All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.lwhite1.tablesaw.api.ml.association.AssociationRuleMining Maven / Gradle / Ivy

package com.github.lwhite1.tablesaw.api.ml.association;

import com.github.lwhite1.tablesaw.api.CategoryColumn;
import com.github.lwhite1.tablesaw.api.FloatColumn;
import com.github.lwhite1.tablesaw.api.ShortColumn;
import com.github.lwhite1.tablesaw.api.Table;
import com.github.lwhite1.tablesaw.table.TemporaryView;
import com.github.lwhite1.tablesaw.table.ViewGroup;
import it.unimi.dsi.fastutil.ints.IntRBTreeSet;
import it.unimi.dsi.fastutil.objects.Object2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.shorts.ShortRBTreeSet;
import smile.association.ARM;
import smile.association.AssociationRule;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/**
 * Association-rule mining over transactional data held in two parallel {@link ShortColumn}s.
 *
 * <p>Row {@code i} pairs the basket (transaction) id {@code sets[i]} with the item id
 * {@code items[i]}. Rows are grouped by basket id, each basket is reduced to its distinct items
 * in ascending order, and the resulting item sets are used to build a Smile {@link ARM} model.
 * Rules are then generated from that model on demand.
 */
public class AssociationRuleMining {

  /** Smile model built once in the constructor and reused by every rule-generating call. */
  private final ARM model;

  /**
   * Builds the mining model from basket/item pairs.
   *
   * @param sets    basket (transaction) id of each row
   * @param items   item id of each row; must have the same number of rows as {@code sets}
   * @param support minimum support threshold, passed unchanged to the {@link ARM} constructor
   * @throws IllegalArgumentException if {@code sets} and {@code items} differ in size
   */
  public AssociationRuleMining(ShortColumn sets, ShortColumn items, double support) {
    if (sets.size() != items.size()) {
      throw new IllegalArgumentException("sets and items must have the same size, but were "
          + sets.size() + " and " + items.size());
    }

    // Work on copies so the caller's columns are never attached to this scratch table.
    Table temp = Table.create("temp");
    temp.addColumn(sets.copy());
    temp.addColumn(items.copy());

    // Rows need not be sorted first: splitOn groups rows by basket id, and
    // distinctSortedItems orders and de-duplicates the items within each basket.
    ViewGroup baskets = temp.splitOn(temp.column(0));
    int[][] itemsets = new int[baskets.size()][];
    int basketIndex = 0;
    for (TemporaryView basket : baskets) {
      itemsets[basketIndex++] = distinctSortedItems(basket);
    }

    this.model = new ARM(itemsets, support);
  }

  /** Returns the distinct item ids (column 1) of one basket, in ascending order. */
  private static int[] distinctSortedItems(TemporaryView basket) {
    ShortRBTreeSet set = new ShortRBTreeSet(basket.shortColumn(1).data());
    int[] result = new int[set.size()];
    int i = 0;
    for (short item : set) {
      result[i++] = item;
    }
    return result;
  }

  /**
   * Generates association rules from the model.
   *
   * @param confidenceThreshold minimum rule confidence, passed to {@link ARM#learn(double)}
   * @return the rules produced by the model
   */
  public List<AssociationRule> learn(double confidenceThreshold) {
    return model.learn(confidenceThreshold);
  }

  /**
   * Generates rules and keeps only those whose absolute interest is at least
   * {@code interestThreshold}. The interest of a rule is its confidence minus the baseline
   * confidence of its consequent.
   *
   * <p>NOTE(review): this method keeps {@code |interest| >= interestThreshold}, while
   * {@link #interest} keeps {@code |interest| > interestThreshold}. Confirm which boundary is
   * intended.
   *
   * @param confidenceThreshold minimum rule confidence, passed to {@link ARM#learn(double)}
   * @param interestThreshold   minimum absolute interest a rule needs to be kept
   * @param confidenceMap       baseline confidence per item set, keyed by the consequent's
   *                            items as an {@link IntRBTreeSet}; absent keys yield the map's
   *                            default return value
   * @return the filtered rules
   */
  public List<AssociationRule> interestingRules(double confidenceThreshold,
                                                double interestThreshold,
                                                Object2DoubleOpenHashMap<IntRBTreeSet> confidenceMap) {
    List<AssociationRule> rules = model.learn(confidenceThreshold);
    // Remove through the iterator: List.remove during a for-each over the same list
    // triggers ConcurrentModificationException.
    Iterator<AssociationRule> it = rules.iterator();
    while (it.hasNext()) {
      if (Math.abs(interestOf(it.next(), confidenceMap)) < interestThreshold) {
        it.remove();
      }
    }
    return rules;
  }

  /**
   * Generates rules and tabulates those whose absolute interest exceeds
   * {@code interestThreshold}.
   *
   * <p>The table has the columns Antecedent and Consequent (each rendered with
   * {@link Arrays#toString(int[])}), Confidence, and Interest (confidence minus the baseline
   * confidence of the consequent).
   *
   * @param confidenceThreshold minimum rule confidence, passed to {@link ARM#learn(double)}
   * @param interestThreshold   a rule is included only if its absolute interest is strictly
   *                            greater than this value
   * @param confidenceMap       baseline confidence per item set, keyed by the consequent's
   *                            items as an {@link IntRBTreeSet}; absent keys yield the map's
   *                            default return value
   * @return a new table named "Interest"
   */
  public Table interest(double confidenceThreshold,
                        double interestThreshold,
                        Object2DoubleOpenHashMap<IntRBTreeSet> confidenceMap) {

    Table interestTable = Table.create("Interest");
    interestTable.addColumn(CategoryColumn.create("Antecedent"));
    interestTable.addColumn(CategoryColumn.create("Consequent"));
    interestTable.addColumn(FloatColumn.create("Confidence"));
    interestTable.addColumn(FloatColumn.create("Interest"));

    List<AssociationRule> rules = model.learn(confidenceThreshold);

    for (AssociationRule rule : rules) {
      double interest = interestOf(rule, confidenceMap);
      if (Math.abs(interest) > interestThreshold) {
        interestTable.categoryColumn(0).addCell(Arrays.toString(rule.antecedent));
        interestTable.categoryColumn(1).addCell(Arrays.toString(rule.consequent));
        interestTable.floatColumn(2).add(rule.confidence);
        interestTable.floatColumn(3).add(interest);
      }
    }
    return interestTable;
  }

  /**
   * Returns the rule's confidence minus the baseline confidence of its consequent.
   *
   * <p>The consequent is wrapped in an {@link IntRBTreeSet} because {@code int[]} uses identity
   * equality and would never match a map key.
   */
  private static double interestOf(AssociationRule rule,
                                   Object2DoubleOpenHashMap<IntRBTreeSet> confidenceMap) {
    return rule.confidence - confidenceMap.getDouble(new IntRBTreeSet(rule.consequent));
  }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy