All downloads are FREE. The search and download functionality uses the official Maven repository.
Please wait. This can take some minutes ...
Downloading a project requires a lot of resources. Please understand that we have to cover our server costs. Thank you in advance.
Project price: only $1
You can buy this project and download or modify it as often as you want.
org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.FeatureImportanceBaseline Maven / Gradle / Ivy
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Feature-importance baseline information: an optional top-level {@code baseline} value plus a
 * list of per-class baselines ({@code classes}).
 *
 * <p>Instances can be parsed from XContent (leniently or strictly), rendered back to XContent,
 * and written to / read from a transport stream.
 *
 * <p>NOTE(review): presumably the single baseline applies to regression models and the per-class
 * list to classification models; confirm against the callers.
 */
public class FeatureImportanceBaseline implements ToXContentObject, Writeable {

    private static final String NAME = "feature_importance_baseline";
    public static final ParseField BASELINE = new ParseField("baseline");
    public static final ParseField CLASSES = new ParseField("classes");

    // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
    public static final ConstructingObjectParser<FeatureImportanceBaseline, Void> LENIENT_PARSER = createParser(true);
    public static final ConstructingObjectParser<FeatureImportanceBaseline, Void> STRICT_PARSER = createParser(false);

    /**
     * Builds a parser for this object. Both {@code baseline} and {@code classes} are optional.
     *
     * @param ignoreUnknownFields {@code true} for the lenient parser, {@code false} for the strict one;
     *                            the same leniency is applied to the nested class-baseline parser
     * @return a new parser instance
     */
    // The unchecked cast is confined to the constructor lambda: a[1] is produced by declareObjectArray
    // with a ClassBaseline parser, so it is a List<ClassBaseline> (or null when the field is absent).
    @SuppressWarnings("unchecked")
    private static ConstructingObjectParser<FeatureImportanceBaseline, Void> createParser(boolean ignoreUnknownFields) {
        ConstructingObjectParser<FeatureImportanceBaseline, Void> parser = new ConstructingObjectParser<>(NAME,
            ignoreUnknownFields,
            a -> new FeatureImportanceBaseline((Double) a[0], (List<ClassBaseline>) a[1]));
        parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), BASELINE);
        parser.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(),
            ignoreUnknownFields ? ClassBaseline.LENIENT_PARSER : ClassBaseline.STRICT_PARSER,
            CLASSES);
        return parser;
    }

    /**
     * Parses a {@link FeatureImportanceBaseline} from the given parser.
     *
     * @param parser  the XContent parser to read from
     * @param lenient whether to use {@link #LENIENT_PARSER} ({@code true}) or {@link #STRICT_PARSER} ({@code false})
     * @return the parsed object
     * @throws IOException if the underlying parser fails
     */
    public static FeatureImportanceBaseline fromXContent(XContentParser parser, boolean lenient) throws IOException {
        return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
    }

    /** Top-level baseline; {@code null} when not provided. */
    public final Double baseline;
    /** Per-class baselines; the public constructor substitutes an empty list for {@code null}. */
    public final List<ClassBaseline> classBaselines;

    /**
     * Reads an instance from a stream, mirroring {@link #writeTo(StreamOutput)}:
     * an optional double followed by a list of {@link ClassBaseline}.
     *
     * @param in the stream to read from
     * @throws IOException if reading fails
     */
    public FeatureImportanceBaseline(StreamInput in) throws IOException {
        this.baseline = in.readOptionalDouble();
        this.classBaselines = in.readList(ClassBaseline::new);
    }

    /**
     * @param baseline       the top-level baseline, may be {@code null}
     * @param classBaselines the per-class baselines; {@code null} is treated as an empty list
     */
    public FeatureImportanceBaseline(Double baseline, List<ClassBaseline> classBaselines) {
        this.baseline = baseline;
        this.classBaselines = classBaselines == null ? Collections.emptyList() : classBaselines;
    }

    /**
     * Writes the optional baseline followed by the list of class baselines.
     *
     * @param out the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeOptionalDouble(baseline);
        out.writeList(classBaselines);
    }

    /** Renders this object as the map produced by {@link #asMap()}. */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        return builder.map(asMap());
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        FeatureImportanceBaseline that = (FeatureImportanceBaseline) o;
        return Objects.equals(that.baseline, baseline)
            && Objects.equals(classBaselines, that.classBaselines);
    }

    /**
     * Returns an insertion-ordered map view of this object. The {@code baseline} entry is omitted
     * when the baseline is {@code null}; the {@code classes} entry is omitted when there are no
     * class baselines.
     *
     * @return a new mutable map keyed by the preferred field names
     */
    public Map<String, Object> asMap() {
        Map<String, Object> map = new LinkedHashMap<>();
        if (baseline != null) {
            map.put(BASELINE.getPreferredName(), baseline);
        }
        if (classBaselines.isEmpty() == false) {
            map.put(CLASSES.getPreferredName(), classBaselines.stream().map(ClassBaseline::asMap).collect(Collectors.toList()));
        }
        return map;
    }

    @Override
    public int hashCode() {
        return Objects.hash(baseline, classBaselines);
    }

    /**
     * Baseline value for a single class. The class name is kept as an {@link Object} because the
     * parser accepts a string, a number or a boolean for it.
     */
    public static class ClassBaseline implements ToXContentObject, Writeable {

        private static final String NAME = "feature_importance_class_baseline";
        public static final ParseField CLASS_NAME = new ParseField("class_name");

        public static final ConstructingObjectParser<ClassBaseline, Void> LENIENT_PARSER = createParser(true);
        public static final ConstructingObjectParser<ClassBaseline, Void> STRICT_PARSER = createParser(false);

        /**
         * Builds a parser for a class baseline. Both {@code class_name} and {@code baseline} are required.
         *
         * @param ignoreUnknownFields {@code true} for the lenient parser, {@code false} for the strict one
         * @return a new parser instance
         */
        private static ConstructingObjectParser<ClassBaseline, Void> createParser(boolean ignoreUnknownFields) {
            ConstructingObjectParser<ClassBaseline, Void> parser = new ConstructingObjectParser<>(NAME,
                ignoreUnknownFields,
                a -> new ClassBaseline(a[0], (double) a[1]));
            parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
                // The class name may be a string, a number or a boolean; anything else is rejected.
                XContentParser.Token token = p.currentToken();
                if (token == XContentParser.Token.VALUE_STRING) {
                    return p.text();
                } else if (token == XContentParser.Token.VALUE_NUMBER) {
                    return p.numberValue();
                } else if (token == XContentParser.Token.VALUE_BOOLEAN) {
                    return p.booleanValue();
                }
                throw new XContentParseException("Unsupported token [" + token + "]");
            }, CLASS_NAME, ObjectParser.ValueType.VALUE);
            parser.declareDouble(ConstructingObjectParser.constructorArg(), BASELINE);
            return parser;
        }

        /**
         * Parses a {@link ClassBaseline} from the given parser.
         *
         * @param parser  the XContent parser to read from
         * @param lenient whether to use {@link #LENIENT_PARSER} ({@code true}) or {@link #STRICT_PARSER} ({@code false})
         * @return the parsed object
         * @throws IOException if the underlying parser fails
         */
        public static ClassBaseline fromXContent(XContentParser parser, boolean lenient) throws IOException {
            return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
        }

        /** The class this baseline belongs to (string, number or boolean when parsed from XContent). */
        public final Object className;
        /** The baseline value for {@link #className}. */
        public final double baseline;

        /**
         * Reads an instance from a stream, mirroring {@link #writeTo(StreamOutput)}:
         * a generic value for the class name followed by a double.
         *
         * @param in the stream to read from
         * @throws IOException if reading fails
         */
        public ClassBaseline(StreamInput in) throws IOException {
            this.className = in.readGenericValue();
            this.baseline = in.readDouble();
        }

        ClassBaseline(Object className, double baseline) {
            this.className = className;
            this.baseline = baseline;
        }

        /**
         * Writes the class name as a generic value followed by the baseline.
         *
         * @param out the stream to write to
         * @throws IOException if writing fails
         */
        @Override
        public void writeTo(StreamOutput out) throws IOException {
            out.writeGenericValue(className);
            out.writeDouble(baseline);
        }

        /** Renders this object as the map produced by {@code asMap()}. */
        @Override
        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
            return builder.map(asMap());
        }

        /** Returns an insertion-ordered map with the {@code class_name} and {@code baseline} entries. */
        private Map<String, Object> asMap() {
            Map<String, Object> map = new LinkedHashMap<>();
            map.put(CLASS_NAME.getPreferredName(), className);
            map.put(BASELINE.getPreferredName(), baseline);
            return map;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (o == null || getClass() != o.getClass()) return false;
            ClassBaseline that = (ClassBaseline) o;
            // Double.compare matches Double.equals semantics (NaN equals NaN, 0.0 differs from -0.0) without boxing.
            return Objects.equals(that.className, className) && Double.compare(that.baseline, baseline) == 0;
        }

        @Override
        public int hashCode() {
            return Objects.hash(className, baseline);
        }
    }
}