All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.autodiff.opstate.OpState Maven / Gradle / Ivy

package org.nd4j.autodiff.opstate;

import lombok.Builder;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.util.Map;
import java.util.UUID;

/**
 * Describes the type of operation that needs to happen, together with the
 * metadata captured from an {@link Op}: its name, element count, extra
 * arguments, input vertex ids and information about its result array.
 * @author Adam Gibson
 */
@Data
@Builder
@EqualsAndHashCode
public class OpState implements Serializable {
    private long n;
    private OpType opType;
    private String opName;
    private Number scalarValue;
    private String[] vertexIds;
    private String id;
    private int[] axes;
    // Raw extra arguments as supplied; a Boolean entry is treated as the
    // "in place" flag (see isInPlace()).
    private Object[] extraArgs;
    // Lazily computed copy of extraArgs with every Boolean entry removed.
    // Filled in by getExtraArgs() and cleared by setExtraArgs(Object[]).
    private Object[] extraArgsWithoutInPlace;
    private NDArrayInformation result;


    /**
     * Creates an op state from the given op.
     *
     * <p>The returned state gets a fresh random UUID as its id. Its vertex ids
     * are the ids mapped to {@code op.x()} and {@code op.y()}. Its result is a
     * new {@link NDArrayInformation} built from {@code op.z()}, whose owner is
     * set to the returned state.
     *
     * @param op the op to describe
     * @param arrToVertexID a map of {@link INDArray}
     *                      to vertex id (this map
     *                      is typically a reference map
     *                      {@link java.util.IdentityHashMap})
     * @return the newly created op state
     * @throws IllegalStateException if the op's type is not recognised
     *         by {@link #opTypeFromOp(Op)}
     */
    public static OpState fromOp(Op op, Map<?, ?> arrToVertexID) {
        OpState opState = OpState.builder()
                .extraArgs(op.extraArgs())
                .n(op.n())
                .id(UUID.randomUUID().toString())
                .opName(op.name())
                .vertexIds(new String[]{
                        // String.valueOf(Object) yields the string "null" when an
                        // array has no entry in the map (or the array itself is null).
                        String.valueOf(arrToVertexID.get(op.x())),
                        String.valueOf(arrToVertexID.get(op.y()))
                })
                .opType(opTypeFromOp(op))
                .build();
        NDArrayInformation ndArrayInformation = NDArrayInformation.newInfo(op.z());
        ndArrayInformation.setOwner(opState);
        opState.setResult(ndArrayInformation);
        return opState;
    }


    /**
     * Create an {@link OpType}
     * based on the type of {@link Op}.
     *
     * <p>The checks run in a fixed order and the first matching interface wins.
     * For example, an op implementing both {@link ScalarOp} and
     * {@link TransformOp} is classified as {@link OpType#SCALAR_TRANSFORM}.
     * {@link GridOp} instances are reported as {@link OpType#AGGREGATE}.
     *
     * @param op the input op
     * @return the optype based on
     * the given op
     * @throws IllegalStateException if the op matches none of the known op interfaces
     */
    public static OpType opTypeFromOp(Op op) {
       if(op instanceof ScalarOp)
           return OpType.SCALAR_TRANSFORM;
       else if(op instanceof Accumulation)
           return OpType.ACCUMULATION;
       else if(op instanceof IndexAccumulation)
           return OpType.INDEX_ACCUMULATION;
       else if(op instanceof GridOp)
           return OpType.AGGREGATE;
       else if(op instanceof TransformOp)
           return OpType.TRANSFORM;
       else if(op instanceof ShapeOp)
           return OpType.SHAPE;
       else if(op instanceof BroadcastOp)
           return OpType.BROADCAST;
       throw new IllegalStateException("Illegal op type " + op.getClass().getName());
    }

    /**
     * Reports whether this op should be executed in place.
     *
     * @return the value of the first {@link Boolean} found in the raw extra
     *         arguments, or {@code false} if there is none or they are null
     */
    public boolean isInPlace() {
        return getInPlace(extraArgs);
    }

    /**
     * Returns the extra arguments with every {@link Boolean} entry (the
     * in-place flag) removed, preserving the order of the remaining entries.
     *
     * <p>The filtered array is cached and returned on later calls until
     * {@link #setExtraArgs(Object[])} replaces the arguments. It is sized to
     * the number of non-Boolean entries, so it never contains padding.
     *
     * @return the filtered extra arguments, or {@code null} if the raw extra
     *         arguments are {@code null}
     */
    public Object[] getExtraArgs() {
        if (extraArgs == null) {
            return null;
        }
        if (extraArgsWithoutInPlace == null) {
            // Count first so the result fits exactly. The previous
            // "length - 1" sizing broke when there was zero or more than one
            // Boolean, or when extraArgs was empty.
            int kept = 0;
            for (Object arg : extraArgs) {
                if (!(arg instanceof Boolean)) {
                    kept++;
                }
            }
            Object[] filtered = new Object[kept];
            int count = 0;
            for (Object arg : extraArgs) {
                if (!(arg instanceof Boolean)) {
                    filtered[count++] = arg;
                }
            }
            extraArgsWithoutInPlace = filtered;
        }
        return extraArgsWithoutInPlace;
    }

    /**
     * Replaces the raw extra arguments (which may include a Boolean in-place flag).
     *
     * @param extraArgs the new raw extra arguments; may be {@code null}
     */
    public void setExtraArgs(Object[] extraArgs) {
        this.extraArgs = extraArgs;
        // Invalidate the cached filtered view so getExtraArgs() reflects
        // the new arguments instead of the old ones.
        this.extraArgsWithoutInPlace = null;
    }

    /**
     * Extracts the in-place flag from an extra-arguments array.
     *
     * @param extraArgs the raw extra arguments; may be {@code null}
     * @return the value of the first {@link Boolean} entry, or {@code false}
     *         if the array is {@code null} or contains no Boolean
     */
    protected boolean getInPlace(Object[] extraArgs) {
        if (extraArgs == null) {
            return false;
        }
        for (Object arg : extraArgs) {
            if (arg instanceof Boolean) {
                return (Boolean) arg;
            }
        }
        return false;
    }

    /**
     * Broad category of an operation, as derived by {@link #opTypeFromOp(Op)}.
     */
    public  enum OpType {
        SCALAR_TRANSFORM,
        ACCUMULATION,
        TRANSFORM,
        BROADCAST,
        INDEX_ACCUMULATION,
        AGGREGATE,
        SHAPE
    }



}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy