io.gitlab.chaver.mining.rules.problems.AssociationRuleMining Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of data-mining Show documentation
Show all versions of data-mining Show documentation
Data mining with Choco solver
/*
* This file is part of io.gitlab.chaver:data-mining (https://gitlab.com/chaver/data-mining)
*
* Copyright (c) 2022, IMT Atlantique
*
* Licensed under the MIT license.
*
* See LICENSE file in the project root for full license information.
*/
package io.gitlab.chaver.mining.rules.problems;
import io.gitlab.chaver.chocotools.io.JsonResultReader;
import io.gitlab.chaver.chocotools.io.ProblemResult;
import io.gitlab.chaver.chocotools.io.ProblemResultReader;
import io.gitlab.chaver.chocotools.problem.ChocoProblem;
import io.gitlab.chaver.chocotools.problem.SetUpException;
import io.gitlab.chaver.mining.patterns.constraints.CoverClosure;
import io.gitlab.chaver.mining.patterns.constraints.CoverSize;
import io.gitlab.chaver.mining.patterns.constraints.Generator;
import io.gitlab.chaver.mining.patterns.io.DatReader;
import io.gitlab.chaver.mining.patterns.io.Database;
import io.gitlab.chaver.mining.patterns.io.Pattern;
import io.gitlab.chaver.mining.patterns.io.PatternProblemProperties;
import io.gitlab.chaver.mining.rules.io.ArMeasuresView;
import io.gitlab.chaver.mining.rules.io.AssociationRule;
import io.gitlab.chaver.mining.rules.io.RuleType;
import io.gitlab.chaver.mining.rules.measure.RuleMeasure;
import io.gitlab.chaver.mining.rules.search.loop.monitors.ArMonitor;
import org.chocosolver.solver.Model;
import org.chocosolver.solver.Settings;
import org.chocosolver.solver.constraints.Constraint;
import org.chocosolver.solver.expression.discrete.relational.ReExpression;
import org.chocosolver.solver.search.strategy.selectors.values.IntDomainMin;
import org.chocosolver.solver.search.strategy.selectors.variables.InputOrder;
import org.chocosolver.solver.variables.BoolVar;
import org.chocosolver.solver.variables.IntVar;
import org.chocosolver.util.tools.ArrayUtils;
import picocli.CommandLine;
import picocli.CommandLine.Command;
import picocli.CommandLine.Option;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.text.DecimalFormat;
import java.util.*;
import java.util.stream.Collectors;
import static io.gitlab.chaver.mining.patterns.util.PatternUtil.findClosedPattern;
import static io.gitlab.chaver.mining.rules.measure.SimpleRuleMeasures.*;
import static org.chocosolver.solver.search.strategy.Search.intVarSearch;
@Command(name = "arm", description = "Association rule mining", mixinStandardHelpOptions = true)
public class AssociationRuleMining extends ChocoProblem {
@Option(names = {"-d", "--data"}, required = true, description = "Datafile to use")
private String dataPath;
//@Option(names = {"--nc"}, description = "Ignore classes of transactions")
private boolean noClasses = true;
@Option(names = {"--rt"}, description = "Rule type : ${COMPLETION-CANDIDATES} (default : ${DEFAULT-VALUE})")
private RuleType ruleType = RuleType.ar;
@Option(names = "--fmin", description = "Min frequency of the rule (absolute value)")
private int minFreq;
@Option(names = "--rfmin", description = "Min frequency of the rule (relative value)")
private double relativeMinFreq;
@Option(names = "--cmin", description = "Min confidence of the rule")
private double minConf;
@Option(names = "--sky", description = "Skypatterns file (impose constraint)")
private String skyPath;
@Option(names = "--0a", description = "Items to exclude in the antecedent (path of a file where each line " +
"represents an item to exclude)")
private String zeroItemsAntecedentPath;
@Option(names = "--0c", description = "Items to exclude in the consequent (path of a file where each line " +
"represents an item to exclude)")
private String zeroItemsConsequentPath;
@Option(names = "--or", description = "Items to include in the antecedent or the consequent (path of a file where" +
"each line represents an item to include")
private String orItemsPath;
@Option(names = "--lab", description = "File path with the label of items (each line corresponds to one item)")
private String labelsPath;
private String[] labels;
private List measures = Arrays.asList(sup, rsup, conf, lift);
private DecimalFormat measureFormat = new DecimalFormat("0.000");
private Database database;
private ArMonitor arMonitor;
private Map, Set> closedPatterns;
private int[] zeroItemsAntecedent = new int[0];
private int[] zeroItemsConsequent = new int[0];
private int[] orItems = new int[0];
private int[] readItems(String path) throws IOException {
Map itemsMap = database.getItemsMap();
return Files.readAllLines(Paths.get(path), StandardCharsets.UTF_8).stream().mapToInt(s -> itemsMap.get(Integer.parseInt(s))).toArray();
}
@Override
public void parseArgs() throws SetUpException {
try {
database = new DatReader(dataPath, 0, noClasses).readFiles();
if (relativeMinFreq != 0 && minFreq != 0) {
throw new SetUpException("--fmin and --rfmin are mutually exclusive (specify only one)");
}
if (relativeMinFreq != 0) {
minFreq = (int) Math.round(relativeMinFreq * database.getNbTransactions());
}
if ((minFreq == 0 || minConf == 0) && (skyPath == null)) {
throw new SetUpException("You should precise (--fmin and --cmin) or --sky");
}
if (skyPath != null) {
closedPatterns = getClosedPatterns();
}
if (zeroItemsAntecedentPath != null) {
zeroItemsAntecedent = readItems(zeroItemsAntecedentPath);
}
if (zeroItemsConsequentPath != null) {
zeroItemsConsequent = readItems(zeroItemsConsequentPath);
}
if (orItemsPath != null) {
orItems = readItems(orItemsPath);
}
if (labelsPath != null) {
labels = Files.readAllLines(Paths.get(labelsPath), StandardCharsets.UTF_8).toArray(new String[0]);
}
}
catch (IOException e) {
throw new SetUpException(e.getMessage());
}
}
/**
* If skypattern constraint, associate each skypattern to its closure
* @return a map which associates each skypattern to its closure
* @throws IOException if the skypattern file doesn't exist
*/
private Map, Set> getClosedPatterns() throws IOException {
if (skyPath == null) return new HashMap<>();
Map, Set> closedPatterns = new HashMap<>();
ProblemResultReader reader = new JsonResultReader<>(skyPath);
ProblemResult result = reader.readResult(Pattern[].class,
PatternProblemProperties.class);
for (Pattern p : result.getSolutions()) {
Set x = Arrays.stream(p.getItems()).boxed().collect(Collectors.toSet());
Set y = Arrays.stream(findClosedPattern(p, database)).boxed().collect(Collectors.toSet());
closedPatterns.put(x, y);
}
return closedPatterns;
}
private Set patternsUnion(Set> patterns) {
Set union = new HashSet<>();
patterns.forEach(union::addAll);
return union;
}
private BoolVar[] skypatternConstraint(BoolVar[] y, BoolVar[] z) {
if (closedPatterns == null) return new BoolVar[0];
Set patternsUnion = patternsUnion(new HashSet<>(closedPatterns.values()));
BoolVar[] skyVar = new BoolVar[closedPatterns.size()];
Map itemsMap = database.getItemsMap();
for (int i = 0; i < database.getNbItems(); i++) {
if (!patternsUnion.contains(database.getItems()[i])) {
model.arithm(z[i], "=", 0).post();
}
}
int i = 0;
//System.out.println(closedPatterns);
for (Map.Entry, Set> entry : closedPatterns.entrySet()) {
ReExpression sky = null;
Set skypattern = entry.getKey();
Set closedPattern = entry.getValue();
Set itemsOnlyInClosure = new HashSet<>(closedPattern);
itemsOnlyInClosure.removeAll(skypattern);
for (int item : patternsUnion) {
ReExpression temp;
int idx = itemsMap.get(item);
if (!closedPattern.contains(item)) {
temp = z[idx].eq(0);
}
else if (itemsOnlyInClosure.contains(item)) {
temp = y[idx].eq(1);
}
else {
temp = z[idx].eq(1);
}
sky = sky == null ? temp : sky.and(temp);
}
skyVar[i] = sky.boolVar();
i++;
}
model.sum(skyVar, ">=", 1).post();
return skyVar;
}
private void zeroItemsConstraint(BoolVar[] items, int[] zeroItems) {
Arrays.stream(zeroItems).forEach(i -> items[i].eq(0).post());
}
private void orItemsConstraint(BoolVar[] z) {
if (orItems.length == 0) return;
BoolVar[] orItemVars = Arrays.stream(orItems).mapToObj(i -> z[i]).toArray(BoolVar[]::new);
model.or(orItemVars).post();
}
@Override
public void buildModel() {
BoolVar[] x = model.boolVarArray("x", database.getNbItems());
BoolVar[] y = model.boolVarArray("y", database.getNbItems());
zeroItemsConstraint(x, zeroItemsAntecedent);
zeroItemsConstraint(y, zeroItemsConsequent);
BoolVar[] z = model.boolVarArray("z", database.getNbItems());
for (int i = 0; i < database.getNbItems(); i++) {
model.arithm(x[i], "+", y[i], "<=", 1).post();
model.addClausesBoolOrEqVar(x[i], y[i], z[i]);
}
orItemsConstraint(z);
model.addClausesBoolOrArrayEqualTrue(x);
model.addClausesBoolOrArrayEqualTrue(y);
IntVar freqZ = model.intVar("freqZ", minFreq, database.getNbTransactions());
new Constraint("frequent Z", new CoverSize(database, freqZ, z)).post();
IntVar freqX = model.intVar("freqX", minFreq, database.getNbTransactions());
new Constraint("frequent X", new CoverSize(database, freqX, x)).post();
if (minConf > 0) freqZ.mul(10000).ge(freqX.mul((int) Math.round(minConf * 10000))).post();
IntVar freqY = model.intVar("freqY", minFreq, database.getNbTransactions());
new Constraint("frequent Y", new CoverSize(database, freqY, y)).post();
if (ruleType.equals(RuleType.mnr)) {
new Constraint("generator x", new Generator(database, x))
.post();
new Constraint("closed z", new CoverClosure(database, z)).post();
}
BoolVar[] skyVars = skypatternConstraint(y, z);
BoolVar[] heuristicVars = ArrayUtils.append(skyVars, x, y, z);
model.getSolver().setSearch(intVarSearch(
new InputOrder<>(model),
new IntDomainMin(),
heuristicVars
));
arMonitor = new ArMonitor(database, x, y, freqX, freqY, freqZ);
model.getSolver().plugMonitor(arMonitor);
}
@Override
protected Model createModel() {
return new Model("AR mining", Settings.prod());
}
@Override
public List getSolutions() {
return arMonitor.getAssociationRules();
}
@Override
public ArMeasuresView getProperties() {
ArMeasuresView measures = new ArMeasuresView(solver.getMeasures());
measures.setNbTransactions(database.getNbTransactions());
return measures;
}
@Override
protected void printSolutions() {
getSolutions().forEach(s -> System.out.println(s.toString(database, labels, measures, measureFormat)));
}
public static void main(String[] args) throws Exception {
new CommandLine(new AssociationRuleMining()).execute(args);
}
}