All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.string.NDArrayStrings Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
package org.nd4j.linalg.string;

import org.apache.commons.lang3.StringUtils;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;

/**
 * Renders an {@link INDArray} as a bracketed, human-readable string.
 *
 * <p>Numeric elements are formatted with a {@link DecimalFormat} whose decimal
 * separator is always {@code '.'} and whose grouping separator is always
 * {@code ','}, regardless of the default locale. Elements of a vector are
 * left-padded to a common minimum width and joined with the configured
 * separator; higher-rank arrays are rendered slice by slice, one slice per
 * line, indented by nesting depth.
 *
 * @author Adam Gibson
 * @author Susan Eraly
 */
public class NDArrayStrings {

    /** Separator used between vector elements when none is supplied. */
    private static final String DEFAULT_SEP = ", ";
    /** Number of fraction digits used when none is supplied. */
    private static final int DEFAULT_PRECISION = 2;
    /** Integer part of the {@link DecimalFormat} pattern used when none is supplied. */
    private static final String DEFAULT_DEC_FORMAT = "#,###,##0";

    /** Separator placed between consecutive elements of a vector. */
    private final String sep;
    /** Formatter applied to every non-complex element. */
    private final DecimalFormat decimalFormat;
    /** Minimum width of a formatted element: the length of a formatted single-digit value. */
    private final int padding;

    /**
     * Creates a formatter with the given element separator, two fraction
     * digits and the default number pattern.
     *
     * @param sep the separator placed between elements of a vector
     */
    public NDArrayStrings(String sep) {
        // The separator used to be dropped here in favour of ", "; honour the argument.
        this(sep, DEFAULT_PRECISION, DEFAULT_DEC_FORMAT);
    }

    /**
     * Creates a formatter with the default separator {@code ", "} and the
     * given number of fraction digits.
     *
     * @param precision the number of digits printed after the decimal point
     */
    public NDArrayStrings(int precision) {
        this(DEFAULT_SEP, precision, DEFAULT_DEC_FORMAT);
    }

    /**
     * Creates a formatter with the given separator and number of fraction
     * digits, using the default number pattern.
     *
     * @param sep       the separator placed between elements of a vector
     * @param precision the number of digits printed after the decimal point
     */
    public NDArrayStrings(String sep, int precision) {
        this(sep, precision, DEFAULT_DEC_FORMAT);
    }

    /**
     * Creates a fully customised formatter.
     *
     * @param sep       the separator placed between elements of a vector
     * @param precision the number of digits printed after the decimal point;
     *                  zero or a negative value prints no fractional part
     * @param decFormat the {@link DecimalFormat} pattern for the integer part,
     *                  e.g. {@code "#,###,##0"}
     */
    public NDArrayStrings(String sep, int precision, String decFormat) {
        this.sep = sep;

        StringBuilder pattern = new StringBuilder();
        pattern.append(decFormat);
        // Only add a decimal point when at least one fraction digit follows it;
        // a negative precision would otherwise leave a dangling '.'.
        if (precision > 0) {
            pattern.append('.');
            for (int i = 0; i < precision; i++) {
                pattern.append('0');
            }
        }
        this.decimalFormat = new DecimalFormat(pattern.toString());

        // Pin the separators so output does not vary with the default locale.
        DecimalFormatSymbols sepNgroup = DecimalFormatSymbols.getInstance();
        sepNgroup.setDecimalSeparator('.');
        sepNgroup.setGroupingSeparator(',');
        this.decimalFormat.setDecimalFormatSymbols(sepNgroup);

        // The width depends only on the formatter, so compute it once here
        // rather than on every format() call.
        this.padding = this.decimalFormat.format(3.0000).length();
    }

    /**
     * Creates a formatter with the default separator {@code ", "}, two
     * fraction digits and the default number pattern.
     */
    public NDArrayStrings() {
        this(DEFAULT_SEP, DEFAULT_PRECISION, DEFAULT_DEC_FORMAT);
    }


    /**
     * Format the given ndarray as a string
     * @param arr the array to format
     * @return the formatted array
     */
    public String format(INDArray arr) {
        return format(arr, arr.rank(), 0);
    }

    /**
     * Recursively formats {@code arr}.
     *
     * @param arr    the array (or slice) to format
     * @param rank   the number of dimensions still to be unrolled for this slice
     * @param offset the current nesting depth, used to indent continuation lines
     * @return the formatted array; the empty string if {@code rank} is not
     *         positive and the array is not a scalar
     */
    private String format(INDArray arr, int rank, int offset) {
        if (arr.isScalar()) {
            if (arr instanceof IComplexNDArray) {
                return ((IComplexNDArray) arr).getComplex(0).toString();
            }
            return decimalFormat.format(arr.getDouble(0));
        }
        if (rank <= 0) {
            return "";
        }
        if (arr.isVector()) {
            return formatVector(arr);
        }

        int childOffset = offset + 1;
        long sliceCount = arr.slices();
        // Text emitted between two sibling slices: a comma, a newline, extra
        // blank lines that grow with the rank, then the indentation.
        String sliceBreak = ",\n"
                + StringUtils.repeat("\n", rank - 2)
                + StringUtils.repeat(" ", childOffset);

        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < sliceCount; i++) {
            sb.append(format(arr.slice(i), rank - 1, childOffset));
            if (i != sliceCount - 1) {
                sb.append(sliceBreak);
            }
        }
        return sb.append("]").toString();
    }

    /**
     * Formats a vector as {@code [e0<sep>e1<sep>...]}.
     *
     * @param arr the vector to format
     * @return the bracketed, separator-joined elements
     */
    private String formatVector(INDArray arr) {
        boolean complex = arr instanceof IComplexNDArray;
        // Left-pads each formatted number to at least 'padding' characters.
        String cellFormat = "%1$" + padding + "s";
        long length = arr.length();

        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < length; i++) {
            if (complex) {
                sb.append(((IComplexNDArray) arr).getComplex(i).toString());
            } else {
                sb.append(String.format(cellFormat, decimalFormat.format(arr.getDouble(i))));
            }
            if (i < length - 1) {
                sb.append(sep);
            }
        }
        return sb.append("]").toString();
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy