From 138a87453c8517ffe3a083295457e725fbe663fa Mon Sep 17 00:00:00 2001 From: koki watanabe Date: Tue, 5 Aug 2025 17:16:19 +0900 Subject: [PATCH 1/2] enh: compression option for java --- .../main/java/triton/client/InferResult.java | 13 +++++ .../triton/client/InferenceServerClient.java | 57 ++++++++++++++++++- 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/src/java/src/main/java/triton/client/InferResult.java b/src/java/src/main/java/triton/client/InferResult.java index c5166f3f0..ff880063e 100644 --- a/src/java/src/main/java/triton/client/InferResult.java +++ b/src/java/src/main/java/triton/client/InferResult.java @@ -33,6 +33,8 @@ import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Array; +import java.util.zip.GZIPInputStream; +import java.util.zip.InflaterInputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; @@ -77,6 +79,17 @@ public InferResult(HttpResponse resp) throws IOException, InferenceException Preconditions.checkState( entity != null, "Get null entity from HTTP response."); InputStream stream = entity.getContent(); + + // Check for response compression and decompress if needed + Header contentEncodingHeader = resp.getFirstHeader("Content-Encoding"); + if (contentEncodingHeader != null) { + String encoding = contentEncodingHeader.getValue().toLowerCase(); + if ("gzip".equals(encoding)) { + stream = new GZIPInputStream(stream); + } else if ("deflate".equals(encoding)) { + stream = new InflaterInputStream(stream); + } + } int httpCode = resp.getStatusLine().getStatusCode(); if (httpCode != HttpStatus.SC_OK) { diff --git a/src/java/src/main/java/triton/client/InferenceServerClient.java b/src/java/src/main/java/triton/client/InferenceServerClient.java index cabd62105..3f86e61f3 100644 --- a/src/java/src/main/java/triton/client/InferenceServerClient.java +++ b/src/java/src/main/java/triton/client/InferenceServerClient.java @@ -37,6 +37,8 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.URLEncoder; +import java.util.zip.GZIPOutputStream; +import java.util.zip.DeflaterOutputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -340,6 +342,27 @@ private HttpPost createHttpPost( "Inference-Header-Content-Length", String.valueOf(jsonBytes.length)); } + // Apply request compression if specified + byte[] finalRequestBody = bodyBytes.toByteArray(); + if (arg.requestCompressionAlgorithm != null) { + if ("gzip".equalsIgnoreCase(arg.requestCompressionAlgorithm)) { + finalRequestBody = compressGzip(finalRequestBody); + arg.headers.put("Content-Encoding", "gzip"); + } else if ("deflate".equalsIgnoreCase(arg.requestCompressionAlgorithm)) { + finalRequestBody = compressDeflate(finalRequestBody); + arg.headers.put("Content-Encoding", "deflate"); + } + } + + // Set response compression header if specified + if (arg.responseCompressionAlgorithm != null) { + if ("gzip".equalsIgnoreCase(arg.responseCompressionAlgorithm)) { + arg.headers.put("Accept-Encoding", "gzip"); + } else if ("deflate".equalsIgnoreCase(arg.responseCompressionAlgorithm)) { + arg.headers.put("Accept-Encoding", "deflate"); + } + } + // Create target URI. URIBuilder ub = new URIBuilder(this.getUrl()); String safeModelName = @@ -356,7 +379,7 @@ private HttpPost createHttpPost( // Crete HttpPost, uri, body and headers. HttpPost post = new HttpPost(ub.build()); arg.headers.forEach(post::setHeader); - post.setEntity(new NByteArrayEntity(bodyBytes.toByteArray())); + post.setEntity(new NByteArrayEntity(finalRequestBody)); return post; } @@ -365,6 +388,24 @@ private String getUrl() throws Exception return "http://" + this.endpoint.getEndpoint(); } + private byte[] compressGzip(byte[] data) throws IOException + { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (GZIPOutputStream gzipOut = new GZIPOutputStream(baos)) { + gzipOut.write(data); + } + return baos.toByteArray(); + } + + private byte[] compressDeflate(byte[] data) throws IOException + { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (DeflaterOutputStream deflateOut = new DeflaterOutputStream(baos)) { + deflateOut.write(data); + } + return baos.toByteArray(); + } + public InferResult infer( String modelName, List inputs, List outputs) throws InferenceException @@ -387,6 +428,8 @@ public static class InferArguments { int timeout = -1; Map headers = new HashMap<>(); Map queryParams = new HashMap<>(); + String requestCompressionAlgorithm = null; + String responseCompressionAlgorithm = null; public InferArguments( String modelName, List inputs, @@ -464,5 +507,17 @@ public InferArguments addQueryParam(String key, String value) this.queryParams.put(key, value); return this; } + + public InferArguments setRequestCompressionAlgorithm(String algorithm) + { + this.requestCompressionAlgorithm = algorithm; + return this; + } + + public InferArguments setResponseCompressionAlgorithm(String algorithm) + { + this.responseCompressionAlgorithm = algorithm; + return this; + } } } From 436ae030c59c79e5188e61bd8330e25a8b8ccba6 Mon Sep 17 00:00:00 2001 From: koki watanabe Date: Tue, 5 Aug 2025 18:42:45 +0900 Subject: [PATCH 2/2] fix: import order --- src/java/src/main/java/triton/client/InferResult.java | 4 ++-- .../src/main/java/triton/client/InferenceServerClient.java | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/java/src/main/java/triton/client/InferResult.java b/src/java/src/main/java/triton/client/InferResult.java index ff880063e..258e9533d 100644 --- a/src/java/src/main/java/triton/client/InferResult.java +++ b/src/java/src/main/java/triton/client/InferResult.java @@ -33,8 +33,6 @@ import java.io.IOException; import java.io.InputStream; import java.lang.reflect.Array; -import java.util.zip.GZIPInputStream; -import java.util.zip.InflaterInputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; @@ -42,6 +40,8 @@ import java.util.List; import java.util.Map; import java.util.function.Function; +import java.util.zip.GZIPInputStream; +import java.util.zip.InflaterInputStream; import org.apache.commons.io.IOUtils; import org.apache.http.Header; import org.apache.http.HttpEntity; diff --git a/src/java/src/main/java/triton/client/InferenceServerClient.java b/src/java/src/main/java/triton/client/InferenceServerClient.java index 3f86e61f3..693b95cd6 100644 --- a/src/java/src/main/java/triton/client/InferenceServerClient.java +++ b/src/java/src/main/java/triton/client/InferenceServerClient.java @@ -37,8 +37,6 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.URLEncoder; -import java.util.zip.GZIPOutputStream; -import java.util.zip.DeflaterOutputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -47,6 +45,8 @@ import java.util.Map; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import java.util.zip.GZIPOutputStream; +import java.util.zip.DeflaterOutputStream; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.http.HttpResponse;