All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.redislabs.redisai.Tensor Maven / Gradle / Ivy

The newest version!
package com.redislabs.redisai;

import com.redislabs.redisai.exceptions.JRedisAIRunTimeException;
import java.util.ArrayList;
import java.util.List;
import redis.clients.jedis.Protocol;
import redis.clients.jedis.util.SafeEncoder;

/**
 * Client-side representation of a RedisAI tensor: an element data type, a shape and the tensor
 * values.
 *
 * <p>Instances are mutable (values and data type have setters) and are not synchronized.
 */
public class Tensor {
  private DataType dataType;
  private long[] shape;
  private Object values;

  /**
   * Creates a tensor from its components. The arguments are stored as given: no defensive copies
   * are made and no validation is performed.
   *
   * @param dataType element data type of the tensor
   * @param shape size of each dimension of the tensor
   * @param values the tensor values; the Java representation expected here is whatever {@link
   *     DataType} encodes/decodes (not checked by this constructor)
   */
  public Tensor(DataType dataType, long[] shape, Object values) {
    this.shape = shape;
    this.values = values;
    this.dataType = dataType;
  }

  /**
   * Builds a new Tensor from the reply of an {@code AI.TENSORGET <key> META VALUES} command.
   *
   * <p>The reply is read as a flat list of alternating keys and values. Keys are decoded from
   * {@code byte[]}; the {@code dtype}, {@code shape} and {@code values} entries are used and any
   * other key is ignored. {@code dtype} must appear before {@code values}, because the data type
   * is needed to decode the values. A trailing key with no paired value is ignored.
   *
   * @param reply reply from RedisAI Server
   * @return Tensor Object
   * @throws JRedisAIRunTimeException if the datatype is not recognized, if the values appear
   *     before the datatype, or if any of dtype/shape/values is missing from the reply
   */
  protected static Tensor createTensorFromRespReply(List<?> reply) {
    DataType dtype = null;
    long[] shape = null;
    Object values = null;
    // Only visit complete key/value pairs so an odd-length (malformed) reply cannot trigger an
    // IndexOutOfBoundsException; missing entries are reported by the completeness check below.
    for (int i = 0; i + 1 < reply.size(); i += 2) {
      String arrayKey = SafeEncoder.encode((byte[]) reply.get(i));
      Object arrayValue = reply.get(i + 1);
      switch (arrayKey) {
        case "dtype":
          String dtypeString = SafeEncoder.encode((byte[]) arrayValue);
          dtype = DataType.getDataTypefromString(dtypeString);
          if (dtype == null) {
            throw new JRedisAIRunTimeException("Unrecognized datatype: " + dtypeString);
          }
          break;
        case "shape":
          shape = parseShape((List<?>) arrayValue);
          break;
        case "values":
          if (dtype == null) {
            throw new JRedisAIRunTimeException(
                "Trying to decode values array without previous datatype info");
          }
          // Raw cast kept on purpose: the element type expected by DataType.toObject is declared
          // outside this class, and a raw List argument is accepted whatever it is.
          List valuesEncoded = (List) arrayValue;
          values = dtype.toObject(valuesEncoded);
          break;
        default:
          break;
      }
    }
    if (dtype == null || shape == null || values == null) {
      throw new JRedisAIRunTimeException(
          "AI.TENSORGET reply did not contained all elements to build the tensor");
    }
    return new Tensor(dtype, shape, values);
  }

  /**
   * Converts the RESP array holding a tensor shape into a {@code long[]}, one entry per dimension.
   *
   * @param shapeResp the shape array from the reply; elements are expected to be boxed integral
   *     numbers (presumably {@code Long} as returned by Jedis for integer replies — verify)
   * @return the shape dimensions in reply order
   */
  private static long[] parseShape(List<?> shapeResp) {
    long[] dims = new long[shapeResp.size()];
    for (int j = 0; j < dims.length; j++) {
      // Go through Number so that any boxed integral type is accepted, not only Long.
      dims[j] = ((Number) shapeResp.get(j)).longValue();
    }
    return dims;
  }

  /**
   * Returns the tensor values.
   *
   * @return the values object held by this tensor (not a copy)
   */
  public Object getValues() {
    return values;
  }

  /**
   * Replaces the tensor values. No validation against the data type or shape is performed.
   *
   * @param values the new values
   */
  public void setValues(Object values) {
    this.values = values;
  }

  /**
   * Returns the tensor shape.
   *
   * @return the internal shape array (not a copy); modifying it changes this tensor
   */
  public long[] getShape() {
    return shape;
  }

  /**
   * Returns the element data type of the tensor.
   *
   * @return the data type
   */
  public DataType getDataType() {
    return dataType;
  }

  /**
   * Replaces the element data type. The values are not converted.
   *
   * @param dataType the new data type
   */
  public void setDataType(DataType dataType) {
    this.dataType = dataType;
  }

  /**
   * Encodes this tensor as the arguments of an {@code AI.TENSORSET} command, to be stored in the
   * RedisAI Server.
   *
   * <p>Argument order: [command name], key, data type, one argument per shape dimension, the
   * {@code VALUES} keyword, then the values as encoded by {@code DataType.toByteArray}.
   *
   * @param key name of the key to store the tensor under
   * @param includeCommandName whether to prepend the {@code TENSOR_SET} command name
   * @return the encoded command arguments
   */
  protected List<byte[]> tensorSetFlatArgs(String key, boolean includeCommandName) {
    List<byte[]> args = new ArrayList<>();
    if (includeCommandName) {
      args.add(Command.TENSOR_SET.getRaw());
    }
    args.add(SafeEncoder.encode(key));
    args.add(dataType.getRaw());
    for (long shapeDimension : shape) {
      args.add(Protocol.toByteArray(shapeDimension));
    }
    args.add(Keyword.VALUES.getRaw());
    args.addAll(dataType.toByteArray(values, shape));
    return args;
  }

  /**
   * Encodes the arguments of an {@code AI.TENSORGET} command that requests both the metadata and
   * the values of a tensor.
   *
   * <p>Argument order: [command name], key, {@code META}, {@code VALUES}.
   *
   * @param key name of the key holding the tensor
   * @param includeCommandName whether to prepend the {@code TENSOR_GET} command name
   * @return the encoded command arguments
   */
  protected static List<byte[]> tensorGetFlatArgs(String key, boolean includeCommandName) {
    List<byte[]> args = new ArrayList<>();
    if (includeCommandName) {
      args.add(Command.TENSOR_GET.getRaw());
    }
    args.add(SafeEncoder.encode(key));
    args.add(Keyword.META.getRaw());
    args.add(Keyword.VALUES.getRaw());
    return args;
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy