All downloads are free. The search and download features use the official Maven repository.

org.latinolib.model.ModelUtils Maven / Gradle / Ivy

package org.latinolib.model;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import org.latinolib.SparseVector;
import org.latinolib.VectorEntry;

import java.util.Collections;
import java.util.List;

/**
 * Static helper routines operating on {@link SparseVector} instances:
 * pruning of low-weight entries, L2 length computation and in-place L2
 * normalization.
 *
 * <p>This class is not instantiable.
 *
 * Author saxo
 */
public class ModelUtils
{
    private ModelUtils() {
    }

    /**
     * Drops the lowest-weighted entries of a vector.
     *
     * <p>The entry weights are summed and sorted in ascending order. A budget
     * of {@code cutLowWgtPerc * sum} is then consumed by walking the sorted
     * weights from the smallest upwards; the weight at which the budget is
     * used up (drops to zero or below) becomes the cut weight. The result
     * contains every entry of {@code vec} whose weight is greater than or
     * equal to that cut weight, in the iteration order of {@code vec}.
     *
     * <p>If the budget is never used up (for example when {@code vec} has no
     * entries), an empty vector is returned.
     *
     * @param vec           the vector to prune; it is not modified by this method
     * @param cutLowWgtPerc the fraction of the total weight used as the cut
     *                      budget; must satisfy {@code 0 <= cutLowWgtPerc < 1}
     * @return {@code vec} itself when {@code cutLowWgtPerc} is {@code 0},
     *         otherwise a newly created vector holding the retained entries
     * @throws NullPointerException     if {@code vec} is {@code null}
     * @throws IllegalArgumentException if {@code cutLowWgtPerc} is outside {@code [0, 1)}
     */
    public static SparseVector cutLowWeights(SparseVector vec, double cutLowWgtPerc) {
        Preconditions.checkNotNull(vec);
        Preconditions.checkArgument(cutLowWgtPerc >= 0 && cutLowWgtPerc < 1);
        if (cutLowWgtPerc > 0) {
            double wgtSum = 0;
            List<WeightIndex> tmp = Lists.newArrayListWithCapacity(vec.size());
            for (VectorEntry item : vec) {
                wgtSum += item.getData();
                tmp.add(new WeightIndex(item.getData(), item.getIndex()));
            }
            // Ascending by weight, so the budget is spent on the smallest weights first.
            Collections.sort(tmp);
            double cutSum = cutLowWgtPerc * wgtSum;
            double cutWgt = 0;
            // An explicit flag is used instead of a magic sentinel value so that
            // no legitimate weight can be mistaken for "no cut weight found".
            boolean cutWgtFound = false;
            for (WeightIndex item : tmp) {
                cutSum -= item.weight;
                if (cutSum <= 0) {
                    cutWgt = item.weight;
                    cutWgtFound = true;
                    break;
                }
            }
            SparseVector newVec = new SparseVector();
            if (cutWgtFound) {
                for (VectorEntry item : vec) {
                    // The entry that exhausted the budget is itself retained (>=).
                    if (item.getData() >= cutWgt) {
                        newVec.add(item.getIndex(), item.getData());
                    }
                }
            }
            return newVec;
        }
        return vec;
    }

    /**
     * Computes the L2 (Euclidean) length of a vector, i.e. the square root of
     * the sum of its squared entry values.
     *
     * @param vec the vector to measure
     * @return the L2 length; {@code 0} when the vector has no entries
     * @throws NullPointerException if {@code vec} is {@code null}
     */
    public static double getVecLenL2(SparseVector vec) {
        Preconditions.checkNotNull(vec);
        double len = 0;
        for (VectorEntry entry : vec) {
            len += entry.getData() * entry.getData();
        }
        return Math.sqrt(len);
    }

    /**
     * Attempts to scale a vector to unit L2 length by dividing every entry
     * value by the vector's L2 length and storing the result back through
     * {@link VectorEntry#setData}.
     *
     * @param vec the vector to normalize
     * @return {@code false} if the L2 length is {@code 0}, in which case no
     *         entry is touched; {@code true} otherwise
     * @throws NullPointerException if {@code vec} is {@code null}
     */
    public static boolean tryNrmVecL2(SparseVector vec) {
        double len = getVecLenL2(vec);
        if (len == 0) {
            // Nothing to scale by; dividing would produce NaN values.
            return false;
        }
        for (VectorEntry entry : vec) {
            entry.setData(entry.getData() / len);
        }
        return true;
    }

    /**
     * Pairs an entry weight with the entry's index; instances are ordered by
     * weight only.
     */
    private static class WeightIndex implements Comparable<WeightIndex>
    {
        public final double weight;
        public final int index;

        private WeightIndex(double weight, int index) {
            this.weight = weight;
            this.index = index;
        }

        /**
         * Orders by ascending weight using {@link Double#compare(double, double)}.
         */
        @Override
        public int compareTo(WeightIndex o) {
            return Double.compare(weight, o.weight);
        }
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy