All downloads are free. Search and download functionality uses the official Maven repository.

com.redislabs.redisai.Dag Maven / Gradle / Ivy

The newest version!
package com.redislabs.redisai;

import java.util.ArrayList;
import java.util.List;
import redis.clients.jedis.util.SafeEncoder;

public class Dag implements DagRunCommands {
  private final List> commands = new ArrayList<>();
  private final List tensorgetflag = new ArrayList<>();

  /** Direct acyclic graph of operations to run within RedisAI */
  public Dag() {}

  protected List processDagReply(List reply) {
    List outputList = new ArrayList<>(reply.size());
    for (int i = 0; i < reply.size(); i++) {
      Object obj = reply.get(i);
      // TODO: Should encode 'OK', 'NA', etc. response
      if (obj instanceof Exception) {
        Exception ex = (Exception) obj;
        outputList.add(new RedisAIException(ex.getMessage(), ex));
      } else if (this.tensorgetflag.get(i)) {
        outputList.add(Tensor.createTensorFromRespReply((List) obj));
      } else {
        outputList.add(obj);
      }
    }
    return outputList;
  }

  @Override
  public Dag setTensor(String key, Tensor tensor) {
    List args = tensor.tensorSetFlatArgs(key, true);
    this.commands.add(args);
    this.tensorgetflag.add(false);
    return this;
  }

  @Override
  public Dag getTensor(String key) {
    List args = Tensor.tensorGetFlatArgs(key, true);
    this.commands.add(args);
    this.tensorgetflag.add(true);
    return this;
  }

  @Override
  public Dag runModel(String key, String[] inputs, String[] outputs) {
    List args = Model.modelRunFlatArgs(key, inputs, outputs, true);
    this.commands.add(args);
    this.tensorgetflag.add(false);
    return this;
  }

  @Override
  public Dag executeModel(String key, String[] inputs, String[] outputs) {
    List args = Model.modelExecuteCommandArgs(key, inputs, outputs, -1L, true);
    this.commands.add(args);
    this.tensorgetflag.add(false);
    return this;
  }

  @Override
  public Dag runScript(String key, String function, String[] inputs, String[] outputs) {
    List args = Script.scriptRunFlatArgs(key, function, inputs, outputs, true);
    this.commands.add(args);
    this.tensorgetflag.add(false);
    return this;
  }

  @Override
  public Dag executeScript(
      String key,
      String function,
      List keys,
      List inputs,
      List args,
      List outputs) {
    List binary =
        Script.scriptExecuteFlatArgs(key, function, keys, inputs, args, outputs, -1L, true);
    this.commands.add(binary);
    this.tensorgetflag.add(false);
    return this;
  }

  List dagRunFlatArgs(String[] loadKeys, String[] persistKeys) {
    List args = new ArrayList<>();
    if (loadKeys != null && loadKeys.length > 0) {
      args.add(Keyword.LOAD.getRaw());
      args.add(SafeEncoder.encode(String.valueOf(loadKeys.length)));
      for (String key : loadKeys) {
        args.add(SafeEncoder.encode(key));
      }
    }
    if (persistKeys != null && persistKeys.length > 0) {
      args.add(Keyword.PERSIST.getRaw());
      args.add(SafeEncoder.encode(String.valueOf(persistKeys.length)));
      for (String key : persistKeys) {
        args.add(SafeEncoder.encode(key));
      }
    }
    for (List command : this.commands) {
      args.add(Keyword.PIPE.getRaw());
      args.addAll(command);
    }
    return args;
  }

  List dagExecuteFlatArgs(
      String[] loadTensors, String[] persistTensors, String routingHint) {
    List args = new ArrayList<>();
    if (loadTensors != null && loadTensors.length > 0) {
      args.add(Keyword.LOAD.getRaw());
      args.add(SafeEncoder.encode(String.valueOf(loadTensors.length)));
      for (String key : loadTensors) {
        args.add(SafeEncoder.encode(key));
      }
    }
    if (persistTensors != null && persistTensors.length > 0) {
      args.add(Keyword.PERSIST.getRaw());
      args.add(SafeEncoder.encode(String.valueOf(persistTensors.length)));
      for (String key : persistTensors) {
        args.add(SafeEncoder.encode(key));
      }
    }

    if (routingHint != null) {
      args.add(Keyword.ROUTING.getRaw());
      args.add(SafeEncoder.encode(routingHint));
    }

    for (List command : this.commands) {
      args.add(Keyword.PIPE.getRaw());
      args.addAll(command);
    }
    return args;
  }
}