All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.github.chen0040.rl.utils.Vec Maven / Gradle / Ivy

package com.github.chen0040.rl.utils;

import lombok.Getter;
import lombok.Setter;

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;


/**
 * Created by xschen on 9/27/2015 0027.
 */
@Getter
@Setter
public class Vec implements Serializable {
    private Map data = new HashMap();
    private int dimension;
    private double defaultValue;
    private int id = -1;

    public Vec(){

    }

    public Vec(double[] v){
        for(int i=0; i < v.length; ++i){
            set(i, v[i]);
        }
    }

    public Vec(int dimension){
        this.dimension = dimension;
        defaultValue = 0;
    }

    public Vec(int dimension, Map data){
        this.dimension = dimension;
        defaultValue = 0;

        for(Map.Entry entry : data.entrySet()){
            set(entry.getKey(), entry.getValue());
        }
    }

    public Vec makeCopy(){
        Vec clone = new Vec(dimension);
        clone.copy(this);
        return clone;
    }

    public void copy(Vec rhs){
        defaultValue = rhs.defaultValue;
        dimension = rhs.dimension;
        id = rhs.id;

        data.clear();
        for(Map.Entry entry : rhs.data.entrySet()){
            data.put(entry.getKey(), entry.getValue());
        }
    }

    public void set(int i, double value){
        if(value == defaultValue) return;

        data.put(i, value);
        if(i >= dimension){
            dimension = i+1;
        }
    }


    public double get(int i){
        return data.getOrDefault(i, defaultValue);
    }

    @Override
    public boolean equals(Object rhs){
        if(rhs != null && rhs instanceof Vec){
            Vec rhs2 = (Vec)rhs;
            if(dimension != rhs2.dimension){
                return false;
            }

            if(data.size() != rhs2.data.size()){
                return false;
            }

            for(Integer index : data.keySet()){
                if(!rhs2.data.containsKey(index)) return false;
                if(!DoubleUtils.equals(data.get(index), rhs2.data.get(index))){
                    return false;
                }
            }

            if(defaultValue != rhs2.defaultValue){
                for(int i=0; i < dimension; ++i){
                    if(data.containsKey(i)){
                        return false;
                    }
                }
            }

            return true;
        }

        return false;
    }

    public void setAll(double value){
        defaultValue = value;
        for(Integer index : data.keySet()){
            data.put(index, defaultValue);
        }
    }

    public IndexValue indexWithMaxValue(Set indices){
        if(indices == null){
            return indexWithMaxValue();
        }else{
            IndexValue iv = new IndexValue();
            iv.setIndex(-1);
            iv.setValue(Double.NEGATIVE_INFINITY);
            for(Integer index : indices){
                double value = data.getOrDefault(index, Double.NEGATIVE_INFINITY);
                if(value > iv.getValue()){
                    iv.setIndex(index);
                    iv.setValue(value);
                }
            }
            return iv;
        }
    }

    public IndexValue indexWithMaxValue(){
        IndexValue iv = new IndexValue();
        iv.setIndex(-1);
        iv.setValue(Double.NEGATIVE_INFINITY);


        for(Map.Entry entry : data.entrySet()){
            if(entry.getKey() >= dimension) continue;

            double value = entry.getValue();
            if(value > iv.getValue()){
                iv.setValue(value);
                iv.setIndex(entry.getKey());
            }
        }

        if(!iv.isValid()){
            iv.setValue(defaultValue);
        } else{
            if(iv.getValue() < defaultValue){
                for(int i=0; i < dimension; ++i){
                    if(!data.containsKey(i)){
                        iv.setValue(defaultValue);
                        iv.setIndex(i);
                        break;
                    }
                }
            }
        }

        return iv;
    }



    public Vec projectOrthogonal(Iterable vlist) {
        Vec b = this;
        for(Vec v : vlist)
        {
            b = b.minus(b.projectAlong(v));
        }

        return b;
    }

    public Vec projectOrthogonal(List vlist, Map alpha) {
        Vec b = this;
        for(int i = 0; i < vlist.size(); ++i)
        {
            Vec v = vlist.get(i);
            double norm_a = v.multiply(v);

            if (DoubleUtils.isZero(norm_a)) {
                return new Vec(dimension);
            }
            double sigma = multiply(v) / norm_a;
            Vec v_parallel = v.multiply(sigma);

            alpha.put(i, sigma);

            b = b.minus(v_parallel);
        }

        return b;
    }

    public Vec projectAlong(Vec rhs)
    {
        double norm_a = rhs.multiply(rhs);

        if (DoubleUtils.isZero(norm_a)) {
            return new Vec(dimension);
        }
        double sigma = multiply(rhs) / norm_a;
        return rhs.multiply(sigma);
    }

    public Vec multiply(double rhs){
        Vec clone = (Vec)this.makeCopy();
        for(Integer i : data.keySet()){
            clone.data.put(i, rhs * data.get(i));
        }
        return clone;
    }

    public double multiply(Vec rhs)
    {
        double productSum = 0;
        if(defaultValue == 0) {
            for (Map.Entry entry : data.entrySet()) {
                productSum += entry.getValue() * rhs.get(entry.getKey());
            }
        } else {
            for(int i=0; i < dimension; ++i){
                productSum += get(i) * rhs.get(i);
            }
        }

        return productSum;
    }

    public Vec pow(double scalar)
    {
        Vec result = new Vec(dimension);
        for (Map.Entry entry : data.entrySet())
        {
            result.data.put(entry.getKey(), Math.pow(entry.getValue(), scalar));
        }
        return result;
    }

    public Vec add(Vec rhs)
    {
        Vec result = new Vec(dimension);
        int index;
        for (Map.Entry entry : data.entrySet()) {
            index = entry.getKey();
            result.data.put(index, entry.getValue() + rhs.data.get(index));
        }
        for(Map.Entry entry : rhs.data.entrySet()){
            index = entry.getKey();
            if(result.data.containsKey(index)) continue;
            result.data.put(index, entry.getValue() + data.get(index));
        }

        return result;
    }

    public Vec minus(Vec rhs)
    {
        Vec result = new Vec(dimension);
        int index;
        for (Map.Entry entry : data.entrySet()) {
            index = entry.getKey();
            result.data.put(index, entry.getValue() - rhs.data.get(index));
        }
        for(Map.Entry entry : rhs.data.entrySet()){
            index = entry.getKey();
            if(result.data.containsKey(index)) continue;
            result.data.put(index, data.get(index) - entry.getValue());
        }

        return result;
    }

    public double sum(){
        double sum = 0;

        for(Map.Entry entry : data.entrySet()){
            sum += entry.getValue();
        }
        sum += defaultValue * (dimension - data.size());

        return sum;
    }

    public boolean isZero(){
        return DoubleUtils.isZero(sum());
    }

    public double norm(int level)
    {
        if (level == 1)
        {
            double sum = 0;
            for (Double val : data.values())
            {
                sum += Math.abs(val);
            }
            if(!DoubleUtils.isZero(defaultValue)) {
                sum += Math.abs(defaultValue) * (dimension - data.size());
            }
            return sum;
        }
        else if (level == 2)
        {
            double sum = multiply(this);
            if(!DoubleUtils.isZero(defaultValue)){
                sum += (dimension - data.size()) * (defaultValue * defaultValue);
            }
            return Math.sqrt(sum);
        }
        else
        {
            double sum = 0;
            for (Double val : this.data.values())
            {
                sum += Math.pow(Math.abs(val), level);
            }
            if(!DoubleUtils.isZero(defaultValue)) {
                sum += Math.pow(Math.abs(defaultValue), level) * (dimension - data.size());
            }
            return Math.pow(sum, 1.0 / level);
        }
    }

    public Vec normalize()
    {
        double norm = norm(2); // L2 norm is the cartesian distance
        if (DoubleUtils.isZero(norm))
        {
            return new Vec(dimension);
        }
        Vec clone = new Vec(dimension);
        clone.setAll(defaultValue / norm);

        for (Integer k : data.keySet())
        {
            clone.data.put(k, data.get(k) / norm);
        }
        return clone;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy