com.spotify.zoltar.DefaultPredictorBuilder Maven / Gradle / Ivy
/*-
* -\-\-
* zoltar-core
* --
* Copyright (C) 2016 - 2018 Spotify AB
* --
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* -/-/-
*/
package com.spotify.zoltar;
import com.google.auto.value.AutoValue;
import com.spotify.zoltar.FeatureExtractFns.ExtractFn;
import com.spotify.zoltar.PredictFns.AsyncPredictFn;
import com.spotify.zoltar.PredictFns.PredictFn;
/**
* Entry point for prediction. Default implementation of a PredictorBuilder that holds the necessary
* info to build a {@link Predictor}.
*
* @param underlying type of the {@link Model}.
* @param type of the input to the {@link FeatureExtractor}.
* @param type of the output from {@link FeatureExtractor}.
* @param type of the prediction result.
*/
@AutoValue
abstract class DefaultPredictorBuilder, InputT, VectorT, ValueT>
implements PredictorBuilder {
/**
* Returns a context given a {@link Model}, {@link FeatureExtractor} and a {@link PredictFn}.
*
* @param modelLoader model loader that loads the model to perform prediction on.
* @param extractFn a feature extract function to use to transform input into extracted
* features.
* @param predictFn a prediction function to perform prediction with {@link PredictFn}.
* @param underlying type of the {@link Model}.
* @param type of the input to the {@link FeatureExtractor}.
* @param type of the output from {@link FeatureExtractor}.
* @param type of the prediction result.
*/
@SuppressWarnings("checkstyle:LineLength")
public static , InputT, VectorT, ValueT> DefaultPredictorBuilder create(
final ModelLoader modelLoader,
final ExtractFn extractFn,
final PredictFn predictFn) {
return create(modelLoader, FeatureExtractor.create(extractFn), AsyncPredictFn.lift(predictFn));
}
/**
* Returns a context given a {@link Model}, {@link FeatureExtractor} and a {@link
* AsyncPredictFn}.
*
* @param modelLoader model loader that loads the model to perform prediction on.
* @param extractFn a feature extract function to use to transform input into extracted
* features.
* @param predictFn a prediction function to perform prediction with {@link AsyncPredictFn}.
* @param underlying type of the {@link Model}.
* @param type of the input to the {@link FeatureExtractor}.
* @param type of the output from {@link FeatureExtractor}.
* @param type of the prediction result.
*/
@SuppressWarnings("checkstyle:LineLength")
public static , InputT, VectorT, ValueT> DefaultPredictorBuilder create(
final ModelLoader modelLoader,
final ExtractFn extractFn,
final AsyncPredictFn predictFn) {
return create(modelLoader, FeatureExtractor.create(extractFn), predictFn);
}
/**
* Returns a context given a {@link Model}, {@link FeatureExtractor} and a {@link PredictFn}.
*
* @param modelLoader model loader that loads the model to perform prediction on.
* @param featureExtractor a feature extractor to use to transform input into extracted features.
* @param predictFn a prediction function to perform prediction with {@link PredictFn}.
* @param underlying type of the {@link Model}.
* @param type of the input to the {@link FeatureExtractor}.
* @param type of the output from {@link FeatureExtractor}.
* @param type of the prediction result.
*/
@SuppressWarnings("checkstyle:LineLength")
public static , InputT, VectorT, ValueT> DefaultPredictorBuilder create(
final ModelLoader modelLoader,
final FeatureExtractor featureExtractor,
final PredictFn predictFn) {
return create(modelLoader, featureExtractor, AsyncPredictFn.lift(predictFn));
}
/**
* Returns a context given a {@link Model}, {@link FeatureExtractor} and a {@link PredictFn}.
*
* @param modelLoader model loader that loads the model to perform prediction on.
* @param featureExtractor a feature extractor to use to transform input into extracted features.
* @param predictFn a prediction function to perform prediction with {@link
* AsyncPredictFn}.
* @param underlying type of the {@link Model}.
* @param type of the input to the {@link FeatureExtractor}.
* @param type of the output from {@link FeatureExtractor}.
* @param type of the prediction result.
*/
@SuppressWarnings("checkstyle:LineLength")
public static , InputT, VectorT, ValueT> DefaultPredictorBuilder create(
final ModelLoader modelLoader,
final FeatureExtractor featureExtractor,
final AsyncPredictFn predictFn) {
return new AutoValue_DefaultPredictorBuilder<>(
modelLoader,
featureExtractor,
predictFn,
DefaultPredictor.create(modelLoader, featureExtractor, predictFn));
}
public abstract ModelLoader modelLoader();
public abstract FeatureExtractor featureExtractor();
public abstract AsyncPredictFn predictFn();
public abstract Predictor predictor();
@Override
public DefaultPredictorBuilder with(
final ModelLoader modelLoader,
final FeatureExtractor featureExtractor,
final AsyncPredictFn predictFn) {
return create(modelLoader, featureExtractor, predictFn);
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy