// com.aliyun.openservices.eas.predict.response.CaffeResponse (Maven / Gradle / Ivy)
package com.aliyun.openservices.eas.predict.response;
import com.aliyun.openservices.eas.predict.proto.CaffePredictProtos.PredictResponse;
import com.aliyun.openservices.eas.predict.proto.CaffePredictProtos.ArrayProto;
import shade.protobuf
.InvalidProtocolBufferException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.util.ArrayList;
import java.util.List;
/**
 * Holds a {@code PredictResponse} parsed from raw bytes and looks up the data
 * and shape of an output by its name.
 *
 * <p>The raw {@code List} return types are kept for compatibility with existing
 * callers; the element types come from the protobuf accessors.
 *
 * <p>Created by yaozheng.wyz on 2017/11/27.
 */
public class CaffeResponse {
    private static final Log log = LogFactory.getLog(CaffeResponse.class);

    /** Parsed response; {@code null} until a payload has been parsed successfully. */
    private PredictResponse response = null;

    /**
     * Parses {@code content} into the wrapped {@code PredictResponse}.
     *
     * <p>Failures are logged rather than thrown. After a failure (null or
     * unparsable content) no response is held, so the getters report
     * "can't get response" instead of returning data from an earlier call.
     *
     * @param content serialized response bytes; may be {@code null}
     */
    public void setContentValues(byte[] content) {
        // Clear first so a failed parse can never leave a stale response behind.
        response = null;
        if (content == null) {
            log.error("request failed: response content is null");
            return;
        }
        try {
            response = PredictResponse.parseFrom(content);
        } catch (InvalidProtocolBufferException e) {
            log.error("request failed: can't parse response content", e);
        }
    }

    /**
     * Returns the dimension list of the shape of the output named {@code outputname}.
     *
     * @param outputname name of the output to look up
     * @return the dimension list of the matching output, or a new empty list
     *         when no response has been parsed
     * @throws RuntimeException if a response is present but has no output with that name
     */
    public List getBlobShape(String outputname) {
        if (response == null) {
            log.error("request failed: can't get response");
            return new ArrayList();
        }
        return findOutput(outputname).getShape().getDimList();
    }

    /**
     * Returns the data list of the output named {@code outputname}.
     *
     * @param outputname name of the output to look up
     * @return the data list of the matching output, or a new empty list when
     *         no response has been parsed
     * @throws RuntimeException if a response is present but has no output with that name
     */
    public List getVals(String outputname) {
        if (response == null) {
            log.error("request failed: can't get response");
            return new ArrayList();
        }
        return findOutput(outputname).getDataList();
    }

    /**
     * Finds the output data stored at the same index as the first output name
     * equal to {@code outputname}. Must only be called when {@code response}
     * is non-null.
     *
     * @throws RuntimeException if no output name matches
     */
    private ArrayProto findOutput(String outputname) {
        int count = response.getOutputNameCount();
        for (int i = 0; i < count; i++) {
            // Compare from the response side so a null outputname is reported
            // as "not found" instead of raising a bare NullPointerException.
            if (response.getOutputName(i).equals(outputname)) {
                return response.getOutputData(i);
            }
        }
        String message = "Not Found output name: " + outputname;
        log.error(message);
        throw new RuntimeException(message);
    }
}