org.nd4j.linalg.util.NDArrayPreconditionsFormat Maven / Gradle / Ivy
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.util;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.base.PreconditionsFormat;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
 * Preconditions format: defines a set of INDArray-specific tags for use with the {@link Preconditions} class.
 * <ul>
 *     <li>%ndRank: rank of INDArray</li>
 *     <li>%ndShape: shape of INDArray</li>
 *     <li>%ndStride: stride of INDArray</li>
 *     <li>%ndLength: length of INDArray</li>
 *     <li>%ndSInfo: shape info of INDArray (collapsed onto a single line)</li>
 *     <li>%nd10: first 10 values of INDArray (or all values if length &lt;= 10)</li>
 * </ul>
 *
 * @author Alex Black
 */
public class NDArrayPreconditionsFormat implements PreconditionsFormat {

    /**
     * Tags handled by {@link #format(String, Object)}. Wrapped as unmodifiable because this shared static
     * list is handed out directly by {@link #formatTags()}.
     */
    private static final List<String> TAGS = Collections.unmodifiableList(Arrays.asList(
            "%ndRank", "%ndShape", "%ndStride", "%ndLength", "%ndSInfo", "%nd10"));

    /**
     * Returns the format tags supported by this class.
     *
     * @return an unmodifiable list of the supported tags
     */
    @Override
    public List<String> formatTags() {
        return TAGS;
    }

    /**
     * Formats the given argument according to the given tag.
     *
     * @param tag one of the tags returned by {@link #formatTags()}
     * @param arg the argument to format; expected to be an {@link INDArray} or null
     * @return "null" if {@code arg} is null, otherwise the tag-specific string representation of the array
     * @throws ClassCastException    if {@code arg} is non-null and not an {@link INDArray}
     * @throws IllegalStateException if {@code tag} is not one of the supported tags
     */
    @Override
    public String format(String tag, Object arg) {
        if (arg == null) {
            return "null";
        }
        INDArray arr = (INDArray) arg;
        switch (tag) {
            case "%ndRank":
                return String.valueOf(arr.rank());
            case "%ndShape":
                return Arrays.toString(arr.shape());
            case "%ndStride":
                return Arrays.toString(arr.stride());
            case "%ndLength":
                return String.valueOf(arr.length());
            case "%ndSInfo":
                // Remove newlines so the shape info fits on one line of the message.
                // Literal replace: equivalent to the regex form for "\n", without compiling a pattern.
                return arr.shapeInfoToString().replace("\n", "");
            case "%nd10":
                if (arr.isScalar() || arr.isEmpty()) {
                    return arr.toString();
                }
                // Reshape to a vector of arr.length() elements, then take at most the first 10
                INDArray sub = arr.reshape(arr.length()).get(NDArrayIndex.interval(0, Math.min(arr.length(), 10)));
                return sub.toString();
            default:
                //Should never happen
                throw new IllegalStateException("Unknown format tag: " + tag);
        }
    }
}