All downloads are free. Search and download functionality uses the official Maven repository.

com.github.chen0040.trees.isolation.IsolationForest Maven / Gradle / Ivy

package com.github.chen0040.trees.isolation;


import com.github.chen0040.data.frame.BasicDataFrame;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;


/**
 * Isolation Forest anomaly detector.
 *
 * <p>{@link #fit(DataFrame)} builds {@code treeCount} {@link IFTreeNode} trees, each from a
 * shuffled copy of the training rows, with the tree height capped at {@code ceil(log2(rowCount))}.
 * {@link #evaluate(DataRow)} averages a row's path length over all trees and converts it into the
 * score {@code 2^(-avgPathLength / IFTreeNode.heuristicCost(rowCount))};
 * {@link #isAnomaly(DataRow)} flags rows whose score is strictly greater than {@code threshold}.
 *
 * <p>Created by xschen on 17/8/15.
 */
@Getter
@Setter
public class IsolationForest {

    /** Rows whose anomaly score is strictly greater than this value are reported as anomalies. */
    private double threshold = 0.5;

    /** Number of trees built by {@link #fit(DataFrame)}; must be positive when fitting. */
    private int treeCount = 100;

    /**
     * Randomness source shared by all instances (it is static). Used to shuffle each tree's
     * training rows and handed to every {@link IFTreeNode} that is built.
     */
    private static final Random random = new Random();

    /** Natural logarithm of 2, used to convert natural logarithms to base 2. */
    private static final double LN_2 = Math.log(2);

    /** Trees built by the most recent successful {@link #fit(DataFrame)}; {@code null} before. */
    @Setter(AccessLevel.NONE)
    private List<IFTreeNode> trees;

    /** Row count of the training batch; normalises path lengths in {@link #evaluate(DataRow)}. */
    @Setter(AccessLevel.NONE)
    private int rowCount;


    /** Creates an unfitted forest with the default {@code threshold} and {@code treeCount}. */
    public IsolationForest() {
    }

    /** Returns the base-2 logarithm of {@code n}. */
    private static double log2(double n) {
        return Math.log(n) / LN_2;
    }


    /**
     * Reports whether {@code tuple} is an anomaly.
     *
     * @param tuple the row to test
     * @return {@code true} if {@link #evaluate(DataRow)} is strictly greater than {@code threshold}
     * @throws IllegalStateException if the forest has not been fitted
     */
    public boolean isAnomaly(DataRow tuple) {
        return evaluate(tuple) > threshold;
    }

    /**
     * Trains the forest on {@code batch}, replacing any previously fitted trees.
     *
     * <p>The new model is assembled locally and only published once every tree has been built,
     * so an exception thrown while building a tree leaves the previous model untouched.
     *
     * @param batch training data; must contain at least one row
     * @throws IllegalStateException if {@code treeCount} is not positive
     * @throws IllegalArgumentException if {@code batch} has no rows
     */
    public void fit(DataFrame batch) {
        if (treeCount < 1) {
            throw new IllegalStateException("treeCount must be positive but was " + treeCount);
        }
        int n = batch.rowCount();
        if (n == 0) {
            // log2(0) is -Infinity, which would cast to Integer.MIN_VALUE as the height cap.
            throw new IllegalArgumentException("Cannot fit an IsolationForest on an empty DataFrame");
        }
        int maxHeight = (int) Math.ceil(log2(n));
        List<IFTreeNode> fitted = new ArrayList<>(treeCount);
        for (int i = 0; i < treeCount; ++i) {
            fitted.add(new IFTreeNode(randomize(batch), random, 0, maxHeight));
        }
        trees = fitted;
        rowCount = n;
    }

    /**
     * Returns a new, locked {@link BasicDataFrame} holding the rows of {@code dataFrame} in a
     * random order. The row objects themselves are shared with {@code dataFrame}, not copied.
     */
    private DataFrame randomize(DataFrame dataFrame) {
        int n = dataFrame.rowCount();
        List<DataRow> rows = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            rows.add(dataFrame.row(i));
        }
        // Use the class's RNG so all randomness in fitting comes from a single source.
        Collections.shuffle(rows, random);

        DataFrame shuffled = new BasicDataFrame();
        for (DataRow row : rows) {
            shuffled.addRow(row);
        }
        shuffled.lock();
        return shuffled;
    }

    /**
     * Returns the anomaly score of {@code tuple} together with its complement.
     *
     * @param tuple the row to score
     * @return a two-element array {@code [score, 1 - score]}, where {@code score} is
     *     {@link #evaluate(DataRow)}
     * @throws IllegalStateException if the forest has not been fitted
     */
    public double[] getDistributionScores(DataRow tuple) {
        double[] scores = new double[2];
        scores[0] = evaluate(tuple);
        scores[1] = 1 - scores[0];

        return scores;
    }

    /**
     * Computes the anomaly score {@code 2^(-avgPathLength / IFTreeNode.heuristicCost(rowCount))},
     * where {@code avgPathLength} is the mean of {@code pathLength(tuple)} over all fitted trees.
     * Higher scores mean the row is isolated more quickly. Assuming path lengths are
     * non-negative and the heuristic cost is positive (not verified here), the score lies in
     * {@code (0, 1]}.
     *
     * @param tuple the row to score
     * @return the anomaly score
     * @throws IllegalStateException if the forest has not been fitted
     */
    public double evaluate(DataRow tuple) {
        if (trees == null || trees.isEmpty()) {
            throw new IllegalStateException("IsolationForest has not been fitted; call fit() first");
        }
        double totalPathLength = 0;
        for (IFTreeNode tree : trees) {
            totalPathLength += tree.pathLength(tuple);
        }
        double avgPathLength = totalPathLength / trees.size();

        return Math.pow(2, -avgPathLength / IFTreeNode.heuristicCost(rowCount));
    }


    /**
     * Fits the forest on {@code data}, then labels a copy of it.
     *
     * <p>Each row of the copy returned by {@code data.makeCopy()} gets the categorical target
     * cell {@code "anomaly"} set to {@code "1"} if {@link #isAnomaly(DataRow)} holds, otherwise
     * {@code "0"}. The labels are written to the copy; whether {@code data} is fully isolated
     * from those writes depends on {@code makeCopy()} (presumably a deep copy — confirm).
     *
     * @param data training data, also the rows to label
     * @return the labelled copy
     */
    public DataFrame fitAndTransform(DataFrame data) {
        fit(data);
        DataFrame labelled = data.makeCopy();
        for (int i = 0; i < labelled.rowCount(); ++i) {
            DataRow row = labelled.row(i);
            row.setCategoricalTargetCell("anomaly", isAnomaly(row) ? "1" : "0");
        }

        return labelled;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy