All downloads are free. Search and download functionality uses the official Maven repository.

clarifai2.api.request.model.PredictRequest Maven / Gradle / Ivy

The newest version!
package clarifai2.api.request.model;

import clarifai2.internal.grpc.api.ConceptOuterClass;
import clarifai2.internal.grpc.api.InputOuterClass;
import clarifai2.internal.grpc.api.ModelOuterClass;
import clarifai2.internal.grpc.api.OutputOuterClass;
import clarifai2.api.BaseClarifaiClient;
import clarifai2.api.request.ClarifaiRequest;
import clarifai2.dto.input.ClarifaiInput;
import clarifai2.dto.model.ModelVersion;
import clarifai2.dto.model.output.ClarifaiOutput;
import clarifai2.dto.prediction.Concept;
import clarifai2.dto.prediction.Prediction;
import com.google.common.util.concurrent.ListenableFuture;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

public final class PredictRequest
    extends ClarifaiRequest.Builder>> {

  @NotNull private final String modelID;
  @NotNull private final List inputData = new ArrayList<>();

  @Nullable private String modelVersionID = null;
  @Nullable private String language = null;

  @Nullable private Double minValue = null;
  @Nullable private Integer maxConcepts = null;

  @Nullable private Integer sampleMs = null;

  @NotNull private final List concepts = new ArrayList<>();

  public PredictRequest(@NotNull final BaseClarifaiClient client, @NotNull String modelID) {
    super(client);
    this.modelID = modelID;
  }

  @NotNull public PredictRequest withInputs(@NotNull ClarifaiInput... inputData) {
    return withInputs(Arrays.asList(inputData));
  }

  @NotNull public PredictRequest withInputs(@NotNull Collection inputData) {
    this.inputData.addAll(inputData);
    return this;
  }

  @NotNull public PredictRequest withVersion(@NotNull ModelVersion version) {
    this.modelVersionID = version.id();
    return this;
  }

  @NotNull public PredictRequest withVersion(@NotNull String versionID) {
    this.modelVersionID = versionID;
    return this;
  }

  @NotNull public PredictRequest withLanguage(@NotNull String language) {
    this.language = language;
    return this;
  }

  @NotNull public final PredictRequest withMinValue(@Nullable Double minValue) {
    this.minValue = minValue;
    return this;
  }

  @NotNull public final PredictRequest withMaxConcepts(@Nullable Integer maxConcepts) {
    this.maxConcepts = maxConcepts;
    return this;
  }

  /**
   * If added, only these concepts will be considered in the prediction.
   * @param concepts the concepts
   * @return PredictRequest instance
   */
  @NotNull public PredictRequest selectConcepts(@NotNull Concept... concepts) {
    return selectConcepts(Arrays.asList(concepts));
  }

  /**
   * See {@link PredictRequest#selectConcepts(Concept...)}.
   * @param concepts the concepts
   * @return PredictRequest instance
   */
  @NotNull public PredictRequest selectConcepts(@NotNull Collection concepts) {
    this.concepts.addAll(concepts);
    return this;
  }

  @NotNull public final PredictRequest withSampleMs(@Nullable Integer sampleMs) {
    this.sampleMs = sampleMs;
    return this;
  }

  @NotNull @Override protected String method() {
    return "POST";
  }

  @NotNull @Override protected String subUrl() {
    if (modelVersionID == null) {
      return "v2/models/" + modelID + "/outputs";
    }
    return "v2/models/" + modelID + "/versions/" + modelVersionID + "/outputs";
  }

  @NotNull @Override protected DeserializedRequest>> request() {
    return new DeserializedRequest>>() {
      @NotNull @Override public ListenableFuture httpRequestGrpc() {
        List inputs = new ArrayList<>();
        for (ClarifaiInput input : inputData) {
          inputs.add(input.serialize());
        }

        boolean anyOutputConfig = false;
        ModelOuterClass.OutputConfig.Builder outputConfigBuilder = ModelOuterClass.OutputConfig.newBuilder();
        if (language != null) {
          outputConfigBuilder.setLanguage(language);
          anyOutputConfig = true;
        }
        if (minValue != null) {
          outputConfigBuilder.setMinValue(minValue.floatValue());
          anyOutputConfig = true;
        }
        if (maxConcepts != null) {
          outputConfigBuilder.setMaxConcepts(maxConcepts);
          anyOutputConfig = true;
        }
        if (sampleMs != null) {
          outputConfigBuilder.setSampleMs(sampleMs);
          anyOutputConfig = true;
        }
        if (!concepts.isEmpty()) {
          List selectConceptsGrpc = new ArrayList<>();
          for (Concept concept : concepts) {
            selectConceptsGrpc.add(concept.serialize());
          }
          outputConfigBuilder.addAllSelectConcepts(selectConceptsGrpc);
          anyOutputConfig = true;
        }
        InputOuterClass.PostModelOutputsRequest.Builder requestBuilder =
            InputOuterClass.PostModelOutputsRequest.newBuilder()
                .addAllInputs(inputs);

        if (anyOutputConfig) {
          requestBuilder.setModel(
              ModelOuterClass.Model.newBuilder()
                  .setOutputInfo(ModelOuterClass.OutputInfo.newBuilder().setOutputConfig(outputConfigBuilder))
          );
        }

        return stub().postModelOutputs(requestBuilder.build());
      }

      @NotNull @Override public List> unmarshalerGrpc(Object returnedObject) {
        OutputOuterClass.MultiOutputResponse response = (OutputOuterClass.MultiOutputResponse) returnedObject;

        List> outputs = new ArrayList<>();
        for (OutputOuterClass.Output output : response.getOutputsList()) {
          outputs.add(ClarifaiOutput.deserialize(output, client));
        }

        return outputs;
      }
    };
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy