ai.djl.ndarray.internal.NDFormat Maven / Gradle / Ivy
The newest version!
/*
* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.ndarray.internal;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.Utils;
import java.lang.management.ManagementFactory;
import java.util.Arrays;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/** A helper for printing an {@link NDArray}. */
/**
 * A helper for printing an {@link NDArray}.
 *
 * <p>The static {@code format} methods write a header (name, shape, device, data type and a
 * gradient marker) and then, unless disabled, the array contents rendered by a data-type specific
 * subclass.
 */
public abstract class NDFormat {

    /**
     * Number of fraction digits used for floating-point values in exponential notation; also caps
     * the fraction digits used for large integers in exponential notation.
     */
    private static final int PRECISION = 8;

    private static final String LF = System.lineSeparator();

    /**
     * Splits a {@code %e}-formatted non-negative value into its significant fraction digits with
     * trailing zeros stripped (group 1) and the unsigned exponent digits (group 2).
     */
    private static final Pattern PATTERN = Pattern.compile("\\s*\\d\\.(\\d*?)0*e[+-](\\d+)");

    /**
     * {@code true} when the JVM was started with a {@code -agentlib:jdwp} argument and the {@code
     * jshell} system property is not set; contents are then omitted by default.
     */
    private static final boolean DEBUGGER =
            !Boolean.getBoolean("jshell")
                    && ManagementFactory.getRuntimeMXBean().getInputArguments().stream()
                            .anyMatch(arg -> arg.startsWith("-agentlib:jdwp"));

    /**
     * Formats the contents of an array as a pretty printable string.
     *
     * @param array the array to print
     * @param maxSize the maximum elements to print out
     * @param maxDepth the maximum depth to print out
     * @param maxRows the maximum rows to print out
     * @param maxColumns the maximum columns to print out
     * @return the string representation of the array
     */
    public static String format(
            NDArray array, int maxSize, int maxDepth, int maxRows, int maxColumns) {
        return format(array, maxSize, maxDepth, maxRows, maxColumns, !DEBUGGER);
    }

    /**
     * Formats the contents of an array as a pretty printable string.
     *
     * @param array the array to print
     * @param maxSize the maximum elements to print out
     * @param maxDepth the maximum depth to print out
     * @param maxRows the maximum rows to print out
     * @param maxColumns the maximum columns to print out
     * @param withContent true to show the content of NDArray
     * @return the string representation of the array
     */
    public static String format(
            NDArray array,
            int maxSize,
            int maxDepth,
            int maxRows,
            int maxColumns,
            boolean withContent) {
        StringBuilder sb = new StringBuilder(1000);
        String name = array.getName();
        if (name != null) {
            sb.append(name).append(": ");
        } else {
            sb.append("ND: ");
        }
        sb.append(array.getShape())
                .append(' ')
                .append(array.getDevice())
                .append(' ')
                .append(array.getDataType());
        if (array.hasGradient()) {
            sb.append(" hasGradient");
        }
        if (!withContent) {
            // Header only: tell the user how to get the contents displayed.
            sb.append("\nCheck the \"Development Guideline\"->Debug to enable array display.\n");
            return sb.toString();
        }

        NDFormat format;
        DataType dataType = array.getDataType();
        if (dataType == DataType.BOOLEAN) {
            format = new BooleanFormat();
        } else if (dataType == DataType.STRING) {
            format = new StringFormat();
        } else if (dataType.isInteger()) {
            format = new IntFormat();
        } else {
            format = new FloatFormat();
        }
        return format.dump(sb, array, maxSize, maxDepth, maxRows, maxColumns);
    }

    /**
     * Renders a single element.
     *
     * @param value the element to render
     * @return the rendered element
     */
    protected abstract CharSequence format(Number value);

    /**
     * Inspects the values about to be printed so that {@link #format(Number)} can choose a
     * consistent width and precision. The default implementation does nothing.
     *
     * @param array the values that will be printed
     */
    protected void init(NDArray array) {}

    /**
     * Appends a line break and the contents of {@code array} to {@code sb}.
     *
     * <p>Empty arrays print {@code []}; scalars print their single value; arrays with more than
     * {@code maxSize} elements or more than {@code maxDepth} dimensions print a flat prefix; all
     * other arrays print as nested rows.
     *
     * @param sb the builder already holding the header
     * @param array the array to print
     * @param maxSize the maximum elements to print out
     * @param maxDepth the maximum depth to print out
     * @param maxRows the maximum rows to print out
     * @param maxColumns the maximum columns to print out
     * @return the complete string held by {@code sb}
     */
    protected String dump(
            StringBuilder sb,
            NDArray array,
            int maxSize,
            int maxDepth,
            int maxRows,
            int maxColumns) {
        sb.append(LF);
        long size = array.size();
        long dimension = array.getShape().dimension();
        if (size == 0) {
            // corner case: 0 dimension
            sb.append("[]").append(LF);
        } else if (dimension == 0) {
            // scalar case
            init(array);
            sb.append(format(array.toArray()[0])).append(LF);
        } else if (size > maxSize) {
            sb.append("Exceed max print size:").append(LF);
            int limit = Math.min(maxSize, maxRows * maxColumns);
            dumpFlat(sb, array, limit);
        } else if (dimension > maxDepth) {
            sb.append("Exceed max print dimension:").append(LF);
            int limit = Math.min(maxSize, maxRows * maxColumns);
            dumpFlat(sb, array, limit);
        } else {
            init(array);
            dump(sb, array, 0, true, maxRows, maxColumns);
        }
        return sb.toString();
    }

    /**
     * Recursively prints {@code array} as nested bracketed rows, showing at most {@code maxRows}
     * sub-arrays per non-innermost dimension and at most {@code maxColumns} values per innermost
     * row.
     */
    private void dump(
            StringBuilder sb,
            NDArray array,
            int depth,
            boolean first,
            int maxRows,
            int maxColumns) {
        if (!first) {
            Utils.pad(sb, ' ', depth);
        }
        sb.append('[');
        Shape shape = array.getShape();
        if (shape.dimension() == 1) {
            append(sb, array.toArray(), maxColumns);
        } else {
            long len = shape.head();
            long limit = Math.min(len, maxRows);
            for (int i = 0; i < limit; ++i) {
                try (NDArray nd = array.get(i)) {
                    dump(sb, nd, depth + 1, i == 0, maxRows, maxColumns);
                }
            }
            long remaining = len - limit;
            if (remaining > 0) {
                Utils.pad(sb, ' ', depth + 1);
                // End the line so the closing bracket below starts on its own line, aligned
                // with the opening one, just like when no rows are elided.
                sb.append("... ").append(remaining).append(" more").append(LF);
            }
            Utils.pad(sb, ' ', depth);
        }
        // last "]"
        if (depth == 0) {
            sb.append(']').append(LF);
        } else {
            sb.append("],").append(LF);
        }
    }

    /**
     * Prints the first {@code limit} elements of the flattened array inside braces, followed by a
     * count of the elements that were left out.
     */
    @SuppressWarnings("try")
    private void dumpFlat(StringBuilder sb, NDArray array, int limit) {
        try (NDScope ignore = new NDScope()) {
            NDArray tmp = array.flatten().get(":" + limit);
            init(tmp);
            sb.append('{');
            // Only copy the sliced prefix to the JVM; the full array may be very large.
            append(sb, tmp.toArray(), limit, array.size());
            sb.append('}').append(LF);
        }
    }

    /** Appends up to {@code maxColumns} comma-separated values of a complete row. */
    private void append(StringBuilder sb, Number[] values, int maxColumns) {
        append(sb, values, maxColumns, values.length);
    }

    /**
     * Appends up to {@code maxColumns} comma-separated values, then a "... N more" marker for the
     * remaining elements out of {@code total}.
     */
    private void append(StringBuilder sb, Number[] values, int maxColumns, long total) {
        int shown = Math.max(0, Math.min(values.length, maxColumns));
        for (int i = 0; i < shown; ++i) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(format(values[i]));
        }
        long remaining = total - shown;
        if (remaining > 0) {
            if (shown > 0) {
                sb.append(", ");
            }
            sb.append("... ").append(remaining).append(" more");
        }
    }

    /** Formats floating-point arrays, choosing between fixed-point and exponential notation. */
    private static final class FloatFormat extends NDFormat {

        private boolean exponential;
        private int precision;
        private int totalLength;

        /** {@inheritDoc} */
        @Override
        public void init(NDArray array) {
            Number[] values = array.toArray();
            int maxIntPartLen = 0;
            int maxFractionLen = 0;
            int expFractionLen = 0;
            boolean sign = false;
            double max = 0;
            double min = Double.MAX_VALUE;
            for (Number n : values) {
                double v = n.doubleValue();
                if (v < 0) {
                    sign = true;
                }
                if (!Double.isFinite(v)) {
                    // room for "nan" / "inf", or "-inf"
                    int intPartLen = v < 0 ? 4 : 3;
                    if (totalLength < intPartLen) {
                        totalLength = intPartLen;
                    }
                    continue;
                }
                double abs = Math.abs(v);
                String str = String.format(Locale.ENGLISH, "%16e", abs);
                Matcher m = PATTERN.matcher(str);
                if (!m.matches()) {
                    throw new AssertionError("Invalid decimal value: " + str);
                }
                int fractionLen = m.group(1).length();
                expFractionLen = Math.max(expFractionLen, fractionLen);

                if (abs >= 1) {
                    // Integer digits excluding the sign: the sign occupies a column but does not
                    // reduce the number of fraction digits the value needs.
                    int digits = (int) Math.log10(abs) + 1;
                    int intPartLen = v < 0 ? digits + 1 : digits;
                    maxIntPartLen = Math.max(maxIntPartLen, intPartLen);
                    maxFractionLen = Math.max(maxFractionLen, fractionLen + 1 - digits);
                } else {
                    int intPartLen = v < 0 ? 2 : 1;
                    maxIntPartLen = Math.max(maxIntPartLen, intPartLen);
                    // d.ddd x 10^-k needs k additional fraction digits in fixed-point notation
                    int fullFractionLen = fractionLen + Integer.parseInt(m.group(2));
                    maxFractionLen = Math.max(maxFractionLen, fullFractionLen);
                }

                if (abs > max) {
                    max = abs;
                }
                if (abs < min && abs > 0) {
                    min = abs;
                }
            }

            double ratio = max / min;
            if (max > 1.e8 || min < 0.0001 || ratio > 1000.) {
                exponential = true;
                precision = Math.min(PRECISION, expFractionLen);
                totalLength = precision + 4;
                if (sign) {
                    ++totalLength;
                }
            } else {
                precision = Math.min(4, maxFractionLen);
                int len = maxIntPartLen + precision + 1;
                if (totalLength < len) {
                    totalLength = len;
                }
            }
        }

        /** {@inheritDoc} */
        @Override
        public CharSequence format(Number value) {
            double d = value.doubleValue();
            if (Double.isNaN(d)) {
                return padLeft("nan");
            } else if (Double.isInfinite(d)) {
                return padLeft(d > 0 ? "inf" : "-inf");
            }
            if (exponential) {
                // Computed locally instead of overwriting the field while formatting.
                int digits = Math.max(PRECISION, precision);
                return String.format(Locale.ENGLISH, "% ." + digits + "e", d);
            }
            if (precision == 0) {
                // integral values: print "<n>." to mark them as floating point
                String fmt = "%" + (totalLength - 1) + '.' + precision + "f.";
                return String.format(Locale.ENGLISH, fmt, d);
            }

            String fmt = "%" + totalLength + '.' + precision + 'f';
            String ret = String.format(Locale.ENGLISH, fmt, d);
            // Replace trailing zeros with space
            char[] chars = ret.toCharArray();
            for (int i = chars.length - 1; i >= 0; --i) {
                if (chars[i] == '0') {
                    chars[i] = ' ';
                } else {
                    break;
                }
            }
            return new String(chars);
        }

        /** Right-aligns {@code text} in a field of {@code totalLength} characters. */
        private String padLeft(String text) {
            return String.format(Locale.ENGLISH, "%" + totalLength + "s", text);
        }
    }

    /** Formats integer arrays, switching to exponential notation for very large magnitudes. */
    private static final class IntFormat extends NDFormat {

        private boolean exponential;
        private int precision;
        private int totalLength;

        /** {@inheritDoc} */
        @Override
        public void init(NDArray array) {
            Number[] values = array.toArray();
            // scalar case
            if (values.length == 1) {
                totalLength = 1;
                return;
            }
            long max = 0;
            long negativeMax = 0;
            for (Number n : values) {
                long v = n.longValue();
                // Math.abs(Long.MIN_VALUE) overflows to a negative number.
                long abs = v == Long.MIN_VALUE ? Long.MAX_VALUE : Math.abs(v);
                if (v < 0 && abs > negativeMax) {
                    negativeMax = abs;
                }
                if (abs > max) {
                    max = abs;
                }
            }

            if (max >= 1.e8) {
                exponential = true;
                precision = Math.min(PRECISION, (int) Math.log10(max) + 1);
            } else {
                int size = (max != 0) ? (int) Math.log10(max) + 1 : 1;
                // negative numbers need one extra column for the sign
                int negativeSize = (negativeMax != 0) ? (int) Math.log10(negativeMax) + 2 : 2;
                totalLength = Math.max(size, negativeSize);
            }
        }

        /** {@inheritDoc} */
        @Override
        public CharSequence format(Number value) {
            if (exponential) {
                // double keeps the digits shown here; float only has ~7 significant digits.
                return String.format(Locale.ENGLISH, "% ." + precision + "e", value.doubleValue());
            }
            return String.format(Locale.ENGLISH, "%" + totalLength + "d", value.longValue());
        }
    }

    /** Formats boolean arrays as right-aligned {@code true}/{@code false}. */
    private static final class BooleanFormat extends NDFormat {

        /** {@inheritDoc} */
        @Override
        public CharSequence format(Number value) {
            return value.byteValue() != 0 ? " true" : "false";
        }
    }

    /** Formats string arrays using {@link Arrays#toString(Object[])}. */
    private static final class StringFormat extends NDFormat {

        /**
         * Not used: string arrays are rendered as a whole by this class's {@code dump} override.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public CharSequence format(Number value) {
            throw new UnsupportedOperationException("String arrays are not formatted per element");
        }

        /** {@inheritDoc} */
        @Override
        protected String dump(
                StringBuilder sb,
                NDArray array,
                int maxSize,
                int maxDepth,
                int maxRows,
                int maxColumns) {
            // Append to sb so the header written by format(...) is kept, matching the layout
            // produced for numeric arrays.
            sb.append(LF).append(Arrays.toString(array.toStringArray())).append(LF);
            return sb.toString();
        }
    }
}