
// smile.data.formula.Formula Maven / Gradle / Ivy
/*
* Copyright (c) 2010-2021 Haifeng Li. All rights reserved.
*
* Smile is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Smile is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Smile. If not, see <https://www.gnu.org/licenses/>.
*/
package smile.data.formula;
import java.io.Serial;
import java.io.Serializable;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.type.*;
import smile.data.vector.*;
import smile.math.matrix.Matrix;
/**
* The model fitting formula in a compact symbolic form.
* An expression of the form {@code y ~ model} is interpreted as a
* specification that the response y is modelled by a linear predictor
* specified symbolically by model. Such a model consists of a series
* of terms separated by {@code +} operators. The terms themselves
* consist of variable and factor names separated by {@code ::} operators.
* Such a term is interpreted as the interaction of all the variables and
* factors appearing in the term. The special term {@code "."} means
* all columns not otherwise in the formula in the context of a data frame.
*
* In addition to {@code +} and {@code ::}, a number of other operators
* are useful in model formulae. The {@code &&} operator denotes factor
* crossing: {@code a && b} interpreted as {@code a+b+a::b}. The {@code ^}
* operator indicates crossing to the specified degree. For example
* {@code (a+b+c)^2} is identical to {@code (a+b+c)*(a+b+c)} which in turn
* expands to a formula containing the main effects for {@code a},
* {@code b} and {@code c} together with their second-order interactions.
* The {@code -} operator removes the specified terms, so that
* {@code (a+b+c)^2 - a::b} is identical to {@code a + b + c + b::c + a::c}.
* It can also be used to remove the intercept term: when fitting a linear model
* {@code y ~ x - 1} specifies a line through the origin. A model with
* no intercept can also be specified as {@code y ~ x + 0}.
*
* While formulae usually involve just variable and factor names, they
* can also involve arithmetic expressions. The formula
* {@code log(y) ~ a + log(x)}, for example, is legal.
*
* Note that the operators {@code ~}, {@code +}, {@code ::}, {@code ^}
* are only available in Scala API.
*
* @author Haifeng Li
*/
public class Formula implements Serializable {
@Serial
private static final long serialVersionUID = 2L;
/** The SLF4J logger of this class. */
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Formula.class);
/** The left-hand side of formula; {@code null} when the formula has no response. */
private final Term response;
/** The right-hand side of formula. */
private final Term[] predictors;
/**
 * The formula-schema binding created by the most recent {@code bind} call.
 * The type argument is required so that {@code binding.get()} yields a
 * {@code Binding} rather than {@code Object}. The field is transient, so it
 * is not serialized and stays {@code null} until the formula is bound.
 */
private transient ThreadLocal<Binding> binding;
/**
 * The result of resolving the formula against one input schema.
 * A plain holder whose fields are assigned by the code that creates it.
 */
private static class Binding {
    /** The schema the formula was bound to. */
    StructType inputSchema;
    /** The features of the predictors. */
    Feature[] x;
    /** The output schema with only predictors. */
    StructType xschema;
    /** The features of the response variable and the predictors. */
    Feature[] yx;
    /** The output schema with response variable and predictors. */
    StructType yxschema;
}
/**
 * Creates a formula from its two sides.
 * @param response the left-hand side of formula, i.e. dependent variable.
 *                 May be {@code null} for a formula without a response.
 * @param predictors the right-hand side of formula, i.e. independent/predictor variables.
 * @throws IllegalArgumentException if the response is a {@code Dot} or a
 *         {@code FactorCrossing} term.
 */
public Formula(Term response, Term... predictors) {
    boolean illegalResponse = response instanceof Dot || response instanceof FactorCrossing;
    if (illegalResponse) {
        throw new IllegalArgumentException("The response variable cannot be '.' or FactorCrossing.");
    }

    this.predictors = predictors;
    this.response = response;
}
/**
 * Gets the right-hand side terms of the formula.
 * @return the predictor terms. This is the array held by the formula
 *         itself, not a copy.
 */
public Term[] predictors() {
    return this.predictors;
}
/**
 * Gets the left-hand side term of the formula.
 * @return the response term, or {@code null} if the formula has none.
 */
public Term response() {
    return this.response;
}
/**
 * Renders the formula as {@code response ~ term1 + term2 - term3}.
 * A missing response is rendered as an empty left-hand side.
 */
@Override
public String toString() {
    String left = "";
    if (response != null) {
        left = response.toString();
    }

    // Prefix every term with its sign; terms that already print as "- x" keep theirs.
    StringJoiner signed = new StringJoiner(" ");
    for (Term predictor : predictors) {
        String text = predictor.toString();
        signed.add(text.startsWith("- ") ? text : "+ " + text);
    }

    // The leading plus sign of the first term is redundant.
    String right = signed.toString();
    if (right.startsWith("+ ")) {
        right = right.substring(2);
    }

    return left + " ~ " + right;
}
/**
 * Compares two formulas by the string form of their response and of each
 * predictor, position by position.
 * @param o the object to compare with.
 * @return true if {@code o} is a formula with the same textual terms.
 */
@Override
public boolean equals(Object o) {
    if (this == o) return true;
    if (!(o instanceof Formula f)) return false;
    if (predictors.length != f.predictors.length) return false;
    if (!String.valueOf(response).equals(String.valueOf(f.response))) return false;
    for (int i = 0; i < predictors.length; i++) {
        if (!String.valueOf(predictors[i]).equals(String.valueOf(f.predictors[i]))) return false;
    }
    return true;
}

/**
 * Returns a hash code derived from the same strings that {@link #equals}
 * compares, so that equal formulas always share a hash code.
 * @return the hash code.
 */
@Override
public int hashCode() {
    int hash = String.valueOf(response).hashCode();
    for (Term predictor : predictors) {
        hash = 31 * hash + String.valueOf(predictor).hashCode();
    }
    return hash;
}
/**
 * Factory method taking the name of the response column. The predictors
 * will be all the columns not otherwise in the formula in the context of
 * a data frame.
 * @param lhs the left-hand side of formula, i.e. dependent variable.
 * @return the formula.
 */
public static Formula lhs(String lhs) {
    Term response = new Variable(lhs);
    return lhs(response);
}
/**
 * Factory method taking the response term. The only predictor is a
 * {@code Dot} term, i.e. all the columns not otherwise in the formula
 * in the context of a data frame.
 * @param lhs the left-hand side of formula, i.e. dependent variable.
 * @return the formula.
 */
public static Formula lhs(Term lhs) {
    Term allOtherColumns = new Dot();
    return new Formula(lhs, allOtherColumns);
}
/**
 * Factory method for a formula given only predictor names.
 * Delegates to {@code of(String, String...)} with a {@code null} response name.
 * @param predictors the right-hand side of formula, i.e. independent/predictor variables.
 * @return the formula.
 */
public static Formula rhs(String... predictors) {
    final String noResponse = null;
    return of(noResponse, predictors);
}
/**
 * Factory method for a formula given only predictor terms.
 * The response of the returned formula is {@code null}.
 * @param predictors the right-hand side of formula, i.e. independent/predictor variables.
 * @return the formula.
 */
public static Formula rhs(Term... predictors) {
    final Term noResponse = null;
    return new Formula(noResponse, predictors);
}
/**
 * Factory method from column names.
 * The predictor strings {@code "."}, {@code "1"} and {@code "0"} are turned
 * into a {@code Dot} term, an enabled intercept and a disabled intercept
 * respectively; any other string becomes a {@code Variable}.
 * @param response the left-hand side of formula, i.e. dependent variable.
 *                 If {@code null}, the formula has no response variable.
 * @param predictors the right-hand side of formula, i.e. independent/predictor variables.
 * @return the formula.
 */
public static Formula of(String response, String... predictors) {
    // rhs(String...) passes a null response; it must stay "no response"
    // rather than become a variable with a null name.
    Term lhs = response == null ? null : new Variable(response);
    return new Formula(
            lhs,
            Arrays.stream(predictors).map(predictor ->
                    switch (predictor) {
                        case "." -> new Dot();
                        case "1" -> new Intercept(true);
                        case "0" -> new Intercept(false);
                        default -> new Variable(predictor);
                    }
            ).toArray(Term[]::new)
    );
}
/**
 * Factory method from a response column name and predictor terms.
 * The response name is wrapped in a {@code Variable} term.
 * @param response the left-hand side of formula, i.e. dependent variable.
 * @param predictors the right-hand side of formula, i.e. independent/predictor variables.
 * @return the formula.
 */
public static Formula of(String response, Term... predictors) {
    Term lhs = new Variable(response);
    return new Formula(lhs, predictors);
}
/**
 * Factory method from terms; equivalent to calling the constructor.
 * @param response the left-hand side of formula, i.e. dependent variable.
 * @param predictors the right-hand side of formula, i.e. independent/predictor variables.
 * @return the formula.
 */
public static Formula of(Term response, Term... predictors) {
    Formula formula = new Formula(response, predictors);
    return formula;
}
/**
 * Matches the closing parenthesis of a parenthesized term together with an
 * optional degree suffix, written either as {@code ^2} or as bare digits.
 */
private static final Pattern GROUP_END = Pattern.compile("\\)(?:\\^?\\d+)?");

/**
 * Parses a formula string of the form {@code lhs ~ term1 + term2 - term3}.
 * Terms on the right-hand side are separated by whitespace and each is
 * preceded by {@code +} or {@code -} (the first sign may be omitted); a
 * term preceded by {@code -} is wrapped by {@code Terms.delete}.
 * A term starting with {@code (} extends to the first closing parenthesis
 * plus an optional degree suffix. An empty left-hand side yields a formula
 * without response; an empty right-hand side yields {@code lhs(response)}.
 * NOTE(review): assumes {@code Terms.$} understands the {@code (...)^degree}
 * form that is now passed to it as a whole - confirm.
 * @param s the string representation of formula.
 * @return the formula.
 * @throws IllegalArgumentException if the string does not contain exactly
 *         one {@code ~}, if both sides are empty, or if the right-hand
 *         side is malformed (missing sign, sign without term, unclosed
 *         parenthesis).
 */
public static Formula of(String s) {
    // A negative limit keeps trailing empty strings, so "y ~" still yields
    // two tokens and reaches the empty right-hand side handling below.
    String[] tokens = s.split("~", -1);
    if (tokens.length != 2) {
        throw new IllegalArgumentException("Invalid formula: " + s);
    }

    String lhs = tokens[0].trim();
    Term response = lhs.isEmpty() ? null : Terms.$(lhs);

    String rhs = tokens[1].trim();
    if (rhs.isEmpty()) {
        if (response == null) {
            throw new IllegalArgumentException("Invalid formula: " + s);
        }
        return lhs(response);
    }

    List<Term> predictors = new ArrayList<>();
    if (!rhs.startsWith("+") && !rhs.startsWith("-")) {
        rhs = "+ " + rhs; // simplify the loop: every term now has a sign
    }

    while (!rhs.isEmpty()) {
        boolean delete;
        if (rhs.startsWith("+")) {
            delete = false;
        } else if (rhs.startsWith("-")) {
            delete = true;
        } else {
            throw new IllegalArgumentException("Invalid formula: " + s);
        }

        rhs = rhs.substring(1).trim();
        if (rhs.isEmpty()) {
            // A sign that is not followed by any term.
            throw new IllegalArgumentException("Invalid formula: " + s);
        }

        // Index just past the end of the current term.
        int end;
        if (rhs.startsWith("(")) {
            Matcher matcher = GROUP_END.matcher(rhs);
            if (!matcher.find()) {
                throw new IllegalArgumentException("Invalid formula: " + s);
            }
            end = matcher.end();
        } else {
            end = rhs.indexOf(' ', 1);
            if (end <= 0) {
                end = rhs.length(); // last term
            }
        }

        String item = rhs.substring(0, end);
        rhs = rhs.substring(end).trim();

        Term term = Terms.$(item);
        if (delete) term = Terms.delete(term);
        predictors.add(term);
    }

    return new Formula(response, predictors.toArray(new Term[0]));
}
/**
 * Expands the Dot and FactorCrossing terms on the given schema.
 * A {@code Dot} predictor is replaced by one {@code Variable} per schema
 * field whose name is not already used by the response or by a
 * {@code Variable}/{@code FactorCrossing} predictor. Every other non-delete
 * predictor is replaced by the result of its own {@code expand()}. Finally,
 * terms whose string form matches an expanded {@code Delete} term are removed.
 * The response is carried over unchanged.
 * @param inputSchema the schema to expand on
 * @return the expanded formula.
 */
public Formula expand(StructType inputSchema) {
    // Names that '.' must not add again.
    // NOTE(review): only Variable and FactorCrossing predictors contribute
    // here; variables referenced by other kinds of terms are not excluded
    // from the '.' expansion - confirm this is intended.
    Set<String> columns = new HashSet<>();
    if (response != null) columns.addAll(response.variables());
    Arrays.stream(predictors)
            .filter(term -> term instanceof FactorCrossing || term instanceof Variable)
            .forEach(term -> columns.addAll(term.variables()));

    List<Variable> rest = Arrays.stream(inputSchema.fields())
            .filter(field -> !columns.contains(field.name))
            .map(field -> new Variable(field.name))
            .toList();

    List<Term> expanded = new ArrayList<>();
    for (Term predictor : predictors) {
        if (predictor instanceof Dot) {
            expanded.addAll(rest);
        } else if (!(predictor instanceof Delete)) {
            expanded.addAll(predictor.expand());
        }
    }

    Set<String> deletes = Arrays.stream(predictors)
            .filter(predictor -> predictor instanceof Delete)
            .flatMap(predictor -> predictor.expand().stream())
            .map(term -> term.toString().substring(2)) // Delete starts with "- "
            .collect(Collectors.toSet());

    expanded.removeIf(term -> deletes.contains(term.toString()));
    return new Formula(response, expanded.toArray(new Term[0]));
}
/**
 * Binds the formula to a schema and returns the schema of predictors.
 * If the formula is already bound to this very schema instance (reference
 * comparison), the existing binding is reused. Otherwise the formula is
 * expanded on the schema, the predictors (except {@code Delete} and
 * {@code Intercept} terms) are bound to features, and, when the formula has
 * a response that can be bound, the response features are placed in front
 * of the predictor features. If binding the response throws an
 * {@code IllegalArgumentException}, this is logged at debug level and the
 * binding carries predictors only.
 * @param inputSchema the schema to bind with
 * @return the data structure of output data frame.
 */
public StructType bind(StructType inputSchema) {
    if (binding != null && binding.get().inputSchema == inputSchema) {
        return binding.get().xschema;
    }

    Formula formula = expand(inputSchema);
    Binding bound = new Binding();
    bound.inputSchema = inputSchema;

    List<Feature> features = Arrays.stream(formula.predictors)
            .filter(predictor -> !(predictor instanceof Delete) && !(predictor instanceof Intercept))
            .flatMap(predictor -> predictor.bind(inputSchema).stream())
            .collect(Collectors.toList());

    bound.x = features.toArray(new Feature[0]);
    bound.xschema = DataTypes.struct(
            features.stream()
                    .map(Feature::field)
                    .toArray(StructField[]::new)
    );

    if (response != null) {
        try {
            // Response features go first, followed by the predictors.
            features.addAll(0, response.bind(inputSchema));
            bound.yx = features.toArray(new Feature[0]);
            bound.yxschema = DataTypes.struct(
                    features.stream()
                            .map(Feature::field)
                            .toArray(StructField[]::new)
            );
        } catch (IllegalArgumentException ex) {
            logger.debug("The response variable {} doesn't exist in the schema {}", response, inputSchema);
        }
    }

    // Every thread's initial value is the binding that was just built.
    this.binding = ThreadLocal.withInitial(() -> bound);
    return bound.xschema;
}
/**
 * Apply the formula on a tuple to generate the model data.
 * The returned tuple is a lazy view: each accessor evaluates the
 * corresponding feature on the input tuple when it is called. It holds the
 * response variable and the predictors when the response could be bound to
 * the tuple's schema, and only the predictors otherwise (the same fallback
 * that {@code frame(DataFrame)} applies).
 * @param tuple the input tuple.
 * @return the output tuple.
 */
public Tuple apply(Tuple tuple) {
    bind(tuple.schema());
    Binding bound = this.binding.get();

    // bind() leaves yx/yxschema null when there is no bound response;
    // using them unconditionally would give a null schema and fail on access.
    boolean hasResponse = bound.yx != null;
    final Feature[] features = hasResponse ? bound.yx : bound.x;
    final StructType schema = hasResponse ? bound.yxschema : bound.xschema;

    return new smile.data.AbstractTuple() {
        @Override
        public StructType schema() {
            return schema;
        }

        @Override
        public Object get(int i) {
            return features[i].apply(tuple);
        }

        @Override
        public int getInt(int i) {
            return features[i].applyAsInt(tuple);
        }

        @Override
        public long getLong(int i) {
            return features[i].applyAsLong(tuple);
        }

        @Override
        public float getFloat(int i) {
            return features[i].applyAsFloat(tuple);
        }

        @Override
        public double getDouble(int i) {
            return features[i].applyAsDouble(tuple);
        }

        @Override
        public String toString() {
            return schema.toString(this);
        }
    };
}
/**
 * Apply the formula on a tuple to generate the predictor data.
 * The returned tuple is a lazy view: each accessor evaluates the
 * corresponding predictor feature on the input tuple when it is called.
 * @param tuple the input tuple.
 * @return the output tuple.
 */
public Tuple x(Tuple tuple) {
    bind(tuple.schema());
    Binding bound = this.binding.get();
    final Feature[] features = bound.x;
    final StructType schema = bound.xschema;

    return new smile.data.AbstractTuple() {
        @Override
        public StructType schema() {
            return schema;
        }

        @Override
        public Object get(int i) {
            return features[i].apply(tuple);
        }

        @Override
        public int getInt(int i) {
            return features[i].applyAsInt(tuple);
        }

        @Override
        public long getLong(int i) {
            return features[i].applyAsLong(tuple);
        }

        @Override
        public float getFloat(int i) {
            return features[i].applyAsFloat(tuple);
        }

        @Override
        public double getDouble(int i) {
            return features[i].applyAsDouble(tuple);
        }

        @Override
        public String toString() {
            return schema.toString(this);
        }
    };
}
/**
 * Returns a data frame of predictors and optionally response variable
 * (if input data frame has the related variable(s)).
 * Each bound feature is evaluated on the whole data frame to produce one
 * column of the result.
 *
 * @param data The input data frame.
 * @return the output data frame.
 */
public DataFrame frame(DataFrame data) {
    bind(data.schema());
    Binding bound = this.binding.get();

    // Fall back to the predictors alone when no response was bound.
    Feature[] features = bound.yx;
    if (features == null) {
        features = bound.x;
    }

    BaseVector[] columns = new BaseVector[features.length];
    for (int i = 0; i < features.length; i++) {
        columns[i] = features[i].apply(data);
    }
    return DataFrame.of(columns);
}
/**
 * Returns a data frame of predictors.
 * Each bound predictor feature is evaluated on the whole data frame to
 * produce one column of the result.
 *
 * @param data The input data frame.
 * @return the data frame of predictors.
 */
public DataFrame x(DataFrame data) {
    bind(data.schema());
    Feature[] features = this.binding.get().x;

    BaseVector[] columns = new BaseVector[features.length];
    for (int i = 0; i < features.length; i++) {
        columns[i] = features[i].apply(data);
    }
    return DataFrame.of(columns);
}
/**
 * Returns true if the formula has the bias term.
 * We assume the formula has the bias term if it isn't
 * explicitly specified, i.e. if no predictor is an {@code Intercept}.
 * Otherwise the setting of an {@code Intercept} predictor is returned.
 * @return true if the formula has the bias term.
 */
private boolean hasBias() {
    return Arrays.stream(predictors)
            .filter(Intercept.class::isInstance)
            .map(Intercept.class::cast)
            .findAny()
            .map(Intercept::bias)
            .orElse(true);
}
/**
 * Returns the design matrix of predictors.
 * All categorical variables will be dummy encoded.
 * Whether the bias column is included is decided by {@code hasBias()}:
 * it is included when the formula has no Intercept term, and otherwise
 * follows the setting of the Intercept term.
 *
 * @param data The input data frame.
 * @return the design matrix.
 */
public Matrix matrix(DataFrame data) {
    boolean bias = hasBias();
    return matrix(data, bias);
}
/**
 * Returns the design matrix of predictors.
 * The predictor data frame is converted with {@code CategoricalEncoder.DUMMY},
 * so all categorical variables will be dummy encoded.
 * @param data The input data frame.
 * @param bias If true, include the bias column.
 * @return the design matrix.
 */
public Matrix matrix(DataFrame data, boolean bias) {
    DataFrame predictorFrame = x(data);
    return predictorFrame.toMatrix(bias, CategoricalEncoder.DUMMY, null);
}
/**
 * Returns the response vector.
 * @param data The input data frame.
 * @return the response vector.
 * @throws UnsupportedOperationException if the formula has no response
 *         term, or if no response was bound for the data's schema.
 */
public BaseVector y(DataFrame data) {
    if (response == null) {
        throw new UnsupportedOperationException("The formula has no response variable.");
    }

    bind(data.schema());
    Feature[] yx = this.binding.get().yx;
    if (yx == null) {
        throw new UnsupportedOperationException("The data has no response variable.");
    }

    // The response is the first bound feature.
    Feature target = yx[0];
    return target.apply(data);
}
/**
 * Returns the real-valued response value.
 * @param tuple the input tuple.
 * @return the response variable.
 * @throws UnsupportedOperationException if the formula has no response
 *         term, or if no response was bound for the tuple's schema.
 */
public double y(Tuple tuple) {
    if (response == null) {
        throw new UnsupportedOperationException("The formula has no response variable.");
    }

    bind(tuple.schema());
    Feature[] yx = this.binding.get().yx;
    if (yx == null) {
        throw new UnsupportedOperationException("The data has no response variable.");
    }

    // The response is the first bound feature.
    Feature target = yx[0];
    return target.applyAsDouble(tuple);
}
/**
 * Returns the integer-valued response value.
 * @param tuple the input tuple.
 * @return the response variable.
 * @throws UnsupportedOperationException if the formula has no response
 *         term, or if no response was bound for the tuple's schema.
 */
public int yint(Tuple tuple) {
    if (response == null) {
        throw new UnsupportedOperationException("The formula has no response variable.");
    }

    bind(tuple.schema());
    Feature[] yx = this.binding.get().yx;
    if (yx == null) {
        throw new UnsupportedOperationException("The data has no response variable.");
    }

    // The response is the first bound feature.
    Feature target = yx[0];
    return target.applyAsInt(tuple);
}
}