All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.deeplearning4j.iterator.bert.BertMaskedLMMasker Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.iterator.bert;

import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class BertMaskedLMMasker implements BertSequenceMasker {
    public static final double DEFAULT_MASK_PROB = 0.15;
    public static final double DEFAULT_MASK_TOKEN_PROB = 0.8;
    public static final double DEFAULT_RANDOM_WORD_PROB = 0.1;

    protected final Random r;
    protected final double maskProb;
    protected final double maskTokenProb;
    protected final double randomTokenProb;

    /**
     * Create a BertMaskedLMMasker with all default probabilities
     */
    public BertMaskedLMMasker(){
        this(new Random(), DEFAULT_MASK_PROB, DEFAULT_MASK_TOKEN_PROB, DEFAULT_RANDOM_WORD_PROB);
    }

    /**
     * See: {@link BertMaskedLMMasker} for details.
     * @param r                 Random number generator
     * @param maskProb          Probability of masking each token
     * @param maskTokenProb     Probability of replacing a selected token with the mask token
     * @param randomTokenProb    Probability of replacing a selected token with a random token
     */
    public BertMaskedLMMasker(Random r, double maskProb, double maskTokenProb, double randomTokenProb){
        Preconditions.checkArgument(maskProb > 0 && maskProb < 1, "Probability must be beteen 0 and 1, got %s", maskProb);
        Preconditions.checkState(maskTokenProb >=0 && maskTokenProb <= 1.0, "Mask token probability must be between 0 and 1, got %s", maskTokenProb);
        Preconditions.checkState(randomTokenProb >=0 && randomTokenProb <= 1.0, "Random token probability must be between 0 and 1, got %s", randomTokenProb);
        Preconditions.checkState(maskTokenProb + randomTokenProb <= 1.0, "Sum of maskTokenProb (%s) and randomTokenProb (%s) must be <= 1.0, got sum is %s",
                maskTokenProb, randomTokenProb, (maskTokenProb + randomTokenProb));
        this.r = r;
        this.maskProb = maskProb;
        this.maskTokenProb = maskTokenProb;
        this.randomTokenProb = randomTokenProb;
    }

    @Override
    public Pair,boolean[]> maskSequence(List input, String maskToken, List vocabWords) {
        List out = new ArrayList<>(input.size());
        boolean[] masked = new boolean[input.size()];
        for(int i=0; i(out, masked);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy