All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.spotify.zoltar.DefaultPredictorBuilder Maven / Gradle / Ivy

There is a newer version: 0.6.0
Show newest version
/*-
 * -\-\-
 * zoltar-core
 * --
 * Copyright (C) 2016 - 2018 Spotify AB
 * --
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * -/-/-
 */

package com.spotify.zoltar;

import com.google.auto.value.AutoValue;
import com.spotify.zoltar.FeatureExtractFns.ExtractFn;
import com.spotify.zoltar.PredictFns.AsyncPredictFn;
import com.spotify.zoltar.PredictFns.PredictFn;

/**
 * Entry point for prediction. Default implementation of a PredictorBuilder that holds the necessary
 * info to build a {@link Predictor}.
 *
 * @param   underlying type of the {@link Model}.
 * @param   type of the input to the {@link FeatureExtractor}.
 * @param  type of the output from {@link FeatureExtractor}.
 * @param   type of the prediction result.
 */
@AutoValue
abstract class DefaultPredictorBuilder, InputT, VectorT, ValueT>
    implements PredictorBuilder {

  /**
   * Returns a context given a {@link Model}, {@link FeatureExtractor} and a {@link PredictFn}.
   *
   * @param modelLoader model loader that loads the model to perform prediction on.
   * @param extractFn   a feature extract function to use to transform input into extracted
   *                    features.
   * @param predictFn   a prediction function to perform prediction with {@link PredictFn}.
   * @param     underlying type of the {@link Model}.
   * @param     type of the input to the {@link FeatureExtractor}.
   * @param    type of the output from {@link FeatureExtractor}.
   * @param     type of the prediction result.
   */
  @SuppressWarnings("checkstyle:LineLength")
  public static , InputT, VectorT, ValueT> DefaultPredictorBuilder create(
      final ModelLoader modelLoader,
      final ExtractFn extractFn,
      final PredictFn predictFn) {
    return create(modelLoader, FeatureExtractor.create(extractFn), AsyncPredictFn.lift(predictFn));
  }

  /**
   * Returns a context given a {@link Model}, {@link FeatureExtractor} and a {@link
   * AsyncPredictFn}.
   *
   * @param modelLoader model loader that loads the model to perform prediction on.
   * @param extractFn   a feature extract function to use to transform input into extracted
   *                    features.
   * @param predictFn   a prediction function to perform prediction with {@link AsyncPredictFn}.
   * @param     underlying type of the {@link Model}.
   * @param     type of the input to the {@link FeatureExtractor}.
   * @param    type of the output from {@link FeatureExtractor}.
   * @param     type of the prediction result.
   */
  @SuppressWarnings("checkstyle:LineLength")
  public static , InputT, VectorT, ValueT> DefaultPredictorBuilder create(
      final ModelLoader modelLoader,
      final ExtractFn extractFn,
      final AsyncPredictFn predictFn) {
    return create(modelLoader, FeatureExtractor.create(extractFn), predictFn);
  }

  /**
   * Returns a context given a {@link Model}, {@link FeatureExtractor} and a {@link PredictFn}.
   *
   * @param modelLoader      model loader that loads the model to perform prediction on.
   * @param featureExtractor a feature extractor to use to transform input into extracted features.
   * @param predictFn        a prediction function to perform prediction with {@link PredictFn}.
   * @param          underlying type of the {@link Model}.
   * @param          type of the input to the {@link FeatureExtractor}.
   * @param         type of the output from {@link FeatureExtractor}.
   * @param          type of the prediction result.
   */
  @SuppressWarnings("checkstyle:LineLength")
  public static , InputT, VectorT, ValueT> DefaultPredictorBuilder create(
      final ModelLoader modelLoader,
      final FeatureExtractor featureExtractor,
      final PredictFn predictFn) {
    return create(modelLoader, featureExtractor, AsyncPredictFn.lift(predictFn));
  }

  /**
   * Returns a context given a {@link Model}, {@link FeatureExtractor} and a {@link PredictFn}.
   *
   * @param modelLoader      model loader that loads the model to perform prediction on.
   * @param featureExtractor a feature extractor to use to transform input into extracted features.
   * @param predictFn        a prediction function to perform prediction with {@link
   *                         AsyncPredictFn}.
   * @param          underlying type of the {@link Model}.
   * @param          type of the input to the {@link FeatureExtractor}.
   * @param         type of the output from {@link FeatureExtractor}.
   * @param          type of the prediction result.
   */
  @SuppressWarnings("checkstyle:LineLength")
  public static , InputT, VectorT, ValueT> DefaultPredictorBuilder create(
      final ModelLoader modelLoader,
      final FeatureExtractor featureExtractor,
      final AsyncPredictFn predictFn) {
    return new AutoValue_DefaultPredictorBuilder<>(
        modelLoader,
        featureExtractor,
        predictFn,
        DefaultPredictor.create(modelLoader, featureExtractor, predictFn));
  }

  public abstract ModelLoader modelLoader();

  public abstract FeatureExtractor featureExtractor();

  public abstract AsyncPredictFn predictFn();

  public abstract Predictor predictor();

  @Override
  public DefaultPredictorBuilder with(
      final ModelLoader modelLoader,
      final FeatureExtractor featureExtractor,
      final AsyncPredictFn predictFn) {
    return create(modelLoader, featureExtractor, predictFn);
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy