org.apache.lucene.classification.utils.DatasetSplitter Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of lucene-classification Show documentation
Apache Lucene (module: classification)
The newest version!
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.classification.utils;
import java.io.IOException;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.GroupingSearch;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.IOUtils;
/**
* Utility class for creating training / test / cross validation indexes from the original index.
*/
public class DatasetSplitter {
/** Ratio of the original index to be used for the test index. */
private final double testRatio;

/** Ratio of the original index to be used for the cross validation index. */
private final double crossValidationRatio;

/**
 * Creates a {@link DatasetSplitter} configured with the relative sizes of the test and cross
 * validation indexes.
 *
 * <p>Both values are stored exactly as given; this constructor performs no range checking.
 *
 * @param testRatio the ratio of the original index to be used for the test index, as a double
 *     between 0.0 and 1.0
 * @param crossValidationRatio the ratio of the original index to be used for the cross
 *     validation index, as a double between 0.0 and 1.0
 */
public DatasetSplitter(double testRatio, double crossValidationRatio) {
  this.testRatio = testRatio;
  this.crossValidationRatio = crossValidationRatio;
}
/**
* Split a given index into 3 indexes for training, test and cross validation tasks respectively
*
* @param originalIndex an {@link org.apache.lucene.index.LeafReader} on the source index
* @param trainingIndex a {@link Directory} used to write the training index
* @param testIndex a {@link Directory} used to write the test index
* @param crossValidationIndex a {@link Directory} used to write the cross validation index
* @param analyzer {@link Analyzer} used to create the new docs
* @param termVectors {@code true} if term vectors should be kept
* @param classFieldName name of the field used as the label for classification; this must be
* indexed with sorted doc values
* @param fieldNames names of fields that need to be put in the new indexes or null
* if all should be used
* @throws IOException if any writing operation fails on any of the indexes
*/
public void split(
IndexReader originalIndex,
Directory trainingIndex,
Directory testIndex,
Directory crossValidationIndex,
Analyzer analyzer,
boolean termVectors,
String classFieldName,
String... fieldNames)
throws IOException {
// create IWs for train / test / cv IDXs
IndexWriter testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
IndexWriter cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
IndexWriter trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));
// get the exact no. of existing classes
int noOfClasses = 0;
for (LeafReaderContext leave : originalIndex.leaves()) {
long valueCount = 0;
SortedDocValues classValues = leave.reader().getSortedDocValues(classFieldName);
if (classValues != null) {
valueCount = classValues.getValueCount();
} else {
SortedSetDocValues sortedSetDocValues =
leave.reader().getSortedSetDocValues(classFieldName);
if (sortedSetDocValues != null) {
valueCount = sortedSetDocValues.getValueCount();
}
}
if (classValues == null) {
// approximate with no. of terms
noOfClasses += leave.reader().terms(classFieldName).size();
}
noOfClasses += valueCount;
}
try {
IndexSearcher indexSearcher = new IndexSearcher(originalIndex);
GroupingSearch gs = new GroupingSearch(classFieldName);
gs.setGroupSort(Sort.INDEXORDER);
gs.setSortWithinGroup(Sort.INDEXORDER);
gs.setAllGroups(true);
gs.setGroupDocsLimit(originalIndex.maxDoc());
TopGroups