
// org.deeplearning4j.models.glove.GloveWork (Maven / Gradle / Ivy listing)
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.models.glove;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.scaleout.perform.models.glove.GloveResult;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaGrad;
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* Glove work
*
* @author Adam Gibson
*/
public class GloveWork implements Serializable {
private Map> vectors = new ConcurrentHashMap<>();
private List> coOccurrences;
private Map indexes = new ConcurrentHashMap<>();
private Map originalVectors = new ConcurrentHashMap<>();
private Map biases = new ConcurrentHashMap<>();
private Map adaGrads = new ConcurrentHashMap<>();
private Map biasAdaGrads = new ConcurrentHashMap<>();
public GloveWork(GloveWeightLookupTable table, List> coOccurrences) {
this.coOccurrences = coOccurrences;
for(Pair coOccurrence : coOccurrences) {
indexes.put(coOccurrence.getFirst().getIndex(),coOccurrence.getFirst());
indexes.put(coOccurrence.getSecond().getIndex(),coOccurrence.getSecond());
addWord(coOccurrence.getFirst(),table);
addWord(coOccurrence.getSecond(),table);
}
}
private void addWord(VocabWord word,GloveWeightLookupTable table) {
if(word == null)
throw new IllegalArgumentException("Word must not be null!");
indexes.put(word.getIndex(),word);
vectors.put(word.getWord(),new Pair<>(word,table.getSyn0().getRow(word.getIndex()).dup()));
originalVectors.put(word.getWord(),table.getSyn0().getRow(word.getIndex()).dup());
biases.put(word.getWord(),table.getBias().getDouble(word.getIndex()));
adaGrads.put(word.getWord(),table.getWeightAdaGrad().createSubset(word.getIndex()));
biasAdaGrads.put(word.getWord(),table.getBiasAdaGrad().createSubset(word.getIndex()));
}
public AdaGrad getBiasAdaGrad(String word) {
return biasAdaGrads.get(word);
}
public AdaGrad getAdaGrad(String word) {
return adaGrads.get(word);
}
public void updateBias(String word,double bias) {
biases.put(word,bias);
}
public GloveResult addDeltas() {
Map syn0Change = new HashMap<>();
for(Pair sentence : coOccurrences) {
VocabWord w1 = sentence.getFirst();
VocabWord w2 = sentence.getSecond();
syn0Change.put(w1.getWord(),vectors.get(w1.getWord()).getSecond().sub(originalVectors.get(w1.getWord())));
syn0Change.put(w2.getWord(),vectors.get(w2.getWord()).getSecond().sub(originalVectors.get(w2.getWord())));
}
return new GloveResult(syn0Change);
}
public double getBias(String word) {
return biases.get(word);
}
public List> getCoOccurrences() {
return coOccurrences;
}
public void setCoOccurrences(List> coOccurrences) {
this.coOccurrences = coOccurrences;
}
public Map> getVectors() {
return vectors;
}
public void setVectors(Map> vectors) {
this.vectors = vectors;
}
public Map getIndexes() {
return indexes;
}
public void setIndexes(Map indexes) {
this.indexes = indexes;
}
public Map getOriginalVectors() {
return originalVectors;
}
public void setOriginalVectors(Map originalVectors) {
this.originalVectors = originalVectors;
}
public Map getBiases() {
return biases;
}
public void setBiases(Map biases) {
this.biases = biases;
}
}
// © 2015 - 2025 Weber Informatics LLC | Privacy Policy