Please wait. This can take some minutes ...
Many resources are needed to download a project. Please understand that we have to compensate our server costs. Thank you in advance.
Project price only 1 $
You can buy this project and download or modify it as often as you want.
org.nd4j.linalg.api.ops.impl.image.CropAndResize Maven / Gradle / Ivy
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.linalg.api.ops.impl.image;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.*;
/**
 * Custom op wrapper for the native {@code crop_and_resize} operation (TensorFlow name
 * {@code CropAndResize}).
 *
 * <p>The op takes four inputs, in this order: the image, the crop boxes, the box indices and the
 * crop output size. The array-based constructor validates that the image is rank 4
 * ({@code [batch, height, width, channels]}), that the crop boxes have shape {@code [num_boxes, 4]}
 * and that the box indices have shape {@code [num_boxes]}.
 *
 * <p>The resize {@link Method} is passed to the native op as a single integer argument and the
 * extrapolation value as a single floating point argument (see {@link #addArgs()}).
 */
@NoArgsConstructor
public class CropAndResize extends DynamicCustomOp {

    /** Interpolation method used when resizing the crops. */
    public enum Method {BILINEAR, NEAREST}

    /** Resize method; {@link Method#BILINEAR} unless configured otherwise. */
    protected Method method = Method.BILINEAR;

    /** Extrapolation value handed to the native op as its floating point argument. */
    protected double extrapolationValue = 0.0;

    /**
     * Creates the op as part of a SameDiff graph.
     *
     * @param sameDiff           graph the op belongs to
     * @param image              image variable
     * @param cropBoxes          crop boxes variable
     * @param boxIndices         box indices variable
     * @param cropOutSize        crop output size variable
     * @param method             resize method, must not be null
     * @param extrapolationValue extrapolation value passed to the native op
     */
    public CropAndResize(@NonNull SameDiff sameDiff, @NonNull SDVariable image, @NonNull SDVariable cropBoxes,
                         @NonNull SDVariable boxIndices, @NonNull SDVariable cropOutSize, @NonNull Method method,
                         double extrapolationValue) {
        super(sameDiff, new SDVariable[]{image, cropBoxes, boxIndices, cropOutSize});
        this.method = method;
        this.extrapolationValue = extrapolationValue;
        addArgs();
    }

    /**
     * Creates the op as part of a SameDiff graph using the default {@link Method#BILINEAR} method.
     *
     * @param sameDiff           graph the op belongs to
     * @param image              image variable
     * @param cropBoxes          crop boxes variable
     * @param boxIndices         box indices variable
     * @param cropOutSize        crop output size variable
     * @param extrapolationValue extrapolation value passed to the native op
     */
    public CropAndResize(@NonNull SameDiff sameDiff, SDVariable image, SDVariable cropBoxes, SDVariable boxIndices,
                         SDVariable cropOutSize, double extrapolationValue) {
        // The delegate declares 'method' as @NonNull, so passing null here would always throw.
        // Use the same default as the field initializer instead.
        this(sameDiff, image, cropBoxes, boxIndices, cropOutSize, Method.BILINEAR, extrapolationValue);
    }

    /**
     * Creates the op for direct execution on arrays.
     *
     * @param image              rank 4 image array with shape {@code [batch, height, width, channels]}
     * @param cropBoxes          rank 2 array with shape {@code [num_boxes, 4]}
     * @param boxIndices         rank 1 array with shape {@code [num_boxes]}
     * @param cropOutSize        crop output size array
     * @param method             resize method, must not be null
     * @param extrapolationValue extrapolation value passed to the native op
     * @param output             optional pre-allocated output array; may be null, in which case no
     *                           output argument is registered
     * @throws IllegalArgumentException if any of the shape checks fail
     */
    public CropAndResize(@NonNull INDArray image, @NonNull INDArray cropBoxes, @NonNull INDArray boxIndices,
                         @NonNull INDArray cropOutSize, @NonNull Method method, double extrapolationValue,
                         INDArray output) {
        super(new INDArray[]{image, cropBoxes, boxIndices, cropOutSize}, null);
        Preconditions.checkArgument(image.rank() == 4,
                "Input image must be rank 4 with shape [batch, height, width, channels], got %ndShape", image);
        Preconditions.checkArgument(cropBoxes.rank() == 2 && cropBoxes.size(1) == 4,
                "Crop boxes must be rank 2 with shape [num_boxes, 4], got %ndShape", cropBoxes);
        Preconditions.checkArgument(boxIndices.rank() == 1 && cropBoxes.size(0) == boxIndices.size(0),
                "Box indices must be rank 1 array with shape [num_boxes] (same as cropBoxes.size(0)), got array with shape %ndShape",
                boxIndices);
        this.method = method;
        this.extrapolationValue = extrapolationValue;
        addArgs();
        // Only register a real output array; a null entry would otherwise sit in the output list.
        if (output != null) {
            outputArguments.add(output);
        }
    }

    /**
     * Creates the op for direct execution on arrays using the default {@link Method#BILINEAR}
     * method and no pre-allocated output.
     *
     * @param image              rank 4 image array with shape {@code [batch, height, width, channels]}
     * @param cropBoxes          rank 2 array with shape {@code [num_boxes, 4]}
     * @param boxIndices         rank 1 array with shape {@code [num_boxes]}
     * @param cropOutSize        crop output size array
     * @param extrapolationValue extrapolation value passed to the native op
     */
    public CropAndResize(INDArray image, INDArray cropBoxes, INDArray boxIndices, INDArray cropOutSize,
                         double extrapolationValue) {
        // 'method' is @NonNull in the delegate; pass the default rather than null.
        this(image, cropBoxes, boxIndices, cropOutSize, Method.BILINEAR, extrapolationValue, null);
    }

    @Override
    public String opName() {
        return "crop_and_resize";
    }

    @Override
    public String tensorflowName() {
        return "CropAndResize";
    }

    /**
     * Configures the op from a TensorFlow node's attributes.
     *
     * <p>Reads the {@code method} attribute ({@code "nearest"}, case-insensitive, selects
     * {@link Method#NEAREST}; anything else, including a missing attribute, selects
     * {@link Method#BILINEAR}) and the optional {@code extrapolation_value} attribute, then
     * registers the op arguments.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode,
                                   GraphDef graph) {
        AttrValue methodAttr = attributesForNode.get("method");
        // A missing attribute falls through to the same default as any non-"nearest" value.
        String methodName = methodAttr == null ? "bilinear" : methodAttr.getS().toStringUtf8();
        this.method = "nearest".equalsIgnoreCase(methodName) ? Method.NEAREST : Method.BILINEAR;

        AttrValue extrapolationAttr = attributesForNode.get("extrapolation_value");
        if (extrapolationAttr != null) {
            extrapolationValue = extrapolationAttr.getF();
        }
        addArgs();
    }

    /**
     * Registers the native op arguments: one integer argument encoding the method
     * (0 for {@link Method#BILINEAR}, 1 otherwise) and one floating point argument holding the
     * extrapolation value.
     */
    protected void addArgs() {
        addIArgument(method == Method.BILINEAR ? 0 : 1);
        addTArgument(extrapolationValue);
    }

    /**
     * Returns one {@code zerosLike} variable per op input, i.e. no gradient is propagated
     * through this op.
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        //TODO we can probably skip this sometimes...
        List<SDVariable> out = new ArrayList<>();
        for (SDVariable v : args()) {
            out.add(sameDiff.zerosLike(v));
        }
        return out;
    }

    /**
     * Reports the output data type of the op.
     *
     * @param inputDataTypes data types of the four inputs
     * @return a single-element list containing {@link DataType#FLOAT}
     * @throws IllegalStateException if {@code inputDataTypes} is null or does not have 4 entries
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 4,
                "Expected 4 input datatypes for %s, got %s", getClass(), inputDataTypes);
        return Collections.singletonList(DataType.FLOAT);   //TF import: always returns float32...
    }
}