All downloads are free. Search and download functionality uses the official Maven repository.

org.jetbrains.kotlinx.dl.api.inference.InferenceModel.kt Maven / Gradle / Ivy

/*
 * Copyright 2020-2022 JetBrains s.r.o. and Kotlin Deep Learning project contributors. All Rights Reserved.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
 */

package org.jetbrains.kotlinx.dl.api.inference

/**
 * The basic interface for all models, defining only the functions required for inference tasks.
 *
 * The interface extends [AutoCloseable]: release an instance with [close] (for example via `use { }`)
 * once it is no longer needed.
 */
public interface InferenceModel : AutoCloseable {
    /** Input shape specification for this model (see also [reshape]). */
    public val inputDimensions: LongArray

    /**
     * Predicts the class of [inputData].
     *
     * @param [inputData] A single example with an unknown label, passed as a [FloatArray].
     * @return Index of the predicted class.
     */
    public fun predict(inputData: FloatArray): Int

    /**
     * Predicts a vector of probabilities for [inputData], instead of the single class index returned by [predict].
     *
     * @param [inputData] A single example, passed as a [FloatArray].
     * @param [predictionTensorName] Name of the tensor to read predictions from. Change it to get alternative
     * outputs from the model graph. Defaults to an empty string; presumably implementations then use their
     * default prediction output (confirm against the implementation in use).
     * @return Vector representing the probability distribution over the potential outcomes.
     */
    public fun predictSoftly(inputData: FloatArray, predictionTensorName: String = ""): FloatArray

    /**
     * Sets up the input shape of this model.
     *
     * This function returns [Unit], so calls cannot be chained.
     *
     * @param [dims] The input shape.
     */
    public fun reshape(vararg dims: Long)

    /**
     * Creates a copy of this model.
     *
     * @param [copiedModelName] Name to give the copy. Defaults to `null`; how a `null` name is handled is up to
     * the implementation.
     * @param [saveOptimizerState] Presumably (going by the name) whether optimizer state is carried over into the
     * copy, for implementations that have one. Defaults to `false`.
     * @param [copyWeights] Presumably (going by the name) whether weights are copied into the new model.
     * Defaults to `true`.
     * @return A copied inference model.
     */
    public fun copy(
        copiedModelName: String? = null,
        saveOptimizerState: Boolean = false,
        copyWeights: Boolean = true
    ): InferenceModel
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy