All Downloads are FREE. Search and download functionalities are using the official Maven repository.

boofcv.deepboof.ImageClassifierNiNImageNet Maven / Gradle / Ivy

/*
 * Copyright (c) 2021, Peter Abeles. All Rights Reserved.
 *
 * This file is part of BoofCV (http://boofcv.org).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package boofcv.deepboof;

import boofcv.alg.misc.GPixelMath;
import boofcv.struct.image.GrayF32;
import boofcv.struct.image.Planar;
import deepboof.Function;
import deepboof.io.torch7.ConvertTorchToBoofForward;
import deepboof.io.torch7.ParseAsciiTorch7;
import deepboof.io.torch7.ParseBinaryTorch7;
import deepboof.io.torch7.SequenceAndParameters;
import deepboof.io.torch7.struct.*;
import deepboof.tensors.Tensor_F32;

import java.io.File;
import java.io.IOException;
import java.util.List;

import static deepboof.misc.TensorOps.WI;

/**
 * 

Pretrained Network-in-Network (NiN) image classifier using imagenet data. Trained by szagoruyko [1,2] and * achieves 62.6% top1 center crop accuracy on validation set.

* *

* [1] https://gist.github.com/szagoruyko/0f5b4c5e2d2b18472854
* [2] https://github.com/soumith/imagenet-multiGPU.torch/blob/master/models/ninbn.lua *

* * @author Peter Abeles */ @SuppressWarnings({"NullAway.Init"}) public class ImageClassifierNiNImageNet extends BaseImageClassifier { // normalization parameters float[] mean; float[] stdev; // int imageSize = 256; static final int imageCrop = 224; // Input image with the bands in the correct order Planar imageBgr = new Planar<>(GrayF32.class, imageCrop, imageCrop, 3); public ImageClassifierNiNImageNet() { super(imageCrop); } @Override public void loadModel( File directory ) throws IOException { List list = new ParseBinaryTorch7().parse(new File(directory, "nin_bn_final.t7")); TorchGeneric torchSequence = ((TorchGeneric)list.get(0)).get("model"); TorchGeneric torchNorm = torchSequence.get("transform"); mean = torchListToArray((TorchList)torchNorm.get("mean")); stdev = torchListToArray((TorchList)torchNorm.get("std")); SequenceAndParameters> seqparam = ConvertTorchToBoofForward.convert(torchSequence); network = seqparam.createForward(3, imageCrop, imageCrop); tensorOutput = new Tensor_F32(WI(1, network.getOutputShape())); TorchList torchCategories = (TorchList)new ParseAsciiTorch7().parse(new File(directory, "synset.t7")).get(0); categories.clear(); for (int i = 0; i < torchCategories.list.size(); i++) { categories.add(((TorchString)torchCategories.list.get(i)).message); } } private float[] torchListToArray( TorchList torch ) { float[] ret = new float[torch.list.size()]; for (int i = 0; i < ret.length; i++) { ret[i] = (float)((TorchNumber)torch.list.get(i)).value; } return ret; } /** * Massage the input image into a format recognized by the network */ @Override protected Planar preprocess( Planar image ) { super.preprocess(image); // image net is BGR color order imageBgr.bands[0] = imageRgb.bands[2]; imageBgr.bands[1] = imageRgb.bands[1]; imageBgr.bands[2] = imageRgb.bands[0]; // image needs to be between 0 and 1 GPixelMath.divide(imageBgr, 255, imageBgr); // Normalize the image's statistics for (int band = 0; band < 3; band++) { DataManipulationOps.normalize(imageBgr.getBand(band), mean[band], stdev[band]); } return imageBgr; } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy