All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.linalg.api.ops.OpContext Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.linalg.api.ops;

import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;

import java.util.List;

/**
 * Holder for everything a single op call is configured with: integer, floating point,
 * data type and boolean arguments, RNG states, and the input/output arrays.
 * <p>
 * The interface extends {@link AutoCloseable}, so an instance can be managed with
 * try-with-resources. What exactly is released on close is implementation-specific.
 */
public interface OpContext extends AutoCloseable {

    /**
     * This method sets integer arguments required for operation
     *
     * @param arguments integer arguments, in the order the op expects them
     */
    void setIArguments(long... arguments);

    /**
     * This method returns integer arguments defined within this context
     *
     * @return list of integer arguments
     */
    List<Long> getIArguments();

    /**
     * @return number of integer arguments defined within this context
     */
    int numIArguments();

    /**
     * This method sets floating point arguments required for operation
     *
     * @param arguments floating point arguments, in the order the op expects them
     */
    void setTArguments(double... arguments);

    /**
     * This method returns floating point arguments defined within this context
     *
     * @return list of floating point arguments
     */
    List<Double> getTArguments();

    /**
     * @return number of floating point arguments defined within this context
     */
    int numTArguments();

    /**
     * This method sets data type arguments required for operation
     *
     * @param arguments data type arguments, in the order the op expects them
     */
    void setDArguments(DataType... arguments);

    /**
     * This method returns data type arguments defined within this context
     *
     * @return list of data type arguments
     */
    List<DataType> getDArguments();

    /**
     * @return number of data type arguments defined within this context
     */
    int numDArguments();

    /**
     * This method sets boolean arguments required for operation
     *
     * @param arguments boolean arguments, in the order the op expects them
     */
    void setBArguments(boolean... arguments);

    /**
     * This method returns boolean arguments defined within this context
     *
     * @return list of boolean arguments
     */
    List<Boolean> getBArguments();

    /**
     * @return number of boolean arguments defined within this context
     */
    int numBArguments();

    /**
     * This method sets RNG states for this context
     *
     * @param rootState root-level state (seed) for rng
     * @param nodeState node-level state for rng
     */
    void setRngStates(long rootState, long nodeState);

    /**
     * This method returns RNG states, root first node second
     *
     * @return pair of (root state, node state)
     */
    Pair<Long, Long> getRngStates();

    /**
     * This method adds INDArray as input argument for future op call
     *
     * @param index position of the input argument
     * @param array array to use as input at the given position
     */
    void setInputArray(int index, INDArray array);

    /**
     * This method sets provided arrays as input arrays
     *
     * @param arrays arrays to use as inputs
     */
    void setInputArrays(List<INDArray> arrays);

    /**
     * This method sets provided arrays as input arrays
     *
     * @param arrays arrays to use as inputs
     */
    void setInputArrays(INDArray... arrays);

    /**
     * This method returns List of input arrays defined within this context
     *
     * @return list of input arrays
     */
    List<INDArray> getInputArrays();

    /**
     * @return number of input arrays defined within this context
     */
    int numInputArguments();

    /**
     * This method returns input array at the given position
     *
     * @param idx position of the input argument
     * @return input array at position {@code idx}
     */
    INDArray getInputArray(int idx);

    /**
     * This method adds INDArray as output for future op call
     *
     * @param index position of the output argument
     * @param array array to use as output at the given position
     */
    void setOutputArray(int index, INDArray array);

    /**
     * This method sets provided arrays as output arrays
     *
     * @param arrays arrays to use as outputs
     */
    void setOutputArrays(List<INDArray> arrays);

    /**
     * This method sets provided arrays as output arrays
     *
     * @param arrays arrays to use as outputs
     */
    void setOutputArrays(INDArray... arrays);

    /**
     * This method returns List of output arrays defined within this context
     *
     * @return list of output arrays
     */
    List<INDArray> getOutputArrays();

    /**
     * This method returns output array at the given position
     *
     * @param i position of the output argument
     * @return output array at position {@code i}
     */
    INDArray getOutputArray(int i);

    /**
     * @return number of output arrays defined within this context
     */
    int numOutputArguments();

    /**
     * This method returns pointer to context, to be used during native op execution
     *
     * @return pointer to the context
     */
    Pointer contextPointer();

    /**
     * This method allows to set op as inplace
     *
     * @param reallyInplace true to mark the op as inplace, false otherwise
     */
    void markInplace(boolean reallyInplace);

    /**
     * This method allows to enable/disable use of platform helpers within ops. I.e. mkldnn or cuDNN.
     * PLEASE NOTE: default value is True
     *
     * @param reallyAllow true to allow platform helpers, false to disallow them
     */
    void allowHelpers(boolean reallyAllow);

    /**
     * This method allows to disable validation of outputs via shape function
     *
     * @param reallyOverride true to override (skip) the shape function validation, false otherwise
     */
    void shapeFunctionOverride(boolean reallyOverride);

    /**
     * This method returns current execution mode for Context
     *
     * @return current execution mode
     */
    ExecutionMode getExecutionMode();

    /**
     * This method allows to set certain execution mode
     *
     * @param mode execution mode to use
     */
    void setExecutionMode(ExecutionMode mode);

    /**
     * This method removes all in/out arrays from this OpContext
     */
    void purge();


    /**
     * Set context arguments
     *
     * @param inputArrs input arrays
     * @param iArgs     integer arguments
     * @param dArgs     data type arguments
     * @param tArgs     floating point arguments
     * @param bArgs     boolean arguments
     */
    void setArgs(INDArray[] inputArrs, long[] iArgs, DataType[] dArgs, double[] tArgs, boolean[] bArgs);
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy