org.deeplearning4j.models.embeddings.WeightLookupTable Maven / Gradle / Ivy
The newest version!
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.models.embeddings;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.core.ui.UiConnectionInfo;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.File;
import java.io.Serializable;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicLong;
/**
* General weight lookup table
*
* @author Adam Gibson
*/
public interface WeightLookupTable extends Serializable {
/**
* Returns unique ID of this table.
* Used for JointStorage/DistributedLookupTable mechanics
*
* @return ID of this table
*/
Long getTableId();
/**
* Set's table Id.
* Please note, it should be unique withing Joint/Distributed LookupTable
*
* @param tableId
*/
void setTableId(Long tableId);
/**
* The layer size for the lookup table
* @return the layer size for the lookup table
*/
int layerSize();
/**
* Returns gradient for specified word
* @param column
* @param gradient
* @return
*/
double getGradient(int column, double gradient);
/**
* Clear out all weights regardless
* @param reset
*/
void resetWeights(boolean reset);
/**
*
* @param codeIndex
* @param code
*/
void putCode(int codeIndex, INDArray code);
/**
* Loads the co-occurrences for the given codes
* @param codes the codes to load
* @return an ndarray of code.length by layerSize
*/
INDArray loadCodes(int[] codes);
/**
* Iterate on the given 2 vocab words
* @param w1 the first word to iterate on
* @param w2 the second word to iterate on
*/
@Deprecated
void iterate(T w1, T w2);
/**
* Iterate on the given 2 vocab words
* @param w1 the first word to iterate on
* @param w2 the second word to iterate on
* @param nextRandom nextRandom for sampling
* @param alpha the alpha to use for learning
*/
@Deprecated
void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha);
/**
* Inserts a word vector
* @param word the word to insert
* @param vector the vector to insert
*/
void putVector(String word, INDArray vector);
/**
*
* @param word
* @return
*/
INDArray vector(String word);
/**
* Reset the weights of the cache
*/
void resetWeights();
/**
* Sets the learning rate
* @param lr
*/
void setLearningRate(double lr);
/**
* Iterates through all of the vectors in the cache
* @return an iterator for all vectors in the cache
*/
Iterator vectors();
INDArray getWeights();
/**
* Returns corresponding vocabulary
*/
VocabCache getVocabCache();
}