com.tencent.angel.sona.psf.embedding.w2v.SkipgramModel Maven / Gradle / Ivy
/*
* Tencent is pleased to support the open source community by making Angel available.
*
* Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* https://opensource.org/licenses/Apache-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License
* is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
* or implied. See the License for the specific language governing permissions and limitations under
* the License.
*
*/
package com.tencent.angel.sona.psf.embedding.w2v;
import io.netty.buffer.ByteBuf;
import it.unimi.dsi.fastutil.floats.FloatArrayList;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap;
import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
import it.unimi.dsi.fastutil.objects.ObjectIterator;
import java.util.Arrays;
import java.util.Random;
public class SkipgramModel extends EmbeddingModel {
public SkipgramModel(int dim, int negative, int window, int seed, int maxIndex, int numNodeOneRow, int maxLength, float[][] layers) {
super(dim, negative, window, seed, maxIndex, numNodeOneRow, maxLength, layers);
}
@Override
public float[] dot(int[][] sentences) {
Random windowSeed = new Random(seed);
Random negativeSeed = new Random(seed + 1);
FloatArrayList partialDots = new FloatArrayList();
IntOpenHashSet numInputs = new IntOpenHashSet();
IntOpenHashSet numOutputs = new IntOpenHashSet();
for (int s = 0; s < sentences.length; s++) {
int[] sen = sentences[s];
for (int position = 0; position < sen.length; position++) {
int word = sen[position];
// window size
int b = windowSeed.nextInt(window);
// Skip-Gram model
float[] inputs = layers[word / numNodeOneRow];
int l1 = (word % numNodeOneRow) * dim * 2;
numInputs.add(word);
// Accumulate the input vectors from context
for (int a = b; a < window * 2 + 1 - b; a++)
if (a != window) {
int c = position - window + a;
if (c < 0) continue;
if (c >= sen.length) continue;
int sentence_word = sen[c];
if (sentence_word == -1) continue;
// Negative sampling
int target;
for (int d = 0; d < negative + 1; d ++) {
if (d == 0) target = word;
else do{
target = negativeSeed.nextInt(maxIndex);
}while (target == word);
numOutputs.add(target);
float[] outputs = layers[target / numNodeOneRow];
int l2 = (target % numNodeOneRow) * dim * 2 + dim;
float f = 0.0f;
for (c = 0; c < dim; c++) f += inputs[l1 + c] * outputs[l2 + c];
partialDots.add(f);
}
}
}
}
this.numInputsToUpdate = numInputs.size();
this.numOutputsToUpdate = numOutputs.size();
return partialDots.toFloatArray();
}
@Override
public void adjust(int[][] sentences, ByteBuf buf, int numInputs, int numOutputs) {
int length = buf.readInt();
// used to accumulate the updates for input vectors
float[] neu1e = new float[dim];
float[] inputUpdates = new float[numInputs * dim];
float[] outputUpdates = new float[numOutputs * dim];
Int2IntOpenHashMap inputIndex = new Int2IntOpenHashMap();
Int2IntOpenHashMap outputIndex = new Int2IntOpenHashMap();
Int2IntOpenHashMap inputUpdateCounter = new Int2IntOpenHashMap();
Int2IntOpenHashMap outputUpdateCounter = new Int2IntOpenHashMap();
Random windowSeed = new Random(seed);
Random negativeSeed = new Random(seed + 1);
for (int s = 0; s < sentences.length; s++) {
int[] sen = sentences[s];
for (int position = 0; position < sen.length; position++) {
int word = sen[position];
float[] inputs = layers[word / numNodeOneRow];
int l1 = (word % numNodeOneRow) * dim * 2;
// window size
int b = windowSeed.nextInt(window);
Arrays.fill(neu1e, 0);
// skip-gram model
for (int a = b; a < window * 2 + 1 - b; a++)
if (a != window) {
int c = position - window + a;
if (c < 0) continue;
if (c >= sen.length) continue;
if (sen[c] == -1) continue;
// Negative sampling
int target;
for (int d = 0; d < negative + 1; d ++) {
if (d == 0) target = word;
else while (true) {
target = negativeSeed.nextInt(maxIndex);
if (target == word) continue;
else break;
}
float[] outputs = layers[target / numNodeOneRow];
int l2 = (target % numNodeOneRow) * dim * 2 + dim;
float g = buf.readFloat();
length --;
// accumulate for the hidden layer
for (c = 0; c < dim; c ++) neu1e[c] += g * outputs[c + l2];
// update output layer
merge(outputUpdates, outputIndex, target, inputs, g, l1);
outputUpdateCounter.addTo(target, 1);
}
// update the hidden layer
merge(inputUpdates, inputIndex, word, neu1e, 1, 0);
inputUpdateCounter.addTo(word, 1);
}
}
}
// update input
ObjectIterator it = inputIndex.int2IntEntrySet().fastIterator();
while (it.hasNext()) {
Int2IntMap.Entry entry = it.next();
int node = entry.getIntKey();
int offset = entry.getIntValue() * dim;
int divider = inputUpdateCounter.get(node);
int col = (node % numNodeOneRow) * dim * 2;
float[] values = layers[node / numNodeOneRow];
for (int a = 0; a < dim; a++) values[a + col] += inputUpdates[offset + a] / divider;
}
// update output
it = outputIndex.int2IntEntrySet().fastIterator();
while (it.hasNext()) {
Int2IntMap.Entry entry = it.next();
int node = entry.getIntKey();
int offset = entry.getIntValue() * dim;
int col = (node % numNodeOneRow) * dim * 2 + dim;
float[] values = layers[node / numNodeOneRow];
int divider = outputUpdateCounter.get(node);
for (int a = 0; a < dim; a++) values[a + col] += outputUpdates[offset + a] / divider;
}
assert length == 0;
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy