All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.io.ObliviousTreeDynamicBinConversionPack Maven / Gradle / Ivy

package com.expleague.ml.io;

import com.expleague.commons.func.types.ConversionPack;
import com.expleague.ml.dynamicGrid.interfaces.BinaryFeature;
import com.expleague.commons.func.types.TypeConverter;
import com.expleague.commons.seq.CharSeqReader;
import com.expleague.commons.seq.CharSeqTools;
import com.expleague.ml.DynamicGridEnabled;
import com.expleague.ml.dynamicGrid.interfaces.DynamicGrid;
import com.expleague.ml.dynamicGrid.models.ObliviousTreeDynamicBin;

import java.io.IOException;
import java.io.LineNumberReader;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.text.MessageFormat;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

public class ObliviousTreeDynamicBinConversionPack implements ConversionPack {
  private static final MessageFormat FEATURE_LINE_PATTERN = new MessageFormat("feature: {0, number}, bin: {1, number}, ge: {2, number,#.#####}", Locale.US);

  static {
    final DecimalFormat format = new DecimalFormat();
    format.setDecimalSeparatorAlwaysShown(false);
    format.setGroupingUsed(false);
    format.setDecimalFormatSymbols(new DecimalFormatSymbols(Locale.US));
    format.setDecimalSeparatorAlwaysShown(true);
    format.setMaximumFractionDigits(5);
    format.setParseIntegerOnly(false);
    format.setParseIntegerOnly(false);
    FEATURE_LINE_PATTERN.setFormat(2, format);
  }

  public static class To implements TypeConverter {
    @Override
    public CharSequence convert(final ObliviousTreeDynamicBin ot) {
      StringBuilder result = new StringBuilder();
      for (final BinaryFeature feature : ot.features()) {
        result.append(FEATURE_LINE_PATTERN.format(new Object[]{feature.fIndex(), feature.binNo(), feature.condition()}))
                .append("\n");
      }
      final int leafsCount = 1 << ot.features().length;
      for (int i = 0; i < leafsCount; i++) {
        if (ot.values()[i] != 0.) {
          result.append(Integer.toBinaryString(i))
                  .append(":").append(ot.values()[i])
                  .append(":").append(0)
                  .append(" ");
        }
      }
      result = result.delete(result.length() - 1, result.length());
      return result;
    }
  }

  public static class From implements DynamicGridEnabled, TypeConverter {
    private DynamicGrid grid;

    @Override
    public DynamicGrid getGrid() {
      return grid;
    }

    @Override
    public void setGrid(final DynamicGrid grid) {
      this.grid = grid;
    }

    @Override
    public ObliviousTreeDynamicBin convert(final CharSequence source) {
      if (grid == null)
        throw new RuntimeException("DynamicGrid must be setup for serialization of oblivious trees with dynamicGrid, use SerializationRepository.customize!");
      String line;
      final LineNumberReader lnr = new LineNumberReader(new CharSeqReader(source));
      final List splits = new ArrayList(10);
      try {
        while ((line = lnr.readLine()) != null) {
          if (line.startsWith("feature")) {
            final Object[] parts = FEATURE_LINE_PATTERN.parse(line);
            final BinaryFeature bf = grid.row(((Long) parts[0]).intValue()).bf(((Long) parts[1]).intValue());
            splits.add(bf);
            if (Math.abs(bf.condition() - ((Number) parts[2]).doubleValue()) > 1e-4)
              throw new RuntimeException("Inconsistent grid set, conditions do not match! Grid: " + bf.condition() + " Found: " + parts[2]);
          } else break;
        }
        final double[] values = new double[1 << splits.size()];
        final CharSequence[] valuesStr = CharSeqTools.split(line, ' ');
        for (final CharSequence value : valuesStr) {
          final CharSequence[] pattern2ValueBased = CharSeqTools.split(value, ':');
          final int leafIndex = Integer.parseInt(pattern2ValueBased[0].toString(), 2);
          values[leafIndex] = Double.parseDouble(pattern2ValueBased[1].toString());
        }
        return new ObliviousTreeDynamicBin(splits, values);
      } catch (IOException e) {
        throw new RuntimeException(e);
      } catch (ParseException e) {
        throw new RuntimeException("Invalid oblivious tree format", e);
      }
    }
  }

  @Override
  public Class> to() {
    return To.class;
  }

  @Override
  public Class> from() {
    return From.class;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy