All Downloads are FREE. Search and download functionalities are using the official Maven repository.
Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition Maven / Gradle / Ivy
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.inference;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Accountables;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
public class TrainedModelDefinition implements ToXContentObject, Writeable, Accountable {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TrainedModelDefinition.class);
public static final String NAME = "trained_model_definition";
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser LENIENT_PARSER = createParser(true);
public static final ObjectParser STRICT_PARSER = createParser(false);
private static ObjectParser createParser(boolean ignoreUnknownFields) {
ObjectParser parser = new ObjectParser<>(NAME,
ignoreUnknownFields,
TrainedModelDefinition.Builder::builderForParser);
parser.declareNamedObject(TrainedModelDefinition.Builder::setTrainedModel,
(p, c, n) -> ignoreUnknownFields ?
p.namedObject(LenientlyParsedTrainedModel.class, n, null) :
p.namedObject(StrictlyParsedTrainedModel.class, n, null),
TRAINED_MODEL);
parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
(p, c, n) -> ignoreUnknownFields ?
p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT) :
p.namedObject(StrictlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
PREPROCESSORS);
return parser;
}
public static TrainedModelDefinition.Builder fromXContent(XContentParser parser, boolean lenient) throws IOException {
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
}
private final TrainedModel trainedModel;
private final List preProcessors;
private TrainedModelDefinition(TrainedModel trainedModel, List preProcessors) {
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
}
public TrainedModelDefinition(StreamInput in) throws IOException {
this.trainedModel = in.readNamedWriteable(TrainedModel.class);
this.preProcessors = in.readNamedWriteableList(PreProcessor.class);
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeNamedWriteable(trainedModel);
out.writeNamedWriteableList(preProcessors);
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
NamedXContentObjectHelper.writeNamedObjects(builder,
params,
false,
TRAINED_MODEL.getPreferredName(),
Collections.singletonList(trainedModel));
NamedXContentObjectHelper.writeNamedObjects(builder,
params,
true,
PREPROCESSORS.getPreferredName(),
preProcessors);
builder.endObject();
return builder;
}
public TrainedModel getTrainedModel() {
return trainedModel;
}
public List getPreProcessors() {
return preProcessors;
}
@Override
public String toString() {
return Strings.toString(this);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition that = (TrainedModelDefinition) o;
return Objects.equals(trainedModel, that.trainedModel) &&
Objects.equals(preProcessors, that.preProcessors);
}
@Override
public int hashCode() {
return Objects.hash(trainedModel, preProcessors);
}
@Override
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(trainedModel);
size += RamUsageEstimator.sizeOfCollection(preProcessors);
return size;
}
@Override
public Collection getChildResources() {
List accountables = new ArrayList<>(preProcessors.size() + 2);
accountables.add(Accountables.namedAccountable("trained_model", trainedModel));
for(PreProcessor preProcessor : preProcessors) {
accountables.add(Accountables.namedAccountable("pre_processor_" + preProcessor.getName(), preProcessor));
}
return accountables;
}
public static class Builder {
private List preProcessors;
private TrainedModel trainedModel;
private boolean processorsInOrder;
private static Builder builderForParser() {
return new Builder(false);
}
private Builder(boolean processorsInOrder) {
this.processorsInOrder = processorsInOrder;
}
public Builder() {
this(true);
}
public Builder(TrainedModelDefinition definition) {
this(true);
this.preProcessors = new ArrayList<>(definition.getPreProcessors());
this.trainedModel = definition.trainedModel;
}
public Builder setPreProcessors(List preProcessors) {
this.preProcessors = preProcessors;
return this;
}
public Builder setTrainedModel(TrainedModel trainedModel) {
this.trainedModel = trainedModel;
return this;
}
private void setProcessorsInOrder(boolean value) {
this.processorsInOrder = value;
}
public TrainedModelDefinition build() {
if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) {
throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects");
}
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
}
}
}