All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.aliyun.openservices.eas.predict.request.TFRequest Maven / Gradle / Ivy

package com.aliyun.openservices.eas.predict.request;

import com.aliyun.openservices.eas.predict.proto.PredictProtos.ArrayShape;
import com.aliyun.openservices.eas.predict.proto.PredictProtos.ArrayDataType;
import com.aliyun.openservices.eas.predict.proto.PredictProtos.ArrayProto;
import com.aliyun.openservices.eas.predict.proto.PredictProtos.PredictRequest;
import shade.protobuf.ByteString;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Created by yaozheng.wyz on 2017/10/24.
 */
public class TFRequest {
    private PredictRequest.Builder request = PredictRequest.newBuilder();
    private static Log log = LogFactory.getLog(TFRequest.class);
    private String prefix = "";
    public void setSignatureName(String value) {
        request.setSignatureName(value);
    }
    public void setPrefix(String value) { this.prefix = value; }

    public void addFetch(String value) {
        if (prefix.isEmpty()) {
            request.addOutputFilter(value);
        } else if (prefix.endsWith("/")) {
            request.addOutputFilter(prefix + value);
        } else {
            request.addOutputFilter(prefix + "/" + value);
        }
    }

    public void addFeed(String inputName, TFDataType dataType, long[] shape, float[] content) {
        ArrayProto.Builder requestProto = ArrayProto.newBuilder();
        if (dataType == TFDataType.DT_FLOAT) {
            requestProto.setDtype(ArrayDataType.DT_FLOAT);
        } else if (dataType == TFDataType.DT_COMPLEX64) {
            requestProto.setDtype(ArrayDataType.DT_COMPLEX64);
        } else if (dataType == TFDataType.DT_BFLOAT16) {
            requestProto.setDtype(ArrayDataType.DT_BFLOAT16);
        } else if (dataType == TFDataType.DT_HALF) {
            requestProto.setDtype(ArrayDataType.DT_HALF);
        } else {
            log.error("call addFeed Error: TFDataType and content mismatch!");
            throw new RuntimeException("call addFeed Error: TFDataType and content mismatch!");
        }
        ArrayShape.Builder arrayShape = ArrayShape.newBuilder();
        for (int i = 0; i < shape.length; i++)
            arrayShape.addDim(shape[i]);
        requestProto.mergeArrayShape(arrayShape.build());
        for (int i = 0; i < content.length; i++)
            requestProto.addFloatVal(content[i]);
        if (prefix.isEmpty()) {
            request.putInputs(inputName, requestProto.build());
        } else if (prefix.endsWith("/")) {
            request.putInputs(prefix + inputName, requestProto.build());
        } else {
            request.putInputs(prefix + "/" + inputName, requestProto.build());
        }
    }

    public void addFeed(String inputName, TFDataType dataType, long[] shape, double[] content) {
        ArrayProto.Builder requestProto = ArrayProto.newBuilder();
        if (dataType == TFDataType.DT_DOUBLE) {
            requestProto.setDtype(ArrayDataType.DT_DOUBLE);
        } else if (dataType == TFDataType.DT_COMPLEX128) {
            requestProto.setDtype(ArrayDataType.DT_COMPLEX128);
        } else {
            log.error("call addFeed Error: TFDataType and content mismatch!");
            throw new RuntimeException("call addFeed Error: TFDataType and content mismatch!");
        }
        ArrayShape.Builder arrayShape = ArrayShape.newBuilder();
        for (int i = 0; i < shape.length; i++)
            arrayShape.addDim(shape[i]);
        requestProto.mergeArrayShape(arrayShape.build());
        for (int i = 0; i < content.length; i++)
            requestProto.addDoubleVal(content[i]);
        if (prefix.isEmpty()) {
            request.putInputs(inputName, requestProto.build());
        } else if (prefix.endsWith("/")) {
            request.putInputs(prefix + inputName, requestProto.build());
        } else {
            request.putInputs(prefix + "/" + inputName, requestProto.build());
        }
    }

    public void addFeed(String inputName, TFDataType dataType, long[] shape, int[] content) {
        ArrayProto.Builder requestProto = ArrayProto.newBuilder();
        if (dataType == TFDataType.DT_INT32) {
            requestProto.setDtype(ArrayDataType.DT_INT32);
        } else if (dataType == TFDataType.DT_UINT8) {
            requestProto.setDtype(ArrayDataType.DT_UINT8);
        } else if (dataType == TFDataType.DT_INT16) {
            requestProto.setDtype(ArrayDataType.DT_INT16);
        } else if (dataType == TFDataType.DT_INT8) {
            requestProto.setDtype(ArrayDataType.DT_INT8);
        } else if (dataType == TFDataType.DT_QINT8) {
            requestProto.setDtype(ArrayDataType.DT_QINT8);
        } else if (dataType == TFDataType.DT_QUINT8) {
            requestProto.setDtype(ArrayDataType.DT_QUINT8);
        } else if (dataType == TFDataType.DT_QINT32) {
            requestProto.setDtype(ArrayDataType.DT_QINT32);
        } else if (dataType == TFDataType.DT_QINT16) {
            requestProto.setDtype(ArrayDataType.DT_QINT16);
        } else if (dataType == TFDataType.DT_QUINT16) {
            requestProto.setDtype(ArrayDataType.DT_QUINT16);
        } else if (dataType == TFDataType.DT_UINT16) {
            requestProto.setDtype(ArrayDataType.DT_UINT16);
        } else {
            log.error("call addFeed Error: TFDataType and content mismatch");
            throw new RuntimeException("call addFeed Error: TFDataType and content mismatch");
        }
        ArrayShape.Builder arrayShape = ArrayShape.newBuilder();
        for (int i = 0; i < shape.length; i++) {
            arrayShape.addDim(shape[i]);
        }
        requestProto.mergeArrayShape(arrayShape.build());
        for (int i = 0; i < content.length; i++)
            requestProto.addIntVal(content[i]);
        if (prefix.isEmpty()) {
            request.putInputs(inputName, requestProto.build());
        } else if (prefix.endsWith("/")) {
            request.putInputs(prefix + inputName, requestProto.build());
        } else {
            request.putInputs(prefix + "/" + inputName, requestProto.build());
        }
    }

    public void addFeed(String inputName, TFDataType dataType, long[] shape, String[] content) {
        ArrayProto.Builder requestProto = ArrayProto.newBuilder();
        if (dataType == TFDataType.DT_STRING) {
            requestProto.setDtype(ArrayDataType.DT_STRING);
        } else {
            log.error("call addFeed Error: TFDataType and content mismatch");
            throw new RuntimeException("call addFeed Error: TFDataType and content mismatch");
        }
        ArrayShape.Builder arrayShape = ArrayShape.newBuilder();
        for (int i = 0; i < shape.length; i++)
            arrayShape.addDim(shape[i]);
        requestProto.mergeArrayShape(arrayShape.build());
        for (int i = 0; i < content.length; i++)
            requestProto.addStringVal(ByteString.copyFromUtf8(content[i]));
        if (prefix.isEmpty()) {
            request.putInputs(inputName, requestProto.build());
        } else if (prefix.endsWith("/")) {
            request.putInputs(prefix + inputName, requestProto.build());
        } else {
            request.putInputs(prefix + "/" + inputName, requestProto.build());
        }
    }

    public void addFeed(String inputName, TFDataType dataType, long[] shape, long[] content) {
        ArrayProto.Builder requestProto = ArrayProto.newBuilder();
        if (dataType == TFDataType.DT_INT64) {
            requestProto.setDtype(ArrayDataType.DT_INT64);
        } else {
            log.error("call addFeed Error: TFDataType and content mismatch");
            throw new RuntimeException("call addFeed Error: TFDataType and content mismatch");
        }
        ArrayShape.Builder arrayShape = ArrayShape.newBuilder();
        for (int i = 0; i < shape.length; i++)
            arrayShape.addDim(shape[i]);
        requestProto.mergeArrayShape(arrayShape.build());
        for (int i = 0; i < content.length; i++)
            requestProto.addInt64Val(content[i]);
        if (prefix.isEmpty()) {
            request.putInputs(inputName, requestProto.build());
        } else if (prefix.endsWith("/")) {
            request.putInputs(prefix + inputName, requestProto.build());
        } else {
            request.putInputs(prefix + "/" + inputName, requestProto.build());
        }
    }

    public void addFeed(String inputName, TFDataType dataType, long[] shape, boolean[] content) {
        ArrayProto.Builder requestProto = ArrayProto.newBuilder();
        if (dataType == TFDataType.DT_BOOL) {
            requestProto.setDtype(ArrayDataType.DT_BOOL);
        } else {
            log.error("call addFeed Error: TFDataType and content mismatch");
            throw new RuntimeException("call addFeed Error: TFDataType and content mismatch");
        }
        ArrayShape.Builder arrayShape = ArrayShape.newBuilder();
        for (int i = 0; i < shape.length; i++)
            arrayShape.addDim(shape[i]);
        requestProto.mergeArrayShape(arrayShape.build());
        for (int i = 0; i < content.length; i++)
            requestProto.addBoolVal(content[i]);
        if (prefix.isEmpty()) {
            request.putInputs(inputName, requestProto.build());
        } else if (prefix.endsWith("/")) {
            request.putInputs(prefix + inputName, requestProto.build());
        } else {
            request.putInputs(prefix + "/" + inputName, requestProto.build());
        }
    }

    public PredictRequest getRequest() {
        return request.build();
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy