// com.expleague.ml.data.tools.PoolByRowsBuilder (Maven / Gradle / Ivy)
package com.expleague.ml.data.tools;
import com.expleague.commons.func.Factory;
import com.expleague.commons.math.vectors.Vec;
import com.expleague.commons.math.vectors.impl.vectors.SparseVecBuilder;
import com.expleague.commons.math.vectors.impl.vectors.VecBuilder;
import com.expleague.commons.seq.*;
import com.expleague.commons.util.Holder;
import com.expleague.ml.Vectorization;
import com.expleague.ml.data.set.DataSet;
import com.expleague.ml.meta.DSItem;
import com.expleague.ml.meta.FeatureMeta;
import com.expleague.ml.meta.PoolFeatureMeta;
import com.expleague.ml.meta.TargetMeta;
import com.expleague.ml.meta.impl.*;
import com.expleague.ml.meta.impl.fake.FakeFeatureMeta;
import com.expleague.ml.meta.impl.fake.FakeTargetMeta;
import com.expleague.ml.meta.items.FakeItem;
import gnu.trove.map.hash.TObjectIntHashMap;
import java.lang.reflect.Array;
import java.util.*;
/**
* User: solar
* Date: 07.07.14
* Time: 12:55
*/
@SuppressWarnings("UnusedDeclaration")
public class PoolByRowsBuilder- implements Factory
> {
/** Metadata handed to every pool built by create(); pre-filled by the constructor, replaceable via setMeta(). */
private JsonDataSetMeta meta = new JsonDataSetMeta();
/** Items accumulated by addItem(); the list is cleared once a pool has been created from them. */
private List- items = new ArrayList<>();
/**
 * Creates a builder for pools whose items are of the given type.
 *
 * <p>The pool metadata is pre-filled with the item type, the current OS user as author
 * (system property {@code user.name}), the creation time, the class name of the bottom-most
 * frame of the calling thread's stack as source, and the placeholder id {@code "Unknown pool"}.
 *
 * @param type runtime class of the pool items
 */
public PoolByRowsBuilder(Class<Item> type) {
  meta.type = type;
  meta.author = System.getProperty("user.name");
  meta.created = new Date();
  // The last frame is the thread's entry point (typically the launching main class).
  // Thread#getStackTrace may legally return an empty array, so guard the index.
  final StackTraceElement[] stack = Thread.currentThread().getStackTrace();
  if (stack.length > 0) {
    meta.source = stack[stack.length - 1].getClassName();
  }
  meta.id = "Unknown pool";
}
@Override
public Pool
- create() {
//noinspection unchecked
return create((Class
- )meta.type());
}
public Pool
- create(final Class
- clazz) {
@SuppressWarnings("unchecked")
final LinkedHashMap
> features = new LinkedHashMap<>();
@SuppressWarnings({"unchecked", "SuspiciousToArrayCall"})
final Item[] items = this.items.toArray((Item[])Array.newInstance(clazz, this.items.size()));
final Pool- result = new Pool<>(meta, new ArraySeq<>(items), features);
final DataSet
- ds = result.data();
final Holder
> dataSet = Holder.create(null);
for (int i = 0; i < featureMetas.size(); i++) {
final FeatureMeta meta = featureMetas.get(i);
features.put(new JsonFeatureMeta(meta, ds.meta().id()), featureBuilders.get(i).build());
featureBuilders.set(i, createBuilderByMeta(meta)); // cleanup
}
for (int i = 0; i < targetMetas.size(); i++) {
final TargetMeta meta = targetMetas.get(i);
features.put(new JsonTargetMeta(meta, ds.meta().id()), targetBuilders.get(i).build());
targetBuilders.set(i, createBuilderByMeta(meta)); // cleanup
}
final Set itemIds = new HashSet<>();
for (final Item item : this.items) {
if (itemIds.contains(item.id()))
throw new RuntimeException(
"Contain duplicates! Id = " + item.id()
);
itemIds.add(toString());
}
this.items.clear();
return result;
}
/**
 * Replaces the metadata attached to pools created from now on. This also replaces the item
 * type that {@code create()} reads from the metadata.
 *
 * @param meta the new pool metadata
 */
public void setMeta(final JsonDataSetMeta meta) { this.meta = meta; }
// public void setItemType(DataSetMeta.ItemType type) {
// this.meta.type = type;
// }
public void addItem(final Item item, final Vectorization- vec) {
items.add(item);
if (featureBuilders.size() == 0) { //
for (int i = 0; i < vec.dim(); i++)
addFeature(vec.meta(i));
}
final Vec value = vec.value(item);
for (int i = 0; i < value.length(); i++) {
//noinspection unchecked
((SeqBuilder