All downloads are free. Search and download functionality uses the official Maven repository.

org.nd4j.linalg.util.NDArrayPreconditionsFormat Maven / Gradle / Ivy

There is a newer version: 1.0.0-M2.1
Show newest version
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.linalg.util;

import org.nd4j.common.base.Preconditions;
import org.nd4j.common.base.PreconditionsFormat;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Preconditions format: Defines a set of tags for use with {@link Preconditions} class.
* %ndRank: rank of INDArray
* %ndShape: shape of INDArray
* %ndStride: stride of INDArray
* %ndLength: length of INDArray
* %ndSInfo: shape info of INDArray
* %nd10: First 10 values of INDArray (or all values if length <= 10
* * @author Alex Black */ public class NDArrayPreconditionsFormat implements PreconditionsFormat { private static final List TAGS = Arrays.asList( "%ndRank", "%ndShape", "%ndStride", "%ndLength", "%ndSInfo", "%nd10"); @Override public List formatTags() { return TAGS; } @Override public String format(String tag, Object arg) { if(arg == null) return "null"; INDArray arr = (INDArray)arg; switch (tag){ case "%ndRank": return String.valueOf(arr.rank()); case "%ndShape": return Arrays.toString(arr.shape()); case "%ndStride": return Arrays.toString(arr.stride()); case "%ndLength": return String.valueOf(arr.length()); case "%ndSInfo": return arr.shapeInfoToString().replaceAll("\n",""); case "%nd10": if(arr.isScalar() || arr.isEmpty()){ return arr.toString(); } INDArray sub = arr.reshape(arr.length()).get(NDArrayIndex.interval(0, Math.min(arr.length(), 10))); return sub.toString(); default: //Should never happen throw new IllegalStateException("Unknown format tag: " + tag); } } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy