
org.deeplearning4j.spark.models.embeddings.glove.Glove Maven / Gradle / Ivy
The newest version!
/*
*
* * Copyright 2015 Skymind,Inc.
* *
* * Licensed under the Apache License, Version 2.0 (the "License");
* * you may not use this file except in compliance with the License.
* * You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.deeplearning4j.spark.models.embeddings.glove;
import org.apache.commons.math3.util.FastMath;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.CounterMap;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator;
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts;
import org.deeplearning4j.spark.text.functions.TextPipeline;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdaGrad;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.atomic.AtomicLong;
import static org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables.*;
/**
* Spark glove
*
* @author Adam Gibson
*/
public class Glove implements Serializable {
private Broadcast> vocabCacheBroadcast;
private String tokenizerFactoryClazz = DefaultTokenizerFactory.class.getName();
private boolean symmetric = true;
private int windowSize = 15;
private int iterations = 300;
private static Logger log = LoggerFactory.getLogger(Glove.class);
/**
 * Builds a GloVe trainer that uses the given tokenizer factory instead of the default one.
 *
 * @param tokenizerFactoryClazz fully qualified class name of the tokenizer factory to use
 * @param symmetric             whether co-occurrence counts are treated as symmetric
 * @param windowSize            size of the co-occurrence window
 * @param iterations            number of training iterations
 */
public Glove(String tokenizerFactoryClazz, boolean symmetric, int windowSize, int iterations) {
    // Reuse the three-argument constructor for the shared settings, then override the tokenizer.
    this(symmetric, windowSize, iterations);
    this.tokenizerFactoryClazz = tokenizerFactoryClazz;
}
/**
 * Builds a GloVe trainer that keeps the default tokenizer factory
 * ({@link DefaultTokenizerFactory}).
 *
 * @param symmetric  whether co-occurrence counts are treated as symmetric
 * @param windowSize size of the co-occurrence window
 * @param iterations number of training iterations
 */
public Glove(boolean symmetric, int windowSize, int iterations) {
    this.iterations = iterations;
    this.windowSize = windowSize;
    this.symmetric = symmetric;
}
/**
 * Applies one AdaGrad-scaled step for a single word to its vector and to its bias entry.
 *
 * @param weightAdaGrad AdaGrad state used to scale the word-vector gradient
 * @param biasAdaGrad   AdaGrad state used to scale the bias gradient
 * @param syn0          the weight matrix; only its shape is passed to {@code weightAdaGrad}
 * @param bias          the bias vector; the entry at {@code w1.getIndex()} is overwritten in place
 * @param w1            the word being updated; its index selects the AdaGrad history and bias entry
 * @param wordVector    the vector of {@code w1}; updated in place
 * @param contextVector the vector of the co-occurring context word
 * @param gradient      the scalar gradient for this word pair
 * @return a pair of (the AdaGrad-scaled delta subtracted from {@code wordVector},
 *         the new bias value written for {@code w1})
 */
private Pair<INDArray, Double> update(
        AdaGrad weightAdaGrad,
        AdaGrad biasAdaGrad,
        INDArray syn0,
        INDArray bias,
        VocabWord w1,
        INDArray wordVector,
        INDArray contextVector,
        double gradient) {
    final int index = w1.getIndex();

    // Gradient for the word vector: the context vector scaled by the pair's gradient.
    INDArray grad1 = contextVector.mul(gradient);
    INDArray update = weightAdaGrad.getGradient(grad1, index, syn0.shape());
    wordVector.subi(update);

    double w1Bias = bias.getDouble(index);
    double biasGradient = biasAdaGrad.getGradient(gradient, index, bias.shape());
    double update2 = w1Bias - biasGradient;
    // Store the stepped bias directly. Subtracting update2 from the current bias again would
    // collapse the entry to just biasGradient and throw away the accumulated bias.
    bias.putScalar(index, update2);

    // NOTE(review): the first element is a delta but the second is the new bias value, not a
    // delta; confirm callers interpret it that way.
    return new Pair<>(update, update2);
}
/**
* Train on the corpus
* @param rdd the rdd to train
* @return the vocab and weights
*/
public Pair,GloveWeightLookupTable> train(JavaRDD rdd) throws Exception{
// Each `train()` can use different parameters
final JavaSparkContext sc = new JavaSparkContext(rdd.context());
final SparkConf conf = sc.getConf();
final int vectorLength = assignVar(VECTOR_LENGTH, conf, Integer.class);
final boolean useAdaGrad = assignVar(ADAGRAD, conf, Boolean.class);
final double negative = assignVar(NEGATIVE, conf, Double.class);
final int numWords = assignVar(NUM_WORDS, conf, Integer.class);
final int window = assignVar(WINDOW, conf, Integer.class);
final double alpha = assignVar(ALPHA, conf, Double.class);
final double minAlpha = assignVar(MIN_ALPHA, conf, Double.class);
final int iterations = assignVar(ITERATIONS, conf, Integer.class);
final int nGrams = assignVar(N_GRAMS, conf, Integer.class);
final String tokenizer = assignVar(TOKENIZER, conf, String.class);
final String tokenPreprocessor = assignVar(TOKEN_PREPROCESSOR, conf, String.class);
final boolean removeStop = assignVar(REMOVE_STOPWORDS, conf, Boolean.class);
Map tokenizerVarMap = new HashMap() {{
put("numWords", numWords);
put("nGrams", nGrams);
put("tokenizer", tokenizer);
put("tokenPreprocessor", tokenPreprocessor);
put("removeStop", removeStop);
}};
Broadcast
© 2015 - 2025 Weber Informatics LLC | Privacy Policy