Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/java/src/main/java/triton/client/InferResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.zip.GZIPInputStream;
import java.util.zip.InflaterInputStream;
import org.apache.commons.io.IOUtils;
import org.apache.http.Header;
import org.apache.http.HttpEntity;
Expand Down Expand Up @@ -77,6 +79,17 @@ public InferResult(HttpResponse resp) throws IOException, InferenceException
Preconditions.checkState(
entity != null, "Get null entity from HTTP response.");
InputStream stream = entity.getContent();

// Check for response compression and decompress if needed
Header contentEncodingHeader = resp.getFirstHeader("Content-Encoding");
if (contentEncodingHeader != null) {
String encoding = contentEncodingHeader.getValue().toLowerCase();
if ("gzip".equals(encoding)) {
stream = new GZIPInputStream(stream);
} else if ("deflate".equals(encoding)) {
stream = new InflaterInputStream(stream);
}
}

int httpCode = resp.getStatusLine().getStatusCode();
if (httpCode != HttpStatus.SC_OK) {
Expand Down
57 changes: 56 additions & 1 deletion src/java/src/main/java/triton/client/InferenceServerClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
import java.util.Map;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.zip.GZIPOutputStream;
import java.util.zip.DeflaterOutputStream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.http.HttpResponse;
Expand Down Expand Up @@ -340,6 +342,27 @@ private HttpPost createHttpPost(
"Inference-Header-Content-Length", String.valueOf(jsonBytes.length));
}

// Apply request compression if specified
byte[] finalRequestBody = bodyBytes.toByteArray();
if (arg.requestCompressionAlgorithm != null) {
if ("gzip".equalsIgnoreCase(arg.requestCompressionAlgorithm)) {
finalRequestBody = compressGzip(finalRequestBody);
arg.headers.put("Content-Encoding", "gzip");
} else if ("deflate".equalsIgnoreCase(arg.requestCompressionAlgorithm)) {
finalRequestBody = compressDeflate(finalRequestBody);
arg.headers.put("Content-Encoding", "deflate");
}
}

// Set response compression header if specified
if (arg.responseCompressionAlgorithm != null) {
if ("gzip".equalsIgnoreCase(arg.responseCompressionAlgorithm)) {
arg.headers.put("Accept-Encoding", "gzip");
} else if ("deflate".equalsIgnoreCase(arg.responseCompressionAlgorithm)) {
arg.headers.put("Accept-Encoding", "deflate");
}
}

// Create target URI.
URIBuilder ub = new URIBuilder(this.getUrl());
String safeModelName =
Expand All @@ -356,7 +379,7 @@ private HttpPost createHttpPost(
// Crete HttpPost, uri, body and headers.
HttpPost post = new HttpPost(ub.build());
arg.headers.forEach(post::setHeader);
post.setEntity(new NByteArrayEntity(bodyBytes.toByteArray()));
post.setEntity(new NByteArrayEntity(finalRequestBody));
return post;
}

Expand All @@ -365,6 +388,24 @@ private String getUrl() throws Exception
return "http://" + this.endpoint.getEndpoint();
}

private byte[] compressGzip(byte[] data) throws IOException
{
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try (GZIPOutputStream gzipOut = new GZIPOutputStream(baos)) {
gzipOut.write(data);
}
return baos.toByteArray();
}

private byte[] compressDeflate(byte[] data) throws IOException
{
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try (DeflaterOutputStream deflateOut = new DeflaterOutputStream(baos)) {
deflateOut.write(data);
}
return baos.toByteArray();
}

public InferResult infer(
String modelName, List<InferInput> inputs,
List<InferRequestedOutput> outputs) throws InferenceException
Expand All @@ -387,6 +428,8 @@ public static class InferArguments {
int timeout = -1;
Map<String, String> headers = new HashMap<>();
Map<String, String> queryParams = new HashMap<>();
String requestCompressionAlgorithm = null;
String responseCompressionAlgorithm = null;

public InferArguments(
String modelName, List<InferInput> inputs,
Expand Down Expand Up @@ -464,5 +507,17 @@ public InferArguments addQueryParam(String key, String value)
this.queryParams.put(key, value);
return this;
}

public InferArguments setRequestCompressionAlgorithm(String algorithm)
{
this.requestCompressionAlgorithm = algorithm;
return this;
}

public InferArguments setResponseCompressionAlgorithm(String algorithm)
{
this.responseCompressionAlgorithm = algorithm;
return this;
}
}
}