org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of elasticsearch-rest-high-level-client Show documentation
Show all versions of elasticsearch-rest-high-level-client Show documentation
Elasticsearch subproject :client:rest-high-level
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.client.ml.inference;
import org.elasticsearch.client.ml.inference.preprocessing.CustomWordEmbedding;
import org.elasticsearch.client.ml.inference.preprocessing.FrequencyEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.Multi;
import org.elasticsearch.client.ml.inference.preprocessing.NGram;
import org.elasticsearch.client.ml.inference.preprocessing.OneHotEncoding;
import org.elasticsearch.client.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.client.ml.inference.preprocessing.TargetMeanEncoding;
import org.elasticsearch.client.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.Exponent;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.LogisticRegression;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.OutputAggregator;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedMode;
import org.elasticsearch.client.ml.inference.trainedmodel.ensemble.WeightedSum;
import org.elasticsearch.client.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
import org.elasticsearch.client.ml.inference.trainedmodel.tree.Tree;
import org.elasticsearch.plugins.spi.NamedXContentProvider;
import org.elasticsearch.xcontent.NamedXContentRegistry;
import org.elasticsearch.xcontent.ParseField;
import java.util.ArrayList;
import java.util.List;
/**
 * {@link NamedXContentProvider} that supplies the named-XContent parser entries for the
 * ML inference client objects: pre-processors, trained models, inference configurations
 * and ensemble output aggregators.
 */
public class MlInferenceNamedXContentProvider implements NamedXContentProvider {

    /**
     * Builds the list of named-XContent registry entries for the ML inference classes.
     *
     * <p>Each entry associates a category class (for example {@link PreProcessor}) with a
     * name and the {@code fromXContent} parser of one concrete implementation. A new,
     * mutable list is created on every call.
     *
     * @return the registry entries, grouped by category in the order: pre-processors,
     *         trained models, inference configs, output aggregators
     */
    @Override
    public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
        List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();

        // PreProcessing
        namedXContent.add(
            new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(OneHotEncoding.NAME), OneHotEncoding::fromXContent)
        );
        namedXContent.add(
            new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(TargetMeanEncoding.NAME), TargetMeanEncoding::fromXContent)
        );
        namedXContent.add(
            new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(FrequencyEncoding.NAME), FrequencyEncoding::fromXContent)
        );
        namedXContent.add(
            new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(CustomWordEmbedding.NAME), CustomWordEmbedding::fromXContent)
        );
        namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(NGram.NAME), NGram::fromXContent));
        namedXContent.add(new NamedXContentRegistry.Entry(PreProcessor.class, new ParseField(Multi.NAME), Multi::fromXContent));

        // Model
        namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Tree.NAME), Tree::fromXContent));
        namedXContent.add(new NamedXContentRegistry.Entry(TrainedModel.class, new ParseField(Ensemble.NAME), Ensemble::fromXContent));
        namedXContent.add(
            new NamedXContentRegistry.Entry(
                TrainedModel.class,
                new ParseField(LangIdentNeuralNetwork.NAME),
                LangIdentNeuralNetwork::fromXContent
            )
        );

        // Inference Config
        // NOTE(review): unlike the other groups, NAME is passed directly here rather than
        // wrapped in a ParseField - presumably these constants are already ParseFields; confirm.
        namedXContent.add(
            new NamedXContentRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::fromXContent)
        );
        namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::fromXContent));

        // Aggregating output
        namedXContent.add(
            new NamedXContentRegistry.Entry(OutputAggregator.class, new ParseField(WeightedMode.NAME), WeightedMode::fromXContent)
        );
        namedXContent.add(
            new NamedXContentRegistry.Entry(OutputAggregator.class, new ParseField(WeightedSum.NAME), WeightedSum::fromXContent)
        );
        namedXContent.add(
            new NamedXContentRegistry.Entry(
                OutputAggregator.class,
                new ParseField(LogisticRegression.NAME),
                LogisticRegression::fromXContent
            )
        );
        namedXContent.add(new NamedXContentRegistry.Entry(OutputAggregator.class, new ParseField(Exponent.NAME), Exponent::fromXContent));

        return namedXContent;
    }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy