All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.samediff.config.SDValue Maven / Gradle / Ivy

The newest version!
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.nd4j.autodiff.samediff.config;

import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.concurrent.atomic.AtomicLong;
import java.util.*;
import org.nd4j.autodiff.samediff.internal.IDependeeGroup;

/**
 * An SDValue represents a value that can be passed in
 * and returned from a {@link org.nd4j.autodiff.samediff.SameDiff}
 * graph for execution.
 * <p>
 * A value is built through one of the static {@code create} factories, which set
 * exactly one of the tensor, list or dict payloads and the matching
 * {@link SDValueType}.
 * <p>
 * Equality and hashing are based solely on a unique id drawn from a shared counter
 * when the instance is constructed, so two values wrapping the same arrays are
 * still distinct.
 *
 * @author Adam Gibson
 */
@Getter
public class SDValue implements IDependeeGroup {

    /** Source of unique instance ids, shared by all instances. */
    private static final AtomicLong counter = new AtomicLong(0);

    private SDValueType sdValueType;
    private INDArray tensorValue;
    // NOTE(review): raw type; the intended key/value types are not visible here
    // (presumably Map<String, INDArray>) — confirm and parameterize.
    private Map dictValue;
    private List<INDArray> listValue;
    // Assigned once at construction; the sole basis of equals/hashCode.
    // transient: excluded from Java serialization.
    protected final transient long id = counter.getAndIncrement();

    /** Instances are only created through the static factory methods. */
    private SDValue() {
    }

    /**
     * @return the unique id of this value, used for equality and hashing
     */
    public long getId() {
        return id;
    }

    /**
     * Returns the arrays held by this value as a collection.
     *
     * @return the same result as {@link #getListValue()}
     */
    public Collection<INDArray> getCollection() {
        return getListValue();
    }

    /**
     * Create an "empty" value of the given {@link SDValueType}.
     *
     * @param valueType the value type to create an {@link SDValue} for
     * @param dataType  the data type of the tensor; only used for {@link SDValueType#TENSOR}
     * @return for {@link SDValueType#TENSOR}, a one-element zero array
     *         ({@code Nd4j.zeros(1)}) cast to {@code dataType} — note this is NOT
     *         {@link Nd4j#empty(DataType)}; for {@link SDValueType#LIST}, an empty
     *         fixed-size list ({@link Arrays#asList}); for {@link SDValueType#DICT},
     *         an empty immutable map ({@link Collections#emptyMap()})
     * @throws IllegalArgumentException if {@code valueType} is not one of the types above
     * @throws NullPointerException     if {@code valueType} is {@code null}
     */
    public static SDValue empty(SDValueType valueType, DataType dataType) {
        switch (valueType) {
            case LIST:
                return SDValue.create(Arrays.asList());
            case DICT:
                return SDValue.create(Collections.emptyMap());
            case TENSOR:
                return SDValue.create(Nd4j.zeros(1).castTo(dataType));
            default:
                throw new IllegalArgumentException("Unable to create empty value, unknown value type " + valueType);
        }
    }

    /**
     * Return the single tensor held by this value.
     * <p>
     * If the list payload holds exactly one element, that element is returned.
     * Otherwise the {@link #tensorValue} field is returned, which is only set by
     * {@link #create(INDArray)} and is {@code null} for other values.
     *
     * @return the tensor, possibly {@code null}
     */
    public INDArray getTensorValue() {
        if (listValue != null && listValue.size() == 1) {
            return listValue.get(0);
        }
        return tensorValue;
    }

    /**
     * Return the arrays held by this value as a list.
     * <p>
     * If a tensor payload is set, it is returned wrapped in a new single-element,
     * fixed-size list. Otherwise the stored list payload is returned as-is, which may
     * be {@code null}, for example for {@link SDValueType#DICT} values.
     *
     * @return the arrays of this value, possibly {@code null}
     */
    public List<INDArray> getListValue() {
        if (tensorValue != null) {
            return Arrays.asList(tensorValue);
        }
        return listValue;
    }

    /**
     * Wrap an {@link INDArray} in a value
     * with an {@link SDValueType#TENSOR} type.
     *
     * @param inputValue the input value for the {@link SDValue}
     * @return the created value
     */
    public static SDValue create(INDArray inputValue) {
        SDValue sdValue = new SDValue();
        sdValue.tensorValue = inputValue;
        sdValue.sdValueType = SDValueType.TENSOR;
        return sdValue;
    }

    /**
     * Wrap a {@link Collection} of arrays in a value
     * with an {@link SDValueType#LIST} type.
     * <p>
     * A {@link List} argument (or {@code null}) is stored as-is, without a copy.
     * Any other collection type, such as a {@link Set}, is copied into a new
     * {@link ArrayList} in its iteration order instead of failing with a
     * {@link ClassCastException}.
     *
     * @param inputValue the input arrays
     * @return the created value
     */
    public static SDValue create(Collection<INDArray> inputValue) {
        if (inputValue == null || inputValue instanceof List) {
            // Preserve the original aliasing behaviour for lists.
            return create((List<INDArray>) inputValue);
        }
        return create(new ArrayList<>(inputValue));
    }

    /**
     * Wrap a {@link List} of arrays in a value
     * with an {@link SDValueType#LIST} type.
     * <p>
     * The list is stored by reference, not copied.
     *
     * @param inputValue the input arrays
     * @return the created value
     */
    public static SDValue create(List<INDArray> inputValue) {
        SDValue sdValue = new SDValue();
        sdValue.listValue = inputValue;
        sdValue.sdValueType = SDValueType.LIST;
        return sdValue;
    }

    /**
     * Wrap a {@link Map} in a value
     * with an {@link SDValueType#DICT} type.
     * <p>
     * The map is stored by reference, not copied.
     *
     * @param inputValue the input value
     * @return the created value
     */
    public static SDValue create(Map inputValue) {
        SDValue sdValue = new SDValue();
        sdValue.dictValue = inputValue;
        sdValue.sdValueType = SDValueType.DICT;
        return sdValue;
    }

    /**
     * Two values are equal only if they have the same {@link #getId() id}.
     * <p>
     * Returns {@code false} for {@code null} and for objects of other types, as
     * required by the {@link Object#equals(Object)} contract.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (!(o instanceof SDValue)) {
            return false;
        }
        return ((SDValue) o).getId() == this.getId();
    }

    /** Hash of the unique id, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Long.hashCode(this.getId());
    }

    /**
     * Debug representation listing this value's id, type and the ids of the arrays it
     * holds.
     * <p>
     * DICT values print their size and keys. Values with no payload, such as an empty
     * list, produce an empty string.
     */
    @Override
    public String toString() {
        StringBuilder st = new StringBuilder();
        INDArray h = this.getTensorValue();
        if (h != null) {
            appendHeader(st).append(" --Array ").append(h.getId());
            return st.toString();
        }

        List<INDArray> listx = this.getListValue();
        if (listx != null && !listx.isEmpty()) {
            appendHeader(st).append(" -- List Size ").append(listx.size());
            for (INDArray gh : listx) {
                if (gh == null) {
                    st.append(" --Array NULL ");
                } else {
                    st.append(" --Array ").append(gh.getId()).append(" --\t ");
                }
            }
        } else if (dictValue != null && !dictValue.isEmpty()) {
            // Previously DICT values rendered as an empty string.
            appendHeader(st).append(" -- Dict Size ").append(dictValue.size())
                    .append(" --Keys ").append(dictValue.keySet());
        }
        return st.toString();
    }

    /** Appends the common {@code --sdValueId-<id>--key--<type>} prefix used by {@link #toString()}. */
    private StringBuilder appendHeader(StringBuilder sb) {
        return sb.append("--sdValueId-").append(this.getId())
                .append("--key--").append(this.getSdValueType());
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy