// All downloads are free. Search and download functionality uses the official Maven repository.
//
// io.seldon.wrapper.utils.DL4JUtils Maven / Gradle / Ivy

package io.seldon.wrapper.utils;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import com.google.protobuf.ListValue;
import com.google.protobuf.Value;

import io.seldon.protos.PredictionProtos.DefaultData;
import io.seldon.protos.PredictionProtos.Tensor;
import io.seldon.protos.PredictionProtos.DefaultData.DataOneofCase;

/**
 * Utilities for converting between Seldon protobuf {@link DefaultData} messages and
 * Deeplearning4j / ND4J {@link INDArray}s.
 *
 * @author clive
 */
public class DL4JUtils {

	/**
	 * Convert a Seldon protobuf {@link DefaultData} message into an ND4J array.
	 * <p>
	 * A {@code TENSOR} payload is rebuilt from its flat value list and its shape, in
	 * row-major ('c') order. An {@code NDARRAY} payload (a list of rows, each a list of
	 * numbers) becomes a 2-d array of shape {@code [rows, columns]}, where the column count
	 * is taken from the first row and every other row must have the same length.
	 *
	 * @param data Seldon protobuf message
	 * @return nd4j array, or {@code null} if {@code data} holds neither a tensor nor an ndarray
	 * @throws IllegalArgumentException if an {@code NDARRAY} payload has no rows or its rows
	 *         have differing lengths
	 */
	public static INDArray getINDArray(DefaultData data) {

		if (data.getDataOneofCase() == DataOneofCase.TENSOR) {
			List<Double> valuesList = data.getTensor().getValuesList();
			List<Integer> shapeList = data.getTensor().getShapeList();

			double[] values = valuesList.stream().mapToDouble(Double::doubleValue).toArray();
			int[] shape = shapeList.stream().mapToInt(Integer::intValue).toArray();

			return Nd4j.create(values, shape, 'c');
		} else if (data.getDataOneofCase() == DataOneofCase.NDARRAY) {
			ListValue rows = data.getNdarray();
			int rowCount = rows.getValuesCount();
			if (rowCount == 0) {
				// The column count is read from the first row, so an empty payload has no shape.
				throw new IllegalArgumentException("Cannot convert an empty ndarray to an INDArray");
			}
			int colCount = rows.getValues(0).getListValue().getValuesCount();

			double[] values = new double[rowCount * colCount];
			for (int i = 0; i < rowCount; ++i) {
				ListValue row = rows.getValues(i).getListValue();
				if (row.getValuesCount() != colCount) {
					// Ragged rows cannot be packed into a rectangular array without losing data.
					throw new IllegalArgumentException("ndarray row " + i + " has " + row.getValuesCount()
							+ " values; expected " + colCount);
				}
				for (int j = 0; j < colCount; j++) {
					// Row-major ('c') layout: each row occupies colCount consecutive slots.
					values[i * colCount + j] = row.getValues(j).getNumberValue();
				}
			}

			return Nd4j.create(values, new int[] { rowCount, colCount }, 'c');
		}
		return null;
	}

	/**
	 * Convert an ND4J array into a Seldon protobuf {@link DefaultData}, using the same payload
	 * type as {@code oldData} and copying its names.
	 * <p>
	 * If {@code oldData} is {@code null} or holds a {@code TENSOR}, the result is a tensor whose
	 * shape is {@code newData.shape()} and whose values are listed row by row. If it holds an
	 * {@code NDARRAY}, the result is a list of rows, each row a list of numbers. Both branches
	 * read {@code newData} element-wise via {@code getDouble(i, j)}.
	 *
	 * @param oldData original data whose names and payload type are reused; may be {@code null},
	 *        in which case a tensor without names is produced
	 * @param newData nd4j array, assumed to be 2-d (rows x columns)
	 * @return seldon DefaultData protobuf message, or {@code null} if {@code oldData} holds
	 *         neither a tensor nor an ndarray
	 */
	public static DefaultData updateData(DefaultData oldData, INDArray newData) {
		DefaultData.Builder dataBuilder = DefaultData.newBuilder();

		// Guard before dereferencing: a null oldData is explicitly supported below.
		if (oldData != null) {
			dataBuilder.addAllNames(oldData.getNamesList());
		}

		if (oldData == null || oldData.getDataOneofCase() == DataOneofCase.TENSOR) {
			List<Integer> shapeList = Arrays.stream(newData.shape()).boxed().collect(Collectors.toList());
			int rowCount = shapeList.get(0);
			int colCount = shapeList.get(1);

			Tensor.Builder tBuilder = Tensor.newBuilder();
			tBuilder.addAllShape(shapeList);
			// Emit values row by row (row-major / 'c' order), matching getINDArray.
			for (int i = 0; i < rowCount; ++i) {
				for (int j = 0; j < colCount; ++j) {
					tBuilder.addValues(newData.getDouble(i, j));
				}
			}
			dataBuilder.setTensor(tBuilder);
			return dataBuilder.build();
		} else if (oldData.getDataOneofCase() == DataOneofCase.NDARRAY) {
			int rowCount = newData.shape()[0];
			int colCount = newData.shape()[1];

			ListValue.Builder rowsBuilder = ListValue.newBuilder();
			for (int i = 0; i < rowCount; ++i) {
				ListValue.Builder rowBuilder = ListValue.newBuilder();
				for (int j = 0; j < colCount; j++) {
					rowBuilder.addValues(Value.newBuilder().setNumberValue(newData.getDouble(i, j)));
				}
				rowsBuilder.addValues(Value.newBuilder().setListValue(rowBuilder.build()));
			}
			dataBuilder.setNdarray(rowsBuilder.build());
			return dataBuilder.build();
		}
		return null;
	}

}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy