All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.expleague.ml.data.tools.DataTools Maven / Gradle / Ivy

There is a newer version: 1.4.9
Show newest version
package com.expleague.ml.data.tools;


import com.expleague.commons.func.types.SerializationRepository;
import com.expleague.commons.func.types.impl.TypeConvertersCollection;
import com.expleague.commons.io.StreamTools;
import com.expleague.commons.math.Func;
import com.expleague.commons.math.MathTools;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.*;
import com.expleague.commons.math.vectors.impl.mx.MxByRowsBuilder;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.math.vectors.impl.vectors.IndexTransVec;
import com.expleague.commons.math.vectors.impl.vectors.VecBuilder;
import com.expleague.commons.random.FastRandom;
import com.expleague.commons.seq.*;
import com.expleague.commons.system.RuntimeUtils;
import com.expleague.commons.text.StringUtils;
import com.expleague.commons.util.JSONTools;
import com.expleague.commons.util.logging.Logger;
import com.expleague.ml.impl.BFGridConstructor;
import com.expleague.ml.impl.BFGridImpl;
import com.expleague.ml.CompositeTrans;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.set.impl.VecDataSetImpl;
import com.expleague.ml.dynamicGrid.models.ObliviousTreeDynamicBin;
import com.expleague.ml.func.FuncJoin;
import com.expleague.ml.BFGrid;
import com.expleague.ml.io.ModelsSerializationRepository;
import com.expleague.ml.loss.L2;
import com.expleague.ml.loss.StatBasedLoss;
import com.expleague.ml.loss.WeightedLoss;
import com.expleague.ml.meta.GroupedDSItem;
import com.expleague.ml.meta.PoolFeatureMeta;
import com.expleague.ml.meta.impl.JsonFeatureMeta;
import com.expleague.ml.meta.impl.JsonTargetMeta;
import com.expleague.ml.meta.items.FakeItem;
import com.expleague.ml.meta.items.QURLItem;
import com.expleague.ml.models.MultiClassModel;
import com.expleague.ml.models.ObliviousTree;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.AnnotationIntrospector;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.type.CollectionType;
import com.fasterxml.jackson.module.jaxb.JaxbAnnotationIntrospector;
import com.expleague.commons.math.vectors.impl.idxtrans.RowsPermutation;
import com.expleague.commons.math.vectors.impl.mx.VecBasedMx;
import com.expleague.commons.math.vectors.impl.vectors.SparseVec;
import com.expleague.commons.util.ArrayTools;
import com.expleague.commons.util.Pair;
import com.expleague.ml.TargetFunc;
import com.expleague.ml.dynamicGrid.interfaces.DynamicGrid;
import com.expleague.ml.func.Ensemble;
import com.expleague.ml.func.TransJoin;
import com.expleague.ml.meta.DSItem;
import com.expleague.ml.meta.impl.JsonDataSetMeta;
import com.expleague.ml.models.ObliviousMultiClassTree;
import com.expleague.ml.models.multilabel.MultiLabelBinarizedModel;
import gnu.trove.list.TDoubleList;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.list.linked.TDoubleLinkedList;
import gnu.trove.list.linked.TIntLinkedList;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.TObjectIntMap;
import gnu.trove.map.hash.TIntObjectHashMap;
import gnu.trove.map.hash.TObjectIntHashMap;
import org.apache.commons.lang3.mutable.MutableInt;

import java.io.*;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.zip.GZIPInputStream;

/**
 * User: solar
 * Date: 26.12.10
 * Time: 19:05
 */
@SuppressWarnings("unused")
public class DataTools {
  public static Logger log = Logger.create(DataTools.class);

  /**
   * Shared serialization repository over {@code CharSequence}, configured with
   * {@code MathTools.CONVERSION} and the {@code com.expleague.ml.io} package. Used to write and
   * read models ({@link #writeModel(Function, OutputStream)}) and pool columns
   * ({@link #writePoolTo}, {@link #readPoolFrom}).
   */
  public static final SerializationRepository SERIALIZATION = new SerializationRepository<>(
      new TypeConvertersCollection(MathTools.CONVERSION, DataTools.class, "com.expleague.ml.io"), CharSequence.class);


  /**
   * Loads a pool in the tab-separated "features.txt" format from a file.
   *
   * <p>Files whose name ends with {@code .gz} are transparently decompressed.
   *
   * @param file path to the pool file
   * @return the parsed pool
   * @throws IOException if the file cannot be opened or read
   */
  public static Pool loadFromFeaturesTxt(final String file) throws IOException {
    // This method opens the reader, so it must also close it once parsing is done.
    try (final Reader in = gzipOrFileReader(new File(file))) {
      return loadFromFeaturesTxt(file, in);
    }
  }

  /**
   * Parses a tab-separated "features.txt" stream into a {@link FeaturesTxtPool}.
   *
   * <p>Each line must contain at least four leading columns followed by the feature values:
   * <ul>
   *   <li>columns 0 (int), 2 (string) and 3 (int) become a {@link QURLItem};</li>
   *   <li>column 1 is the target value;</li>
   *   <li>every remaining column is a feature value.</li>
   * </ul>
   * All lines must carry the same number of feature columns as the first one.
   *
   * @param fileName name of the source; not used by the parser itself
   * @param in reader with the pool contents
   * @return pool with one item, one target value and one feature row per input line
   * @throws RuntimeException if a line has fewer than four columns, or its feature count differs
   *     from that of the first line (the message carries the 1-based line number)
   */
  public static FeaturesTxtPool loadFromFeaturesTxt(final String fileName, final Reader in) throws IOException {
    final List items = new ArrayList<>();
    final VecBuilder target = new VecBuilder();
    final VecBuilder data = new VecBuilder();
    // Single-element array so the anonymous consumer can update it; -1 means "not seen yet".
    final int[] featuresCount = new int[]{-1};
    CharSeqTools.processLines(in, new Consumer() {
      int lindex = 0;

      @Override
      public void accept(final CharSequence arg) {
        lindex++;
        final CharSequence[] parts = CharSeqTools.split(arg, '\t');
        if (parts.length < 4)
          throw new RuntimeException("Failed to parse line " + lindex + ": expected at least 4 columns, found " + parts.length);
        items.add(new QURLItem(CharSeqTools.parseInt(parts[0]), parts[2].toString(), CharSeqTools.parseInt(parts[3])));
        target.append(CharSeqTools.parseDouble(parts[1]));
        final int lineFeatures = parts.length - 4;
        if (featuresCount[0] < 0)
          featuresCount[0] = lineFeatures;
        else if (featuresCount[0] != lineFeatures)
          throw new RuntimeException("Failed to parse line " + lindex + ": expected " + featuresCount[0] + " features, found " + lineFeatures);
        for (int i = 4; i < parts.length; i++) {
          data.append(CharSeqTools.parseDouble(parts[i]));
        }
      }
    });
    return new FeaturesTxtPool(
        new ArraySeq<>(items.toArray(new QURLItem[items.size()])),
        new VecBasedMx(featuresCount[0], data.build()),
        target.build()
    );
  }


  /**
   * Counts the {@code sep}-separated fields in the first line of {@code input}.
   *
   * <p>Despite its name, this returns a column count, not a line count.
   *
   * @param input reader positioned at the first line
   * @param sep field separator
   * @return number of fields in the first line, or 0 if the input has no lines
   */
  public static int getLineCount(final Reader input, final char sep) {
    return CharSeqTools.lines(input, false).limit(1).map(line -> CharSeqTools.split(line, sep).length).findFirst().orElse(0);
  }

  /**
   * Maps a string to a double by scaling {@link String#hashCode()} by 2^-32.
   *
   * <p>Because the hash is an {@code int}, the result always lies in [-0.5, 0.5) and is never NaN.
   *
   * @param in text to hash
   * @return hash value scaled into [-0.5, 0.5)
   */
  public static double stringToDoubleHash(final CharSequence in) {
    final long hashCode = in.toString().hashCode();
    return hashCode * 1.0 / (1L << 32);
  }

  /**
   * Parses a CatBoost-style delimited pool according to {@code poolDescription}.
   *
   * <p>The first line is skipped when the description declares a header. Column types are
   * handled as follows:
   * <ul>
   *   <li>{@code Target} columns go to the target vector;</li>
   *   <li>{@code Num} columns are parsed as doubles;</li>
   *   <li>{@code Categ} columns are hashed with {@link #stringToDoubleHash};</li>
   *   <li>{@code Weight} columns are rejected;</li>
   *   <li>{@code DocId}, {@code Auxiliary} and {@code QueryId} columns are ignored.</li>
   * </ul>
   *
   * @param poolDescription column layout, delimiter and header flag
   * @param in reader with the pool contents
   * @return pool with a {@code factorCount()}-column data matrix, the target vector and the
   *     indices (among factor columns) of categorical features
   * @throws RuntimeException if a line has the wrong number of columns (the message carries the
   *     0-based line index) or a {@code Weight} column is present
   */
  public static CatboostPool loadFromCatBoostPool(final CatBoostPoolDescription poolDescription,
                                                  final Reader in) throws IOException {
    final VecBuilder target = new VecBuilder();
    final VecBuilder data = new VecBuilder();

    CharSeqTools.processLines(in, new Consumer() {
      int lindex = -1;

      @Override
      public void accept(final CharSequence arg) {
        lindex++;
        if (lindex == 0 && poolDescription.hasHeaderColumn()) {
          return;
        }

        final CharSequence[] parts = CharSeqTools.split(arg, poolDescription.getDelimiter());
        if (parts.length != poolDescription.columnCount()) {
          throw new RuntimeException("Failed to parse line " + lindex + ": expected " + poolDescription.columnCount() + " columns, found " + parts.length);
        }

        int id = lindex;

        for (int column = 0; column < poolDescription.columnCount(); ++column) {
          final CharSequence columnSeq = parts[column];
          switch (poolDescription.columnType(column)) {
            case Target: {
              target.append(CharSeqTools.parseDouble(columnSeq));
              break;
            }
            case Num: {
              data.append(CharSeqTools.parseDouble(columnSeq));
              break;
            }
            case Categ: {
              final double value = stringToDoubleHash(columnSeq);
              if (Double.isNaN(value)) {
                throw new RuntimeException("Error: catFeature hash values should not be NaN");
              }
              data.append(value);
              break;
            }
            case Weight: {
              throw new RuntimeException("Unimplemented yet");
            }
            case DocId:
            case Auxiliary:
            case QueryId:
            default: {
              break;
            }
          }
        }
      }
    });

    // Categorical feature ids are indices among factor columns, not raw column indices.
    final Set catFeatureIds = new TreeSet<>();
    int factorId = 0;
    for (int column = 0; column < poolDescription.columnCount(); ++column) {
      final CatBoostPoolDescription.ColumnType columnType = poolDescription.columnType(column);
      if (columnType == CatBoostPoolDescription.ColumnType.Categ) {
        catFeatureIds.add(factorId);
      }
      if (CatBoostPoolDescription.ColumnType.isFactorColumn(columnType)) {
        ++factorId;
      }
    }
    final Mx vecData = new VecBasedMx(poolDescription.factorCount(), data.build());
    final Vec targetVec = target.build();
    return new CatboostPool(vecData, targetVec, catFeatureIds);
  }

  /**
   * Writes a model to a file using {@link #writeModel(Function, OutputStream)}.
   *
   * <p>The file stream is closed when done.
   *
   * @param result model to serialize
   * @param to destination file (created or overwritten)
   * @throws IOException if the file cannot be opened or closed
   */
  public static void writeModel(final Function result, final File to) throws IOException {
    try (final OutputStream out = new FileOutputStream(to)) {
      writeModel(result, out);
    }
  }

  /**
   * Writes a model as a header line followed by its serialized form.
   *
   * <p>The header line is {@code <canonical class name>\t<true|false>\n}. The boolean records
   * whether {@link #grid(Function)} found a binary-feature grid in the model. The body is
   * {@code SERIALIZATION.write(result)}.
   *
   * @param result model to serialize
   * @param to destination stream; not closed here
   */
  public static void writeModel(final Function result, OutputStream to) {
    final BFGrid grid = grid(result);
    StreamTools.writeChars(CharSeqTools.concat(result.getClass().getCanonicalName(), "\t", Boolean.toString(grid != null), "\n", SERIALIZATION.write(result)), to);
  }

  /**
   * Reads a model written by {@link #writeModel(Function, OutputStream)}.
   *
   * <p>The class named in the first tab-separated field of the header line is loaded. The rest
   * of the stream is then deserialized into that class with {@code serializationRepository}.
   *
   * @param inputStream stream positioned at the header line
   * @param serializationRepository repository used to deserialize the model body
   * @return the deserialized model
   * @throws IOException if the stream is empty or cannot be read
   * @throws ClassNotFoundException if the class named in the header cannot be loaded
   */
  public static  T readModel(final InputStream inputStream, final ModelsSerializationRepository serializationRepository) throws IOException, ClassNotFoundException {
    final LineNumberReader modelReader = new LineNumberReader(new InputStreamReader(inputStream));
    final String line = modelReader.readLine();
    if (line == null)
      throw new IOException("Unable to read model: the stream is empty (missing header line)");
    final CharSequence[] parts = CharSeqTools.split(line, '\t');
    //noinspection unchecked
    final Class modelClazz = (Class) Class.forName(parts[0].toString());
    //noinspection unchecked
    return (T) serializationRepository.read(StreamTools.readReader(modelReader), modelClazz);
  }

  /**
   * Reads a model from a file.
   *
   * <p>See {@link #readModel(InputStream, ModelsSerializationRepository)}. The file stream is
   * closed when done.
   *
   * @param fileName path to the model file
   * @param serializationRepository repository used to deserialize the model body
   * @return the deserialized model
   * @throws IOException if the file cannot be opened or read
   * @throws ClassNotFoundException if the class named in the header cannot be loaded
   */
  public static  T readModel(final String fileName, final ModelsSerializationRepository serializationRepository) throws IOException, ClassNotFoundException {
    try (final InputStream in = new FileInputStream(fileName)) {
      return readModel(in, serializationRepository);
    }
  }

  /**
   * Reads a model whose grid is stored separately.
   *
   * <p>The grid is first read from {@code gridInputStream} as a {@link BFGridImpl}. The model is
   * then read with a repository customized to that grid.
   *
   * @param modelInputStream stream with the model (header line + body); not closed here
   * @param gridInputStream stream with the serialized grid; not closed here
   * @return the deserialized model
   * @throws IOException if either stream cannot be read
   * @throws ClassNotFoundException if the class named in the model header cannot be loaded
   */
  public static  T readModel(final InputStream modelInputStream, final InputStream gridInputStream) throws IOException, ClassNotFoundException {
    final ModelsSerializationRepository repository = new ModelsSerializationRepository();
    final BFGrid grid = repository.read(StreamTools.readStream(gridInputStream), BFGridImpl.class);
    final ModelsSerializationRepository customizedRepository = repository.customizeGrid(grid);
    return readModel(modelInputStream, customizedRepository);
  }

  /**
   * Reads a model using a repository customized with a fresh {@link BFGridConstructor}.
   *
   * <p>{@code grid.build()} is called after the model has been read. Presumably the constructor
   * collects the grid from the deserialized model; verify against {@code BFGridConstructor}.
   *
   * @param modelInputStream stream with the model (header line + body); not closed here
   * @return the deserialized model
   * @throws IOException if the stream cannot be read
   * @throws ClassNotFoundException if the class named in the model header cannot be loaded
   */
  public static  T readModel(final InputStream modelInputStream) throws IOException, ClassNotFoundException {
    final ModelsSerializationRepository repository = new ModelsSerializationRepository();
    final BFGridConstructor grid = new BFGridConstructor();
    final ModelsSerializationRepository customizedRepository = repository.customizeGrid(grid);
    T model = readModel(modelInputStream, customizedRepository);
    grid.build();
    return model;
  }

  /**
   * Writes an ensemble of oblivious trees to {@code file} in the binary model format.
   *
   * <p>Only {@link Ensemble}s are handled, and the first model decides the builder:
   * <ul>
   *   <li>{@link ObliviousTreeDynamicBin}: {@code DynamicBinModelBuilder};</li>
   *   <li>{@link ObliviousTree}: {@code BinModelBuilder}.</li>
   * </ul>
   * Every model is cast to the first model's type. Non-ensembles, empty ensembles and other model
   * types are silently ignored, so no file is written for them.
   *
   * @param result model to write
   * @param file destination file
   */
  public static void writeBinModel(final Function result, final File file) {
    if (result instanceof Ensemble) {
      //noinspection unchecked
      final Ensemble ensemble = (Ensemble) result;
      if (ensemble.models.length == 0)
        return;
      if (ensemble.models[0] instanceof ObliviousTreeDynamicBin) {
        final DynamicGrid grid = dynamicGrid(ensemble);
        final DynamicBinModelBuilder builder = new DynamicBinModelBuilder(Objects.requireNonNull(grid));
        for (int i = 0; i < ensemble.models.length; ++i) {
          builder.append((ObliviousTreeDynamicBin) ensemble.models[i], ensemble.weights.at(i));
        }
        builder.build().toFile(file);
      } else if (ensemble.models[0] instanceof ObliviousTree) {
        final BFGrid grid = grid(ensemble);
        final BinModelBuilder builder = new BinModelBuilder(Objects.requireNonNull(grid));
        for (int i = 0; i < ensemble.models.length; ++i) {
          builder.append((ObliviousTree) ensemble.models[i], ensemble.weights.at(i));
        }
        builder.build().toFile(file);
      }
    }
  }

  /**
   * Checks whether a model stream can be deserialized.
   *
   * @param modelInputStream stream with the model (header line + body); not closed here
   * @param repository repository used to deserialize the model body
   * @return {@code (true, "Valid model : <simple class name>")} on success, or
   *     {@code (false, "Invalid model : <missing class>")} if the model class cannot be loaded
   * @throws IOException if the stream cannot be read
   */
  public static Pair validateModel(final InputStream modelInputStream, final ModelsSerializationRepository repository) throws IOException {
    try {
      final Trans trans = readModel(modelInputStream, repository);
      return Pair.create(true, "Valid model : " + trans.getClass().getSimpleName());
    }
    catch (ClassNotFoundException e) {
      // getCause() is normally null for Class.forName failures; the message names the class.
      return Pair.create(false, "Invalid model : " + e.getMessage());
    }
  }

  /**
   * Checks whether a model file can be deserialized.
   *
   * <p>See {@link #validateModel(InputStream, ModelsSerializationRepository)}. The file stream is
   * closed when done.
   *
   * @param filePath path to the model file
   * @param repository repository used to deserialize the model body
   * @return validation flag and a human-readable message
   * @throws IOException if the file cannot be opened or read
   */
  public static Pair validateModel(final String filePath, final ModelsSerializationRepository repository) throws IOException {
    try (final InputStream in = new FileInputStream(filePath)) {
      return validateModel(in, repository);
    }
  }

  /**
   * Returns the dynamic grid of a model.
   *
   * <p>For an {@link Ensemble}, the lookup recurses into its last model.
   *
   * @param result model to inspect
   * @return the grid of an {@link ObliviousTreeDynamicBin}, or null if none is found
   */
  public static DynamicGrid dynamicGrid(final Function result) {
    if (result instanceof Ensemble) {
      final Ensemble ensemble = (Ensemble) result;
      return dynamicGrid(ensemble.last());
    }
    else if (result instanceof ObliviousTreeDynamicBin) {
      return ((ObliviousTreeDynamicBin) result).grid();
    }
    return null;
  }

  /**
   * Finds the binary-feature grid used by a model by descending recursively through wrappers.
   *
   * <p>The search handles each model type as follows:
   * <ul>
   *   <li>{@link CompositeTrans}: {@code f} is searched first, then {@code g};</li>
   *   <li>{@link FuncJoin}, {@link TransJoin} and {@link Ensemble}: the first component that has
   *       a grid wins;</li>
   *   <li>{@link MultiClassModel} and {@link MultiLabelBinarizedModel}: the inner model is
   *       searched;</li>
   *   <li>{@link ObliviousTree} and {@link ObliviousMultiClassTree}: the grid is returned
   *       directly.</li>
   * </ul>
   *
   * @param result model to inspect
   * @return the first grid found, or null if the model contains none
   */
  public static BFGrid grid(final Function result) {
    if (result instanceof CompositeTrans) {
      final CompositeTrans composite = (CompositeTrans) result;
      BFGrid grid = grid(composite.f);
      grid = grid == null ? grid(composite.g) : grid;
      return grid;
    }
    else if (result instanceof FuncJoin) {
      final FuncJoin join = (FuncJoin) result;
      for (final Func dir : join.dirs()) {
        final BFGrid grid = grid(dir);
        if (grid != null)
          return grid;
      }
    }
    else if (result instanceof TransJoin) {
      final TransJoin join = (TransJoin) result;
      for (final Trans dir : join.dirs) {
        final BFGrid grid = grid(dir);
        if (grid != null)
          return grid;
      }
    }
    else if (result instanceof Ensemble) {
      final Ensemble ensemble = (Ensemble) result;
      for (final Trans dir : ensemble.models) {
        final BFGrid grid = grid(dir);
        if (grid != null)
          return grid;
      }
    }
    else if (result instanceof MultiClassModel) {
      return grid(((MultiClassModel) result).getInternModel());
    }
    else if (result instanceof MultiLabelBinarizedModel) {
      return grid(((MultiLabelBinarizedModel) result).getInternModel());
    }
    else if (result instanceof ObliviousTree) {
      return ((ObliviousTree) result).grid();
    }
    else if (result instanceof ObliviousMultiClassTree) {
      return ((ObliviousMultiClassTree) result).binaryClassifier().grid();
    }
    return null;
  }

  /**
   * Appends every column of {@code addedColumns} to the data of {@code sourceDS}.
   *
   * @param sourceDS dataset to extend
   * @param addedColumns matrix whose columns are appended in order
   * @return a new dataset (see {@link #extendDataset(VecDataSet, Vec...)})
   */
  public static DataSet extendDataset(final VecDataSet sourceDS, final Mx addedColumns) {
    final Vec[] columns = new Vec[addedColumns.columns()];
    for (int i = 0; i < addedColumns.columns(); i++) {
      columns[i] = addedColumns.col(i);
    }
    return extendDataset(sourceDS, columns);
  }

  /**
   * Builds a dataset whose matrix is {@code sourceDS.data()} with {@code addedColumns} appended.
   *
   * <p>The data is copied into a new dense matrix; only non-zero entries are written.
   *
   * @param sourceDS dataset to extend; returned unchanged when no columns are given
   * @param addedColumns columns to append; indexed by row (presumably of length
   *     {@code sourceDS.data().rows()}; verify at call sites)
   * @return the extended dataset, or {@code sourceDS} itself if {@code addedColumns} is empty
   */
  public static VecDataSet extendDataset(final VecDataSet sourceDS, final Vec... addedColumns) {
    if (addedColumns.length == 0)
      return sourceDS;

    final Mx oldData = sourceDS.data();
    final Mx newData = new VecBasedMx(oldData.rows(), oldData.columns() + addedColumns.length);
    for (final MxIterator iter = oldData.nonZeroes(); iter.advance(); ) {
      newData.set(iter.row(), iter.column(), iter.value());
    }
    for (int i = 0; i < addedColumns.length; i++) {
      for (final VecIterator iter = addedColumns[i].nonZeroes(); iter.advance(); ) {
        newData.set(iter.index(), oldData.columns() + i, iter.value());
      }
    }
    return new VecDataSetImpl(newData, sourceDS);
  }

  /**
   * Evaluates {@code f} on every row of {@code ds}.
   *
   * @param ds input matrix
   * @param f function applied to each row
   * @return vector whose i-th entry is {@code f.value(ds.row(i))}
   */
  public static Vec value(final Mx ds, final Func f) {
    final Vec result = new ArrayVec(ds.rows());
    for (int i = 0; i < ds.rows(); i++) {
      result.set(i, f.value(ds.row(i)));
    }
    return result;
  }

  /**
   * Wraps {@code loss} with Poisson(1) bootstrap weights.
   *
   * <p>One weight is drawn per each of {@code loss.xdim()} points.
   *
   * @param loss loss to reweight
   * @param rnd random source
   * @return weighted view of {@code loss}
   */
  public static  WeightedLoss bootstrap(final LocalLoss loss, final FastRandom rnd) {
    final int[] poissonWeights = new int[loss.xdim()];
    for (int i = 0; i < loss.xdim(); i++) {
      poissonWeights[i] = rnd.nextPoisson(1.);
    }
    return new WeightedLoss<>(loss, poissonWeights);
  }

  /**
   * Loads a target function class by its simple name from {@code com.expleague.ml.loss}.
   *
   * @param name simple class name, e.g. {@code "L2"}
   * @return the loaded class
   * @throws RuntimeException wrapping the failure if the class cannot be loaded
   */
  public static Class targetByName(final String name) {
    try {
      //noinspection unchecked
      return (Class) Class.forName("com.expleague.ml.loss." + name);
    }
    catch (Exception e) {
      throw new RuntimeException("Unable to create requested target: " + name, e);
    }
  }

  /**
   * Randomly assigns each index in {@code [0, size)} to one of {@code v.length} folds.
   *
   * <p>The fold for each index is drawn with {@code rng.nextSimple} over the weights {@code v}.
   *
   * @param size number of indices to split
   * @param rng random source
   * @param v fold weights
   * @return one ascending index array per fold
   */
  public static int[][] splitAtRandom(final int size, final FastRandom rng, final double... v) {
    final Vec weights = new ArrayVec(v);
    final TIntList[] folds = new TIntList[v.length];
    for (int i = 0; i < folds.length; i++) {
      // Array-backed list: only appends and one final toArray() are needed.
      folds[i] = new TIntArrayList();
    }
    for (int i = 0; i < size; i++) {
      folds[rng.nextSimple(weights)].add(i);
    }
    final int[][] result = new int[folds.length][];
    for (int i = 0; i < folds.length; i++) {
      result[i] = folds[i].toArray();
    }
    return result;
  }

  /**
   * Splits a pool into {@code v.length} random sub-pools item by item.
   *
   * @param pool pool to split
   * @param rng random source
   * @param v fold weights (see {@link #splitAtRandom})
   * @return one sub-pool per fold
   */
  public static  List> splitDataSet(Pool pool, final FastRandom rng, final double... v) {
    final int[][] cvSplit = DataTools.splitAtRandom(pool.size(), rng, v);
    final List> result = new ArrayList<>();
    for (int i = 0; i < v.length; i++) {
      result.add(pool.sub(cvSplit[i]));
    }
    return result;
  }

  /**
   * Splits a pool into {@code v.length} random sub-pools group by group.
   *
   * <p>Groups are identified by {@code groupId()}, and all items of one group land in the same
   * sub-pool.
   *
   * @param pool pool to split
   * @param rng random source
   * @param v fold weights (see {@link #splitAtRandom})
   * @return one sub-pool per fold, preserving the original item order within each
   */
  public static  List> splitGroupDataSet(Pool pool, final FastRandom rng, final double... v) {
    // Dense numbering of distinct group ids (forward and reverse maps).
    final TObjectIntMap groups = new TObjectIntHashMap<>();
    final TIntObjectMap rgroups = new TIntObjectHashMap<>();
    for (int i = 0; i < pool.size(); i++) {
      final String group = pool.items.at(i).groupId();
      if (!groups.containsKey(group)) {
        rgroups.put(groups.size(), group);
        groups.put(group, groups.size());
      }
    }

    final int[][] cvSplit = DataTools.splitAtRandom(groups.size(), rng, v);
    final List> result = new ArrayList<>();
    for (int i = 0; i < v.length; i++) {
      final TIntList indices = new TIntArrayList();
      final Set currentGroups = Arrays.stream(cvSplit[i]).mapToObj(rgroups::get).collect(Collectors.toSet());

      for (int j = 0; j < pool.size(); j++) {
        final String group = pool.items.at(j).groupId();
        if (currentGroups.contains(group))
          indices.add(j);
      }

      result.add(pool.sub(indices.toArray()));
    }
    return result;
  }

  /**
   * Applies {@code result} to every item of {@code data} and concatenates the outputs.
   *
   * @param result function producing a vector per item
   * @param data items to evaluate
   * @return a {@link VecBasedMx} with one row per item if the (last) output has more than one
   *     component, otherwise a flat vector
   */
  public static  Vec calcAll(final Function result, final DataSet data) {
    final VecBuilder results = new VecBuilder(data.length());
    int dim = 0;
    for (int i = 0; i < data.length(); i++) {
      final Vec vec = result.apply(data.at(i));
      for (int j = 0; j < vec.length(); j++) {
        results.add(vec.at(j));
      }
      dim = vec.length();
    }
    return dim > 1 ? new VecBasedMx(dim, results.build()) : results.build();
  }

  /**
   * Instantiates {@code targetClass} with a constructor that accepts {@code (values, ds)}.
   *
   * @param targetClass target class to instantiate
   * @param values target values
   * @param ds dataset the target is defined over
   * @return the new target
   * @throws RuntimeException if no suitable constructor is found
   */
  public static  Target newTarget(final Class targetClass, final Seq values, final DataSet ds) {
    final Target target = RuntimeUtils.newInstanceByAssignable(targetClass, values, ds);
    if (target != null)
      return target;
    throw new RuntimeException("No proper constructor!");
  }

  /**
   * Serializes a pool in the line-based format read by {@link #readPoolFrom(Reader)}.
   *
   * <p>The output consists of three kinds of lines:
   * <ul>
   *   <li>{@code items\t<meta json>\t<items json array>\n}, written once;</li>
   *   <li>{@code feature\t<feature meta json>\t<values>\n}, one per feature;</li>
   *   <li>{@code target\t<target meta json>\t<values>\n}, one per target.</li>
   * </ul>
   * JSON field names are written unquoted.
   *
   * <p>NOTE(review): the generator is closed at the end. With Jackson's default
   * {@code AUTO_CLOSE_TARGET} this also closes {@code out}.
   *
   * @param pool pool to write
   * @param out destination writer
   * @throws IOException if writing fails
   */
  public static  void writePoolTo(final Pool pool, final Writer out) throws IOException {
    final JsonFactory jsonFactory = new JsonFactory();
    jsonFactory.disable(JsonGenerator.Feature.QUOTE_FIELD_NAMES);
    jsonFactory.configure(JsonParser.Feature.ALLOW_COMMENTS, false);
    final JsonGenerator generator = jsonFactory.createGenerator(out);
    { // meta
      out.append("items").append('\t');
      final ObjectMapper mapper = new ObjectMapper(jsonFactory);
      final AnnotationIntrospector introspector =
          new JaxbAnnotationIntrospector(mapper.getTypeFactory());
      {
        mapper.setAnnotationIntrospector(introspector);
        generator.writeStartObject();
        generator.writeStringField("id", pool.meta().id());
        generator.writeStringField("author", pool.meta().author());
        generator.writeStringField("source", pool.meta().source());
        generator.writeNumberField("created", pool.meta().created().getTime());
        generator.writeStringField("type", pool.meta().type().getCanonicalName());
        generator.writeEndObject();
        // Flush so the JSON lands in `out` before the raw separator written next.
        generator.flush();
      }

      out.append('\t');
      {
        generator.setCodec(mapper);
        generator.writeStartArray();

        for (int i = 0; i < pool.size(); i++) {
          generator.writeObject(pool.data().at(i));
        }
        generator.writeEndArray();
        generator.flush();
      }
      out.append('\n');
    }

    for (int i = 0; i < pool.fcount(); i++) { // features
      out.append("feature").append('\t');
      writeFeature(out, jsonFactory, pool.fmeta(i), pool.fdata(i));
    }

    for (int i = 0; i < pool.tcount(); i++) { // targets
      out.append("target").append('\t');
      writeFeature(out, jsonFactory, pool.tmeta(i), pool.tdata(i));
    }
    out.flush();
    generator.close();
  }

  /**
   * Writes {@code <meta json>\t<serialized values>\n} for one feature or target column.
   */
  private static void writeFeature(final Writer out, final JsonFactory jsonFactory,
                                   final PoolFeatureMeta meta, Seq values) throws IOException {
    {
      final StringWriter writer = new StringWriter();
      final JsonGenerator generator = jsonFactory.createGenerator(writer);
      generator.writeStartObject();
      generator.writeStringField("id", meta.id());
      generator.writeStringField("description", meta.description());
      generator.writeStringField("type", meta.type().name());
      generator.writeStringField("associated", meta.associated().meta().id());
      generator.writeEndObject();
      generator.close();
      out.append(writer.getBuffer());
    }
    out.append('\t');
    out.append(SERIALIZATION.write(values));
    out.append('\n');
  }


  /**
   * Reads a pool written by {@link #writePoolTo(Pool, Writer)}.
   *
   * <p>Records named {@code items}, {@code feature} and {@code target} are understood. Other
   * names are skipped, and so are empty names.
   *
   * @param input reader positioned at the start of the pool
   * @return the reconstructed pool
   * @throws IOException if reading fails, including I/O errors wrapped in a
   *     {@link RuntimeException} by the parsing helpers
   */
  public static  Pool readPoolFrom(final Reader input) throws IOException {
    try {
      final PoolBuilder builder = new PoolBuilder();
      final ReaderChopper chopper = new ReaderChopper(input);
      CharSequence name;
      while ((name = chopper.chop('\t')) != null) {
        if (name.length() == 0)
          continue;
        final JsonParser parser = JSONTools.parseJSON(chopper.chop('\t'));

        switch (name.toString()) {
          case "items": {
            final JsonDataSetMeta meta = parser.readValueAs(JsonDataSetMeta.class);
            builder.setMeta(meta);

            final JsonParser parseItems = JSONTools.parseJSON(chopper.chop('\n'));
            final ObjectMapper mapper = (ObjectMapper) parseItems.getCodec();
            mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
            final CollectionType itemsGroupType = mapper.getTypeFactory().constructCollectionType(List.class, meta.type());
            final List myObjects = mapper.readValue(parseItems, itemsGroupType);
            for (int i = 0; i < myObjects.size(); i++) {
              builder.addItem(myObjects.get(i));
            }
            break;
          }
          case "feature": {
            final JsonFeatureMeta fmeta = parser.readValueAs(JsonFeatureMeta.class);
            final Class> vecClass = fmeta.type().clazz();
            builder.newFeature(fmeta, SERIALIZATION.read(
                chopper.chop('\n'),
                vecClass));
            break;
          }
          case "target": {
            final JsonTargetMeta fmeta = parser.readValueAs(JsonTargetMeta.class);
            final Class> vecClass = fmeta.type().clazz();
            builder.newTarget(fmeta, SERIALIZATION.read(
                chopper.chop('\n'),
                vecClass));
            break;
          }
        }
      }
      //noinspection unchecked
      return (Pool) builder.create();
    }
    catch (RuntimeException e) {
      if (e.getCause() instanceof IOException) {
        throw (IOException) e.getCause();
      }
      throw e;
    }
  }

  /**
   * Reads a pool from a (possibly gzipped) file.
   *
   * @see #loadFromFile(File)
   */
  public static Pool loadFromFile(final String fileName) throws IOException {
    return loadFromFile(new File(fileName));
  }

  /**
   * Opens a reader for {@code file}, decompressing it when its name ends with {@code .gz}.
   *
   * @param file file to open
   * @return a reader over the (decompressed) contents; the caller must close it
   * @throws IOException if the file cannot be opened or is not valid gzip
   */
  public static Reader gzipOrFileReader(final File file) throws IOException {
    if (!file.getName().endsWith(".gz"))
      return new FileReader(file);
    final FileInputStream raw = new FileInputStream(file);
    try {
      return new InputStreamReader(new GZIPInputStream(raw));
    }
    catch (IOException | RuntimeException e) {
      // Do not leak the file handle when the gzip header cannot be read.
      try {
        raw.close();
      }
      catch (IOException closeFailure) {
        e.addSuppressed(closeFailure);
      }
      throw e;
    }
  }

  /**
   * Reads a pool in the {@link #writePoolTo} format from a (possibly gzipped) file.
   *
   * @param file pool file; a {@code .gz} suffix enables decompression
   * @return the reconstructed pool
   * @throws IOException if the file cannot be opened or read
   */
  public static Pool loadFromFile(final File file) throws IOException {
    try (final Reader input = gzipOrFileReader(file)) {
      return readPoolFrom(input);
    }
  }

  /**
   * Builds the row subset {@code idxs} of a dataset together with the matching target subset.
   *
   * <p>The data subset is a row-permuted view over {@code sourceDS.data()}, not a copy.
   *
   * @param sourceDS source dataset
   * @param sourceTarget target aligned with {@code sourceDS}
   * @param idxs row indices to keep, in output order
   * @return pair of (data subset, target subset)
   */
  public static > Pair createSubset(final VecDataSet sourceDS, final S sourceTarget, final int[] idxs) {
    final VecDataSet subSet = new VecDataSetImpl(
        new VecBasedMx(
            sourceDS.xdim(),
            new IndexTransVec(sourceDS.data(),
                new RowsPermutation(
                    idxs,
                    sourceDS.xdim()
                )
            )
        ),
        sourceDS
    );
    final S subTarget = (S) ArrayTools.cut((Seq) sourceTarget, idxs);
    return Pair.create(subSet, subTarget);
  }

  /**
   * Returns a short human-readable summary of a pool.
   *
   * <p>The summary contains the item count, the vector feature count and the type of each feature.
   */
  public static String getPoolInfo(final Pool pool) {
    final VecDataSet vecDataSet = pool.vecData();

    final StringBuilder builder = new StringBuilder()
        .append("Pool size = ").append(pool.size())
        .append("\n")
        .append("VecDS features count = ").append(vecDataSet.xdim())
        .append("\n");
    for (int i = 0; i < pool.fcount(); i++) {
      builder
          .append("\n")
          .append("feature #").append(i)
          .append(": type = ").append(pool.fmeta(i).type());
    }
    return builder.toString();
  }

  /**
   * Parses a space-separated LIBSVM-style stream.
   *
   * <p>Each line has the form {@code target index:value index:value ...}. Blank tokens are
   * skipped. The feature count is the largest index seen plus one, and rows are stored sparsely.
   *
   * @param in reader with the data
   * @return a fake pool over the sparse data matrix and the target vector
   * @throws IOException if reading fails
   */
  public static Pool loadFromLibSvmFormat(final Reader in) throws IOException {
    final MutableInt poolFeaturesCount = new MutableInt(-1);

    final VecBuilder targetBuilder = new VecBuilder();
    final List> features = new ArrayList<>();

    CharSeqTools.processLines(in, new Consumer() {
      @Override
      public void accept(final CharSequence arg) {
        final CharSequence[] parts = CharSeqTools.split(arg, ' ');

        targetBuilder.add(CharSeqTools.parseDouble(parts[0]));

        final TIntList rowIndices = new TIntLinkedList();
        final TDoubleList rowValues = new TDoubleLinkedList();

        for (int i = 1; i < parts.length; i++) {
          final CharSequence indexAndValue = parts[i];
          if (StringUtils.isBlank(indexAndValue)) {
            continue;
          }

          final CharSequence[] split = CharSeqTools.split(indexAndValue, ':');
          final int index = CharSeqTools.parseInt(split[0]);
          final double value = CharSeqTools.parseDouble(split[1]);
          rowIndices.add(index);
          rowValues.add(value);

          if (poolFeaturesCount.intValue() < index + 1) {
            poolFeaturesCount.setValue(index + 1);
          }
        }
        features.add(Pair.create(rowIndices, rowValues));
      }
    });

    // Rows can only be sized once the global feature count is known, so build them afterwards.
    final MxBuilder mxBuilder = new MxByRowsBuilder();
    for (Pair pair : features) {
      mxBuilder.add(new SparseVec(poolFeaturesCount.intValue(), pair.getFirst().toArray(), pair.getSecond().toArray()));
    }
    final Mx dataMx = mxBuilder.build();

    return FakePool.create(dataMx, targetBuilder.build());
  }

  /**
   * Writes the pool's L2 target and the non-zero vector features.
   *
   * <p>Each line has the form {@code target\tindex:value\tindex:value...\n}.
   *
   * <p>NOTE(review): fields are tab-separated, while {@link #loadFromLibSvmFormat(Reader)} splits
   * on spaces, so this output is not directly readable by that method.
   *
   * @param pool pool to export
   * @param out destination writer; flushed, not closed
   * @throws IOException if writing fails
   */
  public static void writePoolInLibfmFormat(final Pool pool, final Writer out) throws IOException {
    final Mx data = pool.vecData().data();
    final Vec target = pool.target(L2.class).target;
    for (int i = 0; i < pool.size(); i++) {
      final double t = target.get(i);
      out.append(String.valueOf(t));
      final VecIterator vecIterator = data.row(i).nonZeroes();
      while (vecIterator.advance()) {
        out.append("\t")
            .append(String.valueOf(vecIterator.index()))
            .append(":")
            .append(String.valueOf(vecIterator.value()));
      }
      out.append("\n");
    }
    out.flush();
  }

  /**
   * Writes a pool to a file in the classic "features.txt" layout.
   *
   * <p>See {@link #writeClassicPoolTo(Pool, Writer)}. The file is closed when done.
   *
   * @param pool pool to export
   * @param fileName destination path (created or overwritten)
   * @throws IOException if the file cannot be written or closed
   */
  public static void writeClassicPoolTo(final Pool pool, final String fileName) throws IOException {
    try (final Writer out = new BufferedWriter(new FileWriter(fileName))) {
      DataTools.writeClassicPoolTo(pool, out);
    }
  }

  /**
   * Writes a pool in the classic "features.txt" layout accepted by
   * {@link #loadFromFeaturesTxt(String, Reader)}.
   *
   * <p>Each row is written as {@code <row>\t<L2 target>\turl\t0\t<feature>...}. Numbers are
   * formatted with up to 8 fractional digits in the US locale.
   *
   * @param pool pool to export
   * @param out destination writer; flushed, not closed
   * @throws IOException if writing fails
   */
  public static void writeClassicPoolTo(final Pool pool, final Writer out) throws IOException {
    final DecimalFormat preciseFormatter = new DecimalFormat("###.########", new DecimalFormatSymbols(Locale.US));

    final Mx vecData = pool.vecData().data();
    final Vec target = pool.target(L2.class).target;

    for (int i = 0; i < vecData.rows(); i++) {
      out.write(String.format("%d\t%s\turl\t0", i, preciseFormatter.format(target.get(i))));
      for (int j = 0; j < vecData.columns(); j++) {
        out.append("\t").append(preciseFormatter.format(vecData.get(i, j)));
      }
      out.write("\n");
    }
    out.flush();
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy