// org.deeplearning4j.models.embeddings.WeightLookupTable — Maven / Gradle / Ivy (source-browser artifact header)
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.models.embeddings;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.plot.BarnesHutTsne;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.File;
import java.io.Serializable;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicLong;
/**
 * General weight lookup table.
 *
 * Maps vocabulary elements to their weight vectors and Huffman codes,
 * and exposes training-related hooks (learning rate, iteration callbacks).
 *
 * @param <T> type of the vocabulary elements backing this table
 * @author Adam Gibson
 */
public interface WeightLookupTable<T extends SequenceElement> extends Serializable {

    /**
     * Returns unique ID of this table.
     * Used for JointStorage/DistributedLookupTable mechanics.
     *
     * @return ID of this table
     */
    Long getTableId();

    /**
     * Sets table ID.
     * Please note, it should be unique within Joint/Distributed LookupTable.
     *
     * @param tableId unique ID to assign to this table
     */
    void setTableId(Long tableId);

    /**
     * The layer size for the lookup table
     * @return the layer size for the lookup table
     */
    int layerSize();

    /**
     * Returns gradient for the specified word
     * @param column column (word index) the gradient applies to
     * @param gradient raw gradient value to be processed (e.g. clipped/scaled)
     * @return the processed gradient value
     */
    double getGradient(int column, double gradient);

    /**
     * Clear out all weights regardless
     * @param reset whether to re-initialize the weights after clearing
     */
    void resetWeights(boolean reset);

    /**
     * Render the words via TSNE
     * @param tsne the tsne to use
     * @param numWords number of words to render
     * @param connectionInfo UI connection to render to
     */
    void plotVocab(BarnesHutTsne tsne, int numWords, UiConnectionInfo connectionInfo);

    /**
     * Render the words via TSNE
     * @param tsne the tsne to use
     * @param numWords number of words to render
     * @param file output file for the plot
     */
    void plotVocab(BarnesHutTsne tsne, int numWords, File file);

    /**
     * Render the words via tsne, using a default TSNE configuration
     * @param numWords number of words to render
     * @param connectionInfo UI connection to render to
     */
    void plotVocab(int numWords, UiConnectionInfo connectionInfo);

    /**
     * Render the words via tsne, using a default TSNE configuration
     * @param numWords number of words to render
     * @param file output file for the plot
     */
    void plotVocab(int numWords, File file);

    /**
     * Stores the Huffman code vector for the given index.
     * @param codeIndex index to store the code at
     * @param code code vector to store
     */
    void putCode(int codeIndex, INDArray code);

    /**
     * Loads the co-occurrences for the given codes
     * @param codes the codes to load
     * @return an ndarray of code.length by layerSize
     */
    INDArray loadCodes(int[] codes);

    /**
     * Iterate on the given 2 vocab words
     * @param w1 the first word to iterate on
     * @param w2 the second word to iterate on
     * @deprecated training iteration is handled by learning-algorithm implementations
     */
    @Deprecated
    void iterate(T w1, T w2);

    /**
     * Iterate on the given 2 vocab words
     * @param w1 the first word to iterate on
     * @param w2 the second word to iterate on
     * @param nextRandom nextRandom for sampling
     * @param alpha the alpha to use for learning
     * @deprecated training iteration is handled by learning-algorithm implementations
     */
    @Deprecated
    void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha);

    /**
     * Inserts a word vector
     * @param word the word to insert
     * @param vector the vector to insert
     */
    void putVector(String word, INDArray vector);

    /**
     * Returns the vector for the given word, if present.
     * @param word the word to look up
     * @return the stored vector for {@code word}
     */
    INDArray vector(String word);

    /**
     * Reset the weights of the cache
     */
    void resetWeights();

    /**
     * Sets the learning rate
     * @param lr the learning rate to use
     */
    void setLearningRate(double lr);

    /**
     * Iterates through all of the vectors in the cache
     * @return an iterator for all vectors in the cache
     */
    Iterator<INDArray> vectors();

    /**
     * Returns the full weight matrix backing this table.
     * @return the weight matrix
     */
    INDArray getWeights();

    /**
     * Returns corresponding vocabulary
     * @return the vocabulary cache backing this table
     */
    VocabCache<T> getVocabCache();
}