Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download/modify it how often you want.
org.elasticsearch.xpack.core.ml.dataframe.stats.regression.Hyperparameters Maven / Gradle / Ivy
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.core.ml.dataframe.stats.regression;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import java.io.IOException;
import java.util.Objects;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
/**
 * Immutable value object holding the hyperparameter values reported under the
 * {@code regression_hyperparameters} parser name.
 *
 * <p>Instances can be parsed from and rendered to XContent (see
 * {@link #fromXContent(XContentParser, boolean)} and
 * {@link #toXContent(XContentBuilder, Params)}) and can be written to and read from
 * the transport stream (see {@link #writeTo(StreamOutput)} and
 * {@link #Hyperparameters(StreamInput)}).
 */
public class Hyperparameters implements ToXContentObject, Writeable {

    public static final ParseField ALPHA = new ParseField("alpha");
    public static final ParseField DOWNSAMPLE_FACTOR = new ParseField("downsample_factor");
    public static final ParseField ETA = new ParseField("eta");
    public static final ParseField ETA_GROWTH_RATE_PER_TREE = new ParseField("eta_growth_rate_per_tree");
    public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
    public static final ParseField GAMMA = new ParseField("gamma");
    public static final ParseField LAMBDA = new ParseField("lambda");
    public static final ParseField MAX_ATTEMPTS_TO_ADD_TREE = new ParseField("max_attempts_to_add_tree");
    public static final ParseField MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER = new ParseField(
        "max_optimization_rounds_per_hyperparameter");
    public static final ParseField MAX_TREES = new ParseField("max_trees");
    public static final ParseField NUM_FOLDS = new ParseField("num_folds");
    public static final ParseField NUM_SPLITS_PER_FEATURE = new ParseField("num_splits_per_feature");
    public static final ParseField SOFT_TREE_DEPTH_LIMIT = new ParseField("soft_tree_depth_limit");
    public static final ParseField SOFT_TREE_DEPTH_TOLERANCE = new ParseField("soft_tree_depth_tolerance");

    /**
     * Parses a {@code Hyperparameters} object from the given parser.
     *
     * @param parser              the XContent parser to read from
     * @param ignoreUnknownFields passed through to the underlying {@link ConstructingObjectParser}
     * @return the parsed instance
     */
    public static Hyperparameters fromXContent(XContentParser parser, boolean ignoreUnknownFields) {
        return createParser(ignoreUnknownFields).apply(parser, null);
    }

    /**
     * Builds a parser whose constructor arguments are declared in exactly the order
     * expected by the all-args constructor; the positional indices used in the
     * builder lambda depend on that declaration order.
     */
    private static ConstructingObjectParser<Hyperparameters, Void> createParser(boolean ignoreUnknownFields) {
        // Explicit type arguments: with a raw parser type, apply() would return Object
        // and fromXContent could not return Hyperparameters without a cast.
        ConstructingObjectParser<Hyperparameters, Void> parser = new ConstructingObjectParser<>(
            "regression_hyperparameters",
            ignoreUnknownFields,
            a -> new Hyperparameters(
                (double) a[0],
                (double) a[1],
                (double) a[2],
                (double) a[3],
                (double) a[4],
                (double) a[5],
                (double) a[6],
                (int) a[7],
                (int) a[8],
                (int) a[9],
                (int) a[10],
                (int) a[11],
                (double) a[12],
                (double) a[13]
            ));

        // Keep these declarations in sync with the indices a[0]..a[13] above.
        parser.declareDouble(constructorArg(), ALPHA);
        parser.declareDouble(constructorArg(), DOWNSAMPLE_FACTOR);
        parser.declareDouble(constructorArg(), ETA);
        parser.declareDouble(constructorArg(), ETA_GROWTH_RATE_PER_TREE);
        parser.declareDouble(constructorArg(), FEATURE_BAG_FRACTION);
        parser.declareDouble(constructorArg(), GAMMA);
        parser.declareDouble(constructorArg(), LAMBDA);
        parser.declareInt(constructorArg(), MAX_ATTEMPTS_TO_ADD_TREE);
        parser.declareInt(constructorArg(), MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER);
        parser.declareInt(constructorArg(), MAX_TREES);
        parser.declareInt(constructorArg(), NUM_FOLDS);
        parser.declareInt(constructorArg(), NUM_SPLITS_PER_FEATURE);
        parser.declareDouble(constructorArg(), SOFT_TREE_DEPTH_LIMIT);
        parser.declareDouble(constructorArg(), SOFT_TREE_DEPTH_TOLERANCE);

        return parser;
    }

    private final double alpha;
    private final double downsampleFactor;
    private final double eta;
    private final double etaGrowthRatePerTree;
    private final double featureBagFraction;
    private final double gamma;
    private final double lambda;
    private final int maxAttemptsToAddTree;
    private final int maxOptimizationRoundsPerHyperparameter;
    private final int maxTrees;
    private final int numFolds;
    private final int numSplitsPerFeature;
    private final double softTreeDepthLimit;
    private final double softTreeDepthTolerance;

    /**
     * Creates an instance from explicit values. Each parameter is stored unchanged in
     * the field of the same name; no validation is performed here.
     */
    public Hyperparameters(double alpha,
                           double downsampleFactor,
                           double eta,
                           double etaGrowthRatePerTree,
                           double featureBagFraction,
                           double gamma,
                           double lambda,
                           int maxAttemptsToAddTree,
                           int maxOptimizationRoundsPerHyperparameter,
                           int maxTrees,
                           int numFolds,
                           int numSplitsPerFeature,
                           double softTreeDepthLimit,
                           double softTreeDepthTolerance) {
        this.alpha = alpha;
        this.downsampleFactor = downsampleFactor;
        this.eta = eta;
        this.etaGrowthRatePerTree = etaGrowthRatePerTree;
        this.featureBagFraction = featureBagFraction;
        this.gamma = gamma;
        this.lambda = lambda;
        this.maxAttemptsToAddTree = maxAttemptsToAddTree;
        this.maxOptimizationRoundsPerHyperparameter = maxOptimizationRoundsPerHyperparameter;
        this.maxTrees = maxTrees;
        this.numFolds = numFolds;
        this.numSplitsPerFeature = numSplitsPerFeature;
        this.softTreeDepthLimit = softTreeDepthLimit;
        this.softTreeDepthTolerance = softTreeDepthTolerance;
    }

    /**
     * Reads an instance from the stream. The read order mirrors the write order in
     * {@link #writeTo(StreamOutput)} and must be kept in sync with it.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public Hyperparameters(StreamInput in) throws IOException {
        this.alpha = in.readDouble();
        this.downsampleFactor = in.readDouble();
        this.eta = in.readDouble();
        this.etaGrowthRatePerTree = in.readDouble();
        this.featureBagFraction = in.readDouble();
        this.gamma = in.readDouble();
        this.lambda = in.readDouble();
        this.maxAttemptsToAddTree = in.readVInt();
        this.maxOptimizationRoundsPerHyperparameter = in.readVInt();
        this.maxTrees = in.readVInt();
        this.numFolds = in.readVInt();
        this.numSplitsPerFeature = in.readVInt();
        this.softTreeDepthLimit = in.readDouble();
        this.softTreeDepthTolerance = in.readDouble();
    }

    /**
     * Writes all fields to the stream in the same order in which
     * {@link #Hyperparameters(StreamInput)} reads them.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeDouble(alpha);
        out.writeDouble(downsampleFactor);
        out.writeDouble(eta);
        out.writeDouble(etaGrowthRatePerTree);
        out.writeDouble(featureBagFraction);
        out.writeDouble(gamma);
        out.writeDouble(lambda);
        out.writeVInt(maxAttemptsToAddTree);
        out.writeVInt(maxOptimizationRoundsPerHyperparameter);
        out.writeVInt(maxTrees);
        out.writeVInt(numFolds);
        out.writeVInt(numSplitsPerFeature);
        out.writeDouble(softTreeDepthLimit);
        out.writeDouble(softTreeDepthTolerance);
    }

    /**
     * Renders this instance as a single object with one field per hyperparameter,
     * keyed by the preferred name of the matching {@link ParseField}.
     *
     * @param builder the builder to write to
     * @param params  rendering parameters (not used by this implementation)
     * @return the same builder, for chaining
     * @throws IOException if writing to the builder fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(ALPHA.getPreferredName(), alpha);
        builder.field(DOWNSAMPLE_FACTOR.getPreferredName(), downsampleFactor);
        builder.field(ETA.getPreferredName(), eta);
        builder.field(ETA_GROWTH_RATE_PER_TREE.getPreferredName(), etaGrowthRatePerTree);
        builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction);
        builder.field(GAMMA.getPreferredName(), gamma);
        builder.field(LAMBDA.getPreferredName(), lambda);
        builder.field(MAX_ATTEMPTS_TO_ADD_TREE.getPreferredName(), maxAttemptsToAddTree);
        builder.field(MAX_OPTIMIZATION_ROUNDS_PER_HYPERPARAMETER.getPreferredName(), maxOptimizationRoundsPerHyperparameter);
        builder.field(MAX_TREES.getPreferredName(), maxTrees);
        builder.field(NUM_FOLDS.getPreferredName(), numFolds);
        builder.field(NUM_SPLITS_PER_FEATURE.getPreferredName(), numSplitsPerFeature);
        builder.field(SOFT_TREE_DEPTH_LIMIT.getPreferredName(), softTreeDepthLimit);
        builder.field(SOFT_TREE_DEPTH_TOLERANCE.getPreferredName(), softTreeDepthTolerance);
        builder.endObject();
        return builder;
    }

    /**
     * Two instances are equal when they are of the same class and every field matches.
     *
     * <p>Floating-point fields are compared with {@link Double#compare(double, double)}
     * rather than {@code ==} so that the result agrees with {@link #hashCode()}:
     * {@code NaN} equals {@code NaN}, and {@code 0.0} is distinct from {@code -0.0}.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Hyperparameters that = (Hyperparameters) o;
        return Double.compare(alpha, that.alpha) == 0
            && Double.compare(downsampleFactor, that.downsampleFactor) == 0
            && Double.compare(eta, that.eta) == 0
            && Double.compare(etaGrowthRatePerTree, that.etaGrowthRatePerTree) == 0
            && Double.compare(featureBagFraction, that.featureBagFraction) == 0
            && Double.compare(gamma, that.gamma) == 0
            && Double.compare(lambda, that.lambda) == 0
            && maxAttemptsToAddTree == that.maxAttemptsToAddTree
            && maxOptimizationRoundsPerHyperparameter == that.maxOptimizationRoundsPerHyperparameter
            && maxTrees == that.maxTrees
            && numFolds == that.numFolds
            && numSplitsPerFeature == that.numSplitsPerFeature
            && Double.compare(softTreeDepthLimit, that.softTreeDepthLimit) == 0
            && Double.compare(softTreeDepthTolerance, that.softTreeDepthTolerance) == 0;
    }

    /** Hashes every field that participates in {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Objects.hash(
            alpha,
            downsampleFactor,
            eta,
            etaGrowthRatePerTree,
            featureBagFraction,
            gamma,
            lambda,
            maxAttemptsToAddTree,
            maxOptimizationRoundsPerHyperparameter,
            maxTrees,
            numFolds,
            numSplitsPerFeature,
            softTreeDepthLimit,
            softTreeDepthTolerance
        );
    }
}