All downloads are free. Search and download functionality uses the official Maven repository.

ai.djl.modality.nlp.generate.SearchConfig Maven / Gradle / Ivy

/*
 * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package ai.djl.modality.nlp.generate;

/**
 * {@code SearchConfig} is a class whose fields are parameters used for autoregressive search / text
 * generation.
 */
public class SearchConfig {

    /** Number of top candidates considered at each step ("top-k"). */
    private int k;
    /** Penalty weight for repetition. */
    private float alpha;
    /** Beam size. */
    private int beam;
    /** Maximum sequence length. */
    private int maxSeqLength;
    /** Token id used for padding. */
    private long padTokenId;
    /** Token id marking end of sequence. */
    private long eosTokenId;
    /** Whether padding is applied on the right (suffix) side. */
    private boolean suffixPadding;

    /**
     * Constructs a new {@code SearchConfig} instance with default values.
     *
     * <p>Defaults: {@code k = 4}, {@code alpha = 0.6f}, {@code beam = 3}, {@code maxSeqLength =
     * 30}, {@code eosTokenId = padTokenId = 50256} (presumably the GPT-2 end-of-text token id —
     * confirm against the tokenizer in use), and {@code suffixPadding = false}.
     */
    public SearchConfig() {
        this.k = 4;
        this.alpha = 0.6f;
        this.beam = 3;
        this.maxSeqLength = 30;
        this.eosTokenId = 50256;
        this.padTokenId = 50256;
        // Explicit for readability; this is also the Java default for boolean fields.
        this.suffixPadding = false;
    }

    /**
     * Returns the number of top candidates ({@code k}) considered at each step.
     *
     * @return the value of k
     */
    public int getK() {
        return k;
    }

    /**
     * Sets the number of top candidates ({@code k}) considered at each step.
     *
     * @param k the value for the top-k choice
     */
    public void setK(int k) {
        this.k = k;
    }

    /**
     * Returns the repetition penalty {@code alpha}.
     *
     * @return the value of alpha
     */
    public float getAlpha() {
        return alpha;
    }

    /**
     * Sets the repetition penalty {@code alpha}.
     *
     * @param alpha the value of the penalty for repetition
     */
    public void setAlpha(float alpha) {
        this.alpha = alpha;
    }

    /**
     * Returns the beam size.
     *
     * @return the value of beam
     */
    public int getBeam() {
        return beam;
    }

    /**
     * Sets the beam size.
     *
     * @param beam the value of beam size
     */
    public void setBeam(int beam) {
        this.beam = beam;
    }

    /**
     * Returns the maximum sequence length.
     *
     * @return the value of maxSeqLength
     */
    public int getMaxSeqLength() {
        return maxSeqLength;
    }

    /**
     * Sets the maximum sequence length.
     *
     * @param maxSeqLength the value of the max sequence length
     */
    public void setMaxSeqLength(int maxSeqLength) {
        this.maxSeqLength = maxSeqLength;
    }

    /**
     * Returns the token id used for padding.
     *
     * @return the value of padTokenId
     */
    public long getPadTokenId() {
        return padTokenId;
    }

    /**
     * Sets the token id used for padding.
     *
     * @param padTokenId the token id for padding
     */
    public void setPadTokenId(long padTokenId) {
        this.padTokenId = padTokenId;
    }

    /**
     * Returns the end-of-sequence token id.
     *
     * @return the value of eosTokenId
     */
    public long getEosTokenId() {
        return eosTokenId;
    }

    /**
     * Sets the end-of-sequence token id.
     *
     * <p>Previously this field could only be read, so callers were stuck with the constructor
     * default. This setter mirrors {@link #setPadTokenId(long)}.
     *
     * @param eosTokenId the token id marking the end of a sequence
     */
    public void setEosTokenId(long eosTokenId) {
        this.eosTokenId = eosTokenId;
    }

    /**
     * Returns whether padding is applied on the right (suffix) side.
     *
     * @return the value of suffixPadding
     */
    public boolean isSuffixPadding() {
        return suffixPadding;
    }

    /**
     * Sets whether padding is applied on the right (suffix) side.
     *
     * @param suffixPadding whether the padding is from the right
     */
    public void setSuffixPadding(boolean suffixPadding) {
        this.suffixPadding = suffixPadding;
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy