All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hex.tree.isofor.IsolationForestModel Maven / Gradle / Ivy

There is a newer version: 3.46.0.6
Show newest version
package hex.tree.isofor;

import hex.ModelCategory;
import hex.ModelMetrics;
import hex.genmodel.CategoricalEncoding;
import hex.genmodel.utils.ArrayUtils;
import hex.ScoreKeeper;
import hex.genmodel.utils.DistributionFamily;
import hex.tree.SharedTreeModel;
import water.Key;
import water.fvec.Frame;
import water.util.SBPrintStream;
import water.util.TwoDimTable;


public class IsolationForestModel extends SharedTreeModel {

  public static class IsolationForestParameters extends SharedTreeModel.SharedTreeParameters {
    public String algoName() { return "IsolationForest"; }
    public String fullName() { return "Isolation Forest"; }
    public String javaName() { return IsolationForestModel.class.getName(); }
    public int _mtries;
    public long _sample_size;
    public double _contamination;

    public IsolationForestParameters() {
      super();
      _mtries = -1;
      _sample_size = 256;
      _max_depth = 8; // log2(_sample_size)
      _sample_rate = -1;
      _min_rows = 1;
      _min_split_improvement = 0;
      _nbins = 2;
      _nbins_cats = 2;
      // _nbins_top_level = 2;
      _histogram_type = HistogramType.Random;
      _distribution = DistributionFamily.gaussian;
      
      // IF specific
      _contamination = -1; // disabled
    }

    @Override
    protected double defaultStoppingTolerance() {
      return 0.01; // (inherited value 0.001 would be too low for the default criterion anomaly_score)
    }
  }

  public static class IsolationForestOutput extends SharedTreeModel.SharedTreeOutput {
    public int _max_path_length;
    public int _min_path_length;

    public String _response_column;
    public String[] _response_domain;
    
    public IsolationForest.VarSplits _var_splits;
    public TwoDimTable _variable_splits;

    public IsolationForestOutput(IsolationForest b) { 
      super(b);
      if (b.vresponse() != null) {
        _response_column = b._parms._response_column;
        _response_domain = b.vresponse().domain();
      }
    }

    @Override
    public ModelCategory getModelCategory() {
      return ModelCategory.AnomalyDetection;
    }

    @Override
    public double defaultThreshold() {
      return _defaultThreshold;
    }

    @Override
    public String responseName() {
      return _response_column;
    }

    @Override
    public boolean hasResponse() {
      return _response_column != null;
    }

    @Override
    public int responseIdx() {
      return _names.length;
    }
  }

  public IsolationForestModel(Key selfKey, IsolationForestParameters parms, IsolationForestOutput output ) { 
    super(selfKey, parms, output);
  }

  @Override
  public void initActualParamValues() {
    super.initActualParamValues();
    if (_parms._stopping_metric == ScoreKeeper.StoppingMetric.AUTO){
      if (_parms._stopping_rounds == 0) {
        _parms._stopping_metric = null;
      } else {
        _parms._stopping_metric = ScoreKeeper.StoppingMetric.anomaly_score;
      }
    }
    if (_parms._categorical_encoding == Parameters.CategoricalEncodingScheme.AUTO) {
      _parms._categorical_encoding = Parameters.CategoricalEncodingScheme.Enum;
    }
  }

  @Override
  public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
    // note: in the context of scoring on a training frame during model building domain will be null
    if (domain != null && _output.hasResponse()) {
      return new MetricBuilderAnomalySupervised(domain);
    } else {
      return new ModelMetricsAnomaly.MetricBuilderAnomaly("Isolation Forest Metrics", outputAnomalyFlag());
    }
  }

  @Override
  protected String[] makeScoringNames() {
    if (outputAnomalyFlag()) {
      return new String[]{"predict", "score", "mean_length"};
    } else {
      return new String[]{"predict", "mean_length"};
    }
  }

  @Override
  protected String[][] makeScoringDomains(Frame adaptFrm, boolean computeMetrics, String[] names) {
    assert outputAnomalyFlag() ? names.length == 3 : names.length == 2;
    String[][] domains = new String[names.length][];
    if (outputAnomalyFlag()) {
      domains[0] = _output._response_domain != null ? _output._response_domain : new String[]{"0", "1"};
    }
    return domains;
  }

  /** Bulk scoring API for one row.  Chunks are all compatible with the model,
   *  and expect the last Chunks are for the final distribution and prediction.
   *  Default method is to just load the data into the tmp array, then call
   *  subclass scoring logic. */
  @Override protected double[] score0(double[] data, double[] preds, double offset, int ntrees) {
    super.score0(data, preds, offset, ntrees);
    boolean outputAnomalyFlag = outputAnomalyFlag();
    int off = outputAnomalyFlag ? 1 : 0;
    if (ntrees >= 1) 
      preds[off + 1] = preds[0] / ntrees;
    preds[off] = normalizePathLength(preds[0]);
    if (outputAnomalyFlag)
      preds[0] = preds[1] >= _output._defaultThreshold ? 1 : 0; 
    return preds;
  }

  final double normalizePathLength(double pathLength) {
    return normalizePathLength(pathLength, _output._min_path_length, _output._max_path_length);
  }

  static double normalizePathLength(double pathLength, int minPathLength, int maxPathLength) {
    if (maxPathLength > minPathLength) {
      return  (maxPathLength - pathLength) / (maxPathLength - minPathLength);
    } else {
      return 1;
    }
  }

  @Override
  public IsolationForestMojoWriter getMojo() {
    return new IsolationForestMojoWriter(this);
  }

  @Override
  public String[] adaptTestForTrain(Frame test, boolean expensive, boolean computeMetrics) {
    if (!computeMetrics || _output._response_column == null) {
      return super.adaptTestForTrain(test, expensive, computeMetrics);
    } else {
      return adaptTestForTrain(
              test,
              _output._origNames,
              _output._origDomains,
              ArrayUtils.append(_output._names, _output._response_column),
              ArrayUtils.append(_output._domains, _output._response_domain),
              _parms,
              expensive,
              true,
              _output.interactionBuilder(),
              getToEigenVec(),
              _toDelete,
              false
      );
    }
  }

  final boolean outputAnomalyFlag() {
    return _output._defaultThreshold >= 0;
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy