diff --git a/src/java/src/main/java/triton/client/InferResult.java b/src/java/src/main/java/triton/client/InferResult.java index c5166f3f0..258e9533d 100644 --- a/src/java/src/main/java/triton/client/InferResult.java +++ b/src/java/src/main/java/triton/client/InferResult.java @@ -40,6 +40,8 @@ import java.util.List; import java.util.Map; import java.util.function.Function; +import java.util.zip.GZIPInputStream; +import java.util.zip.InflaterInputStream; import org.apache.commons.io.IOUtils; import org.apache.http.Header; import org.apache.http.HttpEntity; @@ -77,6 +79,17 @@ public InferResult(HttpResponse resp) throws IOException, InferenceException Preconditions.checkState( entity != null, "Get null entity from HTTP response."); InputStream stream = entity.getContent(); + + // Check for response compression and decompress if needed + Header contentEncodingHeader = resp.getFirstHeader("Content-Encoding"); + if (contentEncodingHeader != null) { + String encoding = contentEncodingHeader.getValue().toLowerCase(); + if ("gzip".equals(encoding)) { + stream = new GZIPInputStream(stream); + } else if ("deflate".equals(encoding)) { + stream = new InflaterInputStream(stream); + } + } int httpCode = resp.getStatusLine().getStatusCode(); if (httpCode != HttpStatus.SC_OK) { diff --git a/src/java/src/main/java/triton/client/InferenceServerClient.java b/src/java/src/main/java/triton/client/InferenceServerClient.java index cabd62105..693b95cd6 100644 --- a/src/java/src/main/java/triton/client/InferenceServerClient.java +++ b/src/java/src/main/java/triton/client/InferenceServerClient.java @@ -45,6 +45,8 @@ import java.util.Map; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.zip.GZIPOutputStream; +import java.util.zip.DeflaterOutputStream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.http.HttpResponse; @@ -340,6 +342,27 @@ private HttpPost createHttpPost( "Inference-Header-Content-Length", String.valueOf(jsonBytes.length)); } + // Apply request compression if specified + byte[] finalRequestBody = bodyBytes.toByteArray(); + if (arg.requestCompressionAlgorithm != null) { + if ("gzip".equalsIgnoreCase(arg.requestCompressionAlgorithm)) { + finalRequestBody = compressGzip(finalRequestBody); + arg.headers.put("Content-Encoding", "gzip"); + } else if ("deflate".equalsIgnoreCase(arg.requestCompressionAlgorithm)) { + finalRequestBody = compressDeflate(finalRequestBody); + arg.headers.put("Content-Encoding", "deflate"); + } + } + + // Set response compression header if specified + if (arg.responseCompressionAlgorithm != null) { + if ("gzip".equalsIgnoreCase(arg.responseCompressionAlgorithm)) { + arg.headers.put("Accept-Encoding", "gzip"); + } else if ("deflate".equalsIgnoreCase(arg.responseCompressionAlgorithm)) { + arg.headers.put("Accept-Encoding", "deflate"); + } + } + // Create target URI. URIBuilder ub = new URIBuilder(this.getUrl()); String safeModelName = @@ -356,7 +379,7 @@ private HttpPost createHttpPost( // Crete HttpPost, uri, body and headers. HttpPost post = new HttpPost(ub.build()); arg.headers.forEach(post::setHeader); - post.setEntity(new NByteArrayEntity(bodyBytes.toByteArray())); + post.setEntity(new NByteArrayEntity(finalRequestBody)); return post; } @@ -365,6 +388,24 @@ private String getUrl() throws Exception return "http://" + this.endpoint.getEndpoint(); } + private byte[] compressGzip(byte[] data) throws IOException + { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (GZIPOutputStream gzipOut = new GZIPOutputStream(baos)) { + gzipOut.write(data); + } + return baos.toByteArray(); + } + + private byte[] compressDeflate(byte[] data) throws IOException + { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (DeflaterOutputStream deflateOut = new DeflaterOutputStream(baos)) { + deflateOut.write(data); + } + return baos.toByteArray(); + } + public InferResult infer( String modelName, List inputs, List outputs) throws InferenceException @@ -387,6 +428,8 @@ public static class InferArguments { int timeout = -1; Map headers = new HashMap<>(); Map queryParams = new HashMap<>(); + String requestCompressionAlgorithm = null; + String responseCompressionAlgorithm = null; public InferArguments( String modelName, List inputs, @@ -464,5 +507,17 @@ public InferArguments addQueryParam(String key, String value) this.queryParams.put(key, value); return this; } + + public InferArguments setRequestCompressionAlgorithm(String algorithm) + { + this.requestCompressionAlgorithm = algorithm; + return this; + } + + public InferArguments setResponseCompressionAlgorithm(String algorithm) + { + this.responseCompressionAlgorithm = algorithm; + return this; + } } }