boofcv.deepboof.ImageClassifierNiNImageNet Maven / Gradle / Ivy
/*
* Copyright (c) 2021, Peter Abeles. All Rights Reserved.
*
* This file is part of BoofCV (http://boofcv.org).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package boofcv.deepboof;
import boofcv.alg.misc.GPixelMath;
import boofcv.struct.image.GrayF32;
import boofcv.struct.image.Planar;
import deepboof.Function;
import deepboof.io.torch7.ConvertTorchToBoofForward;
import deepboof.io.torch7.ParseAsciiTorch7;
import deepboof.io.torch7.ParseBinaryTorch7;
import deepboof.io.torch7.SequenceAndParameters;
import deepboof.io.torch7.struct.*;
import deepboof.tensors.Tensor_F32;
import java.io.File;
import java.io.IOException;
import java.util.List;
import static deepboof.misc.TensorOps.WI;
/**
 * Pretrained Network-in-Network (NiN) image classifier using imagenet data. Trained by szagoruyko [1,2] and
 * achieves 62.6% top1 center crop accuracy on validation set.
 *
 * <p>
 * [1] https://gist.github.com/szagoruyko/0f5b4c5e2d2b18472854<br>
 * [2] https://github.com/soumith/imagenet-multiGPU.torch/blob/master/models/ninbn.lua
 * </p>
 *
 * <p>{@link #loadModel(File)} must be called before an image is processed; until then {@link #mean} and
 * {@link #stdev} are not assigned.</p>
 *
 * @author Peter Abeles
 */
@SuppressWarnings({"NullAway.Init"})
public class ImageClassifierNiNImageNet extends BaseImageClassifier {

	// Normalization parameters. Read from the "mean" and "std" entries of the model's "transform" table
	// by loadModel() and indexed by band in preprocess().
	float[] mean;
	float[] stdev;

	// int imageSize = 256;

	// Width and height of the square image passed to the network
	static final int imageCrop = 224;

	// Input image with the bands in the correct order
	Planar<GrayF32> imageBgr = new Planar<>(GrayF32.class, imageCrop, imageCrop, 3);

	public ImageClassifierNiNImageNet() {
		super(imageCrop);
	}

	/**
	 * Loads the network, its normalization parameters, and the category names from Torch7 files.
	 *
	 * <p>Two files are read from the directory: "nin_bn_final.t7" (binary Torch7) which contains the model and
	 * its "transform" table, and "synset.t7" (ASCII Torch7) which contains the list of category names.</p>
	 *
	 * @param directory Directory containing "nin_bn_final.t7" and "synset.t7"
	 * @throws IOException Propagated from the Torch7 parsers
	 */
	@Override
	public void loadModel( File directory ) throws IOException {
		List<TorchObject> list = new ParseBinaryTorch7().parse(new File(directory, "nin_bn_final.t7"));

		// The first parsed object holds the model, which in turn holds the normalization ("transform") table
		TorchGeneric torchSequence = ((TorchGeneric)list.get(0)).get("model");
		TorchGeneric torchNorm = torchSequence.get("transform");

		mean = torchListToArray((TorchList)torchNorm.get("mean"));
		stdev = torchListToArray((TorchList)torchNorm.get("std"));

		SequenceAndParameters<Tensor_F32, Function<Tensor_F32>> seqparam =
				ConvertTorchToBoofForward.convert(torchSequence);

		// Forward network for a 3-band imageCrop x imageCrop input and storage for a single output
		network = seqparam.createForward(3, imageCrop, imageCrop);
		tensorOutput = new Tensor_F32(WI(1, network.getOutputShape()));

		// Replace any previously loaded category names with the ones found in the synset file
		TorchList torchCategories = (TorchList)new ParseAsciiTorch7().parse(new File(directory, "synset.t7")).get(0);
		categories.clear();
		for (int i = 0; i < torchCategories.list.size(); i++) {
			categories.add(((TorchString)torchCategories.list.get(i)).message);
		}
	}

	/**
	 * Copies a Torch list of numbers into a float array. Every element is cast to {@link TorchNumber}.
	 *
	 * @param torch List whose elements are all {@link TorchNumber}
	 * @return New array with one value per list element, in list order
	 */
	private static float[] torchListToArray( TorchList torch ) {
		float[] ret = new float[torch.list.size()];
		for (int i = 0; i < ret.length; i++) {
			ret[i] = (float)((TorchNumber)torch.list.get(i)).value;
		}
		return ret;
	}

	/**
	 * Massage the input image into a format recognized by the network
	 *
	 * @param image Input image which is passed to the parent's preprocess
	 * @return {@link #imageBgr}, with bands in BGR order, scaled by 1/255 and normalized per band
	 */
	@Override
	protected Planar<GrayF32> preprocess( Planar<GrayF32> image ) {
		// NOTE(review): presumably this fills in imageRgb with the resized/cropped input - confirm in parent
		super.preprocess(image);

		// image net is BGR color order.
		// Only the references are swapped, no pixels are copied. imageBgr shares its band objects with
		// imageRgb, so the in-place operations below also change the pixel values seen through imageRgb.
		imageBgr.bands[0] = imageRgb.bands[2];
		imageBgr.bands[1] = imageRgb.bands[1];
		imageBgr.bands[2] = imageRgb.bands[0];

		// image needs to be between 0 and 1
		GPixelMath.divide(imageBgr, 255, imageBgr);

		// Normalize the image's statistics. mean and stdev are indexed using the BGR band order
		for (int band = 0; band < 3; band++) {
			DataManipulationOps.normalize(imageBgr.getBand(band), mean[band], stdev[band]);
		}

		return imageBgr;
	}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy