boofcv.alg.bow.LearnSceneFromFiles Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of boofcv-learning Show documentation
Show all versions of boofcv-learning Show documentation
BoofCV is an open source Java library for real-time computer vision and robotics applications.
The newest version!
/*
* Copyright (c) 2021, Peter Abeles. All Rights Reserved.
*
* This file is part of BoofCV (http://boofcv.org).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package boofcv.alg.bow;
import boofcv.struct.learning.ClassificationHistogram;
import boofcv.struct.learning.Confusion;
import org.jetbrains.annotations.Nullable;
import java.io.File;
import java.util.*;
/**
* Abstract class which provides a frame work for learning a scene classifier from a set of images.
*
* TODO describe how it provides learning
*
* @author Peter Abeles
*/
@SuppressWarnings("NullAway.Init")
public abstract class LearnSceneFromFiles {
protected Random rand;
protected List scenes = new ArrayList<>();
// The minimum number of images in each type of set
int minimumTrain;
int minimumCross;
int minimumTest;
// how to divide the input set up
double fractionTrain;
double fractionCross;
// maps for each set of images
protected Map> train;
protected Map> cross;
protected Map> test;
public Confusion evaluateTest() {
return evaluate(test);
}
/**
* Given a set of images with known classification, predict which scene each one belongs in and compute
* a confusion matrix for the results.
*
* @param set Set of classified images
* @return Confusion matrix
*/
protected Confusion evaluate( Map> set ) {
ClassificationHistogram histogram = new ClassificationHistogram(scenes.size());
int total = 0;
for (int i = 0; i < scenes.size(); i++) {
total += Objects.requireNonNull(set.get(scenes.get(i))).size();
}
System.out.println("total images " + total);
for (int sceneIdx = 0; sceneIdx < scenes.size(); sceneIdx++) {
String scene = scenes.get(sceneIdx);
List images = Objects.requireNonNull(set.get(scene));
System.out.println(" " + scene + " " + images.size());
for (int imageIdx = 0; imageIdx < images.size(); imageIdx++) {
String image = images.get(imageIdx);
int predicted = classify(image);
histogram.increment(sceneIdx, predicted);
}
}
return histogram.createConfusion();
}
/**
* Given an image compute which scene it belongs to
*
* @param path Path to input image
* @return integer corresponding to the scene
*/
protected abstract int classify( String path );
public void loadSets( File dirTraining, File dirCross, File dirTest ) {
train = Objects.requireNonNull(findImages(dirTraining));
if (dirCross != null)
cross = Objects.requireNonNull(findImages(dirCross));
test = Objects.requireNonNull(findImages(dirTest));
extractKeys(train);
extractKeys(test);
}
private void extractKeys( Map> images ) {
Set keys = images.keySet();
for (String key : keys) { // lint:forbidden ignore_line
if (!scenes.contains(key)) {
scenes.add(key);
}
}
}
public void loadThenSplit( File directory ) {
Map> all = Objects.requireNonNull(findImages(directory));
train = new HashMap<>();
if (fractionCross != 0)
cross = new HashMap<>();
test = new HashMap<>();
Set keys = all.keySet();
for (String key : keys) { // lint:forbidden ignore_line
List allImages = Objects.requireNonNull(all.get(key));
// randomize the ordering to remove bias
Collections.shuffle(allImages, rand);
int numTrain = (int)(allImages.size()*fractionTrain);
numTrain = Math.max(minimumTrain, numTrain);
int numCross = (int)(allImages.size()*fractionCross);
numCross = Math.max(minimumCross, numCross);
int numTest = allImages.size() - numTrain - numCross;
if (numTest < minimumTest)
throw new RuntimeException("Not enough images to create test set. " + key + " total = " + allImages.size());
createSubSet(key, allImages, train, 0, numTrain);
if (cross != null) {
createSubSet(key, allImages, cross, numTrain, numCross + numTrain);
}
createSubSet(key, allImages, test, numCross + numTrain, allImages.size());
}
scenes.addAll(keys);
}
private void createSubSet( String key, List allImages, Map> subset,
int start, int end ) {
List trainImages = new ArrayList<>();
for (int i = start; i < end; i++) {
trainImages.add(allImages.get(i));
}
subset.put(key, trainImages);
}
/**
* Loads the paths to image files contained in subdirectories of the root directory. Each sub directory
* is assumed to be a different category of images.
*/
public static @Nullable Map> findImages( File rootDir ) {
File[] files = rootDir.listFiles();
if (files == null)
return null;
List imageDirectories = new ArrayList<>();
for (File f : files) { // lint:forbidden ignore_line
if (f.isDirectory()) {
imageDirectories.add(f);
}
}
Map> out = new HashMap<>();
for (File d : imageDirectories) { // lint:forbidden ignore_line
List images = new ArrayList<>();
files = d.listFiles();
if (files == null)
throw new RuntimeException("Should be a directory!");
for (File f : files) { // lint:forbidden ignore_line
if (f.isHidden() || f.isDirectory() || f.getName().endsWith(".txt")) {
continue;
}
images.add(f.getPath());
}
String key = d.getName().toLowerCase();
out.put(key, images);
}
return out;
}
public List getScenes() {
return scenes;
}
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy