All downloads are free. Search and download functionality uses the official Maven repository.

com.gengoai.apollo.ml.model.embedding.Glove Maven / Gradle / Ivy

There is a newer version: 2.1
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package com.gengoai.apollo.ml.model.embedding;

import com.gengoai.ParameterDef;
import com.gengoai.Stopwatch;
import com.gengoai.apollo.math.linalg.DenseMatrix;
import com.gengoai.apollo.ml.DataSet;
import com.gengoai.apollo.ml.Datum;
import com.gengoai.apollo.ml.model.Params;
import com.gengoai.apollo.ml.observation.Observation;
import com.gengoai.apollo.ml.observation.Sequence;
import com.gengoai.collection.counter.Counter;
import com.gengoai.collection.counter.Counters;
import com.gengoai.tuple.IntPair;
import lombok.NonNull;
import lombok.extern.java.Log;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;

import static com.gengoai.LogUtils.logInfo;
import static com.gengoai.function.Functional.with;

/**
 * 

Implementation of Glove as defined in:

*
 *  Jeffrey Pennington, Richard Socher, and Christopher D. Manning. 2014. GloVe: Global Vectors for Word Representation.
 * 
* * @author David B. Bracewell */ @Log public class Glove extends TrainableWordEmbedding { private static final long serialVersionUID = 1L; public static final ParameterDef alpha = ParameterDef.doubleParam("alpha"); public static final ParameterDef xMax = ParameterDef.intParam("xMax"); /** * Instantiates a new Glove with default Parameters. */ public Glove() { super(new Parameters()); } /** * Instantiates a new Glove model with the given Glove parameters. * * @param parameters the parameters */ public Glove(@NonNull Glove.Parameters parameters) { super(parameters); } /** * Instantiates a new Glove with the given Parameter updater. * * @param updater method to update the model parameters */ public Glove(@NonNull Consumer updater) { super(with(new Parameters(), updater)); } @Override public void estimate(@NonNull DataSet dataset) { Stopwatch sw = Stopwatch.createStarted(); dataset = dataset.cache(); double size = dataset.size(); final AtomicLong processed = new AtomicLong(0); Counter counts = Counters.newCounter(); vectorStore = new InMemoryVectorStore(parameters.dimension.value(), parameters.unknownWord.value(), parameters.specialWords.value()); for(Datum datum : dataset) { datum.stream(parameters.inputs.value()).forEach(source -> { Sequence input = source.asSequence(); List ids = toIndices(input); for(int i = 1; i < ids.size(); i++) { int iW = ids.get(i); for(int j = Math.max(0, i - parameters.windowSize.value()); j < i; j++) { int jW = ids.get(j); double incrementBy = 1.0 / (i - j); counts.increment(IntPair.of(iW, jW), incrementBy); counts.increment(IntPair.of(jW, iW), incrementBy); } } double cnt = processed.incrementAndGet(); if(cnt % 1000 == 0) { if(parameters.verbose.value()) { logInfo(log, "processed {0}", (100 * cnt / size)); } } }); } sw.stop(); if(parameters.verbose.value()) { logInfo(log, "Cooccurrence Matrix computed in {0}", sw); } List cooccurrences = new ArrayList<>(); counts.forEach((e, v) -> cooccurrences.add(new Cooccurrence(e.v1, e.v2, v))); 
counts.clear(); DoubleMatrix[] W = new DoubleMatrix[vectorStore.size() * 2]; DoubleMatrix[] gradSq = new DoubleMatrix[vectorStore.size() * 2]; for(int i = 0; i < vectorStore.size() * 2; i++) { W[i] = DoubleMatrix.rand(parameters.dimension.value()).sub(0.5f).divi(parameters.dimension.value()); gradSq[i] = DoubleMatrix.ones(parameters.dimension.value()); } DoubleMatrix biases = DoubleMatrix.rand(vectorStore.size() * 2).sub(0.5f).divi(parameters.dimension.value()); DoubleMatrix gradSqBiases = DoubleMatrix.ones(vectorStore.size() * 2); int vocabLength = vectorStore.size(); for(int itr = 0; itr < parameters.maxIterations.value(); itr++) { double globalCost = 0d; Collections.shuffle(cooccurrences); for(Cooccurrence cooccurrence : cooccurrences) { int iWord = cooccurrence.word1; int iContext = cooccurrence.word2 + vocabLength; double count = cooccurrence.count; DoubleMatrix v_main = W[iWord]; double b_main = biases.get(iWord); DoubleMatrix gradsq_W_main = gradSq[iWord]; double gradsq_b_main = gradSqBiases.get(iWord); DoubleMatrix v_context = W[iContext]; double b_context = biases.get(iContext); DoubleMatrix gradsq_W_context = gradSq[iContext]; double gradsq_b_contenxt = gradSqBiases.get(iContext); double diff = v_main.dot(v_context) + b_main + b_context - Math.log(count); double fdiff = count > parameters.xMax.value() ? 
diff : Math.pow(count / parameters.xMax.value(), parameters.alpha.value()) * diff; globalCost += 0.5 * fdiff * diff; fdiff *= parameters.learningRate.value(); //Gradients for word vector terms DoubleMatrix grad_main = v_context.mmul(fdiff); DoubleMatrix grad_context = v_main.mmul(fdiff); v_main.subi(grad_main.divi(MatrixFunctions.sqrt(gradsq_W_main))); v_context.subi(grad_context.divi(MatrixFunctions.sqrt(gradsq_W_context))); gradsq_W_main.addi(MatrixFunctions.pow(grad_context, 2)); gradsq_W_context.addi(MatrixFunctions.pow(grad_main, 2)); biases.put(iWord, b_main - fdiff / Math.sqrt(gradsq_b_main)); biases.put(iContext, b_context - fdiff / Math.sqrt(gradsq_b_contenxt)); fdiff *= fdiff; gradSqBiases.put(iWord, gradSqBiases.get(iWord) + fdiff); gradSqBiases.put(iContext, gradSqBiases.get(iContext) + fdiff); } if(parameters.verbose.value()) { logInfo(log, "Iteration: {0}, cost:{1}", (itr + 1), globalCost / cooccurrences.size()); } } for(int i = 0; i < vocabLength; i++) { W[i].addi(W[i + vocabLength]); String k = vectorStore.decode(i); vectorStore.updateVector(i, new DenseMatrix(W[i]).setLabel(k).T()); } } private List toIndices(Sequence sequence) { List out = new ArrayList<>(); for(Observation example : sequence) { example.getVariableSpace() .forEach(v -> out.add(vectorStore.addOrGetIndex(parameters.nameSpace.value().getName(v)))); } return out; } private static class Cooccurrence { public final double count; public final int word1; public final int word2; public Cooccurrence(int word1, int word2, double count) { this.word1 = word1; this.word2 = word2; this.count = count; } }//END OF Cooccurrence /** * Fit parameters for Glove models */ public static class Parameters extends WordEmbeddingFitParameters { /** * Controls the the normalization of cooccurrence counts to probabilities (default 0.75). 
*/ public final Parameter alpha = parameter(Glove.alpha, 0.75); /** * The learning rate (default 0.05) */ public final Parameter learningRate = parameter(Params.Optimizable.learningRate, 0.05); /** * The cutoff value for cooccurrence counts (default 100). */ public final Parameter xMax = parameter(Glove.xMax, 100); /** * The maximum number of iterations to train the model (default 25). */ public final Parameter maxIterations = parameter(Params.Optimizable.maxIterations, 25); } }//END OF Glove




© 2015 - 2025 Weber Informatics LLC | Privacy Policy