All downloads are free. Search and download functionality uses the official Maven repository.

com.aliyun.openservices.eas.predict.http.PredictClient Maven / Gradle / Ivy

There is a newer version: 2.0.20
Show newest version
package com.aliyun.openservices.eas.predict.http;

import com.aliyun.openservices.eas.discovery.core.DiscoveryClient;
import com.aliyun.openservices.eas.predict.auth.HmacSha1Signature;
import com.aliyun.openservices.eas.predict.proto.EasyRecPredictProtos;
import com.aliyun.openservices.eas.predict.proto.TorchRecPredictProtos;
import com.aliyun.openservices.eas.predict.request.*;
import com.aliyun.openservices.eas.predict.response.*;
import com.aliyun.openservices.eas.predict.utils.*;
import org.apache.commons.io.IOUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.http.ConnectionClosedException;
import org.apache.http.Header;
import org.apache.http.HttpResponse;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.concurrent.FutureCallback;
import org.apache.http.conn.ConnectTimeoutException;
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
import org.apache.http.impl.nio.client.HttpAsyncClients;
import org.apache.http.impl.nio.conn.PoolingNHttpClientConnectionManager;
import org.apache.http.impl.nio.reactor.DefaultConnectingIOReactor;
import org.apache.http.impl.nio.reactor.IOReactorConfig;
import org.apache.http.nio.entity.NByteArrayEntity;
import org.apache.http.nio.reactor.ConnectingIOReactor;
import org.xerial.snappy.Snappy;

import java.io.IOException;
import java.net.ConnectException;
import java.net.SocketTimeoutException;
import java.net.URL;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Created by xiping.zk on 2018/07/25.
 */

/**
 * Mutable per-key bookkeeping record kept in the client's blacklist map.
 *
 * <p>It carries two values: an expiration timestamp in epoch milliseconds
 * (the cleanup task drops entries whose timestamp is not in the future) and
 * a counter that the client compares against its configured threshold.
 * The class performs no synchronization of its own.
 */
class BlacklistData {
    private long timestamp;
    private int count;

    /**
     * @param timestamp initial expiration timestamp (epoch milliseconds)
     * @param count     initial counter value
     */
    public BlacklistData(long timestamp, int count) {
        this.timestamp = timestamp;
        this.count = count;
    }

    /** @return the stored expiration timestamp (epoch milliseconds) */
    public long getTimestamp() {
        return this.timestamp;
    }

    /** @param timestamp new expiration timestamp (epoch milliseconds) */
    public void setTimestamp(long timestamp) {
        this.timestamp = timestamp;
    }

    /** @return the stored counter value */
    public int getCount() {
        return this.count;
    }

    /** @param count new counter value */
    public void setCount(int count) {
        this.count = count;
    }
}

// define blacklist task
/**
 * Background job that periodically removes expired entries from the shared
 * blacklist map.
 *
 * <p>Each pass takes the write lock, drops every entry whose timestamp is at
 * or before the current time, releases the lock and then sleeps for
 * {@code blacklistTimeout} seconds. The loop ends when the running thread is
 * interrupted.
 */
class BlacklistTask implements Runnable {
    private static final Log log = LogFactory.getLog(BlacklistTask.class);

    private final Map<String, BlacklistData> blacklist;
    private final ReentrantReadWriteLock rwlock;
    private final int blacklistTimeout;

    /**
     * @param blacklist        shared map of blacklisted keys to their bookkeeping data
     * @param rwlock           lock guarding {@code blacklist}
     * @param blacklistTimeout pause between cleanup passes, in seconds
     */
    public BlacklistTask(Map<String, BlacklistData> blacklist,
                         ReentrantReadWriteLock rwlock, int blacklistTimeout) {
        this.blacklist = blacklist;
        this.rwlock = rwlock;
        this.blacklistTimeout = blacklistTimeout;
    }

    @Override
    public void run() {
        while (!Thread.currentThread().isInterrupted()) {
            try {
                purgeExpiredEntries();
                // 1000L forces long arithmetic so large timeouts cannot overflow int.
                Thread.sleep(blacklistTimeout * 1000L);
            } catch (InterruptedException e) {
                // Preserve the interrupt status and stop instead of spinning forever.
                Thread.currentThread().interrupt();
                return;
            } catch (Exception e) {
                log.warn("Blacklist cleanup failed: " + e.getMessage(), e);
            }
        }
    }

    /** Removes all entries whose expiration timestamp has been reached. */
    private void purgeExpiredEntries() {
        rwlock.writeLock().lock();
        try {
            long currentTimestamp = System.currentTimeMillis();
            Iterator<Map.Entry<String, BlacklistData>> it = blacklist.entrySet().iterator();
            while (it.hasNext()) {
                Map.Entry<String, BlacklistData> entry = it.next();
                if (entry.getValue().getTimestamp() <= currentTimestamp) {
                    log.info("Remove [" + entry.getKey() + "] from blacklist");
                    it.remove();
                }
            }
        } finally {
            // Always release: a leaked write lock would block every reader permanently.
            rwlock.writeLock().unlock();
        }
    }
}

public class PredictClient {
    private static final Log log = LogFactory.getLog(PredictClient.class);
    // Minimum number of endpoint lookups attempted per request when resolving a URL.
    private final int endpointRetryCount = 10;
    // When non-null, synchronous calls copy response headers into this map.
    private HashMap<String, String> mapHeader = null;
    private CloseableHttpAsyncClient httpclient = null;
    // Signing secret; no Authorization header is sent while this is null.
    private String token = null;
    private String modelName = null;
    private String requestPath = "";
    private String endpoint = null;
    // When non-null, overrides endpoint/model based URL construction.
    private String url = null;
    private boolean isCompressed = false;
    // Number of retries performed after the first attempt.
    private int retryCount = 3;
    // Empty set means "retry on any failure".
    private EnumSet<RetryCondition> retryConditions = EnumSet.noneOf(RetryCondition.class);
    private String contentType = "application/octet-stream";
    private int errorCode = 400;
    private String errorMessage;
    private String vipSrvEndPoint = null;
    private String directEndPoint = null;
    // Milliseconds to wait for a synchronous response; values <= 0 wait indefinitely.
    private int requestTimeout = 0;
    private boolean enableBlacklist = false;
    private int blacklistSize = 10;
    // Seconds; used both as entry lifetime and as the cleanup interval.
    private int blacklistTimeout = 30;
    private int blacklistTimeoutCount = 10;
    // Maps request URL to its blacklist bookkeeping; guarded by rwlock.
    private Map<String, BlacklistData> blacklist = null;
    private ReentrantReadWriteLock rwlock = new ReentrantReadWriteLock();
    private Compressor compressor = null;
    private Map<String, String> extraHeaders = new HashMap<>();

    /**
     * Creates a client without an HTTP transport; used internally for child
     * clients, which receive the parent's transport afterwards.
     */
    public PredictClient() {
        // intentionally empty: all fields keep their declared defaults
    }

    /**
     * Creates a client backed by a started asynchronous HTTP client configured
     * from {@code httpConfig}.
     *
     * <p>The I/O reactor is constructed with the reactor configuration directly:
     * when a custom connection manager is supplied, the builder's default
     * reactor configuration is not used to create the reactor, so the settings
     * (thread count, TCP options, timeouts) would otherwise be ignored.
     *
     * <p>If the reactor cannot be created the failure is logged and the client
     * is left without a transport.
     *
     * @param httpConfig connection pool, timeout and I/O thread settings
     */
    public PredictClient(HttpConfig httpConfig) {
        try {
            IOReactorConfig config = IOReactorConfig.custom()
                    .setTcpNoDelay(true)
                    .setSoTimeout(httpConfig.getReadTimeout())
                    .setSoReuseAddress(true)
                    .setConnectTimeout(httpConfig.getConnectTimeout())
                    .setIoThreadCount(httpConfig.getIoThreadNum())
                    .setSoKeepAlive(httpConfig.isKeepAlive()).build();
            ConnectingIOReactor ioReactor = new DefaultConnectingIOReactor(config);
            PoolingNHttpClientConnectionManager cm = new PoolingNHttpClientConnectionManager(
                    ioReactor);
            cm.setMaxTotal(httpConfig.getMaxConnectionCount());
            cm.setDefaultMaxPerRoute(httpConfig.getMaxConnectionPerRoute());
            requestTimeout = httpConfig.getRequestTimeout();
            final RequestConfig requestConfig =
                RequestConfig.custom()
                    .setRedirectsEnabled(httpConfig.getRedirectsEnabled())
                    .setConnectTimeout(httpConfig.getConnectTimeout())
                    .setSocketTimeout(httpConfig.getReadTimeout()).build();
            httpclient = HttpAsyncClients.custom().setConnectionManager(cm)
                    .setDefaultIOReactorConfig(config)
                    .setDefaultRequestConfig(requestConfig).build();
            httpclient.start();
        } catch (IOException e) {
            // Report through the class logger, consistent with the rest of the class.
            log.error("Failed to initialize the HTTP client", e);
        }
    }

    /**
     * Installs the HTTP transport to use; child clients share the parent's instance.
     *
     * @param httpclient asynchronous HTTP client to send requests with
     * @return this client, for chaining
     */
    private PredictClient setHttp(CloseableHttpAsyncClient httpclient) {
        PredictClient self = this;
        self.httpclient = httpclient;
        return self;
    }

    /**
     * Sets the token used to sign requests. {@code null} clears the token;
     * an empty string is ignored and the current token is kept.
     *
     * @param token signing token, or {@code null} to disable signing
     * @return this client, for chaining
     */
    public PredictClient setToken(String token) {
        boolean emptyString = token != null && token.length() == 0;
        if (emptyString) {
            return this;
        }
        this.token = token;
        return this;
    }

    /**
     * Sets how long synchronous calls wait for a response.
     *
     * @param requestTimeout timeout in milliseconds; values {@code <= 0} wait without limit
     * @return this client, for chaining
     */
    public PredictClient setRequestTimeout(int requestTimeout) {
        PredictClient self = this;
        self.requestTimeout = requestTimeout;
        return self;
    }

    /**
     * Sets the model name placed in the {@code /api/predict/<modelName>} request path.
     *
     * @param modelName name of the target model
     * @return this client, for chaining
     */
    public PredictClient setModelName(String modelName) {
        PredictClient self = this;
        self.modelName = modelName;
        return self;
    }

    /**
     * Sets the service endpoint. A non-null value without an {@code http://}
     * or {@code https://} prefix is stored with {@code http://} prepended;
     * anything else (including {@code null}) is stored unchanged.
     *
     * @param endpoint host or URL of the service
     * @return this client, for chaining
     */
    public PredictClient setEndpoint(String endpoint) {
        boolean hasScheme = endpoint == null
                || endpoint.startsWith("http://")
                || endpoint.startsWith("https://");
        this.endpoint = hasScheme ? endpoint : "http://" + endpoint;
        return this;
    }

    /**
     * Sets a complete base URL. When non-null it is used (with the request path
     * appended) instead of the endpoint/model based URL.
     *
     * @param url base URL, or {@code null} to build URLs from endpoint and model name
     * @return this client, for chaining
     */
    public PredictClient setUrl(String url) {
        PredictClient self = this;
        self.url = url;
        return self;
    }

    /**
     * Stores the VIP server endpoint. {@code null} clears the value; an empty
     * string is ignored and the current value is kept.
     *
     * @param vipSrvEndPoint VIP server endpoint, or {@code null}
     * @return this client, for chaining
     */
    public PredictClient setVIPServer(String vipSrvEndPoint) {
        boolean emptyString = vipSrvEndPoint != null && vipSrvEndPoint.length() == 0;
        if (emptyString) {
            return this;
        }
        this.vipSrvEndPoint = vipSrvEndPoint;
        return this;
    }

    /**
     * Enables direct (service-discovery based) access through the given
     * discovery endpoint and publishes it in the
     * {@code com.aliyun.eas.discovery} system property.
     *
     * <p>The check is made on the argument, matching {@code setToken} and
     * {@code setVIPServer}: {@code null} clears the stored endpoint (the system
     * property is left untouched, since it cannot hold {@code null}), and an
     * empty string is ignored.
     *
     * @param directEndpoint discovery endpoint, or {@code null} to disable direct access
     * @return this client, for chaining
     */
    public PredictClient setDirectEndpoint(String directEndpoint) {
        if (directEndpoint == null) {
            this.directEndPoint = null;
        } else if (directEndpoint.length() > 0) {
            this.directEndPoint = directEndpoint;
            System.setProperty("com.aliyun.eas.discovery", directEndpoint);
        }
        return this;
    }

    /**
     * Turns Snappy handling on or off for request signing and response decoding.
     *
     * @param isCompressed {@code true} to enable Snappy handling
     * @return this client, for chaining
     */
    public PredictClient setIsCompressed(boolean isCompressed) {
        PredictClient self = this;
        self.isCompressed = isCompressed;
        return self;
    }
    /**
     * Selects the compressor applied to request bodies by the synchronous
     * {@code predict(byte[])} call.
     *
     * @param compressor compressor to use, or {@code null} for none
     * @return this client, for chaining
     */
    public PredictClient setCompressor(Compressor compressor) {
        PredictClient self = this;
        self.compressor = compressor;
        return self;
    }

    /**
     * Sets how many times a failed request is retried after the first attempt.
     *
     * @param retryCount number of retries
     * @return this client, for chaining
     */
    public PredictClient setRetryCount(int retryCount) {
        PredictClient self = this;
        self.retryCount = retryCount;
        return self;
    }

    /**
     * Restricts retries to the given failure categories. An empty set means
     * every failure is retried.
     *
     * @param retryConditions failure categories that trigger a retry
     * @return this client, for chaining
     */
    public PredictClient setRetryConditions(EnumSet retryConditions) {
        PredictClient self = this;
        self.retryConditions = retryConditions;
        return self;
    }

    /**
     * Enables tracing. While the map is non-null, requests carry a
     * {@code Client-Timestamp} header and synchronous calls copy all response
     * headers into the map.
     *
     * @param mapHeader map receiving response headers, or {@code null} to disable tracing
     * @return this client, for chaining
     */
    public PredictClient setTracing(HashMap mapHeader) {
        PredictClient self = this;
        self.mapHeader = mapHeader;
        return self;
    }

    /**
     * Sets the value sent as the Content-Type header and included in the signature.
     *
     * @param contentType content type of request bodies
     * @return this client, for chaining
     */
    public PredictClient setContentType(String contentType) {
        PredictClient self = this;
        self.contentType = contentType;
        return self;
    }

    /**
     * @return the path suffix appended to request URLs (empty string when unset)
     */
    public String getRequestPath() {
        return this.requestPath;
    }

    /**
     * Sets the path suffix appended to request URLs. {@code null} is ignored;
     * a non-empty value lacking a leading {@code '/'} gets one prepended.
     *
     * @param requestPath path suffix, with or without leading slash
     * @return this client, for chaining
     */
    public PredictClient setRequestPath(String requestPath) {
        if (requestPath != null) {
            boolean needsSlash = !requestPath.isEmpty() && requestPath.charAt(0) != '/';
            this.requestPath = needsSlash ? "/" + requestPath : requestPath;
        }
        return this;
    }

    /**
     * Merges the given headers into the set of extra headers added to every
     * request; existing names are overwritten.
     *
     * @param extraHeaders header names and values to add
     * @return this client, for chaining
     */
    public PredictClient addExtraHeaders(Map extraHeaders) {
        PredictClient self = this;
        self.extraHeaders.putAll(extraHeaders);
        return self;
    }

    /**
     * Enables the endpoint blacklist and starts the background cleanup task.
     *
     * <p>The cleanup thread loops forever, so it is started as a daemon thread;
     * otherwise it would keep the JVM alive after the application finishes.
     * Each call replaces the blacklist map and starts a new cleanup thread.
     *
     * @param blacklistSize         maximum number of entries kept in the blacklist
     * @param blacklistTimeout      entry lifetime and cleanup interval, in seconds
     * @param blacklistTimeoutCount counter threshold at which a URL is avoided
     * @return this client, for chaining
     */
    public PredictClient startBlacklistMechanism(int blacklistSize,
                                                 int blacklistTimeout,
                                                 int blacklistTimeoutCount) {
        this.enableBlacklist = true;
        this.blacklistSize = blacklistSize;
        this.blacklistTimeout = blacklistTimeout;
        this.blacklistTimeoutCount = blacklistTimeoutCount;
        this.blacklist = new HashMap<String, BlacklistData>();
        BlacklistTask task = new BlacklistTask(this.blacklist,
                this.rwlock, this.blacklistTimeout);
        Thread t = new Thread(task, "eas-blacklist-cleaner");
        t.setDaemon(true);
        t.start();
        return this;
    }

    /**
     * @return the stored error code (initialised to 400)
     */
    public int getErrorCode() {
        return this.errorCode;
    }

    /**
     * @return the stored error message, or {@code null} if none has been set
     */
    public String getErrorMessage() {
        return this.errorMessage;
    }

    /**
     * Creates a new client that shares this client's HTTP transport but uses
     * its own token, endpoint and model name. No other setting is copied.
     *
     * @param token     signing token for the child
     * @param endPoint  service endpoint for the child
     * @param modelName model name for the child
     * @return the new child client
     */
    public PredictClient createChildClient(String token, String endPoint,
                                           String modelName) {
        PredictClient child = new PredictClient();
        child.setHttp(this.httpclient);
        child.setToken(token);
        child.setEndpoint(endPoint);
        child.setModelName(modelName);
        return child;
    }

    // to be compatible with old version typo
    /**
     * Misspelled alias of {@code createChildClient(String, String, String)},
     * kept for backward compatibility.
     *
     * @param token     signing token for the child
     * @param endPoint  service endpoint for the child
     * @param modelName model name for the child
     * @return the new child client
     */
    public PredictClient createChlidClient(String token, String endPoint,
                                           String modelName) {
        // Forward the caller's argument, not this client's own endpoint field.
        return createChildClient(token, endPoint, modelName);
    }

    /**
     * Creates a new client that shares this client's HTTP transport and copies
     * its token, model name, retry settings, request timeout, compression
     * settings, content type, request path, URL and extra headers.
     *
     * <p>Exactly one addressing mode is carried over, in this order of
     * precedence: VIP server, direct endpoint, plain endpoint.
     *
     * @return the new child client
     */
    public PredictClient createChildClient() {
        PredictClient child = new PredictClient();
        child.setHttp(this.httpclient);
        child.setToken(this.token);
        child.setModelName(this.modelName);
        child.setRetryCount(this.retryCount);
        child.setRetryConditions(this.retryConditions);
        child.setRequestTimeout(this.requestTimeout);
        child.setIsCompressed(this.isCompressed);
        child.setContentType(this.contentType);
        child.setRequestPath(this.requestPath);
        child.setUrl(this.url);
        child.addExtraHeaders(this.extraHeaders);

        if (this.vipSrvEndPoint != null) {
            child.setVIPServer(this.vipSrvEndPoint);
        } else if (this.directEndPoint != null) {
            child.setDirectEndpoint(this.directEndPoint);
        } else {
            child.setEndpoint(this.endpoint);
        }

        if (this.compressor != null) {
            child.setCompressor(this.compressor);
        }
        return child;
    }

    // to be compatible with old version typo
    /**
     * Misspelled alias of {@code createChildClient()}, kept for backward compatibility.
     *
     * @return the new child client
     */
    public PredictClient createChlidClient() {
        PredictClient child = createChildClient();
        return child;
    }

    /**
     * Resolves the URL for the next request.
     *
     * <ul>
     *   <li>An explicit URL, if set, wins: URL + request path.</li>
     *   <li>Without a direct endpoint the URL is
     *       {@code endpoint + "/api/predict/" + modelName + requestPath}.</li>
     *   <li>With a direct endpoint a host is obtained from service discovery,
     *       repeatedly if necessary, preferring one that differs from
     *       {@code lastUrl} and, when the blacklist is enabled, whose blacklist
     *       counter is below the threshold. If fewer than two hosts are known
     *       the first candidate is returned as is. When all attempts are used
     *       up, the last candidate is returned.</li>
     * </ul>
     *
     * @param lastUrl URL used by the previous attempt (empty on the first attempt)
     * @return URL to send the request to
     * @throws Exception if service discovery fails
     */
    private String getUrl(String lastUrl) throws Exception {
        if (this.url != null) {
            return this.url + this.requestPath;
        }
        if (directEndPoint == null) {
            return this.endpoint + "/api/predict/" + modelName + requestPath;
        }

        final boolean useBlacklist = enableBlacklist;
        final int attempts = useBlacklist
                ? Math.max(endpointRetryCount, blacklistSize * 2)
                : endpointRetryCount;

        String candidate = "";
        for (int attempt = 0; attempt < attempts; attempt++) {
            String host = DiscoveryClient.srvHost(this.modelName).toInetAddr();
            candidate = "http://" + host + "/api/predict/" + modelName + requestPath;
            // With a single known host there is nothing to choose between.
            if (DiscoveryClient.getHosts(this.modelName).size() < 2) {
                return candidate;
            }
            if (!useBlacklist) {
                if (!candidate.equals(lastUrl)) {
                    return candidate;
                }
                continue;
            }
            try {
                rwlock.readLock().lock();
                boolean differsFromLast = !candidate.equals(lastUrl);
                if (differsFromLast
                        && (!blacklist.containsKey(candidate)
                            || blacklist.get(candidate).getCount() < blacklistTimeoutCount)) {
                    return candidate;
                }
            } finally {
                rwlock.readLock().unlock();
            }
        }
        return candidate;
    }

    /**
     * Builds the signed POST request for one attempt.
     *
     * <p>When Snappy compression is enabled the body is compressed first and
     * the compressed bytes are used for the entity, the Content-MD5 header and
     * the signature alike, so the digest always matches the bytes sent. If
     * compression fails the error is logged and the uncompressed body is used.
     *
     * <p>Headers added: the configured extra headers, Content-MD5, Date,
     * Content-Type, {@code Client-Timestamp} (only when tracing is enabled) and
     * Authorization (only when a token is set).
     *
     * @param requestContent request body
     * @param lastUrl        URL used by the previous attempt, passed to URL resolution
     * @return the request, ready to execute
     * @throws Exception if URL resolution or signing fails
     */
    private HttpPost generateSignature(byte[] requestContent, String lastUrl) throws Exception {
        HttpPost request = new HttpPost(getUrl(lastUrl));
        if (isCompressed) {
            try {
                requestContent = Snappy.compress(requestContent);
            } catch (IOException e) {
                log.error("Compress Error", e);
            }
        }
        // Set the entity after compression so body and Content-MD5 agree.
        request.setEntity(new NByteArrayEntity(requestContent));

        HmacSha1Signature signature = new HmacSha1Signature();
        String md5Content = signature.getMD5(requestContent);
        for (Map.Entry<String, String> entry : this.extraHeaders.entrySet()) {
            request.addHeader(entry.getKey(), entry.getValue());
        }
        request.addHeader(HttpHeaders.CONTENT_MD5, md5Content);

        // SimpleDateFormat is not thread-safe, hence one instance per call.
        SimpleDateFormat dateFormat = new SimpleDateFormat(
                "EEE, dd MMM yyyy HH:mm:ss", Locale.ENGLISH);
        dateFormat.setTimeZone(TimeZone.getTimeZone("GMT"));
        String currentTime = dateFormat.format(new Date()) + " GMT";
        request.addHeader(HttpHeaders.DATE, currentTime);
        request.addHeader(HttpHeaders.CONTENT_TYPE, contentType);

        if (mapHeader != null) {
            request.addHeader("Client-Timestamp",
                    String.valueOf(System.currentTimeMillis()));
        }

        if (this.token != null) {
            // String to sign: method, body digest, content type, date, resource path.
            String auth = "POST" + "\n" + md5Content + "\n"
                + this.contentType + "\n" + currentTime + "\n";
            if (this.url == null) {
                auth = auth + "/api/predict/" + this.modelName + this.requestPath;
            } else {
                URL u = new URL(this.url);
                auth = auth + u.getPath() + this.requestPath;
            }
            request.addHeader(HttpHeaders.AUTHORIZATION,
                "EAS " + signature.computeSignature(token, auth));
        }
        return request;
    }

    /**
     * Extracts the payload from a response.
     *
     * @param response HTTP response to decode
     * @return the response body, Snappy-uncompressed when compression is enabled
     * @throws IOException   if the body cannot be read or uncompressed
     * @throws HttpException if the status code is not 200; carries the status
     *                       code and the body decoded as UTF-8
     */
    private byte[] handleResponse(HttpResponse response) throws IOException, HttpException {
        int statusCode = response.getStatusLine().getStatusCode();
        if (statusCode != 200) {
            String errorMsg = IOUtils.toString(response.getEntity().getContent(), "UTF-8");
            throw new HttpException(statusCode, errorMsg);
        }
        byte[] body = IOUtils.toByteArray(response.getEntity().getContent());
        return isCompressed ? Snappy.uncompress(body) : body;
    }

    /**
     * Executes the request and blocks until the response arrives, then returns
     * the decoded body.
     *
     * <p>If the wait times out or the calling thread is interrupted, the
     * in-flight exchange is cancelled before the exception propagates, so an
     * abandoned request does not keep running in the background.
     * When tracing is enabled, all response headers are copied into the tracing map.
     *
     * @param request request to execute
     * @return decoded response body
     * @throws IOException          if the body cannot be read or uncompressed
     * @throws InterruptedException if the calling thread is interrupted while waiting
     * @throws ExecutionException   if the HTTP exchange itself failed
     * @throws TimeoutException     if no response arrived within the request timeout
     */
    private byte[] getContent(HttpPost request) throws IOException,
            InterruptedException, ExecutionException, TimeoutException {
        Future<HttpResponse> future = httpclient.execute(request, null);
        HttpResponse response;
        try {
            if (requestTimeout > 0) {
                response = future.get(requestTimeout, TimeUnit.MILLISECONDS);
            } else {
                response = future.get();
            }
        } catch (TimeoutException | InterruptedException e) {
            // The caller has given up on this exchange; abort it.
            future.cancel(true);
            throw e;
        }

        if (mapHeader != null) {
            for (Header header : response.getAllHeaders()) {
                mapHeader.put(header.getName(), header.getValue());
            }
        }

        // get() returned normally, so the future is done and not cancelled.
        try {
            return handleResponse(response);
        } catch (HttpException e) {
            throw e;
        } catch (Exception e) {
            log.error("handle response error:", e);
            throw e;
        }
    }


    /**
     * Decides whether a failure qualifies for a retry under the configured
     * retry conditions. With no conditions configured every failure qualifies.
     * Both the exception and its direct cause are inspected.
     *
     * @param e failure of the last attempt
     * @return {@code true} if the request may be retried
     */
    private boolean shouldRetry(Exception e) {
        if (retryConditions.isEmpty()) {
            // No specific conditions configured: always retry.
            return true;
        }

        if (e instanceof HttpException) {
            int statusClass = ((HttpException) e).getCode() / 100;
            if (statusClass == 4 && retryConditions.contains(RetryCondition.RESPONSE_4XX)) {
                return true;
            }
            if (statusClass == 5 && retryConditions.contains(RetryCondition.RESPONSE_5XX)) {
                return true;
            }
        }

        Throwable cause = e.getCause();

        boolean connectionFailed = cause instanceof ConnectException
                || cause instanceof ConnectionClosedException
                || e instanceof ConnectException
                || e instanceof ConnectionClosedException;
        boolean connectTimedOut = cause instanceof ConnectTimeoutException
                || e instanceof ConnectTimeoutException;
        boolean readTimedOut = cause instanceof SocketTimeoutException
                || cause instanceof TimeoutException
                || e instanceof SocketTimeoutException
                || e instanceof TimeoutException;

        return (connectionFailed && retryConditions.contains(RetryCondition.CONNECTION_FAILED))
                || (connectTimedOut && retryConditions.contains(RetryCondition.CONNECTION_TIMEOUT))
                || (readTimedOut && retryConditions.contains(RetryCondition.READ_TIMEOUT));
    }

    /**
     * Sends the request body synchronously and returns the raw response body.
     *
     * <p>The body is first passed through the configured compressor, if any.
     * The request is attempted up to {@code retryCount + 1} times; a failed
     * attempt is retried only if it qualifies under the retry conditions.
     *
     * @param requestContent raw request body
     * @return raw response body; {@code null} only if no attempt was made
     *         (negative retry count)
     * @throws Exception the failure of the last attempt; HTTP failures are
     *                   rethrown as {@code HttpException} with URL and status
     *                   code in the message
     */
    public byte[] predict(byte[] requestContent) throws Exception {
        byte[] payload = compressRequestBody(requestContent);
        String lastUrl = "";
        int attempt = 0;
        while (attempt <= retryCount) {
            try {
                HttpPost request = generateSignature(payload, lastUrl);
                lastUrl = request.getURI().toString();
                return getContent(request);
            } catch (HttpException e) {
                int statusCode = e.getCode();
                String errorMessage = String.format("URL: %s, Status Code:: %d, Message: %s", lastUrl, statusCode, e.getMessage());
                boolean retryable = shouldRetry(e) && attempt < retryCount;
                if (!retryable) {
                    log.error(errorMessage);
                    throw new HttpException(statusCode, errorMessage);
                }
                log.warn(String.format("Predict failed on %dth retry, %s", attempt + 1, errorMessage));
            } catch (Exception e) {
                String errorMessage = String.format("URL: %s, Message: %s", lastUrl, (e.getMessage() == null) ? e : e.getMessage());
                boolean retryable = shouldRetry(e) && attempt < retryCount;
                if (!retryable) {
                    log.error(errorMessage);
                    throw e;
                }
                log.warn(String.format("Predict failed on %dth retry, %s", attempt + 1, errorMessage));
            }
            attempt++;
        }
        return null;
    }

    /**
     * Applies the configured compressor to the request body; returns the body
     * unchanged when no compressor is set or the compressor is not recognised
     * (the latter is logged as a warning).
     */
    private byte[] compressRequestBody(byte[] body) throws Exception {
        if (compressor == null) {
            return body;
        }
        if (compressor == Compressor.Gzip) {
            return GzipUtils.compress(body);
        }
        if (compressor == Compressor.Zlib) {
            return ZlibUtils.compress(body);
        }
        if (compressor == Compressor.Snappy) {
            return SnappyUtils.compress(body);
        }
        if (compressor == Compressor.LZ4) {
            return LZ4Utils.compress(body);
        }
        if (compressor == Compressor.Zstd) {
            return ZstdUtils.compress(body);
        }
        log.warn("Compressor are not supported!");
        return body;
    }

    /**
     * Sends a Blade request synchronously.
     *
     * @param runRequest request to serialize and send
     * @return response populated from the reply bytes; left unpopulated if no reply was obtained
     * @throws Exception if the call fails
     */
    public BladeResponse predict(BladeRequest runRequest) throws Exception {
        BladeResponse bladeResponse = new BladeResponse();
        byte[] payload = runRequest.getRequest().toByteArray();
        byte[] reply = predict(payload);
        if (reply == null) {
            return bladeResponse;
        }
        bladeResponse.setContentValues(reply);
        return bladeResponse;
    }

    /**
     * Sends a TensorFlow request synchronously.
     *
     * @param runRequest request to serialize and send
     * @return response populated from the reply bytes; left unpopulated if no reply was obtained
     * @throws Exception if the call fails
     */
    public TFResponse predict(TFRequest runRequest) throws Exception {
        TFResponse tfResponse = new TFResponse();
        byte[] payload = runRequest.getRequest().toByteArray();
        byte[] reply = predict(payload);
        if (reply == null) {
            return tfResponse;
        }
        tfResponse.setContentValues(reply);
        return tfResponse;
    }

    /**
     * Sends a Caffe request synchronously.
     *
     * @param runRequest request to serialize and send
     * @return response populated from the reply bytes; left unpopulated if no reply was obtained
     * @throws Exception if the call fails
     */
    public CaffeResponse predict(CaffeRequest runRequest) throws Exception {
        CaffeResponse caffeResponse = new CaffeResponse();
        byte[] payload = runRequest.getRequest().toByteArray();
        byte[] reply = predict(payload);
        if (reply == null) {
            return caffeResponse;
        }
        caffeResponse.setContentValues(reply);
        return caffeResponse;
    }

    /**
     * Sends a JSON request synchronously.
     *
     * <p>The JSON text is encoded as UTF-8 explicitly rather than with the
     * platform default charset, so the bytes on the wire do not depend on the
     * JVM's locale settings.
     *
     * @param requestContent request whose JSON form is sent as the body
     * @return response populated from the reply bytes; left unpopulated if no reply was obtained
     * @throws Exception if the call fails
     */
    public JsonResponse predict(JsonRequest requestContent)
            throws Exception {
        byte[] payload = requestContent.getJSON()
                .getBytes(java.nio.charset.StandardCharsets.UTF_8);
        byte[] result = predict(payload);
        JsonResponse jsonResponse = new JsonResponse();
        if (result != null) {
            jsonResponse.setContentValues(result);
        }
        return jsonResponse;
    }

    /**
     * Sends a Torch request synchronously.
     *
     * @param runRequest request to serialize and send
     * @return response populated from the reply bytes; left unpopulated if no reply was obtained
     * @throws Exception if the call fails
     */
    public TorchResponse predict(TorchRequest runRequest) throws Exception {
        TorchResponse torchResponse = new TorchResponse();
        byte[] payload = runRequest.getRequest().toByteArray();
        byte[] reply = predict(payload);
        if (reply == null) {
            return torchResponse;
        }
        torchResponse.setContentValues(reply);
        return torchResponse;
    }



    /**
     * Sends an EasyRec request synchronously and parses the reply.
     *
     * @param runRequest request to serialize and send
     * @return parsed response, or {@code null} if no reply was obtained
     * @throws Exception if the call or the parsing fails
     */
    public EasyRecPredictProtos.PBResponse predict(EasyRecRequest runRequest) throws Exception {
        byte[] payload = runRequest.getRequest().toByteArray();
        byte[] reply = this.predict(payload);
        return (reply == null) ? null : EasyRecPredictProtos.PBResponse.parseFrom(reply);
    }

    /**
     * Sends a TorchRec request synchronously and parses the reply.
     *
     * @param runRequest request to serialize and send
     * @return parsed response, or {@code null} if no reply was obtained
     * @throws Exception if the call or the parsing fails
     */
    public TorchRecPredictProtos.PBResponse predict(TorchRecRequest runRequest) throws Exception {
        byte[] payload = runRequest.getRequest().toByteArray();
        byte[] reply = this.predict(payload);
        return (reply == null) ? null : TorchRecPredictProtos.PBResponse.parseFrom(reply);
    }

    /**
     * Sends a text request synchronously and returns the reply as text.
     *
     * <p>Request and reply are converted with UTF-8 explicitly instead of the
     * platform default charset, so results do not vary with the JVM's locale
     * settings.
     *
     * @param requestContent request body as text
     * @return reply decoded as UTF-8, or {@code null} if no reply was obtained
     * @throws Exception if the call fails
     */
    public String predict(String requestContent) throws Exception {
        byte[] result = predict(
                requestContent.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        if (result != null) {
            return new String(result, java.nio.charset.StandardCharsets.UTF_8);
        }
        return null;
    }

    /**
     * Records a failure for {@code key} in the blacklist.
     *
     * <ul>
     *   <li>Unknown key: added with count 1 and a fresh expiration, provided
     *       the blacklist is not full.</li>
     *   <li>Known key below the threshold: its counter is incremented.</li>
     *   <li>Known key at or above the threshold: its expiration is pushed out
     *       by another {@code blacklistTimeout} seconds.</li>
     * </ul>
     *
     * <p>The map is shared with the cleanup task and with URL resolution, so
     * all changes are made under the write lock. Does nothing if the blacklist
     * has not been started.
     *
     * @param key blacklist key (the request URL)
     */
    private void handleBlacklist(String key) {
        if (blacklist == null) {
            return;
        }
        rwlock.writeLock().lock();
        try {
            // 1000L forces long arithmetic so large timeouts cannot overflow int.
            long expirationTimestamp = System.currentTimeMillis() + blacklistTimeout * 1000L;
            BlacklistData data = blacklist.get(key);
            if (data == null) {
                if (blacklist.size() < blacklistSize) {
                    blacklist.put(key, new BlacklistData(expirationTimestamp, 1));
                    log.info("Put [" + key + "] into blacklist");
                }
                return;
            }
            int timeoutCount = data.getCount();
            if (timeoutCount < blacklistTimeoutCount) {
                data.setCount(timeoutCount + 1);
                log.info("Set [" + key + "] timeoutCount:" + data.getCount());
            } else {
                data.setTimestamp(expirationTimestamp);
                log.info("Set [" + key + "] timestamp: " + data.getTimestamp()
                        + " timeoutCount: " + data.getCount());
            }
        } finally {
            rwlock.writeLock().unlock();
        }
    }

    /**
     * Handles a failed asynchronous attempt: either starts the next attempt and
     * forwards its outcome to {@code futureResponse}, or fails
     * {@code futureResponse} with the given exception.
     *
     * @param futureResponse future handed to the caller of the failed attempt
     * @param exception      failure of the attempt
     * @param currentRetry   zero-based index of the failed attempt
     * @param url            URL the failed attempt was sent to
     * @param requestData    request body, reused for the retry
     */
    private void handleRetryOrFailure(CompletableFuture futureResponse, Exception exception, int currentRetry, String url, byte[] requestData) {
        final String errorMessage = (exception instanceof HttpException)
                ? String.format("URL: %s, Status Code:: %d, Message: %s", url, ((HttpException) exception).getCode(), exception.getMessage())
                : String.format("URL: %s, Message: %s", url, (exception.getMessage() == null) ? exception : exception.getMessage());

        boolean retry = currentRetry < retryCount && shouldRetry(exception);
        if (!retry) {
            log.error(errorMessage);
            futureResponse.completeExceptionally(exception);
            return;
        }

        log.warn(String.format("PredictAsync failed on %dth retry, %s", currentRetry + 1, errorMessage));
        predictAsyncInternal(requestData, currentRetry + 1, url).whenComplete((result, ex) -> {
            if (ex == null) {
                futureResponse.complete(result);
                return;
            }
            // Report the underlying failure rather than the wrapper, when there is one.
            Throwable cause = ex.getCause();
            futureResponse.completeExceptionally((cause != null) ? cause : ex);
        });
    }

    /**
     * Performs one asynchronous attempt.
     *
     * <p>A successful, decodable response completes the returned future with
     * the body. A failed exchange or an undecodable/non-200 response is handed
     * to the retry logic. A cancelled exchange cancels the returned future, and
     * a failure while building or submitting the request fails it directly.
     *
     * @param requestData  request body
     * @param currentRetry zero-based index of this attempt
     * @param lastUrl      URL used by the previous attempt (empty on the first attempt)
     * @return future completed with the response body or the failure
     */
    private CompletableFuture predictAsyncInternal(byte[] requestData, int currentRetry, String lastUrl) {
        final CompletableFuture<byte[]> promise = new CompletableFuture<>();
        try {
            final HttpPost request = generateSignature(requestData, lastUrl);
            FutureCallback<HttpResponse> callback = new FutureCallback<HttpResponse>() {
                @Override
                public void completed(HttpResponse response) {
                    try {
                        promise.complete(handleResponse(response));
                    } catch (Exception e) {
                        // Covers HttpException as well as decoding failures.
                        handleRetryOrFailure(promise, e, currentRetry, request.getURI().toString(), requestData);
                    }
                }

                @Override
                public void failed(Exception ex) {
                    handleRetryOrFailure(promise, ex, currentRetry, request.getURI().toString(), requestData);
                }

                @Override
                public void cancelled() {
                    promise.cancel(true);
                }
            };
            httpclient.execute(request, callback);
        } catch (Exception ex) {
            promise.completeExceptionally(ex);
        }
        return promise;
    }

    /**
     * Sends the request body asynchronously, starting with attempt 0 and no
     * previous URL.
     *
     * @param requestContent raw request body
     * @return future completed with the raw response body or the failure
     */
    public CompletableFuture predictAsync(byte[] requestContent) {
        final int firstAttempt = 0;
        final String noPreviousUrl = "";
        return predictAsyncInternal(requestContent, firstAttempt, noPreviousUrl);
    }

    /**
     * Sends a Blade request asynchronously.
     *
     * @param runRequest request to serialize and send
     * @return future completed with the response built from the reply bytes,
     *         or completed exceptionally with the failure of the underlying stage
     */
    public CompletableFuture predictAsync(BladeRequest runRequest) {
        final CompletableFuture<BladeResponse> promise = new CompletableFuture<>();
        byte[] payload = runRequest.getRequest().toByteArray();

        predictAsync(payload)
            .thenApply(bytes -> {
                BladeResponse decoded = new BladeResponse();
                decoded.setContentValues(bytes);
                return decoded;
            })
            .whenComplete((decoded, failure) -> {
                if (failure == null) {
                    promise.complete(decoded);
                } else {
                    promise.completeExceptionally(failure);
                }
            });

        return promise;
    }

    /**
     * Sends a TensorFlow request asynchronously.
     *
     * @param runRequest request to serialize and send
     * @return future completed with the response built from the reply bytes,
     *         or completed exceptionally with the failure of the underlying stage
     */
    public CompletableFuture predictAsync(TFRequest runRequest) {
        final CompletableFuture<TFResponse> promise = new CompletableFuture<>();
        byte[] payload = runRequest.getRequest().toByteArray();

        predictAsync(payload)
            .thenApply(bytes -> {
                TFResponse decoded = new TFResponse();
                decoded.setContentValues(bytes);
                return decoded;
            })
            .whenComplete((decoded, failure) -> {
                if (failure == null) {
                    promise.complete(decoded);
                } else {
                    promise.completeExceptionally(failure);
                }
            });

        return promise;
    }

    /**
     * Sends a Caffe request asynchronously.
     *
     * @param runRequest request to serialize and send
     * @return future completed with the response built from the reply bytes,
     *         or completed exceptionally with the failure of the underlying stage
     */
    public CompletableFuture predictAsync(CaffeRequest runRequest) {
        final CompletableFuture<CaffeResponse> promise = new CompletableFuture<>();
        byte[] payload = runRequest.getRequest().toByteArray();

        predictAsync(payload)
            .thenApply(bytes -> {
                CaffeResponse decoded = new CaffeResponse();
                decoded.setContentValues(bytes);
                return decoded;
            })
            .whenComplete((decoded, failure) -> {
                if (failure == null) {
                    promise.complete(decoded);
                } else {
                    promise.completeExceptionally(failure);
                }
            });

        return promise;
    }

    /**
     * Sends a JSON request asynchronously.
     *
     * <p>The JSON text is encoded as UTF-8 explicitly rather than with the
     * platform default charset. On failure the returned future is completed
     * with the underlying cause when one is available, otherwise with the
     * failure itself, so the future is always completed.
     *
     * @param requestContent request whose JSON form is sent as the body
     * @return future completed with the response built from the reply bytes,
     *         or completed exceptionally with the failure
     */
    public CompletableFuture predictAsync(JsonRequest requestContent) {
        CompletableFuture<JsonResponse> futureResponse = new CompletableFuture<>();

        byte[] requestData;
        try {
            requestData = requestContent.getJSON()
                    .getBytes(java.nio.charset.StandardCharsets.UTF_8);
        } catch (IOException ex) {
            futureResponse.completeExceptionally(ex);
            return futureResponse;
        }

        predictAsync(requestData)
            .thenApply(resultBytes -> {
                JsonResponse jsonResponse = new JsonResponse();
                try {
                    jsonResponse.setContentValues(resultBytes);
                    return jsonResponse;
                } catch (Exception ex) {
                    throw new CompletionException(ex);
                }
            })
            .whenComplete((jsonResponse, throwable) -> {
                if (throwable != null) {
                    // completeExceptionally(null) would throw and leave the future pending.
                    Throwable cause = throwable.getCause();
                    futureResponse.completeExceptionally((cause != null) ? cause : throwable);
                } else {
                    futureResponse.complete(jsonResponse);
                }
            });

        return futureResponse;
    }


    /**
     * Sends a Torch request asynchronously.
     *
     * @param runRequest request to serialize and send
     * @return future completed with the response built from the reply bytes,
     *         or completed exceptionally with the failure of the underlying stage
     */
    public CompletableFuture predictAsync(TorchRequest runRequest) {
        final CompletableFuture<TorchResponse> promise = new CompletableFuture<>();
        byte[] payload = runRequest.getRequest().toByteArray();

        predictAsync(payload)
            .thenApply(bytes -> {
                TorchResponse decoded = new TorchResponse();
                decoded.setContentValues(bytes);
                return decoded;
            })
            .whenComplete((decoded, failure) -> {
                if (failure == null) {
                    promise.complete(decoded);
                } else {
                    promise.completeExceptionally(failure);
                }
            });

        return promise;
    }

    /**
     * Sends a text request asynchronously and delivers the reply as text.
     *
     * <p>Request and reply are converted with UTF-8 explicitly instead of the
     * platform default charset, so results do not vary with the JVM's locale
     * settings.
     *
     * @param requestContent request body as text
     * @return future completed with the reply decoded as UTF-8, or completed
     *         exceptionally with the failure of the underlying stage
     */
    public CompletableFuture predictAsync(String requestContent) {
        CompletableFuture<String> futureResponse = new CompletableFuture<>();

        predictAsync(requestContent.getBytes(java.nio.charset.StandardCharsets.UTF_8))
            .thenApply(result -> {
                return new String(result, java.nio.charset.StandardCharsets.UTF_8);
            })
            .whenComplete((res, ex) -> {
                if (ex != null) {
                    futureResponse.completeExceptionally(ex);
                } else {
                    futureResponse.complete(res);
                }
            });

        return futureResponse;
    }

    /**
     * Closes the HTTP transport. Child clients share their parent's transport,
     * so closing it through any of them closes it for all.
     *
     * <p>Does nothing if this client has no transport (for example when it was
     * created with the no-argument constructor). A failure to close is logged.
     */
    public void shutdown() {
        if (httpclient == null) {
            return;
        }
        try {
            httpclient.close();
        } catch (IOException e) {
            log.error("Failed to close the HTTP client", e);
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy