All downloads are free. Search and download functionality uses the official Maven repository.

com.aliyun.openservices.eas.predict.response.TorchResponse Maven / Gradle / Ivy

package com.aliyun.openservices.eas.predict.response;

import com.aliyun.openservices.eas.predict.proto.TorchPredictProtos.ArrayProto;
import com.aliyun.openservices.eas.predict.proto.TorchPredictProtos.PredictResponse;
import shade.protobuf
        .ByteString;
import shade.protobuf
        .InvalidProtocolBufferException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import java.util.ArrayList;
import java.util.List;


/**
 * Wraps a parsed {@link PredictResponse} and exposes the shape and the typed
 * value lists of each output {@link ArrayProto} by index.
 *
 * <p>The response is populated through {@link #setContentValues(byte[])}. Until
 * a payload has been parsed successfully, every getter logs an error and
 * returns an empty list.
 */
public class TorchResponse {
    private static final Log log = LogFactory.getLog(TorchResponse.class);

    /** Last successfully parsed response; {@code null} when none is available. */
    private PredictResponse response = null;

    /**
     * Parses the serialized response and stores it for the getters.
     *
     * <p>Any previously stored response is discarded first, so a {@code null}
     * or malformed payload can never leave stale data from an earlier call
     * visible through the getters. Failures are logged, not thrown.
     *
     * @param content serialized {@code PredictResponse} bytes; may be {@code null}
     */
    public void setContentValues(byte[] content) {
        // Reset first: after a failed parse the getters must report "no response"
        // rather than silently serving the result of an earlier request.
        response = null;
        if (content == null) {
            log.error("request failed: response content is null");
            return;
        }
        try {
            response = PredictResponse.parseFrom(content);
        } catch (InvalidProtocolBufferException e) {
            log.error("request failed: can't parse response content", e);
        }
    }

    /**
     * Returns the dimension list of the array shape of the output at {@code index}.
     *
     * @param index zero-based position of the output in the response
     * @return the dimension list (presumably {@code Long} elements; confirm
     *         against the proto definition), or an empty list when no response
     *         is available
     * @throws RuntimeException if {@code index} is negative or not smaller than
     *                          the number of outputs
     */
    public List getTensorShape(int index) {
        ArrayProto output = outputAt(index);
        if (output == null) {
            return new ArrayList();
        }
        return output.getArrayShape().getDimList();
    }

    /**
     * Returns the float value list of the output at {@code index}.
     *
     * @param index zero-based position of the output in the response
     * @return the float value list, or an empty list when no response is available
     * @throws RuntimeException if {@code index} is negative or not smaller than
     *                          the number of outputs
     */
    public List getFloatVals(int index) {
        ArrayProto output = outputAt(index);
        if (output == null) {
            return new ArrayList();
        }
        return output.getFloatValList();
    }

    /**
     * Returns the double value list of the output at {@code index}.
     *
     * @param index zero-based position of the output in the response
     * @return the double value list, or an empty list when no response is available
     * @throws RuntimeException if {@code index} is negative or not smaller than
     *                          the number of outputs
     */
    public List getDoubleVals(int index) {
        ArrayProto output = outputAt(index);
        if (output == null) {
            return new ArrayList();
        }
        return output.getDoubleValList();
    }

    /**
     * Returns the int value list of the output at {@code index}.
     *
     * @param index zero-based position of the output in the response
     * @return the int value list, or an empty list when no response is available
     * @throws RuntimeException if {@code index} is negative or not smaller than
     *                          the number of outputs
     */
    public List getIntVals(int index) {
        ArrayProto output = outputAt(index);
        if (output == null) {
            return new ArrayList();
        }
        return output.getIntValList();
    }

    /**
     * Returns the int64 value list of the output at {@code index}.
     *
     * @param index zero-based position of the output in the response
     * @return the int64 value list, or an empty list when no response is available
     * @throws RuntimeException if {@code index} is negative or not smaller than
     *                          the number of outputs
     */
    public List getInt64Vals(int index) {
        ArrayProto output = outputAt(index);
        if (output == null) {
            return new ArrayList();
        }
        return output.getInt64ValList();
    }

    /**
     * Looks up the output at {@code index}, applying the checks shared by all getters.
     *
     * @param index zero-based position of the output in the response
     * @return the output, or {@code null} (after logging) when no response is available
     * @throws RuntimeException if {@code index} is negative or not smaller than
     *                          the number of outputs
     */
    private ArrayProto outputAt(int index) {
        if (response == null) {
            log.error("request failed: can't get response");
            return null;
        }
        if (index < 0) {
            // Reject explicitly so callers get the same exception type and a clear
            // message instead of an index error from the generated accessor.
            String message = "Output index should not be negative: " + index;
            log.error(message);
            throw new RuntimeException(message);
        }
        int outputsCount = response.getOutputsCount();
        if (outputsCount <= index) {
            String message = "Output_filter should not have more tensors than model's outputs: " + outputsCount;
            log.error(message);
            throw new RuntimeException(message);
        }
        return response.getOutputs(index);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy