All downloads are free. Search and download functionality uses the official Maven repository.

com.redislabs.redisai.Script Maven / Gradle / Ivy

The newest version!
package com.redislabs.redisai;

import com.redislabs.redisai.exceptions.JRedisAIRunTimeException;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import redis.clients.jedis.BuilderFactory;
import redis.clients.jedis.Protocol;
import redis.clients.jedis.util.SafeEncoder;

public class Script {

  /** The device that will execute the script; can be CPU or GPU. */
  private Device device; // TODO: final

  /** A string containing TorchScript source code. */
  private String source; // TODO: final

  /**
   * Optional string for tagging the script, such as a version number or any arbitrary identifier.
   * When {@code null}, no TAG clause is sent to the server.
   */
  private String tag;

  /**
   * Entry point names sent after the ENTRY_POINTS keyword by AI.SCRIPTSTORE. When {@code null} or
   * empty, no ENTRY_POINTS clause is sent.
   */
  private List<String> entryPoints;

  /**
   * Creates a script with an empty source.
   *
   * @param device the device that will execute the script; can be CPU or GPU
   * @deprecated Use {@link #Script(Device, String)} or {@link #Script(Device, Path)}.
   */
  @Deprecated
  public Script(Device device) {
    this(device, "");
  }

  /**
   * Creates a script from in-memory TorchScript source.
   *
   * @param device the device that will execute the script; can be CPU or GPU
   * @param source a string containing TorchScript source code
   */
  public Script(Device device, String source) {
    this.device = device;
    this.source = source;
  }

  /**
   * Creates a script whose source is loaded from a file.
   *
   * @param device the device that will execute the script; can be CPU or GPU
   * @param filePath file path to load the script from (read as UTF-8)
   * @throws java.io.IOException if the file cannot be read
   */
  public Script(Device device, Path filePath) throws IOException {
    this(device, fileContent(filePath));
  }

  /**
   * Reads a UTF-8 text file line by line and re-joins the lines with {@code "\n"}, always
   * appending a trailing {@code "\n"}. Original line separators are therefore normalized.
   */
  private static String fileContent(Path filePath) throws IOException {
    return Files.readAllLines(filePath, StandardCharsets.UTF_8).stream()
            .collect(Collectors.joining("\n"))
        + "\n";
  }

  /**
   * Builds a {@link Script} from the flat key/value reply of AI.SCRIPTGET.
   *
   * <p>The reply is read as alternating keys and values; each key is a {@code byte[]} decoded to
   * a String. Recognized keys are {@code source}, {@code device}, {@code tag} and {@code Entry
   * Points}; any other key is ignored.
   *
   * @param reply the raw reply elements, alternating key and value
   * @return the decoded script, with tag and entry points set (possibly to {@code null})
   * @throws JRedisAIRunTimeException if the reply is not made of key/value pairs, or if it lacks
   *     the device or the source
   */
  public static Script createScriptFromRespReply(List<?> reply) {
    // Guard before indexing reply.get(i + 1): a dangling key would otherwise surface as a bare
    // IndexOutOfBoundsException instead of this library's exception type.
    if (reply.size() % 2 != 0) {
      throw new JRedisAIRunTimeException(
          "AI.SCRIPTGET reply has an odd number of elements: " + reply.size());
    }
    Device device = null;
    String tag = null;
    String source = null;
    List<String> entryPoints = null;
    for (int i = 0; i < reply.size(); i += 2) {
      String mapKey = SafeEncoder.encode((byte[]) reply.get(i));
      Object mapVal = reply.get(i + 1);
      switch (mapKey) {
        case "source":
          source = BuilderFactory.STRING.build(mapVal);
          break;
        case "device":
          device = Device.valueOf(BuilderFactory.STRING.build(mapVal));
          break;
        case "tag":
          tag = BuilderFactory.STRING.build(mapVal);
          break;
        case "Entry Points":
          entryPoints = BuilderFactory.STRING_LIST.build(mapVal);
          break;
        default:
          // Unknown fields are tolerated so newer server replies do not break parsing.
          break;
      }
    }
    if (device != null && source != null) {
      return new Script(device, source).setTag(tag).setEntryPoints(entryPoints);
    }
    throw new JRedisAIRunTimeException(
        "AI.SCRIPTGET reply did not contained all elements to build the script");
  }

  /** @return the device that will execute the script */
  public Device getDevice() {
    return device;
  }

  /** @deprecated the device should be supplied through the constructor. */
  @Deprecated
  public void setDevice(Device device) {
    this.device = device;
  }

  /** @return the TorchScript source code */
  public String getSource() {
    return source;
  }

  /** @deprecated the source should be supplied through the constructor. */
  @Deprecated
  public void setSource(String source) {
    this.source = source;
  }

  /** @return the tag, or {@code null} if none was set */
  public String getTag() {
    return tag;
  }

  /**
   * Sets the optional tag.
   *
   * @param tag the tag, or {@code null} to send no TAG clause
   * @return this script, for chaining
   */
  public Script setTag(String tag) {
    this.tag = tag;
    return this;
  }

  /** @return the entry point names, or {@code null} if none were set */
  public List<String> getEntryPoints() {
    return entryPoints;
  }

  /**
   * Sets the entry point names. The list is stored as given, without copying.
   *
   * @param entryPoints entry point names, or {@code null}/empty to send no ENTRY_POINTS clause
   * @return this script, for chaining
   */
  public Script setEntryPoints(List<String> entryPoints) {
    this.entryPoints = entryPoints;
    return this;
  }

  /**
   * Varargs convenience for {@link #setEntryPoints(List)}.
   *
   * @param entryPoints entry point names
   * @return this script, for chaining
   */
  public Script setEntryPoints(String... entryPoints) {
    return setEntryPoints(Arrays.asList(entryPoints));
  }

  /**
   * Encodes the current script into AI.SCRIPTSET command arguments to be stored on the RedisAI
   * server. Entry points are not included.
   *
   * @param key name of the key to store the script under
   * @return the encoded arguments: key, device, optional TAG clause, then SOURCE and the source
   */
  protected List<byte[]> getScriptSetCommandBytes(String key) {
    List<byte[]> args = new ArrayList<>();
    args.add(SafeEncoder.encode(key));
    args.add(device.getRaw());
    if (tag != null) {
      args.add(Keyword.TAG.getRaw());
      args.add(SafeEncoder.encode(tag));
    }
    args.add(Keyword.SOURCE.getRaw());
    args.add(SafeEncoder.encode(source));
    return args;
  }

  /**
   * Prepares AI.SCRIPTSTORE command arguments (as Strings, despite the method name).
   *
   * @param key name of the key to store the script under
   * @return key, device, optional TAG clause, optional ENTRY_POINTS clause (count followed by the
   *     names), then SOURCE and the source
   */
  protected List<String> getScriptStoreCommandBytes(String key) {
    List<String> args = new ArrayList<>();
    args.add(key);
    args.add(device.name());
    if (tag != null) {
      args.add(Keyword.TAG.name());
      args.add(tag);
    }
    if (entryPoints != null && !entryPoints.isEmpty()) {
      args.add(Keyword.ENTRY_POINTS.name());
      args.add(Integer.toString(entryPoints.size()));
      args.addAll(entryPoints);
    }
    args.add(Keyword.SOURCE.name());
    args.add(source);
    return args;
  }

  /**
   * Sets the script source given a file path.
   *
   * @param filePath path of the file to read (UTF-8)
   * @throws IOException if the file cannot be read
   * @deprecated Use {@link #Script(com.redislabs.redisai.Device, java.nio.file.Path)}.
   */
  @Deprecated
  public void readSourceFromFile(String filePath) throws IOException {
    this.source = fileContent(Paths.get(filePath));
  }

  /**
   * Builds the binary arguments of AI.SCRIPTRUN.
   *
   * @param key script key
   * @param function name of the function to run
   * @param inputs input key names, emitted after INPUTS
   * @param outputs output key names, emitted after OUTPUTS
   * @param includeCommandName whether to prepend the command name itself
   * @return the encoded arguments; INPUTS and OUTPUTS are always emitted, even if empty
   */
  protected static List<byte[]> scriptRunFlatArgs(
      String key, String function, String[] inputs, String[] outputs, boolean includeCommandName) {
    List<byte[]> args = new ArrayList<>();
    if (includeCommandName) {
      args.add(Command.SCRIPT_RUN.getRaw());
    }
    args.add(SafeEncoder.encode(key));
    args.add(SafeEncoder.encode(function));
    args.add(Keyword.INPUTS.getRaw());
    for (String input : inputs) {
      args.add(SafeEncoder.encode(input));
    }

    args.add(Keyword.OUTPUTS.getRaw());
    for (String output : outputs) {
      args.add(SafeEncoder.encode(output));
    }
    return args;
  }

  /**
   * Builds the binary arguments of AI.SCRIPTEXECUTE.
   *
   * @param key script key
   * @param function name of the function to execute
   * @param keys KEYS values; clause omitted when {@code null} or empty
   * @param inputs INPUTS values; clause omitted when {@code null} or empty
   * @param args ARGS values; clause omitted when {@code null} or empty
   * @param outputs OUTPUTS values; clause omitted when {@code null} or empty
   * @param timeout TIMEOUT value; clause omitted when negative
   * @param includeCommandName whether to prepend the command name itself
   * @return the encoded arguments
   */
  protected static List<byte[]> scriptExecuteFlatArgs(
      String key,
      String function,
      List<String> keys,
      List<String> inputs,
      List<String> args,
      List<String> outputs,
      long timeout,
      boolean includeCommandName) {
    List<byte[]> binary = new ArrayList<>();
    if (includeCommandName) {
      binary.add(Command.SCRIPT_EXECUTE.getRaw());
    }

    binary.add(SafeEncoder.encode(key));
    binary.add(SafeEncoder.encode(function));
    variadicArgumentsCheckAndAddWithCount(binary, Keyword.KEYS, keys);
    variadicArgumentsCheckAndAddWithCount(binary, Keyword.INPUTS, inputs);
    variadicArgumentsCheckAndAddWithCount(binary, Keyword.ARGS, args);
    variadicArgumentsCheckAndAddWithCount(binary, Keyword.OUTPUTS, outputs);
    if (timeout >= 0) {
      binary.add(Keyword.TIMEOUT.getRaw());
      binary.add(Protocol.toByteArray(timeout));
    }

    return binary;
  }

  /**
   * Appends {@code keyword count value...} to {@code arguments}; does nothing when {@code values}
   * is {@code null} or empty.
   */
  private static void variadicArgumentsCheckAndAddWithCount(
      List<byte[]> arguments, Keyword keyword, List<String> values) {
    if (values == null || values.isEmpty()) {
      return;
    }
    arguments.add(keyword.getRaw());
    arguments.add(Protocol.toByteArray(values.size()));
    values.forEach(v -> arguments.add(SafeEncoder.encode(v)));
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy