
// com.redislabs.redisai.RedisAI (Maven / Gradle / Ivy)
package com.redislabs.redisai;
import com.redislabs.redisai.exceptions.JRedisAIRunTimeException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import redis.clients.jedis.BinaryClient;
import redis.clients.jedis.Client;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisClientConfig;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
import redis.clients.jedis.exceptions.JedisDataException;
import redis.clients.jedis.util.Pool;
import redis.clients.jedis.util.SafeEncoder;
public class RedisAI implements AutoCloseable {
/** Connection pool from which every command borrows a {@link Jedis} connection. */
private final Pool<Jedis> pool;
/**
 * Creates a RedisAI client for a server at {@code localhost:6379}, using the default timeout and
 * pool size applied by {@link #RedisAI(String, int)}.
 */
public RedisAI() {
  this("localhost", 6379);
}
/**
 * Create a new RedisAI client with a 500 ms timeout and a pool of at most 100 connections.
 *
 * @param host the redis host
 * @param port the redis port
 */
public RedisAI(String host, int port) {
  this(host, port, 500, 100);
}
/**
 * Create a new RedisAI client that does not authenticate (a {@code null} password is passed to
 * the pool).
 *
 * @param host the redis host
 * @param port the redis port
 * @param timeout timeout in milliseconds, passed through to {@link JedisPool}
 * @param poolSize maximum number of connections held by the pool
 */
public RedisAI(String host, int port, int timeout, int poolSize) {
  this(host, port, timeout, poolSize, null);
}
/**
 * Create a new RedisAI client backed by a {@link JedisPool} configured for {@code poolSize}
 * connections.
 *
 * @param host the redis host
 * @param port the redis port
 * @param timeout timeout in milliseconds, passed through to {@link JedisPool}
 * @param poolSize maximum number of connections held by the pool
 * @param password the password for authentication in a password protected Redis server
 */
public RedisAI(String host, int port, int timeout, int poolSize, String password) {
  this(new JedisPool(initPoolConfig(poolSize), host, port, timeout, password));
}
/**
 * Create a new RedisAI client.
 *
 * <p>Unlike the host/port constructors, this one builds its {@link JedisPool} with a default
 * {@link GenericObjectPoolConfig} rather than the settings produced by {@code initPoolConfig}.
 *
 * @param hostAndPort address of the redis server
 * @param clientConfig Jedis client settings (for example timeouts and credentials)
 */
public RedisAI(HostAndPort hostAndPort, JedisClientConfig clientConfig) {
  this(new JedisPool(new GenericObjectPoolConfig<>(), hostAndPort, clientConfig));
}
/**
 * Create a new RedisAI client on top of an existing connection pool.
 *
 * <p>The client takes ownership of the pool: {@link #close()} closes it.
 *
 * @param pool jedis connection pool
 */
public RedisAI(Pool<Jedis> pool) {
  this.pool = pool;
}
/** Closes the underlying connection pool. */
@Override
public void close() {
  pool.close();
}
/**
 * Builds the pool configuration used by the host/port constructors.
 *
 * <p>All connection validation (on create, borrow, return and while idle) is switched off. Idle
 * connections become evictable after 60 000 ms, the evictor runs every 30 000 ms and examines
 * every idle connection per run ({@code -1}). Waiting borrowers are served fairly.
 *
 * @param poolSize maximum number of connections the pool may hold
 * @return a new {@link JedisPoolConfig} carrying these settings
 */
private static JedisPoolConfig initPoolConfig(int poolSize) {
  final JedisPoolConfig config = new JedisPoolConfig();
  config.setMaxTotal(poolSize);
  config.setFairness(true);

  // No health checks at any point of a connection's lifecycle.
  config.setTestOnCreate(false);
  config.setTestOnBorrow(false);
  config.setTestOnReturn(false);
  config.setTestWhileIdle(false);

  // Idle-connection eviction schedule.
  config.setTimeBetweenEvictionRunsMillis(30000);
  config.setMinEvictableIdleTimeMillis(60000);
  config.setNumTestsPerEvictionRun(-1);
  return config;
}
/** Borrows a connection from the pool; callers close it (try-with-resources) to hand it back. */
private Jedis getConnection() {
  final Jedis borrowed = pool.getResource();
  return borrowed;
}
/**
 * Sends {@code command} with binary arguments over {@code conn} and returns the connection's
 * client so the caller can read the reply from it.
 */
private BinaryClient sendCommand(Jedis conn, Command command, byte[]... args) {
  final BinaryClient binaryClient = conn.getClient();
  binaryClient.sendCommand(command, args);
  return binaryClient;
}
/**
 * Sends {@code command} with string arguments over {@code conn} and returns the connection's
 * client so the caller can read the reply from it.
 */
private Client sendCommand(Jedis conn, Command command, String... args) {
  final Client stringClient = conn.getClient();
  stringClient.sendCommand(command, args);
  return stringClient;
}
/**
 * Direct mapping to AI.TENSORSET
 *
 * <p>The tensor's data type is derived from {@code values}. The {@code int} shape is widened to
 * {@code long[]} before the {@link Tensor} is built and stored via {@link #setTensor(String,
 * Tensor)}.
 *
 * @param key name of key to store the Tensor
 * @param values multi-dimension numeric data
 * @param shape one or more dimensions, or the number of elements per axis, for the tensor
 * @return true if Tensor was properly set in RedisAI server
 */
public boolean setTensor(String key, Object values, int[] shape) {
  final DataType type = DataType.baseObjType(values);
  final long[] dims = new long[shape.length];
  int axis = 0;
  for (int extent : shape) {
    dims[axis++] = extent;
  }
  return setTensor(key, new Tensor(type, dims, values));
}
/**
 * Direct mapping to AI.TENSORSET
 *
 * @param key name of key to store the Tensor
 * @param tensor Tensor object
 * @return true if the server replied with the {@code OK} status
 * @throws RedisAIException wrapping the {@link JedisDataException} raised when the server rejects
 *     the command
 */
public boolean setTensor(String key, Tensor tensor) {
  try (Jedis conn = getConnection()) {
    // Typed list so toArray(byte[][]) yields byte[][] and selects the binary sendCommand overload.
    List<byte[]> args = tensor.tensorSetFlatArgs(key, false);
    String status =
        sendCommand(conn, Command.TENSOR_SET, args.toArray(new byte[args.size()][]))
            .getStatusCodeReply();
    // getStatusCodeReply() may return null; compare from the constant side to avoid an NPE.
    return "OK".equals(status);
  } catch (JedisDataException ex) {
    throw new RedisAIException(ex);
  }
}
/**
 * Direct mapping to AI.TENSORGET
 *
 * <p>Unlike the setter methods of this class, this method has no {@code catch} for
 * {@link JedisDataException}, so a server error reply reaches the caller unwrapped.
 *
 * @param key name of key to get the Tensor from
 * @return Tensor decoded from the reply, or {@code null} if the reply list is empty
 * @throws JRedisAIRunTimeException presumably raised while decoding the reply in
 *     {@code Tensor.createTensorFromRespReply} — TODO confirm
 */
public Tensor getTensor(String key) {
try (Jedis conn = getConnection()) {
// NOTE(review): generic type arguments of `args` and `reply` appear to have been lost in this copy
// (raw `List`, and `List>` does not parse). The toArray call implies `List<byte[]>` for args;
// confirm the reply type against Tensor.createTensorFromRespReply before restoring it.
List args = Tensor.tensorGetFlatArgs(key, false);
List> reply =
sendCommand(conn, Command.TENSOR_GET, args.toArray(new byte[args.size()][]))
.getObjectMultiBulkReply();
// Defensive: an empty reply is reported as "no tensor" rather than decoded.
if (reply.isEmpty()) {
return null;
}
return Tensor.createTensorFromRespReply(reply);
}
}
/**
 * Direct mapping to AI.MODELSET
 *
 * <p>Loads the serialized model from {@code modelPath} and stores it via {@link
 * #setModel(String, Model)}.
 *
 * @param key name of key to store the Model
 * @param backend - the backend for the model. can be one of TF, TFLITE, TORCH or ONNX
 * @param device - the device that will execute the model. can be of CPU or GPU
 * @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow
 *     models)
 * @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow
 *     models)
 * @param modelPath - the file path for the Protobuf-serialized model
 * @return true if Model was properly set in RedisAI server
 * @throws RedisAIException wrapping the {@link IOException} if the model file cannot be read
 */
public boolean setModel(
    String key,
    Backend backend,
    Device device,
    String[] inputs,
    String[] outputs,
    String modelPath) {
  try {
    return setModel(
        key,
        new Model(backend, device, inputs, outputs, Files.readAllBytes(Paths.get(modelPath))));
  } catch (IOException ioe) {
    throw new RedisAIException(ioe);
  }
}
/**
 * Direct mapping to AI.MODELSET
 *
 * @param key name of key to store the Model
 * @param model Model object
 * @return true if the server replied with the {@code OK} status
 * @throws RedisAIException wrapping the {@link JedisDataException} raised when the server rejects
 *     the command
 */
public boolean setModel(String key, Model model) {
  try (Jedis conn = getConnection()) {
    // Typed list so toArray(byte[][]) yields byte[][] and selects the binary sendCommand overload.
    List<byte[]> args = model.getModelSetCommandBytes(key);
    String status =
        sendCommand(conn, Command.MODEL_SET, args.toArray(new byte[args.size()][]))
            .getStatusCodeReply();
    // getStatusCodeReply() may return null; compare from the constant side to avoid an NPE.
    return "OK".equals(status);
  } catch (JedisDataException ex) {
    throw new RedisAIException(ex);
  }
}
/**
 * Direct mapping to AI.MODELSTORE command.
 *
 * <p>{@code AI.MODELSTORE <key> <backend> <device> [TAG tag] [BATCHSIZE n [MINBATCHSIZE m]]
 * [INPUTS <input_count> <name> ...] [OUTPUTS <output_count> <name> ...] BLOB <model>}
 *
 * @param key name of key to store the Model
 * @param model Model object
 * @return true if the server replied with the {@code OK} status
 * @throws RedisAIException carrying the server's error message when the command is rejected
 */
public boolean storeModel(String key, Model model) {
  try (Jedis conn = getConnection()) {
    // Typed list so toArray(byte[][]) yields byte[][] and selects the binary sendCommand overload.
    List<byte[]> args = model.getModelStoreCommandArgs(key);
    String status =
        sendCommand(conn, Command.MODEL_STORE, args.toArray(new byte[args.size()][]))
            .getStatusCodeReply();
    // getStatusCodeReply() may return null; compare from the constant side to avoid an NPE.
    return "OK".equals(status);
  } catch (JedisDataException ex) {
    throw new RedisAIException(ex.getMessage(), ex);
  }
}
/**
 * Direct mapping to AI.MODELGET
 *
 * <p>Sends {@code AI.MODELGET key META BLOB}, i.e. requests both the model metadata and its blob.
 * Unlike the setter methods of this class, this method has no {@code catch} for
 * {@link JedisDataException}, so a server error reply reaches the caller unwrapped.
 *
 * @param key name of key to get the Model from RedisAI server
 * @return Model decoded from the reply, or {@code null} if the reply list is empty
 * @throws JRedisAIRunTimeException presumably raised while decoding the reply in
 *     {@code Model.createModelFromRespReply} — TODO confirm
 */
public Model getModel(String key) {
try (Jedis conn = getConnection()) {
// NOTE(review): the generic type argument of `reply` appears to have been lost in this copy
// (`List>` does not parse); confirm it against Model.createModelFromRespReply before restoring.
List> reply =
sendCommand(
conn,
Command.MODEL_GET,
SafeEncoder.encode(key),
Keyword.META.getRaw(),
Keyword.BLOB.getRaw())
.getObjectMultiBulkReply();
// Defensive: an empty reply is reported as "no model" rather than decoded.
if (reply.isEmpty()) {
return null;
}
return Model.createModelFromRespReply(reply);
}
}
/**
 * Direct mapping to AI.MODELDEL
 *
 * @param key name of key to delete the Model
 * @return true if the server replied with the {@code OK} status
 * @throws RedisAIException wrapping the {@link JedisDataException} raised when the server rejects
 *     the command
 */
public boolean delModel(String key) {
  try (Jedis conn = getConnection()) {
    String status =
        sendCommand(conn, Command.MODEL_DEL, SafeEncoder.encode(key)).getStatusCodeReply();
    // getStatusCodeReply() may return null; compare from the constant side to avoid an NPE.
    return "OK".equals(status);
  } catch (JedisDataException ex) {
    throw new RedisAIException(ex);
  }
}
/**
 * Direct mapping to AI.SCRIPTSET
 *
 * <p>Builds a {@link Script} from the source file at {@code scriptFile} and stores it via {@link
 * #setScript(String, Script)}.
 *
 * @param key name of key to store the Script in RedisAI server
 * @param device - the device that will execute the script. can be of CPU or GPU
 * @param scriptFile - the file path for the script source code
 * @return true if Script was properly set in RedisAI server
 * @throws RedisAIException wrapping the {@link IOException} if the script cannot be loaded
 */
public boolean setScriptFile(String key, Device device, String scriptFile) {
  final Script loaded;
  try {
    loaded = new Script(device, Paths.get(scriptFile));
  } catch (IOException ioe) {
    throw new RedisAIException(ioe);
  }
  return setScript(key, loaded);
}
/**
 * Direct mapping to AI.SCRIPTSET
 *
 * <p>Wraps {@code source} in a {@link Script} and stores it via {@link #setScript(String,
 * Script)}.
 *
 * @param key name of key to store the Script in RedisAI server
 * @param device - the device that will execute the script. can be of CPU or GPU
 * @param source - the script source code
 * @return true if Script was properly set in RedisAI server
 */
public boolean setScript(String key, Device device, String source) {
  return setScript(key, new Script(device, source));
}
/**
 * Direct mapping to AI.SCRIPTSET
 *
 * @param key name of key to store the Script in RedisAI server
 * @param script the Script Object
 * @return true if the server replied with the {@code OK} status
 * @throws RedisAIException wrapping the {@link JedisDataException} raised when the server rejects
 *     the command
 */
public boolean setScript(String key, Script script) {
  try (Jedis conn = getConnection()) {
    // Typed list so toArray(byte[][]) yields byte[][] and selects the binary sendCommand overload.
    List<byte[]> args = script.getScriptSetCommandBytes(key);
    String status =
        sendCommand(conn, Command.SCRIPT_SET, args.toArray(new byte[args.size()][]))
            .getStatusCodeReply();
    // getStatusCodeReply() may return null; compare from the constant side to avoid an NPE.
    return "OK".equals(status);
  } catch (JedisDataException ex) {
    throw new RedisAIException(ex);
  }
}
/**
* Direct mapping to AI.SCRIPTSTORE command.
*
* {@code AI.SCRIPTSTORE [TAG tag] ENTRY_POINTS
* [...] SOURCE "