
package com.redislabs.redisai;
import com.redislabs.redisai.exceptions.JRedisAIRunTimeException;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import redis.clients.jedis.Protocol;
import redis.clients.jedis.util.SafeEncoder;
/** Direct mapping to RedisAI Model */
public class Model {
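// The model blob may be split into chunks when building the AI.MODELSTORE command (see
// getModelStoreCommandArgs and collectChunks). The chunk size defaults to 536870912 bytes (512 MB)
// and can be overridden through the "redisai.blob.chunkSize" system property.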
public static final String BLOB_CHUNK_SIZE_PROPERTY = "redisai.blob.chunkSize";
private static final int BLOB_CHUNK_SIZE =
Integer.parseInt(System.getProperty(BLOB_CHUNK_SIZE_PROPERTY, "536870912"));
private Backend backend; // TODO: final
private Device device; // TODO: final
private String[] inputs;
private String[] outputs;
private byte[] blob; // TODO: final
private String tag;
private long batchSize;
private long minBatchSize;
private long minBatchTimeout;
/**
* @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
* @param device - the device that will execute the model. can be one of CPU or GPU
* @param modelUri - filepath of the Protobuf-serialized model
* @throws java.io.IOException if reading the model from {@code modelUri} fails
* @see #Model(com.redislabs.redisai.Backend, com.redislabs.redisai.Device, byte[])
* @see Files#readAllBytes(java.nio.file.Path)
* @see Paths#get(java.net.URI)
*/
public Model(Backend backend, Device device, URI modelUri) throws IOException {
this(backend, device, Files.readAllBytes(Paths.get(modelUri)));
}
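// Usage sketch (illustrative only; the file path, node names and tag are placeholders):
//   Model model =
//       new Model(Backend.TF, Device.CPU, Paths.get("graph.pb").toUri())
//           .setInputs(new String[] {"input"})
//           .setOutputs(new String[] {"output"})
//           .setTag("v1");
// Reading the file may throw an IOException, which the caller must handle.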
/**
* @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
* @param device - the device that will execute the model. can be one of CPU or GPU
* @param blob - the Protobuf-serialized model
*/
public Model(Backend backend, Device device, byte[] blob) {
this.backend = backend;
this.device = device;
this.blob = blob;
}
/**
* @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
* @param device - the device that will execute the model. can be one of CPU or GPU
* @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow
* models)
* @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow
* models)
* @param blob - the Protobuf-serialized model
*/
public Model(Backend backend, Device device, String[] inputs, String[] outputs, byte[] blob) {
this(backend, device, inputs, outputs, blob, 0, 0);
}
/**
* @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
* @param device - the device that will execute the model. can be one of CPU or GPU
* @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow
* models)
* @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow
* models)
* @param blob - the Protobuf-serialized model
* @param batchSize - when provided with a batchsize greater than 0, the engine will batch
* incoming requests from multiple clients that use the model with input tensors of the same
* shape.
* @param minBatchSize - when provided with a minbatchsize greater than 0, the engine will
* postpone calls to AI.MODELRUN until the batch's size has reached minbatchsize
*/
public Model(
Backend backend,
Device device,
String[] inputs,
String[] outputs,
byte[] blob,
long batchSize,
long minBatchSize) {
this.backend = backend;
this.device = device;
this.inputs = inputs;
this.outputs = outputs;
this.blob = blob;
this.tag = null;
this.batchSize = batchSize;
this.minBatchSize = minBatchSize;
}
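// Builds a Model from the flat key/value reply of AI.MODELGET: the reply alternates between a
// field name ("backend", "device", "tag", "blob", "batchsize", "minbatchsize", "minbatchtimeout",
// "inputs", "outputs") and its value. Unknown fields are ignored; backend, device and blob must be
// present, otherwise a JRedisAIRunTimeException is thrown.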
public static Model createModelFromRespReply(List<?> reply) {
Backend backend = null;
Device device = null;
String tag = null;
byte[] blob = null;
long batchsize = 0;
long minbatchsize = 0;
long minbatchtimeout = 0;
String[] inputs = new String[0];
String[] outputs = new String[0];
for (int i = 0; i < reply.size(); i += 2) {
String arrayKey = SafeEncoder.encode((byte[]) reply.get(i));
switch (arrayKey) {
case "backend":
String backendString = SafeEncoder.encode((byte[]) reply.get(i + 1));
backend = Backend.valueOf(backendString);
if (backend == null) {
throw new JRedisAIRunTimeException("Unrecognized backend: " + backendString);
}
break;
case "device":
String deviceString = SafeEncoder.encode((byte[]) reply.get(i + 1));
device = Device.valueOf(deviceString);
if (device == null) {
throw new JRedisAIRunTimeException("Unrecognized device: " + deviceString);
}
break;
case "tag":
tag = SafeEncoder.encode((byte[]) reply.get(i + 1));
break;
case "blob":
blob = (byte[]) reply.get(i + 1);
break;
case "batchsize":
batchsize = (Long) reply.get(i + 1);
break;
case "minbatchsize":
minbatchsize = (Long) reply.get(i + 1);
break;
case "minbatchtimeout":
minbatchtimeout = (Long) reply.get(i + 1);
break;
case "inputs":
List<byte[]> inputsEncoded = (List<byte[]>) reply.get(i + 1);
if (!inputsEncoded.isEmpty()) {
inputs = new String[inputsEncoded.size()];
for (int j = 0; j < inputsEncoded.size(); j++) {
inputs[j] = SafeEncoder.encode(inputsEncoded.get(j));
}
}
break;
case "outputs":
List<byte[]> outputsEncoded = (List<byte[]>) reply.get(i + 1);
if (!outputsEncoded.isEmpty()) {
outputs = new String[outputsEncoded.size()];
for (int j = 0; j < outputsEncoded.size(); j++) {
outputs[j] = SafeEncoder.encode(outputsEncoded.get(j));
}
}
break;
default:
break;
}
}
if (backend == null || device == null || blob == null) {
throw new JRedisAIRunTimeException(
"AI.MODELGET reply did not contained all elements to build the model");
}
return new Model(backend, device, blob)
.setInputs(inputs)
.setOutputs(outputs)
.setBatchSize(batchsize)
.setMinBatchSize(minbatchsize)
.setMinBatchTimeout(minbatchtimeout)
.setTag(tag);
}
public String getTag() {
return tag;
}
public Model setTag(String tag) {
this.tag = tag;
return this;
}
public byte[] getBlob() {
return blob;
}
/**
* @param blob the Protobuf-serialized model
* @deprecated This field will become final; set the blob via a constructor instead.
*/
@Deprecated
public void setBlob(byte[] blob) {
this.blob = blob;
}
public String[] getOutputs() {
return outputs;
}
public Model setOutputs(String[] outputs) {
this.outputs = outputs;
return this;
}
public String[] getInputs() {
return inputs;
}
public Model setInputs(String[] inputs) {
this.inputs = inputs;
return this;
}
public Device getDevice() {
return device;
}
/**
* @param device the device that will execute the model
* @deprecated This field will become final; set the device via a constructor instead.
*/
@Deprecated
public void setDevice(Device device) {
this.device = device;
}
public Backend getBackend() {
return backend;
}
/**
* @param backend the backend for the model
* @deprecated This field will become final; set the backend via a constructor instead.
*/
@Deprecated
public void setBackend(Backend backend) {
this.backend = backend;
}
public long getBatchSize() {
return batchSize;
}
public Model setBatchSize(long batchsize) {
this.batchSize = batchsize;
return this;
}
public long getMinBatchSize() {
return minBatchSize;
}
public Model setMinBatchSize(long minbatchsize) {
this.minBatchSize = minbatchsize;
return this;
}
public long getMinBatchTimeout() {
return minBatchTimeout;
}
public Model setMinBatchTimeout(long minBatchTimeout) {
this.minBatchTimeout = minBatchTimeout;
return this;
}
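// The arguments returned below are laid out as:
//   key backend device [TAG tag] [BATCHSIZE n [MINBATCHSIZE m]] INPUTS name... OUTPUTS name... BLOB blob
// The AI.MODELSET command name itself is not part of the returned list.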
/**
* Encodes the current model properties into an AI.MODELSET command to be stored in the RedisAI
* server.
*
* @param key name of the key under which the Model is stored
* @return the encoded command arguments
*/
protected List<byte[]> getModelSetCommandBytes(String key) {
List<byte[]> args = new ArrayList<>();
args.add(SafeEncoder.encode(key));
args.add(backend.getRaw());
args.add(device.getRaw());
if (tag != null) {
args.add(Keyword.TAG.getRaw());
args.add(SafeEncoder.encode(tag));
}
if (batchSize > 0) {
args.add(Keyword.BATCHSIZE.getRaw());
args.add(Protocol.toByteArray(batchSize));
if (minBatchSize > 0) {
args.add(Keyword.MINBATCHSIZE.getRaw());
args.add(Protocol.toByteArray(minBatchSize));
}
}
args.add(Keyword.INPUTS.getRaw());
for (String input : inputs) {
args.add(SafeEncoder.encode(input));
}
args.add(Keyword.OUTPUTS.getRaw());
for (String output : outputs) {
args.add(SafeEncoder.encode(output));
}
args.add(Keyword.BLOB.getRaw());
args.add(blob);
return args;
}
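// The arguments returned below are laid out as:
//   key backend device [TAG tag] [BATCHSIZE n MINBATCHSIZE m MINBATCHTIMEOUT t]
//   [INPUTS count name...] [OUTPUTS count name...] BLOB chunk [chunk...]
// Unlike AI.MODELSET, INPUTS and OUTPUTS carry an explicit count and the blob may be sent as
// multiple chunks (see collectChunks). The AI.MODELSTORE command name itself is not part of the
// returned list.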
/**
* Encodes the current model properties into an AI.MODELSTORE command to be stored in the RedisAI
* server.
*
* @param key name of the key under which the Model is stored
* @return the encoded command arguments
*/
protected List<byte[]> getModelStoreCommandArgs(String key) {
List<byte[]> args = new ArrayList<>();
args.add(SafeEncoder.encode(key));
args.add(backend.getRaw());
args.add(device.getRaw());
if (tag != null) {
args.add(Keyword.TAG.getRaw());
args.add(SafeEncoder.encode(tag));
}
if (batchSize > 0) {
args.add(Keyword.BATCHSIZE.getRaw());
args.add(Protocol.toByteArray(batchSize));
args.add(Keyword.MINBATCHSIZE.getRaw());
args.add(Protocol.toByteArray(minBatchSize));
args.add(Keyword.MINBATCHTIMEOUT.getRaw());
args.add(Protocol.toByteArray(minBatchTimeout));
}
if (inputs != null && inputs.length > 0) {
args.add(Keyword.INPUTS.getRaw());
args.add(Protocol.toByteArray(inputs.length));
for (String input : inputs) {
args.add(SafeEncoder.encode(input));
}
}
if (outputs != null && outputs.length > 0) {
args.add(Keyword.OUTPUTS.getRaw());
args.add(Protocol.toByteArray(outputs.length));
for (String output : outputs) {
args.add(SafeEncoder.encode(output));
}
}
args.add(Keyword.BLOB.getRaw());
collectChunks(args, blob);
return args;
}
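// Splits the blob into chunks of at most BLOB_CHUNK_SIZE bytes and appends each chunk to the
// collector. If chunking is disabled (chunk size <= 0) or the blob fits into a single chunk, the
// blob is added unsplit.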
private static void collectChunks(List<byte[]> collector, byte[] array) {
final int chunkSize = BLOB_CHUNK_SIZE;
if (chunkSize <= 0 || array.length <= chunkSize) {
collector.add(array);
return;
}
int from = 0;
while (from < array.length) {
int copySize = Math.min(array.length - from, chunkSize);
collector.add(Arrays.copyOfRange(array, from, from + copySize));
from += copySize;
}
}
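// Builds the flat argument list for AI.MODELRUN:
//   [AI.MODELRUN] key INPUTS name... OUTPUTS name...
// The command name is prepended only when includeCommandName is true.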
protected static List<byte[]> modelRunFlatArgs(
String key, String[] inputs, String[] outputs, boolean includeCommandName) {
List<byte[]> args = new ArrayList<>();
if (includeCommandName) {
args.add(Command.MODEL_RUN.getRaw());
}
args.add(SafeEncoder.encode(key));
args.add(Keyword.INPUTS.getRaw());
for (String input : inputs) {
args.add(SafeEncoder.encode(input));
}
args.add(Keyword.OUTPUTS.getRaw());
for (String output : outputs) {
args.add(SafeEncoder.encode(output));
}
return args;
}
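// Builds the flat argument list for AI.MODELEXECUTE:
//   [AI.MODELEXECUTE] key INPUTS count name... OUTPUTS count name... [TIMEOUT t]
// The command name is prepended only when includeCommandName is true; TIMEOUT is added only for a
// non-negative timeout value.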
protected static List<byte[]> modelExecuteCommandArgs(
String key, String[] inputs, String[] outputs, long timeout, boolean includeCommandName) {
List<byte[]> args = new ArrayList<>();
if (includeCommandName) {
args.add(Command.MODEL_EXECUTE.getRaw());
}
args.add(SafeEncoder.encode(key));
args.add(Keyword.INPUTS.getRaw());
args.add(Protocol.toByteArray(inputs.length));
for (String input : inputs) {
args.add(SafeEncoder.encode(input));
}
args.add(Keyword.OUTPUTS.getRaw());
args.add(Protocol.toByteArray(outputs.length));
for (String output : outputs) {
args.add(SafeEncoder.encode(output));
}
if (timeout >= 0) {
args.add(Keyword.TIMEOUT.getRaw());
args.add(Protocol.toByteArray(timeout));
}
return args;
}
}