// org.nd4j.autodiff.functions.FunctionProperties Maven / Gradle / Ivy
package org.nd4j.autodiff.functions;
import com.google.common.primitives.Doubles;
import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import com.google.flatbuffers.FlatBufferBuilder;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatProperties;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
/**
 * Named bundle of function arguments: integer, long, double and array values,
 * convertible to and from the FlatBuffers {@link FlatProperties} table.
 */
@Data
@Slf4j
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class FunctionProperties {
    private String name;
    // NOTE(review): key/value types are not determinable from this file, so the
    // map is left raw; parameterize it once the intended types are confirmed.
    @Builder.Default private Map fieldNames = new LinkedHashMap<>();
    @Builder.Default private List<Integer> i = new ArrayList<>();
    @Builder.Default private List<Long> l = new ArrayList<>();
    @Builder.Default private List<Double> d = new ArrayList<>();
    @Builder.Default private List<INDArray> a = new ArrayList<>();

    /**
     * Converts this FunctionProperties instance to its FlatBuffers representation.
     * The {@code fieldNames} map is not part of the serialized form.
     *
     * @param bufferBuilder builder the name, the value vectors and the arrays are written into
     * @return offset of the created FlatProperties table within {@code bufferBuilder}
     */
    public int asFlatProperties(FlatBufferBuilder bufferBuilder) {
        // createString(null) would throw; offset 0 is the FlatBuffers convention for
        // "field not set", so a nameless instance is serialized without a name.
        int nameOffset = name == null ? 0 : bufferBuilder.createString(name);
        int intsOffset = FlatProperties.createIVector(bufferBuilder, Ints.toArray(i));
        int longsOffset = FlatProperties.createLVector(bufferBuilder, Longs.toArray(l));
        int doublesOffset = FlatProperties.createDVector(bufferBuilder, Doubles.toArray(d));

        // Every array must be written before the vector that references it is created.
        int[] arrayOffsets = new int[a.size()];
        int cnt = 0;
        for (INDArray array : a) {
            arrayOffsets[cnt++] = array.toFlatArray(bufferBuilder);
        }
        int arraysOffset = FlatProperties.createAVector(bufferBuilder, arrayOffsets);

        return FlatProperties.createFlatProperties(
                bufferBuilder, nameOffset, intsOffset, longsOffset, doublesOffset, arraysOffset);
    }

    /**
     * Creates a new FunctionProperties instance from its FlatBuffers representation.
     * The {@code fieldNames} map is not restored and stays empty.
     *
     * @param properties FlatBuffers table to read from
     * @return new instance carrying the name and the i/l/d/a values of {@code properties}
     */
    public static FunctionProperties fromFlatProperties(FlatProperties properties) {
        FunctionProperties props = new FunctionProperties();
        // Restore the name so that a serialize/deserialize round trip preserves it.
        props.setName(properties.name());

        for (int e = 0; e < properties.iLength(); e++) {
            props.getI().add(properties.i(e));
        }
        for (int e = 0; e < properties.lLength(); e++) {
            props.getL().add(properties.l(e));
        }
        for (int e = 0; e < properties.dLength(); e++) {
            props.getD().add(properties.d(e));
        }
        // The array vector has its own length; it must not be bounded by iLength().
        for (int e = 0; e < properties.aLength(); e++) {
            props.getA().add(Nd4j.createFromFlatArray(properties.a(e)));
        }

        return props;
    }

    /**
     * Converts multiple FunctionProperties to a FlatBuffers properties vector.
     *
     * @param bufferBuilder builder every element and the resulting vector are written into
     * @param properties instances to serialize, in iteration order
     * @return offset of the created properties vector within {@code bufferBuilder}
     */
    public static int asFlatProperties(FlatBufferBuilder bufferBuilder, Collection<FunctionProperties> properties) {
        int[] offsets = new int[properties.size()];
        int cnt = 0;
        for (FunctionProperties p : properties) {
            offsets[cnt++] = p.asFlatProperties(bufferBuilder);
        }
        return FlatNode.createPropertiesVector(bufferBuilder, offsets);
    }
}