All downloads are free. Search and download functionality uses the official Maven repository.

com.redislabs.redisai.RedisAI Maven / Gradle / Ivy

The newest version!
package com.redislabs.redisai;

import com.redislabs.redisai.exceptions.JRedisAIRunTimeException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import redis.clients.jedis.BinaryClient;
import redis.clients.jedis.Client;
import redis.clients.jedis.HostAndPort;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisClientConfig;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
import redis.clients.jedis.exceptions.JedisDataException;
import redis.clients.jedis.util.Pool;
import redis.clients.jedis.util.SafeEncoder;

public class RedisAI implements AutoCloseable {

  /**
   * Connection pool from which every command borrows a {@link Jedis} connection. Parameterized so
   * that {@code pool.getResource()} is typed as {@link Jedis} rather than the raw {@code Object}.
   */
  private final Pool<Jedis> pool;

  /**
   * Creates a RedisAI client for a Redis server listening on {@code localhost:6379}.
   *
   * <p>Delegates to {@link #RedisAI(String, int)}, so that constructor's timeout and pool-size
   * defaults apply.
   */
  public RedisAI() {
    this(/* host= */ "localhost", /* port= */ 6379);
  }

  /**
   * Creates a RedisAI client for the given Redis server using built-in pool settings.
   *
   * <p>Forwards a timeout of {@code 500} (handed to the underlying {@link JedisPool}; presumably
   * milliseconds) and a pool size of {@code 100} to {@link #RedisAI(String, int, int, int)}.
   *
   * @param host the Redis host name or address
   * @param port the Redis port
   */
  public RedisAI(String host, int port) {
    this(host, port, /* timeout= */ 500, /* poolSize= */ 100);
  }

  /**
   * Creates a RedisAI client for the given Redis server with no password (a {@code null} password
   * is forwarded).
   *
   * @param host the Redis host name or address
   * @param port the Redis port
   * @param timeout timeout value handed to the underlying {@link JedisPool}
   * @param poolSize maximum number of connections kept in the pool
   */
  public RedisAI(String host, int port, int timeout, int poolSize) {
    this(host, port, timeout, poolSize, /* password= */ null);
  }

  /**
   * Creates a RedisAI client backed by a new {@link JedisPool} whose settings come from {@code
   * initPoolConfig(poolSize)}.
   *
   * @param host the Redis host name or address
   * @param port the Redis port
   * @param timeout timeout value handed to the {@link JedisPool} constructor
   * @param poolSize maximum number of connections kept in the pool
   * @param password the password for a password-protected Redis server; may be {@code null}
   */
  public RedisAI(String host, int port, int timeout, int poolSize, String password) {
    this(
        new JedisPool(
            initPoolConfig(poolSize),
            host,
            port,
            timeout,
            password));
  }

  /**
   * Creates a RedisAI client from a Jedis endpoint and client configuration.
   *
   * <p>Unlike the host/port constructors, this one builds its {@link JedisPool} with a default
   * {@link GenericObjectPoolConfig} instead of {@code initPoolConfig}.
   *
   * @param hostAndPort address of the Redis server
   * @param clientConfig Jedis client configuration used for the pooled connections
   */
  public RedisAI(HostAndPort hostAndPort, JedisClientConfig clientConfig) {
    this(
        new JedisPool(
            new GenericObjectPoolConfig<>(),
            hostAndPort,
            clientConfig));
  }

  /**
   * Creates a RedisAI client on top of an existing connection pool.
   *
   * <p>The client takes ownership of the pool: {@link #close()} closes it.
   *
   * @param pool Jedis connection pool from which each command borrows a connection
   */
  public RedisAI(Pool<Jedis> pool) {
    // Parameterized (was a raw Pool) so the borrowed resource is statically a Jedis; the erasure
    // is unchanged, so existing binaries and callers passing a JedisPool are unaffected.
    this.pool = pool;
  }

  /** Closes the connection pool backing this client. */
  @Override
  public void close() {
    pool.close();
  }

  /**
   * Builds the {@link JedisPoolConfig} used by the host/port constructors.
   *
   * <p>Configuration applied:
   *
   * <ul>
   *   <li>the pool is capped at {@code poolSize} connections, with fairness enabled;
   *   <li>connection validation is disabled on create, borrow, return and while idle;
   *   <li>eviction runs every 30 000 ms, an idle connection becomes evictable after 60 000 ms, and
   *       {@code numTestsPerEvictionRun} is {@code -1}.
   * </ul>
   *
   * @param poolSize maximum total number of connections in the pool
   * @return a freshly built {@link JedisPoolConfig}
   */
  private static JedisPoolConfig initPoolConfig(int poolSize) {
    JedisPoolConfig config = new JedisPoolConfig();
    config.setMaxTotal(poolSize);
    config.setFairness(true);

    // No connection validation at any point in the connection lifecycle.
    config.setTestOnCreate(false);
    config.setTestOnBorrow(false);
    config.setTestOnReturn(false);
    config.setTestWhileIdle(false);

    // Idle-connection eviction.
    config.setTimeBetweenEvictionRunsMillis(30_000);
    config.setMinEvictableIdleTimeMillis(60_000);
    config.setNumTestsPerEvictionRun(-1);

    return config;
  }

  /** Borrows a connection from the pool; callers close it (try-with-resources) when done. */
  private Jedis getConnection() {
    final Jedis connection = this.pool.getResource();
    return connection;
  }

  /**
   * Sends {@code command} with binary arguments through the client of {@code conn}.
   *
   * @return the client the command was sent on, so the caller can read the reply from it
   */
  private BinaryClient sendCommand(Jedis conn, Command command, byte[]... args) {
    final BinaryClient binaryClient = conn.getClient();
    binaryClient.sendCommand(command, args);
    return binaryClient;
  }

  /**
   * Sends {@code command} with string arguments through the client of {@code conn}.
   *
   * @return the client the command was sent on, so the caller can read the reply from it
   */
  private Client sendCommand(Jedis conn, Command command, String... args) {
    final Client textClient = conn.getClient();
    textClient.sendCommand(command, args);
    return textClient;
  }

  /**
   * Direct mapping to AI.TENSORSET for raw Java array data.
   *
   * <p>The data type is derived from {@code values} via {@code DataType.baseObjType}, and the
   * {@code int} shape is widened to {@code long} before delegating to {@link #setTensor(String,
   * Tensor)}.
   *
   * @param key name of key to store the Tensor
   * @param values multi-dimension numeric data
   * @param shape one or more dimensions, or the number of elements per axis, for the tensor
   * @return true if Tensor was properly set in RedisAI server
   */
  public boolean setTensor(String key, Object values, int[] shape) {
    // Resolve the type first so a bad `values` is reported before the shape is touched.
    DataType type = DataType.baseObjType(values);
    long[] longShape = new long[shape.length];
    int axis = 0;
    for (int dim : shape) {
      longShape[axis++] = dim;
    }
    return setTensor(key, new Tensor(type, longShape, values));
  }

  /**
   * Direct mapping to AI.TENSORSET
   *
   * @param key name of key to store the Tensor
   * @param tensor Tensor object
   * @return true if the server acknowledged the command with the {@code OK} status
   * @throws RedisAIException wrapping the {@link JedisDataException} raised for an error reply
   */
  public boolean setTensor(String key, Tensor tensor) {
    try (Jedis conn = getConnection()) {
      // Typed list: on a raw List, toArray(T[]) erases to Object[] and matches no sendCommand
      // overload.
      List<byte[]> args = tensor.tensorSetFlatArgs(key, false);
      String status =
          sendCommand(conn, Command.TENSOR_SET, args.toArray(new byte[args.size()][]))
              .getStatusCodeReply();
      // Constant-first comparison: a null status yields false instead of a NullPointerException.
      return "OK".equals(status);
    } catch (JedisDataException ex) {
      throw new RedisAIException(ex);
    }
  }

  /**
   * Direct mapping to AI.TENSORGET
   *
   * <p>Unlike the setters, this method does not wrap server error replies in {@code
   * RedisAIException}; they propagate as thrown by Jedis.
   *
   * @param key name of key to get the Tensor from
   * @return the Tensor decoded from the reply, or {@code null} if the reply is null or empty
   * @throws JRedisAIRunTimeException presumably raised by {@code
   *     Tensor.createTensorFromRespReply} when the reply cannot be decoded (TODO confirm)
   */
  public Tensor getTensor(String key) {
    try (Jedis conn = getConnection()) {
      List<byte[]> args = Tensor.tensorGetFlatArgs(key, false);
      List<Object> reply =
          sendCommand(conn, Command.TENSOR_GET, args.toArray(new byte[args.size()][]))
              .getObjectMultiBulkReply();
      // A nil multi-bulk reply comes back as null; treat it like an empty reply instead of
      // failing with a NullPointerException on isEmpty().
      if (reply == null || reply.isEmpty()) {
        return null;
      }
      return Tensor.createTensorFromRespReply(reply);
    }
  }

  /**
   * Direct mapping to AI.MODELSET, reading the serialized model from a file.
   *
   * <p>The whole file is loaded into memory and passed as the model blob to {@link
   * #setModel(String, Model)}.
   *
   * @param key name of key to store the Model
   * @param backend the backend for the model (see {@link Backend})
   * @param device the device that will execute the model (see {@link Device})
   * @param inputs one or more names of the model's input nodes (applicable only for TensorFlow
   *     models)
   * @param outputs one or more names of the model's output nodes (applicable only for TensorFlow
   *     models)
   * @param modelPath the file path for the serialized model
   * @return true if Model was properly set in RedisAI server
   * @throws RedisAIException wrapping the {@link IOException} if the file cannot be read
   */
  public boolean setModel(
      String key,
      Backend backend,
      Device device,
      String[] inputs,
      String[] outputs,
      String modelPath) {
    try {
      byte[] modelBlob = Files.readAllBytes(Paths.get(modelPath));
      return setModel(key, new Model(backend, device, inputs, outputs, modelBlob));
    } catch (IOException ioError) {
      throw new RedisAIException(ioError);
    }
  }

  /**
   * Direct mapping to AI.MODELSET
   *
   * @param key name of key to store the Model
   * @param model Model object
   * @return true if the server acknowledged the command with the {@code OK} status
   * @throws RedisAIException wrapping the {@link JedisDataException} raised for an error reply
   */
  public boolean setModel(String key, Model model) {
    try (Jedis conn = getConnection()) {
      // Typed list: on a raw List, toArray(T[]) erases to Object[] and matches no sendCommand
      // overload.
      List<byte[]> args = model.getModelSetCommandBytes(key);
      String status =
          sendCommand(conn, Command.MODEL_SET, args.toArray(new byte[args.size()][]))
              .getStatusCodeReply();
      // Constant-first comparison: a null status yields false instead of a NullPointerException.
      return "OK".equals(status);
    } catch (JedisDataException ex) {
      throw new RedisAIException(ex);
    }
  }

  /**
   * Direct mapping to AI.MODELSTORE command.
   *
   * 

{@code AI.MODELSTORE [TAG tag] [BATCHSIZE n [MINBATCHSIZE m]] * [INPUTS ...] [OUTPUTS ...] BLOB } * * @param key name of key to store the Model * @param model Model object * @return true if Model was properly stored in RedisAI server */ public boolean storeModel(String key, Model model) { try (Jedis conn = getConnection()) { List args = model.getModelStoreCommandArgs(key); return sendCommand(conn, Command.MODEL_STORE, args.toArray(new byte[args.size()][])) .getStatusCodeReply() .equals("OK"); } catch (JedisDataException ex) { throw new RedisAIException(ex.getMessage(), ex); } } /** * Direct mapping to AI.MODELGET * * @param key name of key to get the Model from RedisAI server * @return Model * @throws JRedisAIRunTimeException */ public Model getModel(String key) { try (Jedis conn = getConnection()) { List reply = sendCommand( conn, Command.MODEL_GET, SafeEncoder.encode(key), Keyword.META.getRaw(), Keyword.BLOB.getRaw()) .getObjectMultiBulkReply(); if (reply.isEmpty()) { return null; } return Model.createModelFromRespReply(reply); } } /** * Direct mapping to AI.MODELDEL * * @param key name of key to delete the Model * @return true if Model was properly delete in RedisAI server */ public boolean delModel(String key) { try (Jedis conn = getConnection()) { return sendCommand(conn, Command.MODEL_DEL, SafeEncoder.encode(key)) .getStatusCodeReply() .equals("OK"); } catch (JedisDataException ex) { throw new RedisAIException(ex); } } /** * Direct mapping to AI.SCRIPTSET * * @param key name of key to store the Script in RedisAI server * @param device - the device that will execute the model. 
can be of CPU or GPU * @param scriptFile - the file path for the script source code * @return true if Script was properly set in RedisAI server */ public boolean setScriptFile(String key, Device device, String scriptFile) { try { Script script = new Script(device, Paths.get(scriptFile)); return setScript(key, script); } catch (IOException ex) { throw new RedisAIException(ex); } } /** * Direct mapping to AI.SCRIPTSET * * @param key name of key to store the Script in RedisAI server * @param device - the device that will execute the model. can be of CPU or GPU * @param source - the script source code * @return true if Script was properly set in RedisAI server */ public boolean setScript(String key, Device device, String source) { Script script = new Script(device, source); return setScript(key, script); } /** * Direct mapping to AI.SCRIPTSET * * @param key name of key to store the Script in RedisAI server * @param script the Script Object * @return true if Script was properly set in RedisAI server */ public boolean setScript(String key, Script script) { try (Jedis conn = getConnection()) { List args = script.getScriptSetCommandBytes(key); return sendCommand(conn, Command.SCRIPT_SET, args.toArray(new byte[args.size()][])) .getStatusCodeReply() .equals("OK"); } catch (JedisDataException ex) { throw new RedisAIException(ex); } } /** * Direct mapping to AI.MODELSTORE command. * *

{@code AI.SCRIPTSTORE [TAG tag] ENTRY_POINTS * [...] SOURCE "