tech.tablesaw.api.ml.classification.LogisticRegression Maven / Gradle / Ivy
package tech.tablesaw.api.ml.classification;
import com.google.common.base.Preconditions;
import tech.tablesaw.api.BooleanColumn;
import tech.tablesaw.api.CategoryColumn;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.NumericColumn;
import tech.tablesaw.api.ShortColumn;
import tech.tablesaw.util.DoubleArrays;
import java.util.SortedSet;
import java.util.TreeSet;
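/**
 * Logistic-regression classifier backed by smile's {@code smile.classification.LogisticRegression}.
 *
 * <p>Instances are created through the static {@code learn} factory methods, which accept
 * {@link ShortColumn}, {@link IntColumn}, {@link BooleanColumn} or {@link CategoryColumn} class
 * labels together with {@link NumericColumn} predictor columns.
 */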
public class LogisticRegression extends AbstractClassifier {
/** The fitted smile model that {@link #predict(double[])} delegates to. */
private final smile.classification.LogisticRegression classifierModel;

/**
 * Wraps an already-trained smile model. Callers obtain instances through the static
 * {@code learn} factories rather than this constructor.
 *
 * @param trainedModel the fitted smile logistic-regression model to wrap
 */
private LogisticRegression(smile.classification.LogisticRegression trainedModel) {
    this.classifierModel = trainedModel;
}
/**
 * Trains a classifier whose class labels come from a {@link ShortColumn}, using the smile
 * trainer's default settings.
 *
 * @param labels     the class label of each row, converted via {@code toIntArray()}
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(ShortColumn labels, NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.toIntArray()));
}

/**
 * Trains a classifier whose class labels come from an {@link IntColumn}, using the smile
 * trainer's default settings.
 *
 * @param labels     the class label of each row, taken from the column's {@code data()} as an int array
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(IntColumn labels, NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.data().toIntArray()));
}

/**
 * Trains a classifier whose class labels come from a {@link BooleanColumn}, using the smile
 * trainer's default settings.
 *
 * @param labels     the class label of each row, converted via {@code toIntArray()}
 *                   (presumably false to 0 and true to 1; confirm against BooleanColumn)
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(BooleanColumn labels, NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.toIntArray()));
}

/**
 * Trains a classifier whose class labels come from a {@link CategoryColumn}, using the smile
 * trainer's default settings.
 *
 * @param labels     the class label of each row; the integer values backing the column
 *                   ({@code data()}) are used as class ids, not the category strings
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(CategoryColumn labels, NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.data().toIntArray()));
}

/**
 * Trains a regularized classifier whose class labels come from a {@link ShortColumn}.
 *
 * @param labels     the class label of each row, converted via {@code toIntArray()}
 * @param lambda     regularization parameter forwarded to the smile trainer
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(ShortColumn labels, double lambda, NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.toIntArray(), lambda));
}

/**
 * Trains a regularized classifier whose class labels come from an {@link IntColumn}.
 *
 * @param labels     the class label of each row, taken from the column's {@code data()} as an int array
 * @param lambda     regularization parameter forwarded to the smile trainer
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(IntColumn labels, double lambda, NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.data().toIntArray(), lambda));
}

/**
 * Trains a regularized classifier whose class labels come from a {@link BooleanColumn}.
 *
 * @param labels     the class label of each row, converted via {@code toIntArray()}
 * @param lambda     regularization parameter forwarded to the smile trainer
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(BooleanColumn labels, double lambda, NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.toIntArray(), lambda));
}

/**
 * Trains a regularized classifier whose class labels come from a {@link CategoryColumn}.
 *
 * @param labels     the class label of each row; the integer values backing the column
 *                   ({@code data()}) are used as class ids
 * @param lambda     regularization parameter forwarded to the smile trainer
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(CategoryColumn labels, double lambda, NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.data().toIntArray(), lambda));
}

/**
 * Trains a regularized classifier with explicit stopping criteria, with class labels from a
 * {@link ShortColumn}.
 *
 * @param labels     the class label of each row, converted via {@code toIntArray()}
 * @param lambda     regularization parameter forwarded to the smile trainer
 * @param tolerance  convergence tolerance forwarded to the smile trainer
 * @param maxIters   maximum number of training iterations forwarded to the smile trainer
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(ShortColumn labels,
                                       double lambda,
                                       double tolerance,
                                       int maxIters,
                                       NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.toIntArray(), lambda, tolerance, maxIters));
}

/**
 * Trains a regularized classifier with explicit stopping criteria, with class labels from an
 * {@link IntColumn}.
 *
 * @param labels     the class label of each row, taken from the column's {@code data()} as an int array
 * @param lambda     regularization parameter forwarded to the smile trainer
 * @param tolerance  convergence tolerance forwarded to the smile trainer
 * @param maxIters   maximum number of training iterations forwarded to the smile trainer
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(IntColumn labels,
                                       double lambda,
                                       double tolerance,
                                       int maxIters,
                                       NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.data().toIntArray(), lambda, tolerance, maxIters));
}

/**
 * Trains a regularized classifier with explicit stopping criteria, with class labels from a
 * {@link BooleanColumn}.
 *
 * @param labels     the class label of each row, converted via {@code toIntArray()}
 * @param lambda     regularization parameter forwarded to the smile trainer
 * @param tolerance  convergence tolerance forwarded to the smile trainer
 * @param maxIters   maximum number of training iterations forwarded to the smile trainer
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(BooleanColumn labels,
                                       double lambda,
                                       double tolerance,
                                       int maxIters,
                                       NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.toIntArray(), lambda, tolerance, maxIters));
}

/**
 * Trains a regularized classifier with explicit stopping criteria, with class labels from a
 * {@link CategoryColumn}.
 *
 * @param labels     the class label of each row; the integer values backing the column
 *                   ({@code data()}) are used as class ids
 * @param lambda     regularization parameter forwarded to the smile trainer
 * @param tolerance  convergence tolerance forwarded to the smile trainer
 * @param maxIters   maximum number of training iterations forwarded to the smile trainer
 * @param predictors the numeric feature columns; at least one is required
 * @return a classifier wrapping the trained smile model
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
public static LogisticRegression learn(CategoryColumn labels,
                                       double lambda,
                                       double tolerance,
                                       int maxIters,
                                       NumericColumn... predictors) {
    return new LogisticRegression(new smile.classification.LogisticRegression(
            trainingMatrix(predictors), labels.data().toIntArray(), lambda, tolerance, maxIters));
}

/**
 * Validates the predictor columns and converts them into the feature matrix handed to smile.
 *
 * <p>An empty predictor list is rejected up front, matching the check {@code predictMatrix}
 * performs on its predictors. This avoids relying on whatever the conversion or the smile
 * trainer would do with no features at all.
 *
 * @param predictors the numeric feature columns
 * @return the result of {@code DoubleArrays.to2dArray(predictors)}. This is presumably one row
 *         per observation and one column per predictor; confirm against DoubleArrays.
 * @throws IllegalArgumentException if {@code predictors} is empty
 */
private static double[][] trainingMatrix(NumericColumn... predictors) {
    if (predictors.length == 0) {
        throw new IllegalArgumentException("At least one predictor column is required");
    }
    return DoubleArrays.to2dArray(predictors);
}
/**
 * Predicts the class of a single observation with the trained smile model.
 *
 * @param data the feature values of one observation. These are presumably in the same order
 *             as the predictor columns used for training; confirm with callers.
 * @return the predicted class id, exactly as returned by the wrapped smile model
 */
public int predict(double[] data) {
    final smile.classification.LogisticRegression model = this.classifierModel;
    final int predictedClass = model.predict(data);
    return predictedClass;
}
public ConfusionMatrix predictMatrix(ShortColumn labels, NumericColumn... predictors) {
Preconditions.checkArgument(predictors.length > 0);
SortedSet
© 2015 - 2025 Weber Informatics LLC | Privacy Policy