All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.nd4j.autodiff.samediff.config.ExecutionResult Maven / Gradle / Ivy

The newest version!
package org.nd4j.autodiff.samediff.config;

import lombok.Builder;
import lombok.Data;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.util.MultiValueMap;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;

@Builder
@Data
public class ExecutionResult {

    private Map> outputs;
    private Map valueOutputs;



    public static  ExecutionResult createFrom(List names,List input) {
        Preconditions.checkState(names.size() == input.size(),"Inputs and names must be equal size!");
        Map> outputs = new LinkedHashMap<>();
        for(int i = 0; i < input.size(); i++) {
            outputs.put(names.get(i),input.get(i) == null ? Optional.empty() : Optional.of(input.get(i)));
        }
        return ExecutionResult.builder()
                .outputs(outputs)
                .build();
    }

    public static ExecutionResult createValue(String name,SDValue inputs) {
        return ExecutionResult.builder()
                .valueOutputs(Collections.singletonMap(name,inputs))
                .build();
    }


    public static ExecutionResult createValue(String name,List inputs) {
        return ExecutionResult.builder()
                .valueOutputs(Collections.singletonMap(name,SDValue.create(inputs)))
                .build();
    }

    public static  ExecutionResult createFrom(String name,INDArray input) {
        return createFrom(Arrays.asList(name),Arrays.asList(input));
    }


    public static  ExecutionResult createFrom(DifferentialFunction func, OpContext opContext) {
        return createFrom(Arrays.asList(func.outputVariablesNames())
                ,opContext.getOutputArrays().toArray(new INDArray[opContext.getOutputArrays().size()]));
    }

    public static  ExecutionResult createFrom(List names,INDArray[] input) {
        Preconditions.checkState(names.size() == input.length,"Inputs and names must be equal size!");
        Map> outputs = new LinkedHashMap<>();
        for(int i = 0; i < input.length; i++) {
            outputs.put(names.get(i),Optional.ofNullable(input[i]));
        }
        return ExecutionResult.builder()
                .outputs(outputs)
                .build();
    }

    public INDArray[] outputsToArray(List inputs) {
        if(valueOutputs != null) {
            INDArray[] ret =  new INDArray[valueOutputs.size()];
            int count = 0;
            for(Map.Entry entry : valueOutputs.entrySet()) {
                if(entry.getValue() != null)
                    ret[count++] = entry.getValue().getTensorValue();
            }
            return ret;
        } else if(outputs != null) {
            INDArray[] ret =  new INDArray[inputs.size()];
            for(int i = 0; i < inputs.size(); i++) {
                Optional get = outputs.get(inputs.get(i));
                try {
                    ret[i] = get.get();
                }catch(NullPointerException e) {
                    ret[i] = null;
                }
            }

            return ret;
        } else {
            throw new IllegalStateException("No outputs to be converted.");
        }

    }


    public boolean hasValues() {
        return valueOutputs != null;
    }

    public boolean hasSingle() {
        return outputs != null;
    }


    public int numResults() {
        if(outputs != null && !outputs.isEmpty())
            return outputs.size();
        else if(valueOutputs != null && !valueOutputs.isEmpty())
            return valueOutputs.size();
        return 0;
    }




    public boolean valueExistsAtIndex(int index) {
        if (outputs != null)
            return resultAt(index) != null;
        else if (valueOutputs != null) {
            SDValue value = valueWithKey(valueAtIndex(index));
            if (value != null) {
                switch (value.getSdValueType()) {
                    case TENSOR:
                        return value.getTensorValue() != null;
                    case LIST:
                        return value.getListValue() != null;
                }
            }

        }

        return false;

    }


    public boolean isNull() {
        return valueOutputs == null && outputs == null;
    }


    public INDArray resultOrValueAt(int index, boolean returnDummy) {
        if(hasValues()) {
            SDValue sdValue = valueWithKeyAtIndex(index, returnDummy);
            if(sdValue != null)
                return sdValue.getTensorValue();
            return null;
        }
        else
            return resultAt(index);
    }


    private String valueAtIndex(int index) {
        Set keys = valueOutputs != null ? valueOutputs.keySet() : outputs.keySet();
        int count = 0;
        for(String value : keys) {
            if(count == index)
                return value;
            count++;
        }

        return null;
    }

    public SDValue valueWithKeyAtIndex(int index, boolean returnDummy) {
        if(valueOutputs == null)
            return null;
        String key = valueAtIndex(index);
        if(valueOutputs.containsKey(key)) {
            SDValue sdValue = valueOutputs.get(key);
            if(sdValue != null && sdValue.getSdValueType() == SDValueType.LIST && returnDummy)
                return SDValue.create(Nd4j.empty(DataType.FLOAT));
            else
                return sdValue;
        }
        return valueOutputs.get(key);
    }

    public SDValue valueWithKey(String name) {
        if(valueOutputs == null)
            return null;
        return valueOutputs.get(name);
    }

    public INDArray resultAt(int index) {
        if(outputs == null) {
            return null;
        }

        String name = this.valueAtIndex(index);
        return outputs.get(name).get();
    }


    public static Map unpack(Map> result) {
        Map ret = new LinkedHashMap<>();
        for(Map.Entry> entry : result.entrySet()) {
            ret.put(entry.getKey(),entry.getValue().get());
        }

        return ret;
    }


    public static Map> pack(Map result) {
        Map> ret = new LinkedHashMap<>();
        for(Map.Entry entry : result.entrySet()) {
            ret.put(entry.getKey(),Optional.ofNullable(entry.getValue().get()));
        }

        return ret;
    }


}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy