All downloads are FREE. Search and download functionality uses the official Maven repository.

org.deeplearning4j.models.word2vec.VocabWord Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*
 *
 *  * Copyright 2015 Skymind,Inc.
 *  *
 *  *    Licensed under the Apache License, Version 2.0 (the "License");
 *  *    you may not use this file except in compliance with the License.
 *  *    You may obtain a copy of the License at
 *  *
 *  *        http://www.apache.org/licenses/LICENSE-2.0
 *  *
 *  *    Unless required by applicable law or agreed to in writing, software
 *  *    distributed under the License is distributed on an "AS IS" BASIS,
 *  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  *    See the License for the specific language governing permissions and
 *  *    limitations under the License.
 *
 */

package org.deeplearning4j.models.word2vec;

import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import com.google.common.util.concurrent.AtomicDouble;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;


/**
 * A single vocabulary entry: the word itself, its frequency count, its index in
 * the vocabulary, and per-word code/point lists with their code length
 * (presumably the word's Huffman code and the tree nodes on its path — confirm
 * against the Huffman builder). Also keeps a lazily allocated history of
 * squared gradients used by {@link #getGradient(int, double)}.
 *
 * <p>Apart from the frequency counter, which is held in an {@link AtomicDouble},
 * no field is synchronized.
 *
 * <p>Note: {@link #compareTo(VocabWord)} orders by frequency only and is
 * therefore not consistent with {@link #equals(Object)}.
 *
 * @author Adam Gibson
 */
public class VocabWord implements Comparable<VocabWord>, Serializable {

    private static final long serialVersionUID = 2223750736522624256L;

    // used in comparison when building the huffman tree
    private AtomicDouble wordFrequency = new AtomicDouble(0);
    private int index = -1;
    private List<Integer> codes = new ArrayList<>();
    private String word;
    private INDArray historicalGradient;
    private List<Integer> points = new ArrayList<>();
    private int codeLength = 0;

    /**
     * Creates a placeholder entry with frequency 0 and the word {@code "none"}.
     *
     * @return a new placeholder entry
     */
    public static VocabWord none() {
        return new VocabWord(0, "none");
    }

    /**
     * @param wordFrequency count of the word
     * @param word the word itself; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code word} is null or empty
     */
    public VocabWord(double wordFrequency, String word) {
        if (word == null || word.isEmpty()) {
            throw new IllegalArgumentException("Word must not be null or empty");
        }
        this.wordFrequency.set(wordFrequency);
        this.word = word;
    }

    /** Creates an entry with no word set, frequency 0 and index -1. */
    public VocabWord() {}

    /**
     * Writes this entry's frequency to the stream. Only the frequency is
     * written; the word, index, codes and points are not.
     *
     * @param dos destination stream
     * @throws IOException if the underlying stream fails
     */
    public void write(DataOutputStream dos) throws IOException {
        dos.writeDouble(wordFrequency.get());
    }

    /**
     * Reads a frequency previously written by {@link #write(DataOutputStream)}
     * into this entry.
     *
     * @param dis source stream
     * @return this entry, for chaining
     * @throws IOException if the underlying stream fails
     */
    public VocabWord read(DataInputStream dis) throws IOException {
        wordFrequency.set(dis.readDouble());
        return this;
    }

    public String getWord() {
        return word;
    }

    public void setWord(String word) {
        this.word = word;
    }

    /** Increments the frequency count by one. */
    public void increment() {
        increment(1);
    }

    /**
     * Atomically adds {@code by} to the frequency count.
     *
     * @param by amount to add
     */
    public void increment(int by) {
        wordFrequency.getAndAdd(by);
    }

    public int getIndex() {
        return index;
    }

    public void setIndex(int index) {
        this.index = index;
    }

    /** @return the frequency count, or 0.0 if the counter is absent */
    public double getWordFrequency() {
        if (wordFrequency == null) {
            return 0.0;
        }
        return wordFrequency.get();
    }

    public List<Integer> getCodes() {
        return codes;
    }

    public void setCodes(List<Integer> codes) {
        this.codes = codes;
    }

    /** Orders entries by ascending frequency. */
    @Override
    public int compareTo(VocabWord o) {
        return Double.compare(wordFrequency.get(), o.wordFrequency.get());
    }

    /**
     * Returns an adaptively scaled gradient magnitude for {@code g} at slot
     * {@code index} (AdaGrad-style). The value {@code g^2} is added to a running
     * per-slot sum, and the method returns
     * {@code |g| / (sqrt(sum) + 1e-6) * 0.1}.
     *
     * <p>The history vector is allocated on the first call, with one slot per
     * entry currently in {@link #getCodes()}. NOTE(review): an {@code index}
     * outside that range presumably fails inside ND4J — confirm.
     *
     * @param index slot in the gradient history to update
     * @param g raw gradient value
     * @return the scaled gradient magnitude
     */
    public double getGradient(int index, double g) {
        if (historicalGradient == null) {
            historicalGradient = Nd4j.zeros(getCodes().size());
        }

        double pow = Math.pow(g, 2);
        historicalGradient.putScalar(index, historicalGradient.getDouble(index) + pow);
        double sqrt = FastMath.sqrt(historicalGradient.getDouble(index));
        // Float literals (1e-6f, 1e-1f) widen to slightly inexact doubles; they
        // are kept as-is to preserve existing numeric results.
        double abs = FastMath.abs(g) / (sqrt + 1e-6f);
        return abs * 1e-1f;
    }

    public List<Integer> getPoints() {
        return points;
    }

    public void setPoints(List<Integer> points) {
        this.points = points;
    }

    public int getCodeLength() {
        return codeLength;
    }

    /**
     * Sets the code length and pads {@code codes} and {@code points} with zeros
     * so that each holds at least {@code codeLength} entries. Existing entries
     * are kept and the lists are never truncated.
     *
     * @param codeLength the new code length
     */
    public void setCodeLength(int codeLength) {
        this.codeLength = codeLength;
        // Pad up to codeLength (the original appended codeLength extra zeros,
        // over-growing any list that was already partially filled).
        while (codes.size() < codeLength) {
            codes.add(0);
        }
        while (points.size() < codeLength) {
            points.add(0);
        }
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        // Guards against NPE on null and ClassCastException on foreign types.
        if (!(o instanceof VocabWord)) {
            return false;
        }
        VocabWord other = (VocabWord) o;
        return codeLength == other.codeLength
                && index == other.index
                && Objects.equals(codes, other.codes)
                && Objects.equals(historicalGradient, other.historicalGradient)
                && Objects.equals(points, other.points)
                && Objects.equals(word, other.word)
                // Double.compare keeps equals consistent with the value-based hash below.
                && Double.compare(getWordFrequency(), other.getWordFrequency()) == 0;
    }

    @Override
    public int hashCode() {
        // AtomicDouble does not override hashCode(), so hash the value it holds;
        // hashing the holder itself gives equal objects different hashes.
        int result = Double.valueOf(getWordFrequency()).hashCode();
        result = 31 * result + index;
        result = 31 * result + Objects.hashCode(codes);
        result = 31 * result + Objects.hashCode(word);
        result = 31 * result + Objects.hashCode(historicalGradient);
        result = 31 * result + Objects.hashCode(points);
        result = 31 * result + codeLength;
        return result;
    }

    @Override
    public String toString() {
        return "VocabWord{" +
                "wordFrequency=" + wordFrequency +
                ", index=" + index +
                ", codes=" + codes +
                ", word='" + word + '\'' +
                ", historicalGradient=" + historicalGradient +
                ", points=" + points +
                ", codeLength=" + codeLength +
                '}';
    }

}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy