All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.data.tools.CatBoostPoolDescription Maven / Gradle / Ivy

package com.expleague.ml.data.tools;

import com.expleague.commons.func.Processor;
import com.expleague.commons.seq.CharSeqTools;

import java.io.IOException;
import java.io.Reader;
import java.util.Arrays;

/**
 * Column layout of a CatBoost-style pool file: the role of every column, the field delimiter and a
 * header flag. Instances are normally obtained through {@link DescriptionBuilder}.
 *
 * <p>Unless a column description is loaded, column 0 is the {@link ColumnType#Target} and every
 * other column is a numeric feature ({@link ColumnType#Num}).
 *
 * Created by noxoomo on 15/10/2017.
 */
public class CatBoostPoolDescription {

  /** Role of a single pool column. */
  enum ColumnType {
    Num,
    Cat,
    Target,
    Auxiliary,
    DocId,
    QueryId,
    Weight;

    /** Returns {@code true} for columns holding feature values ({@link #Num} or {@link #Cat}). */
    static boolean isFactorColumn(ColumnType type) {
      return type == Num || type == Cat;
    }

    /**
     * Parses a column type token from a column-description file.
     *
     * <p>Besides the enum constant names, the spellings used by CatBoost's own column-description
     * format are accepted: {@code Categ} (categorical feature), {@code Label} (target) and
     * {@code GroupId} (query id).
     *
     * @param name the type token, already stripped of surrounding whitespace
     * @return the matching column type
     * @throws IllegalArgumentException if {@code name} is not a known column type
     */
    static ColumnType fromName(String name) {
      switch (name) {
        case "Categ":
          return Cat;
        case "Label":
          return Target;
        case "GroupId":
          return QueryId;
        default:
          return valueOf(name);
      }
    }
  }

  private ColumnType[] columnTypes;
  private char delimiter = '\t';
  private boolean headerColumnFlag = false;


  /**
   * Creates a default description: column 0 is the target, all other columns are numeric features.
   *
   * @param columnCount number of columns in the pool
   */
  public CatBoostPoolDescription(int columnCount) {
    columnTypes = new ColumnType[columnCount];
    Arrays.fill(columnTypes, ColumnType.Num);
    // Guard so that an empty pool does not fail here with an index error;
    // validate() reports the missing target column instead.
    if (columnCount > 0) {
      columnTypes[0] = ColumnType.Target;
    }
  }

  /** Returns the role of the column at {@code columnIndex}. */
  ColumnType columnType(int columnIndex) {
    return columnTypes[columnIndex];
  }

  /** Returns the character that separates fields in the pool file (tab by default). */
  public char getDelimiter() {
    return delimiter;
  }

  /** Returns the header flag set via {@link DescriptionBuilder#setHasHeaderColumnFlag}. */
  public boolean hasHeaderColumn() {
    return headerColumnFlag;
  }

  /** Returns the total number of columns described. */
  public int columnCount() {
    return columnTypes.length;
  }

  /** Returns the number of feature columns, i.e. columns of type {@code Num} or {@code Cat}. */
  public int factorCount() {
    int total = 0;
    for (ColumnType type : columnTypes) {
      if (ColumnType.isFactorColumn(type)) {
        total++;
      }
    }
    return total;
  }

  /**
   * Checks that the description has exactly one target column, at most one weight column and at
   * least one feature column.
   *
   * @throws IllegalStateException if any of these constraints is violated
   */
  private void validate() {
    int targetCount = 0;
    int weightCount = 0;
    int featureCount = 0;
    for (ColumnType type : columnTypes) {
      switch (type) {
        case Target: {
          ++targetCount;
          break;
        }
        case Num:
        case Cat: {
          ++featureCount;
          break;
        }
        case Weight: {
          ++weightCount;
          break;
        }
        default:
          // other column types are not constrained
          break;
      }
    }
    if ((targetCount != 1) || (weightCount > 1) || (featureCount == 0)) {
      throw new IllegalStateException("Wrong pool description format: expected exactly one Target"
          + " column, at most one Weight column and at least one feature column, but found "
          + targetCount + " Target, " + weightCount + " Weight and " + featureCount
          + " feature column(s)");
    }
  }

  /**
   * Mutable builder for {@link CatBoostPoolDescription}. All setters modify a single underlying
   * description instance, which {@link #description()} validates and returns.
   */
  public static class DescriptionBuilder {
    private final CatBoostPoolDescription description;
    private final int columnCount;

    /**
     * Starts from the default layout (target in column 0, numeric features elsewhere).
     *
     * @param columnCount number of columns in the pool
     */
    public DescriptionBuilder(int columnCount) {
      this.columnCount = columnCount;
      this.description = new CatBoostPoolDescription(columnCount);
    }

    /**
     * Starts from the default layout, with the column count taken from the pool data itself.
     *
     * <p>NOTE(review): the column count comes from {@code DataTools.getLineCount(poolInput,
     * delimiter)}; presumably this counts delimiter-separated fields of a line despite its name —
     * verify against {@code DataTools}.
     *
     * @param poolInput reader over the pool data
     * @param delimiter field delimiter used by the pool file
     */
    public DescriptionBuilder(final Reader poolInput,
                              final char delimiter) {
      this.columnCount = DataTools.getLineCount(poolInput, delimiter);
      this.description = new CatBoostPoolDescription(columnCount);
      description.delimiter = delimiter;
    }

    /**
     * Replaces the column layout with one read from a tab-separated column-description file.
     *
     * <p>Each non-blank line is {@code <column index>\t<column type>[\t...]}; any further fields
     * are ignored. Columns that are not mentioned become numeric features. This also drops the
     * default target in column 0, so the file must declare the target explicitly.
     *
     * @param input reader over the column-description file
     * @return this builder
     * @throws IOException if reading {@code input} fails
     * @throws IllegalArgumentException if a line is malformed, references a column outside
     *     {@code [0, columnCount)}, or names an unknown column type
     */
    public DescriptionBuilder loadColumnDescription(final Reader input) throws IOException {
      description.columnTypes = new ColumnType[columnCount];
      Arrays.fill(description.columnTypes, ColumnType.Num);
      CharSeqTools.processLines(input, (Processor<CharSequence>) line -> {
        final String text = line.toString();
        if (text.trim().isEmpty()) {
          return; // tolerate blank lines, e.g. a trailing empty line at the end of the file
        }
        final CharSequence[] parts = CharSeqTools.split(line, '\t');
        if (parts.length < 2) {
          throw new IllegalArgumentException(
              "Malformed column description line (expected <index>\\t<type>): \"" + text + "\"");
        }

        final int index = CharSeqTools.parseInt(parts[0]);
        if (index < 0 || index >= columnCount) {
          throw new IllegalArgumentException("Column index " + index + " is out of range [0, "
              + columnCount + ") in column description line: \"" + text + "\"");
        }
        // trim() so that CRLF line endings or stray spaces do not break the type lookup
        description.columnTypes[index] = ColumnType.fromName(parts[1].toString().trim());
      });
      return this;
    }

    /** Sets the field delimiter of the pool file. */
    public DescriptionBuilder setDelimiter(char delimiter) {
      description.delimiter = delimiter;
      return this;
    }

    /** Sets the flag reported by {@link CatBoostPoolDescription#hasHeaderColumn()}. */
    public DescriptionBuilder setHasHeaderColumnFlag(boolean headerColumn) {
      description.headerColumnFlag = headerColumn;
      return this;
    }

    /**
     * Validates and returns the description being built. The same instance is returned on every
     * call, so later builder calls keep mutating it.
     *
     * @throws IllegalStateException if the column layout is invalid
     */
    public CatBoostPoolDescription description() {
      description.validate();
      return description;
    }
  }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy