// com.expleague.ml.methods.wrappers.MultiLabelWrapper (Maven / Gradle / Ivy)
package com.expleague.ml.methods.wrappers;
import com.expleague.commons.func.WeakListenerHolder;
import com.expleague.commons.func.impl.WeakListenerHolderImpl;
import com.expleague.commons.math.Trans;
import com.expleague.ml.data.set.VecDataSet;
import com.expleague.ml.data.tools.MultiLabelTools;
import com.expleague.ml.methods.VecOptimization;
import com.expleague.ml.models.multilabel.MultiLabelModel;
import com.expleague.ml.TargetFunc;
import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
/**
* User: qdeee
* Date: 03.04.15
*/
public class MultiLabelWrapper extends WeakListenerHolderImpl implements VecOptimization {
private final VecOptimization strong;
public MultiLabelWrapper(final VecOptimization strong) {
this.strong = strong;
}
@Override
public MultiLabelModel fit(final VecDataSet learn, final GlobalLoss targetFunc) {
List internListeners = new ArrayList<>();
if (strong instanceof WeakListenerHolder) {
for (WeakReference> externalListenerRef : listeners) {
final Consumer super Trans> externalListener = externalListenerRef.get();
if (externalListener != null) {
final Consumer internListener = externalListener::accept;
internListeners.add(internListener);
((WeakListenerHolder) strong).addListener(internListener);
}
}
}
final Trans model = strong.fit(learn, targetFunc);
return MultiLabelTools.extractMultiLabelModel(model);
}
}