boofcv.deepboof.ImageClassifierVggCifar10 Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of boofcv-recognition Show documentation
Show all versions of boofcv-recognition Show documentation
BoofCV is an open source Java library for real-time computer vision and robotics applications.
/*
* Copyright (c) 2011-2017, Peter Abeles. All Rights Reserved.
*
* This file is part of BoofCV (http://boofcv.org).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package boofcv.deepboof;
import boofcv.alg.color.ColorYuv;
import boofcv.alg.filter.stat.ImageLocalNormalization;
import boofcv.core.image.border.BorderType;
import boofcv.struct.convolve.Kernel1D_F32;
import boofcv.struct.image.GrayF32;
import boofcv.struct.image.Planar;
import deepboof.Function;
import deepboof.datasets.UtilCifar10;
import deepboof.io.torch7.ParseBinaryTorch7;
import deepboof.io.torch7.SequenceAndParameters;
import deepboof.models.DeepModelIO;
import deepboof.models.YuvStatistics;
import deepboof.tensors.Tensor_F32;
import java.io.File;
import java.io.IOException;
import static deepboof.misc.TensorOps.WI;
/**
* Image classification using a VGG network trained on CIFAR 10 data. On the CIFAR 10 training set it has
* 89.9% accuracy. The dataset contains 32x32 images in 10 categories.
*
* @see <a href="https://github.com/szagoruyko/cifar.torch">szagoruyko/cifar.torch</a>
*
* @author Peter Abeles
*/
public class ImageClassifierVggCifar10 extends BaseImageClassifier {

    /** Side length, in pixels, used for the YUV work image and for the network input shape. */
    static final int inputSize = 32;

    /** Three-band float work image that {@link #preprocess} fills with the YUV version of the input. */
    Planar<GrayF32> imageYuv = new Planar<>(GrayF32.class, inputSize, inputSize, 3);

    /** Local normalization applied to band 0 of {@link #imageYuv}. Created by {@link #loadModel}. */
    ImageLocalNormalization<GrayF32> localNorm;

    /** Statistics read from {@code YuvStatistics.txt} by {@link #loadModel}. */
    YuvStatistics stats;

    /** 1D kernel built from {@code stats.kernel} and passed to the local normalization. */
    Kernel1D_F32 kernel;

    /**
     * Passes {@link #inputSize} to the base class and adds the CIFAR 10 class names to
     * {@code categories}.
     */
    public ImageClassifierVggCifar10() {
        super(inputSize);
        categories.addAll(UtilCifar10.getClassNames());
    }

    /**
     * Expects there to be two files in the provided directory:
     * YuvStatistics.txt
     * model.net
     *
     * After this call {@code stats}, {@code network}, {@code tensorOutput}, {@code localNorm} and
     * {@code kernel} have all been assigned.
     *
     * @param directory Directory containing model files
     * @throws IOException Throw if anything goes wrong while reading data
     */
    @Override
    public void loadModel(File directory) throws IOException {
        stats = DeepModelIO.load(new File(directory, "YuvStatistics.txt"));

        // Parse the Torch7 model file into a DeepBoof sequence with its parameters
        SequenceAndParameters<Tensor_F32, Function<Tensor_F32>> sequence =
                new ParseBinaryTorch7().parseIntoBoof(new File(directory, "model.net"));

        // Forward network for a 3 x inputSize x inputSize input
        network = sequence.createForward(3, inputSize, inputSize);
        // Output tensor shaped as the network's output shape with a leading 1 passed to WI
        tensorOutput = new Tensor_F32(WI(1, network.getOutputShape()));

        // Border handling and kernel for the local normalization both come from the statistics file
        BorderType type = BorderType.valueOf(stats.border);
        localNorm = new ImageLocalNormalization<>(GrayF32.class, type);
        kernel = DataManipulationOps.create1D_F32(stats.kernel);
    }

    /**
     * Converts the input into the normalized YUV image stored in {@link #imageYuv}.
     *
     * {@link #loadModel} must be called first, since this method reads {@code localNorm},
     * {@code kernel} and {@code stats}.
     *
     * The signature is left with raw {@code Planar} types so that it stays a valid override
     * regardless of how the base class, which is not visible here, declares the method.
     *
     * @param image Input image, forwarded to the base class preprocessing
     * @return {@link #imageYuv}, the same instance on every call
     */
    @Override
    protected Planar preprocess(Planar image) {
        // NOTE(review): presumably the base class fills imageRgb from image -- confirm in BaseImageClassifier
        super.preprocess(image);
        ColorYuv.rgbToYuv_F32(imageRgb, imageYuv);

        // Normalize the image
        // Band 0 is normalized locally, in place (same band is both input and output).
        // NOTE(review): 255.0 and 1e-4 look like the max pixel value and a divide-by-zero guard -- confirm
        localNorm.zeroMeanStdOne(kernel, imageYuv.getBand(0), 255.0, 1e-4, imageYuv.getBand(0));
        // Bands 1 and 2 are normalized with the U and V mean/stdev from the statistics file
        DataManipulationOps.normalize(imageYuv.getBand(1), (float) stats.meanU, (float) stats.stdevU);
        DataManipulationOps.normalize(imageYuv.getBand(2), (float) stats.meanV, (float) stats.stdevV);

        return imageYuv;
    }
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy