// All downloads are free. Search and download functionality uses the official Maven repository.
//
// io.gitlab.chaver.mining.patterns.problems.PatternProblem Maven / Gradle / Ivy
//
// There is a newer version: 1.0.1
// Show newest version
/*
 * This file is part of io.gitlab.chaver:data-mining (https://gitlab.com/chaver/data-mining)
 *
 * Copyright (c) 2022, IMT Atlantique
 *
 * Licensed under the MIT license.
 *
 * See LICENSE file in the project root for full license information.
 */
package io.gitlab.chaver.mining.patterns.problems;

import io.gitlab.chaver.chocotools.problem.BuildModelException;
import io.gitlab.chaver.chocotools.problem.ChocoProblem;
import io.gitlab.chaver.chocotools.problem.SetUpException;
import io.gitlab.chaver.chocotools.search.loop.monitors.SolutionRecorderMonitor;
import io.gitlab.chaver.chocotools.util.ISolutionProvider;
import io.gitlab.chaver.mining.patterns.io.DatReader;
import io.gitlab.chaver.mining.patterns.io.Database;
import io.gitlab.chaver.mining.patterns.io.Pattern;
import io.gitlab.chaver.mining.patterns.io.PatternProblemProperties;
import io.gitlab.chaver.mining.patterns.measure.Measure;
import io.gitlab.chaver.mining.patterns.measure.attribute.*;
import io.gitlab.chaver.mining.patterns.measure.operand.MeasureOperand;
import io.gitlab.chaver.mining.patterns.measure.pattern.*;
import io.gitlab.chaver.mining.patterns.search.loop.monitors.SkypatternMonitor;
import io.gitlab.chaver.mining.patterns.search.strategy.selectors.variables.MinCov;
import io.gitlab.chaver.mining.patterns.util.MeasureListConverter;
import io.gitlab.chaver.mining.patterns.util.PatternCreator;
import io.gitlab.chaver.mining.patterns.util.TransactionGetter;
import org.chocosolver.solver.constraints.Constraint;
import org.chocosolver.solver.expression.discrete.relational.ReExpression;
import org.chocosolver.solver.search.strategy.Search;
import org.chocosolver.solver.search.strategy.selectors.values.IntDomainMin;
import org.chocosolver.solver.variables.BoolVar;
import org.chocosolver.solver.variables.IntVar;
import picocli.CommandLine.Option;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static io.gitlab.chaver.mining.patterns.measure.MeasureFactory.*;

public abstract class PatternProblem extends ChocoProblem {

    // NOTE(review): the type arguments below were missing (raw List / Map / ISolutionProvider);
    // they are restored from how the fields are used further down in this class.
    // NOTE(review): paramLabel is empty on the three measure options; presumably an
    // angle-bracketed label was lost before this file was captured — confirm against upstream.

    /** Path of the transactional database handed to {@code DatReader}. */
    @Option(names = "-d", required = true, description = "Path of the transactional database")
    private String dataPath;
    /** Measures used as skypattern objectives; cannot be combined with {@code --clom}. */
    @Option(names = "--skym", description = "Skypattern measures", converter = MeasureListConverter.class,
            paramLabel = "")
    private List<Measure> skypatternMeasures = new LinkedList<>();
    /** Closed pattern measures; derived from the skypattern measures when {@code --skym} is used. */
    @Option(names = "--clom", description = "Closed pattern measures", converter = MeasureListConverter.class,
            paramLabel = "")
    protected List<Measure> closedMeasures = new LinkedList<>();
    /** Extra measures that get a variable and are reported, on top of the main measures. */
    @Option(names = "--addm", description = "Additional measures", converter = MeasureListConverter.class,
            paramLabel = "")
    private List<Measure> additionalMeasures = new LinkedList<>();
    /** Forwarded to {@code DatReader} when the database is read. */
    @Option(names = {"--nc"}, description = "Ignore class of the transactions")
    private boolean noClasses;
    /** Lower bound of the domain of the length variable. */
    @Option(names = "--lmin", description = "Min length of the pattern (default : ${DEFAULT-VALUE})",
            defaultValue = "1")
    private int lengthMin;
    /** Minimum frequency; not read in this class, available to subclasses. */
    @Option(names = "--fmin", description = "Min freq of the pattern (default : ${DEFAULT-VALUE})", defaultValue = "1")
    protected int freqMin;
    /** When set, the growth-rate variable cannot take the value used to encode an infinite rate. */
    @Option(names = "--no-infgr", description = "No infinite growth-rate")
    private boolean noInfiniteGr;
    /** When set, a {@code TransactionGetter} is given to the pattern creator. */
    @Option(names = "--trans", description = "Save transactions of the patterns")
    private boolean saveTrans;
    @Option(names = "--0i", description = "Items to exclude from the mining (path of a file where each line " +
            "represents an item to exclude)")
    private String zeroItemsPath;
    @Option(names = "--ri", description = "Required items : post a constraint such that at least one of these items" +
            " is present in the pattern (path of a file where each line represents an item)")
    private String requiredItemsPath;
    @Option(names = "--lab", description = "File path with the label of items (each line corresponds to one item)")
    private String labelsPath;
    /** Lines of the label file, or {@code null} when {@code --lab} is not given. */
    private String[] labels;

    /** Distinct union of the main (skypattern or closed) measures and the additional measures. */
    private List<Measure> allMeasures;
    protected Database database;
    /** Monitor that collects the solutions; {@code null} until the search monitor is plugged. */
    private ISolutionProvider<Pattern> solutionProvider;
    /** Indexes (in {@link #items}) of the items forced to 0, or {@code null}. */
    private int[] zeroItems;
    /** Indexes (in {@link #items}) of the items of which at least one must be selected, or {@code null}. */
    private int[] requiredItems;

    // CP variables
    /** One boolean variable per item of the database. */
    protected BoolVar[] items;
    /** Measure id -> variable holding the value of that measure. */
    protected Map<String, IntVar> measureVars = new HashMap<>();

    /**
     * Creates one boolean variable per database item and fixes the first
     * {@code database.getNbClass()} of them to 0.
     */
    protected void itemVars() {
        items = model.boolVarArray("items", database.getNbItems());
        // NOTE(review): the leading positions presumably hold the class items — confirm in Database.
        int idx = 0;
        while (idx < database.getNbClass()) {
            model.arithm(items[idx], "=", 0).post();
            idx++;
        }
    }

    /**
     * Creates the variable of the given measure unless one is already registered under its id.
     *
     * @param m measure to create a variable for
     * @throws BuildModelException if the exact class of {@code m} is not one of the supported measures
     */
    private void createMeasureVar(Measure m) throws BuildModelException {
        if (measureVars.containsKey(m.getId())) {
            return;
        }
        // Attribute index for attribute measures, -1 for every other measure
        int num = -1;
        if (m instanceof AttributeMeasure) {
            num = ((AttributeMeasure) m).getNum();
        }
        // Dispatch on the exact runtime class (subclasses are deliberately not matched)
        final Class<?> type = m.getClass();
        if (type == Freq.class) {
            freqVar();
        } else if (type == Length.class) {
            lengthVar();
        } else if (type == Area.class) {
            areaVar();
        } else if (type == MaxFreq.class) {
            maxFreqVar();
        } else if (type == AllConf.class) {
            aconfVar();
        } else if (type == GrowthRate.class) {
            growthRateVar();
        } else if (type == Mean.class) {
            meanValueVar(num);
        } else if (type == Min.class) {
            minValueVar(num);
        } else if (type == Max.class) {
            maxValueVar(num);
        } else {
            throw new BuildModelException("Can't create var for this measure : " + m);
        }
    }

    /**
     * Creates the length variable (number of item variables set to 1), with domain
     * {@code [lengthMin, nbItems]}, and registers it under the id of the length measure.
     */
    protected void lengthVar() {
        final String id = length().getId();
        final IntVar patternLength = model.intVar(id, lengthMin, database.getNbItems());
        measureVars.put(id, patternLength);
        // patternLength == number of entries of items equal to 1
        model.count(1, items, patternLength).post();
    }

    /**
     * Creates the frequency variable of the pattern.
     * Code in this class reads {@code measureVars.get(freq().getId())} right after calling this
     * method, so implementations are expected to register the variable under that id.
     */
    protected abstract void freqVar();
    /**
     * Creates the variable registered under {@code freq1().getId()}.
     * Code in this class reads that entry of {@code measureVars} right after calling this method.
     * NOTE(review): presumably the frequency restricted to the first class — confirm in subclasses.
     */
    protected abstract void freq1Var();

    /**
     * Creates the {@code freq2} variable, constrained to {@code freq - freq1}, with domain
     * {@code [0, database.getClassCount()[1]]}, and registers it under the id of the freq2 measure.
     * The {@code freq} and {@code freq1} variables are created first when they do not exist yet,
     * mirroring what {@link #areaVar()} and {@link #growthRateVar()} do for their operands.
     */
    protected void freq2Var() {
        // Guarantee both operands exist instead of posting a constraint over null variables
        if (!measureVars.containsKey(freq().getId())) freqVar();
        if (!measureVars.containsKey(freq1().getId())) freq1Var();
        String freq2Id = freq2().getId();
        IntVar freq2 = model.intVar(freq2Id, 0, database.getClassCount()[1]);
        IntVar freq = measureVars.get(freq().getId());
        IntVar freq1 = measureVars.get(freq1().getId());
        // freq - freq1 = freq2
        model.arithm(freq, "-", freq1, "=", freq2).post();
        measureVars.put(freq2Id, freq2);
    }

    /**
     * Registers the area variable, defined as {@code freq * length}; the two operands are
     * created on demand.
     */
    protected void areaVar() {
        final String freqId = freq().getId();
        if (!measureVars.containsKey(freqId)) {
            freqVar();
        }
        final String lengthId = length().getId();
        if (!measureVars.containsKey(lengthId)) {
            lengthVar();
        }
        final IntVar support = measureVars.get(freqId);
        final IntVar size = measureVars.get(lengthId);
        final IntVar area = support.mul(size).intVar();
        measureVars.put(area().getId(), area);
    }

    /**
     * Creates the growth-rate variable and registers it under the id of the growth-rate measure.
     * <p>
     * With {@code d1 = getClassCount()[0]} and {@code d2 = getClassCount()[1]}, the variable is
     * constrained by a disjunction of three cases:
     * <ul>
     *   <li>{@code freq1 == 0}: growth rate is 0;</li>
     *   <li>{@code freq1 > 0} and {@code freq2 == 0}: growth rate is {@code IntVar.MAX_INT_BOUND},
     *       which stands for an infinite rate;</li>
     *   <li>{@code freq1 > 0} and {@code freq2 > 0}: growth rate is
     *       {@code (freq1 * d2) / (freq2 * d1)} (integer division).</li>
     * </ul>
     * When {@code noInfiniteGr} is set the upper bound of the domain is
     * {@code MAX_INT_BOUND - 1}, so the second case can never hold.
     */
    protected void growthRateVar() {
        // Create the per-class frequency variables on demand
        if (!measureVars.containsKey(freq1().getId())) freq1Var();
        if (!measureVars.containsKey(freq2().getId())) freq2Var();
        String growthRateId = growthRate().getId();
        IntVar freq1 = measureVars.get(freq1().getId());
        IntVar freq2 = measureVars.get(freq2().getId());
        int d1 = database.getClassCount()[0];
        int d2 = database.getClassCount()[1];
        // Excluding MAX_INT_BOUND from the domain rules out the "infinite" case below
        int grUB = noInfiniteGr ? IntVar.MAX_INT_BOUND - 1 : IntVar.MAX_INT_BOUND;
        IntVar growthRate = model.intVar(growthRateId, 0, grUB);
        // Case 1: freq1 == 0 && growthRate == 0
        ReExpression e1 = (freq1.eq(0)).and(growthRate.eq(0));
        // Case 2: freq1 > 0 && freq2 == 0 && growthRate == MAX_INT_BOUND (infinite growth rate)
        ReExpression e2 = (freq1.gt(0)).and(freq2.eq(0)).and(growthRate.eq(IntVar.MAX_INT_BOUND));
        // Computed rate; (freq2 == 0) is added to freq2 so that the denominator is never 0
        IntVar computedGr = freq1.mul(d2).div((freq2.add(freq2.eq(0))).mul(d1)).intVar();
        // Case 3: freq1 > 0 && freq2 > 0 && growthRate == computedGr
        // (the "<=" term is implied by the "==" term)
        ReExpression e3 = (freq1.gt(0)).and(freq2.gt(0)).and(growthRate.le(computedGr)).and(growthRate.eq(computedGr));
        e1.or(e2).or(e3).post();
        measureVars.put(growthRateId, growthRate);
    }

    /**
     * Creates the max-frequency variable: the largest individual item frequency among the
     * selected items (0 when no item is selected), registered under the id of the max-freq measure.
     */
    protected void maxFreqVar() {
        final int[] supportOfItem = database.computeItemFreq();
        final IntVar[] selectedSupport =
                model.intVarArray(database.getNbItems(), 0, database.getNbTransactions());
        int idx = 0;
        while (idx < database.getNbItems()) {
            // selectedSupport[idx] = items[idx] * supportOfItem[idx], i.e. 0 for an unselected item
            model.arithm(items[idx], "*", model.intVar(supportOfItem[idx]), "=", selectedSupport[idx]).post();
            idx++;
        }
        final String id = maxFreq().getId();
        final IntVar largest = model.intVar(id, 0, database.getNbTransactions());
        // largest = max(selectedSupport)
        model.max(largest, selectedSupport).post();
        measureVars.put(id, largest);
    }

    /**
     * Creates the all-confidence variable, {@code freq * 10000 / maxFreq} (integer division),
     * with domain {@code [0, 10000]}; the two operands are created on demand.
     */
    protected void aconfVar() {
        final String freqId = freq().getId();
        if (!measureVars.containsKey(freqId)) {
            freqVar();
        }
        final String maxFreqId = maxFreq().getId();
        if (!measureVars.containsKey(maxFreqId)) {
            maxFreqVar();
        }
        final IntVar support = measureVars.get(freqId);
        final IntVar largestItemSupport = measureVars.get(maxFreqId);
        // The ratio is stored as an integer scaled by this factor
        final int scale = 10000;
        final String aconfId = allConf().getId();
        final IntVar allConfidence = model.intVar(aconfId, 0, scale);
        allConfidence.eq(support.mul(scale).div(largestItemSupport)).post();
        measureVars.put(aconfId, allConfidence);
    }

    /**
     * Creates the variable holding the minimum value of attribute {@code num} over the selected
     * items and registers it under the id of the corresponding min measure.
     *
     * @param num index of the attribute in {@code database.getValues()}
     */
    protected void minValueVar(int num) {
        final int[] attrValues = database.getValues()[num];
        final int upper = Arrays.stream(attrValues).max().getAsInt();
        final IntVar[] perItem = model.intVarArray("valuesMin_" + num, database.getNbItems(), 0, upper);
        int idx = 0;
        while (idx < database.getNbItems()) {
            // An unselected item takes the upper bound so that it cannot lower the minimum
            ReExpression unselected = items[idx].eq(0).and(perItem[idx].eq(upper));
            // A selected item contributes its own attribute value
            ReExpression selected = items[idx].eq(1).and(perItem[idx].eq(attrValues[idx]));
            unselected.or(selected).post();
            idx++;
        }
        final IntVar smallest = model.intVar("minVal_" + num, 0, upper);
        model.min(smallest, perItem).post();
        measureVars.put(min(num).getId(), smallest);
    }

    /**
     * Creates the variable holding the maximum value of attribute {@code num} over the selected
     * items and registers it under the id of the corresponding max measure.
     *
     * @param num index of the attribute in {@code database.getValues()}
     */
    protected void maxValueVar(int num) {
        final int[] attrValues = database.getValues()[num];
        final int upper = Arrays.stream(attrValues).max().getAsInt();
        final IntVar[] perItem = model.intVarArray("valuesMax_" + num, database.getNbItems(), 0, upper);
        int idx = 0;
        while (idx < database.getNbItems()) {
            // perItem[idx] = attribute value when the item is selected, 0 otherwise
            perItem[idx].eq(items[idx].mul(attrValues[idx])).post();
            idx++;
        }
        final IntVar largest = model.intVar("maxVal_" + num, 0, upper);
        model.max(largest, perItem).post();
        measureVars.put(max(num).getId(), largest);
    }

    /**
     * Registers the "mean" variable of attribute {@code num}, defined here as
     * {@code (min + max) / 2} (integer division) of the attribute's min and max variables,
     * which are created on demand.
     *
     * @param num index of the attribute
     */
    protected void meanValueVar(int num) {
        final String minId = min(num).getId();
        if (!measureVars.containsKey(minId)) {
            minValueVar(num);
        }
        final String maxId = max(num).getId();
        if (!measureVars.containsKey(maxId)) {
            maxValueVar(num);
        }
        final IntVar lowest = measureVars.get(minId);
        final IntVar highest = measureVars.get(maxId);
        final IntVar midpoint = (lowest.add(highest)).div(2).intVar();
        measureVars.put(mean(num).getId(), midpoint);
    }

    /**
     * Validates the measure options, computes the list of all measures, reads the database and
     * loads the optional item / label files.
     * <p>
     * Exactly one of {@code --skym} and {@code --clom} must be given. With {@code --skym}, the
     * closed measures are derived from the skypattern measures through
     * {@code MeasureOperand.maxConvert}.
     *
     * @throws SetUpException if the measure options are inconsistent, if a file cannot be read,
     *                        or if an item file contains a malformed or unknown item
     */
    @Override
    protected void parseArgs() throws SetUpException {
        if (closedMeasures.isEmpty() && skypatternMeasures.isEmpty()) {
            throw new SetUpException("--skym or --clom must be specified");
        }
        if (!closedMeasures.isEmpty() && !skypatternMeasures.isEmpty()) {
            throw new SetUpException("--skym and --clom can't be specified both");
        }
        // From here on exactly one of the two lists is non-empty
        List<Measure> mainMeasures;
        if (closedMeasures.isEmpty()) {
            closedMeasures = new LinkedList<>(MeasureOperand.maxConvert(skypatternMeasures));
            mainMeasures = skypatternMeasures;
        } else {
            mainMeasures = closedMeasures;
        }
        allMeasures = Stream
                .of(mainMeasures, additionalMeasures)
                .flatMap(Collection::stream)
                .distinct()
                .collect(Collectors.toList());
        // Highest attribute index referenced by a measure (-1 when there is none);
        // the reader is given that index + 1
        int idxValMeasure = allMeasures
                .stream()
                .filter(m -> m instanceof AttributeMeasure)
                .mapToInt(m -> ((AttributeMeasure) m).getNum())
                .max()
                .orElse(-1);
        try {
            database = new DatReader(dataPath, idxValMeasure + 1, noClasses).readFiles();
            Map<Integer, Integer> itemsMap = database.getItemsMap();
            if (zeroItemsPath != null) {
                zeroItems = readItemIndexes(zeroItemsPath, itemsMap);
            }
            if (requiredItemsPath != null) {
                requiredItems = readItemIndexes(requiredItemsPath, itemsMap);
            }
            if (labelsPath != null) {
                labels = Files.readAllLines(Paths.get(labelsPath), StandardCharsets.UTF_8).toArray(new String[0]);
            }
        } catch (IOException e) {
            throw new SetUpException(e.getMessage(), e);
        }
    }

    /**
     * Reads a file holding one item per line and translates each item through {@code itemsMap}
     * into the index used for the item variables. Blank lines are skipped.
     *
     * @param path     path of the item file (UTF-8)
     * @param itemsMap mapping from an item as written in the file to its index
     * @return the indexes, in file order
     * @throws IOException    if the file cannot be read
     * @throws SetUpException if a line is not an integer or names an item absent from the database
     */
    private static int[] readItemIndexes(String path, Map<Integer, Integer> itemsMap)
            throws IOException, SetUpException {
        List<String> lines = Files.readAllLines(Paths.get(path), StandardCharsets.UTF_8);
        int[] indexes = new int[lines.size()];
        int count = 0;
        for (String line : lines) {
            String token = line.trim();
            if (token.isEmpty()) {
                continue;
            }
            Integer index;
            try {
                index = itemsMap.get(Integer.parseInt(token));
            } catch (NumberFormatException e) {
                throw new SetUpException("Invalid item '" + token + "' in file " + path, e);
            }
            if (index == null) {
                // Fail with a clear message rather than an unboxing NullPointerException
                throw new SetUpException("Unknown item " + token + " in file " + path);
            }
            indexes[count++] = index;
        }
        return Arrays.copyOf(indexes, count);
    }

    /** Forces every excluded item (if any were given) to 0. */
    private void zeroItemsConstraint() {
        if (zeroItems != null) {
            for (int excluded : zeroItems) {
                items[excluded].eq(0).post();
            }
        }
    }

    /** Posts a disjunction over the required items (if any were given): at least one must be selected. */
    private void requiredItemsConstraint() {
        if (requiredItems == null) {
            return;
        }
        final BoolVar[] candidates = new BoolVar[requiredItems.length];
        for (int k = 0; k < requiredItems.length; k++) {
            candidates[k] = items[requiredItems[k]];
        }
        model.or(candidates).post();
    }

    /**
     * Builds the whole model: item variables and item restrictions, the frequency and length
     * variables, one variable per requested measure, the subclass-specific constraint, the
     * solution monitor and finally the search strategy over the item variables.
     *
     * @throws BuildModelException if a measure has no variable implementation or the
     *                             subclass-specific constraint cannot be built
     */
    @Override
    public void buildModel() throws BuildModelException {
        // Item variables and restrictions on individual items
        itemVars();
        zeroItemsConstraint();
        requiredItemsConstraint();
        // Frequency and length are always created, whatever measures were requested
        freqVar();
        lengthVar();
        for (Measure measure : allMeasures) {
            createMeasureVar(measure);
        }
        closedConstraint();
        plugSearchMonitor();
        // Branch on the item variables: MinCov picks the variable, the smallest value is tried first
        final MinCov varSelector = new MinCov(model, database);
        final IntDomainMin valSelector = new IntDomainMin();
        solver.setSearch(Search.intVarSearch(varSelector, valSelector, items));
    }

    /**
     * Plugs the monitor that records solutions and keeps it as the solution provider.
     * Without skypattern measures a plain {@code SolutionRecorderMonitor} is used; otherwise a
     * {@code SkypatternMonitor} over the skypattern measure variables is both posted as the
     * propagator of a "Pareto" constraint and plugged as a search monitor.
     */
    private void plugSearchMonitor() {
        List<String> allMeasuresId = allMeasures.stream().map(Measure::getId).collect(Collectors.toList());
        // Transactions are only collected when --trans was given
        TransactionGetter transactionGetter = saveTrans ? transactionGetter() : null;
        PatternCreator creator = new PatternCreator(database, items, allMeasuresId, measureVars, transactionGetter);
        if (skypatternMeasures.isEmpty()) {
            SolutionRecorderMonitor<Pattern> monitor = new SolutionRecorderMonitor<>(creator);
            solver.plugMonitor(monitor);
            solutionProvider = monitor;
        } else {
            // One objective variable per skypattern measure, in option order
            IntVar[] obj = skypatternMeasures.stream().map(m -> measureVars.get(m.getId())).toArray(IntVar[]::new);
            // NOTE(review): meaning of the trailing 'false' flag is defined by SkypatternMonitor — confirm there
            SkypatternMonitor monitor = new SkypatternMonitor(obj, creator, false);
            model.post(new Constraint("Pareto", monitor));
            solver.plugMonitor(monitor);
            solutionProvider = monitor;
        }
    }

    /**
     * Supplies the {@code TransactionGetter} passed to the {@code PatternCreator}; only called
     * when the {@code --trans} option is set.
     *
     * @return the transaction getter to use for recorded patterns
     */
    protected abstract TransactionGetter transactionGetter();

    /**
     * Posts the subclass-specific constraint. It is called by {@code buildModel()} once the item
     * variables and every measure variable exist, and before the search monitor is plugged.
     * NOTE(review): presumably enforces closedness w.r.t. {@code closedMeasures} — confirm in subclasses.
     *
     * @throws BuildModelException if the constraint cannot be built
     */
    protected abstract void closedConstraint() throws BuildModelException;

    /**
     * Returns the patterns collected by the plugged monitor.
     *
     * @return the recorded patterns, or {@code null} when no monitor has been plugged yet
     *         (i.e. the model has not been built)
     */
    @Override
    public List<Pattern> getSolutions() {
        if (solutionProvider == null) {
            return null;
        }
        return solutionProvider.getSolutions();
    }

    /**
     * Builds the properties of this run: solver measures, number of recorded solutions and the
     * ids of the closed, skypattern and all measures.
     *
     * @return the populated properties; the solution count is 0 when no monitor was plugged
     */
    @Override
    public PatternProblemProperties getProperties() {
        PatternProblemProperties properties = new PatternProblemProperties(solver.getMeasures());
        // getSolutions() returns null before the model is built; report 0 instead of failing
        List<?> solutions = getSolutions();
        properties.setBestSolutionCount(solutions == null ? 0 : solutions.size());
        properties.setClosedMeasures(closedMeasures.stream().map(Measure::getId).collect(Collectors.toList()));
        properties.setSkyMeasures(skypatternMeasures.stream().map(Measure::getId).collect(Collectors.toList()));
        properties.setAllMeasures(allMeasures.stream().map(Measure::getId).collect(Collectors.toList()));
        return properties;
    }

    /**
     * Prints the statistics of the parent class and, when skypattern measures are used, the
     * number of skypatterns found (0 when no monitor was plugged).
     */
    @Override
    protected void printStats() {
        super.printStats();
        if (skypatternMeasures.isEmpty()) {
            return;
        }
        // getSolutions() returns null before the model is built; report 0 instead of failing
        List<?> solutions = getSolutions();
        int nbSkypatterns = solutions == null ? 0 : solutions.size();
        System.out.println("\tNb skypatterns : " + nbSkypatterns);
    }

    /**
     * Prints every recorded pattern on standard output, one per line, using the ids of all
     * measures, the optional item labels and the database. Prints nothing when no monitor was
     * plugged.
     */
    @Override
    protected void printSolutions() {
        List<Pattern> solutions = getSolutions();
        if (solutions == null) {
            // Nothing was recorded: the model has not been built
            return;
        }
        List<String> allMeasuresId = allMeasures.stream().map(Measure::getId).collect(Collectors.toList());
        for (Pattern p : solutions) {
            System.out.println(p.toString(allMeasuresId, labels, database));
        }
    }
}




// © 2015 - 2024 Weber Informatics LLC | Privacy Policy