All downloads are free. Search and download functionality uses the official Maven repository.

com.expleague.ml.io.EnsembleModelConversionPack Maven / Gradle / Ivy

package com.expleague.ml.io;

import com.expleague.commons.func.types.ConversionDependant;
import com.expleague.commons.func.types.ConversionPack;
import com.expleague.commons.func.types.ConversionRepository;
import com.expleague.commons.func.types.TypeConverter;
import com.expleague.commons.math.Trans;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.ArrayVec;
import com.expleague.commons.seq.CharSeqTools;
import com.expleague.commons.util.Pair;
import com.expleague.ml.func.Ensemble;

import java.util.StringTokenizer;

/**
 * Conversion pack that writes an {@link Ensemble} as text and parses such text back.
 *
 * <p>Text layout (sections are separated by a blank line):
 * <pre>
 * &lt;number of models&gt;
 *
 * &lt;model class name&gt; &lt;weight&gt;
 * &lt;model body, as produced by the conversion repository&gt;
 *
 * ...
 * </pre>
 *
 * <p>NOTE(review): the reader splits the whole text on blank lines, so a model body that itself
 * contains a blank line would presumably be mis-parsed; confirm against the element converters.
 *
 * <p>NOTE(review): the generic parameters below were reconstructed from how the types are used
 * in this file; confirm them against the declarations of {@code TypeConverter},
 * {@code ConversionPack} and {@code Pair}.
 *
 * User: solar
 * Date: 12.08.13
 * Time: 17:16
 */
public class EnsembleModelConversionPack implements ConversionPack<Ensemble, CharSequence> {
  /**
   * Base for converters that serialize an ensemble (or a subclass of it) to text.
   *
   * @param <F> concrete ensemble type accepted by the converter
   */
  public static abstract class BaseTo<F extends Ensemble> implements TypeConverter<F, CharSequence>, ConversionDependant {
    /** Repository used to serialize the individual models; injected via the setter below. */
    private ConversionRepository repository;

    @Override
    public void setConversionRepository(final ConversionRepository repository) {
      this.repository = repository;
    }

    /**
     * Serializes the models and weights of {@code from}.
     *
     * <p>The result starts with the model count and a blank line. Each model follows as a header
     * line {@code "<class name> <weight>"}, the model body obtained from the conversion
     * repository, and a blank line. The very last newline is dropped, so the text ends with a
     * single {@code '\n'}.
     *
     * @param from ensemble to serialize
     * @return textual representation of the ensemble
     */
    protected CharSequence convertModels(final F from) {
      final int count = from.size();
      final StringBuilder builder = new StringBuilder();
      builder.append(count).append("\n\n");
      for (int i = 0; i < count; i++) {
        final Trans model = from.models[i];
        // Write the binary name: the reader resolves it with Class.forName, which does not
        // accept canonical names of nested classes ('.' instead of '$'), and
        // getCanonicalName() is null for anonymous and local classes.
        builder.append(model.getClass().getName()).append(" ");
        builder.append(from.weights.get(i)).append("\n");
        builder.append(repository.convert(model, CharSequence.class));
        builder.append("\n\n");
      }
      // Trim the trailing newline of the last separator.
      builder.setLength(builder.length() - 1);
      return builder;
    }
  }

  /** Serializes a plain {@link Ensemble} to text. */
  public static class To extends BaseTo<Ensemble> {
    @Override
    public CharSequence convert(final Ensemble from) {
      return convertModels(from);
    }
  }

  /**
   * Base for converters that parse the textual ensemble representation.
   *
   * @param <T> concrete ensemble type produced by the converter
   */
  public abstract static class BaseFrom<T extends Ensemble> implements TypeConverter<CharSequence, T>, ConversionDependant {

    /** Repository used to parse the individual models; injected via the setter below. */
    private ConversionRepository repository;

    @Override
    public void setConversionRepository(final ConversionRepository repository) {
      this.repository = repository;
    }

    /**
     * Parses the text layout described on the enclosing class into models and their weights.
     *
     * @param from serialized ensemble; carriage returns are ignored
     * @return pair of the parsed models and a vector of their weights, in file order
     * @throws IllegalArgumentException if the declared model count is negative or larger than the
     *                                  number of model sections present
     * @throws NumberFormatException    if the model count or a weight is not a valid number
     * @throws ClassCastException       if a named class is not a {@link Trans}
     * @throws RuntimeException         if a named model class cannot be found
     */
    protected Pair<Trans[], Vec> convertModels(CharSequence from) {
      final String text = from.toString();
      if (text.indexOf('\r') >= 0) {
        from = text.replace("\r", ""); // fix windows newlines created by GIT
      }

      final CharSequence[] elements = CharSeqTools.split(from, "\n\n");
      final int count = Integer.parseInt(elements[0].toString());
      if (count < 0 || elements.length < count + 1) {
        throw new IllegalArgumentException("Invalid ensemble text: declared model count is " + count
            + ", model sections found: " + (elements.length - 1));
      }

      final Trans[] models = new Trans[count];
      final double[] weights = new double[count];
      for (int i = 0; i < count; i++) {
        final CharSequence section = elements[i + 1];
        // The first line of a section is "<class name> <weight>"; the rest is the model body.
        final CharSequence header = CharSeqTools.split(section, "\n")[0];
        final StringTokenizer tok = new StringTokenizer(header.toString(), " ");
        final Class<? extends Trans> elementClass = loadModelClass(tok.nextToken());
        weights[i] = Double.parseDouble(tok.nextToken());
        // Skip the header line and the newline that terminates it.
        final CharSequence body = section.subSequence(header.length() + 1, section.length());
        models[i] = repository.convert(body, elementClass);
      }
      return Pair.create(models, (Vec) new ArrayVec(weights));
    }

    /** Resolves a model class by name, rejecting classes that are not a {@link Trans}. */
    private static Class<? extends Trans> loadModelClass(final String className) {
      try {
        return Class.forName(className).asSubclass(Trans.class);
      } catch (ClassNotFoundException e) {
        throw new RuntimeException("Element class not found!", e);
      }
    }
  }

  /** Parses text into a plain {@link Ensemble}. */
  public static class From extends BaseFrom<Ensemble> {
    @Override
    public Ensemble convert(final CharSequence from) {
      final Pair<Trans[], Vec> pair = convertModels(from);
      final Trans[] models = pair.getFirst();
      final Vec weights = pair.getSecond();
      return new Ensemble(models, weights);
    }
  }

  /** @return the converter class that turns an {@link Ensemble} into text */
  @Override
  public Class<? extends TypeConverter<Ensemble, CharSequence>> to() {
    return To.class;
  }

  /** @return the converter class that turns text into an {@link Ensemble} */
  @Override
  public Class<? extends TypeConverter<CharSequence, Ensemble>> from() {
    return From.class;
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy