meka.filters.multilabel.SuperNodeFilter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of meka Show documentation
Show all versions of meka Show documentation
The MEKA project provides an open source implementation of methods for multi-label classification and evaluation. It is based on the WEKA Machine Learning Toolkit. Several benchmark methods are also included, as well as the pruned sets and classifier chains methods, other methods from the scientific literature, and a wrapper to the MULAN framework.
/*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package meka.filters.multilabel;
import meka.core.SuperLabelUtils;
import weka.core.*;
import meka.core.MLUtils;
import meka.classifiers.multitarget.NSR;
import weka.filters.*;
import java.util.*;
import java.io.*; // for test routine main()
/**
 * SuperNodeFilter.java - Super Class Filter.
 *
 * <p>Input:
 * <ul>
 *   <li>data with label attributes, e.g., [0,1,2,3,4];</li>
 *   <li>a desired partition of the label indices, e.g., [[1,3],[4],[0,2]].</li>
 * </ul>
 * Output: new data with one (super-)label attribute per group of the partition,
 * e.g., [1+3,4,0+2]. The values each new attribute can take are pruned if necessary.
 *
 * <p>Naming conventions used by the static helpers: a class (attribute) name is
 * {@code "c_"} followed by the original label indices joined with {@code '+'}
 * (e.g. {@code "c_3+1"}); a value of a merged attribute is the original values
 * joined with {@code '+'} (e.g. {@code "0+1"}).
 *
 * @author Jesse Read
 * @version June 2012
 */
public class SuperNodeFilter extends SimpleBatchFilter {

    /** First instance of the most recently processed data; {@code null} until {@link #process} has run (or if it produced no instances). */
    protected Instance x_template = null;

    /** Pruning value 'p' and subset count 'n', both forwarded to {@link #mergeLabels(Instances, int[][], int, int)}. */
    protected int m_P = 0, m_N = 0;

    /** The partition of label indices; each inner array is one super-label. */
    protected int indices[][] = null;

    /** Not used by this class itself. */
    protected int m_Seed = 0;

    /**
     * Sets the partition of label indices.
     *
     * <p>Each group is sorted <em>in place</em>, i.e. the caller's arrays are
     * modified, and the outer array is stored by reference (not copied).
     *
     * @param n the partition, e.g. [[1,3],[4],[0,2]]; {@code null} unsets the partition
     */
    public void setIndices(int n[][]) {
        if (n != null) {
            for (int[] group : n) {
                Arrays.sort(group); // always sorted!
            }
        }
        this.indices = n;
    }

    /** @param p the pruning value passed on to the label merge */
    public void setP(int p) {
        this.m_P = p;
    }

    /** @return the pruning value passed on to the label merge */
    public int getP() {
        return this.m_P;
    }

    /** @param n the number of closest subsets used to replace a pruned value */
    public void setN(int n) {
        this.m_N = n;
    }

    /**
     * Builds the output header: an empty copy of {@code D} from which the first
     * {@code L - K} attributes are removed, where {@code L = D.classIndex()} and
     * {@code K} is the number of groups in the partition.
     *
     * @param D the input format
     * @return the empty output structure
     * @throws IllegalStateException if no partition has been set
     */
    @Override
    public Instances determineOutputFormat(Instances D) throws Exception {
        requireIndices();
        Instances D_out = new Instances(D, 0);
        int L = D.classIndex();
        for (int i = 0; i < L - indices.length; i++) {
            D_out.deleteAttributeAt(0);
        }
        return D_out;
    }

    /**
     * @return the first instance produced by the last call to {@link #process},
     *         or {@code null} if there is none
     */
    public Instance getTemplate() {
        return x_template;
    }

    /**
     * Merges the label attributes of {@code D} according to the configured partition.
     *
     * <p>Works on a copy of {@code D}: the first {@code D.classIndex()} attributes are
     * renamed to {@code c_0, c_1, ...}, then merged into super-labels. The result also
     * becomes this filter's output format, and its first instance the template.
     *
     * @param D the data to transform (not modified)
     * @return the transformed data
     * @throws IllegalStateException if no partition has been set
     */
    @Override
    public Instances process(Instances D) throws Exception {
        requireIndices();
        int L = D.classIndex();
        D = new Instances(D); // work on a copy
        // rename classes
        for (int j = 0; j < L; j++) {
            D.renameAttribute(j, encodeClass(j));
        }
        // merge labels
        D = mergeLabels(D, indices, m_P, m_N);
        // template (there is none if no instance is left)
        x_template = D.numInstances() > 0 ? D.firstInstance() : null;
        setOutputFormat(D);
        return D;
    }

    /** Fails with a clear message (instead of a bare NullPointerException) when the partition is missing. */
    private void requireIndices() {
        if (indices == null) {
            throw new IllegalStateException("No label partition set; call setIndices(int[][]) first.");
        }
    }

    /** Joins the numbers with the delimiter; {@code null} or empty input gives {@code ""}. */
    private static String join(int objs[], final String delimiter) {
        if (objs == null || objs.length < 1) {
            return "";
        }
        StringBuilder buffer = new StringBuilder(String.valueOf(objs[0]));
        for (int j = 1; j < objs.length; j++) {
            buffer.append(delimiter).append(objs[j]);
        }
        return buffer.toString();
    }

    /** (3) -> "c_3" */
    public static String encodeClass(int j) {
        return "c_" + j;
    }

    /** ("c_3") -> 3 */
    public static int decodeClass(String a) {
        return Integer.parseInt(a.substring(a.indexOf('_') + 1));
    }

    /** ("c_3","c_1") -> "c_3+1" */
    public static String encodeClass(String c_j, String c_k) {
        return "c_" + join(decodeClasses(c_j), "+") + "+" + join(decodeClasses(c_k), "+");
    }

    /**
     * ([3,1]) -> "c_3+1"
     *
     * <p>An empty (or {@code null}) array gives {@code "c_"}.
     */
    public static String encodeClass(int c_[]) {
        return "c_" + join(c_, "+");
    }

    /** ("c_3+1") -> [3,1] */
    public static int[] decodeClasses(String a) {
        String s[] = a.substring(a.indexOf('_') + 1).split("\\+");
        int vals[] = new int[s.length];
        for (int j = 0; j < vals.length; j++) {
            vals[j] = Integer.parseInt(s[j]);
        }
        return vals;
    }

    /** (3,1) -> "3+1" */
    public static String encodeValue(String v_j, String v_k) {
        return v_j + "+" + v_k;
    }

    /**
     * (3,1,2) -> "3+1+2": the string values of {@code x} at {@code indices}, joined by '+'.
     *
     * <p>Empty {@code indices} gives {@code ""}.
     */
    public static String encodeValue(Instance x, int indices[]) {
        StringBuilder v = new StringBuilder();
        for (int j = 0; j < indices.length; j++) {
            if (j > 0) {
                v.append('+');
            }
            v.append(x.stringValue(indices[j]));
        }
        return v.toString();
    }

    /** "C+A+B" -> ["C","A","B"] */
    public static String[] decodeValue(String a) {
        return a.split("\\+");
    }

    /**
     * Return a set of all the combinations of attributes at 'indices' in 'D', pruned by 'p'; e.g., {00,01,11}.
     *
     * @return the key set of {@link #getCounts(Instances, int[], int)}
     */
    public static Set<String> getValues(Instances D, int indices[], int p) {
        return getCounts(D, indices, p).keySet();
    }

    /**
     * Return a set of all the combinations of attributes at 'indices' in 'D', pruned by 'p'; AND THEIR COUNTS, e.g., {(00:3),(01:8),(11:3)}.
     *
     * <p>The pruning itself is delegated to {@code MLUtils.pruneCountHashMap(count, p)}.
     *
     * @param D       the data
     * @param indices the attributes whose joint value is counted
     * @param p       the pruning value
     * @return a new map from encoded value to number of occurrences
     */
    public static HashMap<String,Integer> getCounts(Instances D, int indices[], int p) {
        HashMap<String,Integer> count = new HashMap<String,Integer>();
        for (int i = 0; i < D.numInstances(); i++) {
            String v = encodeValue(D.instance(i), indices);
            Integer seen = count.get(v);
            count.put(v, seen == null ? 1 : seen + 1);
        }
        MLUtils.pruneCountHashMap(count, p);
        return count;
    }

    /**
     * Merge Labels - Make a new 'D', with labels made into superlabels, according to partition 'indices', and pruning values 'p' and 'n'.
     *
     * <p>An instance whose combined value for some super-label was pruned away is
     * removed and replaced by weighted copies, one per value returned by
     * {@code NSR.getTopNSubsets(value, counts, n)}; each copy gets weight
     * {@code 1/(number of copies)}. The copies are appended to the end of the data.
     *
     * @param D       assume attributes in D labeled by original index (not modified)
     * @param indices the partition of label indices
     * @param p       pruning value for the super-label values
     * @param n       number of subsets requested for a pruned value
     * @return new Instances whose first K = indices.length attributes are the super-labels, with classIndex = K
     */
    public static Instances mergeLabels(Instances D, int indices[][], int p, int n) {
        int L = D.classIndex();
        int K = indices.length;
        ArrayList<HashMap<String,Integer>> counts = new ArrayList<HashMap<String,Integer>>(K);

        // create D_ and clear its original label attributes
        Instances D_ = new Instances(D);
        for (int j = 0; j < L; j++) {
            D_.deleteAttributeAt(0);
        }

        // create one nominal attribute per super-label, holding the values that survived pruning
        for (int j = 0; j < K; j++) {
            int att[] = indices[j];
            HashMap<String,Integer> c = getCounts(D, att, p);
            counts.add(c);
            D_.insertAttributeAt(new Attribute(encodeClass(att), new ArrayList<String>(c.keySet())), j);
        }

        // copy over values
        ArrayList<Integer> deleteList = new ArrayList<Integer>();
        for (int i = 0; i < D.numInstances(); i++) {
            Instance x = D.instance(i);
            boolean marked = false; // an instance must be listed for deletion only once
            for (int j = 0; j < K; j++) {
                String y = encodeValue(x, indices[j]);
                try {
                    D_.instance(i).setValue(j, y);
                } catch (Exception e) {
                    // value not allowed (pruned): mark the instance for deletion ...
                    if (!marked) {
                        deleteList.add(i);
                        marked = true;
                    }
                    // ... and add weighted copies carrying the closest allowed values instead.
                    // NOTE(review): a copy is taken while later super-labels of this row are
                    // still unset, so those stay missing in the copy - confirm this is intended.
                    String y_close[] = NSR.getTopNSubsets(y, counts.get(j), n);
                    for (int m = 0; m < y_close.length; m++) {
                        Instance x_copy = (Instance) D_.instance(i).copy();
                        x_copy.setValue(j, y_close[m]);
                        x_copy.setWeight(1.0 / y_close.length);
                        D_.add(x_copy);
                    }
                }
            }
        }

        // clean up: delete from the highest index down so the remaining indices stay valid
        Collections.sort(deleteList, Collections.reverseOrder());
        for (int i : deleteList) {
            D_.delete(i);
        }

        // set class
        D_.setClassIndex(K);
        return D_;
    }

    /**
     * Merge Labels - merge the two label attributes j and k of D into one.
     *
     * <p>{@code D} is modified in place: the merged attribute is inserted at position L
     * (the class index), the attributes j and k are deleted. An instance whose combined
     * value was pruned away gets the most frequent value among its neighbours (values at
     * most one bit different); if it has none, the overall most frequent value is used.
     * Progress information is printed to standard output.
     *
     * @param D Instances, with attributes labeled by original index
     * @param j index 1 (assume that {@code j < k})
     * @param k index 2 (assume that {@code j < k})
     * @param p pruning value for the combined values
     * @return Instances with attributes at j and k moved to position L as (j,k), with classIndex = L-1
     */
    public static Instances mergeLabels(Instances D, int j, int k, int p) {
        int L = D.classIndex();

        // count the combined values
        HashMap<String,Integer> count = new HashMap<String,Integer>();
        for (int i = 0; i < D.numInstances(); i++) {
            String v = encodeValue(D.instance(i).stringValue(j), D.instance(i).stringValue(k));
            Integer seen = count.get(v);
            count.put(v, seen == null ? 1 : seen + 1);
        }
        System.out.print("pruned from "+count.size()+" to ");
        MLUtils.pruneCountHashMap(count, p);
        String y_default = (String) MLUtils.argmax(count); // fallback when a value has no neighbours
        System.out.println(""+count.size()+" with p = "+p);
        System.out.println(""+count);

        // Create and insert the new attribute
        D.insertAttributeAt(new Attribute(encodeClass(D.attribute(j).name(), D.attribute(k).name()), new ArrayList<String>(count.keySet())), L);

        // Set values for the new attribute
        for (int i = 0; i < D.numInstances(); i++) {
            Instance x = D.instance(i);
            String y_jk = encodeValue(x.stringValue(j), x.stringValue(k));
            try {
                x.setValue(L, y_jk);
            } catch (Exception e) {
                // that value didn't exist (pruned): use the neighbour with the highest count.
                // The choice starts from the overall fallback for every instance, so a result
                // picked for an earlier instance can not leak into this one.
                String y_close[] = getNeighbours(y_jk, count, 1);
                String y_max = y_default;
                int max_c = 0;
                for (String y_ : y_close) {
                    int c = count.get(y_);
                    if (c > max_c) {
                        max_c = c;
                        y_max = y_;
                    }
                }
                x.setValue(L, y_max);
            }
        }

        // Delete separate attributes (the higher index first, so the lower one stays valid)
        D.deleteAttributeAt(Math.max(j, k));
        D.deleteAttributeAt(Math.min(j, k));

        // Set class index
        D.setClassIndex(L - 1);
        return D;
    }

    /**
     * GetNeighbours - return from set S, label-vectors closest to y, having no more than 'n' bits different.
     *
     * <p>The difference is computed by {@code MLUtils.bitDifference} on the decoded values.
     */
    public static String[] getNeighbours(String y, ArrayList<String> S, int n) {
        String ya[] = decodeValue(y);
        ArrayList<String> Y = new ArrayList<String>();
        for (String y_ : S) {
            if (MLUtils.bitDifference(ya, decodeValue(y_)) <= n) {
                Y.add(y_);
            }
        }
        return Y.toArray(new String[Y.size()]);
    }

    /**
     * GetNeighbours - return from set S (the keySet of HashMap C), label-vectors closest to y, having no more than 'n' bits different.
     */
    public static String[] getNeighbours(String y, HashMap<String,Integer> C, int n) {
        return getNeighbours(y, new ArrayList<String>(C.keySet()), n);
    }

    @Override
    public String globalInfo() {
        return "A SuperNode Filter";
    }

    /**
     * Test routine: loads the file given by {@code -i}, sets the class index given by
     * {@code -c}, and prints the processed data. Every label forms its own group
     * (singleton partition). Any error is printed to standard error.
     */
    public static void main(String[] argv) {
        try {
            String fname = Utils.getOption('i', argv);
            Instances D;
            BufferedReader reader = new BufferedReader(new FileReader(fname));
            try {
                D = new Instances(reader);
            } finally {
                reader.close(); // do not leak the file handle
            }
            SuperNodeFilter f = new SuperNodeFilter();
            int c = Integer.parseInt(Utils.getOption('c', argv));
            D.setClassIndex(c);
            // process() needs a partition: use one group per label
            int partition[][] = new int[c][];
            for (int j = 0; j < c; j++) {
                partition[j] = new int[]{j};
            }
            f.setIndices(partition);
            System.out.println(""+f.process(D));
        } catch (Exception e) {
            System.err.println("");
            e.printStackTrace();
        }
    }
}