All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.lucene.classification.utils.DatasetSplitter Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.classification.utils;

import java.io.IOException;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.StoredFields;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.GroupingSearch;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.IOUtils;

/**
 * Utility class for creating training / test / cross validation indexes from the original index.
 */
public class DatasetSplitter {

  private final double crossValidationRatio;
  private final double testRatio;

  /**
   * Create a {@link DatasetSplitter} by giving test and cross validation IDXs sizes.
   *
   * <p>Both ratios are applied per class: each class group of the original index contributes up to
   * {@code ratio * groupSize} documents to the corresponding output index.
   *
   * @param testRatio the ratio of the original index to be used for the test IDX as a double
   *     between 0.0 and 1.0
   * @param crossValidationRatio the ratio of the original index to be used for the c.v. IDX as a
   *     double between 0.0 and 1.0
   */
  public DatasetSplitter(double testRatio, double crossValidationRatio) {
    this.crossValidationRatio = crossValidationRatio;
    this.testRatio = testRatio;
  }

  /**
   * Split a given index into 3 indexes for training, test and cross validation tasks respectively.
   *
   * <p>Documents are grouped by the value of {@code classFieldName} and visited in index order.
   * Within each group, up to {@code testRatio * groupSize} documents go to the test index (only
   * documents at an even position in the overall visiting order are eligible). Then up to {@code
   * crossValidationRatio * groupSize} documents go to the cross validation index, and the remaining
   * ones go to the training index. Each output index is committed and force-merged down to at most
   * 3 segments.
   *
   * <p><b>Note:</b> this method closes {@code originalIndex}, on success as well as on failure.
   *
   * @param originalIndex an {@link IndexReader} on the source index; closed by this method
   * @param trainingIndex a {@link Directory} used to write the training index
   * @param testIndex a {@link Directory} used to write the test index
   * @param crossValidationIndex a {@link Directory} used to write the cross validation index
   * @param analyzer {@link Analyzer} used to create the new docs
   * @param termVectors {@code true} if term vectors should be kept
   * @param classFieldName name of the field used as the label for classification; this must be
   *     indexed with sorted doc values
   * @param fieldNames names of the stored fields to copy into the new indexes, or null / empty if
   *     all stored fields should be copied
   * @throws IOException if any reading or writing operation fails on any of the indexes; unchecked
   *     exceptions raised while splitting are wrapped in an {@link IOException}
   */
  public void split(
      IndexReader originalIndex,
      Directory trainingIndex,
      Directory testIndex,
      Directory crossValidationIndex,
      Analyzer analyzer,
      boolean termVectors,
      String classFieldName,
      String... fieldNames)
      throws IOException {

    IndexWriter testWriter = null;
    IndexWriter cvWriter = null;
    IndexWriter trainingWriter = null;
    boolean success = false;
    try {
      // create IWs for train / test / cv IDXs inside the try block, so that a failure while
      // opening one of them still closes (and unlocks) the ones already opened
      testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
      cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
      trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));

      int maxGroups = estimateMaxGroups(originalIndex, classFieldName);

      IndexSearcher indexSearcher = new IndexSearcher(originalIndex);
      GroupingSearch gs = new GroupingSearch(classFieldName);
      gs.setGroupSort(Sort.INDEXORDER);
      gs.setSortWithinGroup(Sort.INDEXORDER);
      gs.setAllGroups(true);
      // no group can hold more documents than the index has
      gs.setGroupDocsLimit(originalIndex.maxDoc());
      TopGroups<?> topGroups = gs.search(indexSearcher, new MatchAllDocsQuery(), 0, maxGroups);

      // set the type to be indexed, stored, with term vectors
      FieldType ft = new FieldType(TextField.TYPE_STORED);
      if (termVectors) {
        ft.setStoreTermVectors(true);
        ft.setStoreTermVectorOffsets(true);
        ft.setStoreTermVectorPositions(true);
      }

      // position of the current document in the overall visiting order (across all groups)
      int position = 0;

      // iterate over existing documents
      StoredFields storedFields = originalIndex.storedFields();
      for (GroupDocs<?> group : topGroups.groups) {
        // every document of the group is collected, so the hit count must be exact
        assert group.totalHits().relation() == TotalHits.Relation.EQUAL_TO;
        long totalHits = group.totalHits().value();
        double testSize = totalHits * testRatio;
        int testCount = 0;
        double cvSize = totalHits * crossValidationRatio;
        int cvCount = 0;
        for (ScoreDoc scoreDoc : group.scoreDocs()) {

          // create a new document for indexing
          Document doc = createNewDoc(storedFields, ft, scoreDoc, fieldNames);

          // add it to one of the IDXs; alternating on position spreads test docs across the group
          if (position % 2 == 0 && testCount < testSize) {
            testWriter.addDocument(doc);
            testCount++;
          } else if (cvCount < cvSize) {
            cvWriter.addDocument(doc);
            cvCount++;
          } else {
            trainingWriter.addDocument(doc);
          }
          position++;
        }
      }
      // commit
      testWriter.commit();
      cvWriter.commit();
      trainingWriter.commit();

      // merge
      testWriter.forceMerge(3);
      cvWriter.forceMerge(3);
      trainingWriter.forceMerge(3);
      success = true;
    } catch (RuntimeException e) {
      // callers only declare IOException, so unchecked failures are reported as IOExceptions
      throw new IOException(e);
    } finally {
      // close IWs and the source reader; on failure, don't let a close error mask the root cause
      if (success) {
        IOUtils.close(testWriter, cvWriter, trainingWriter, originalIndex);
      } else {
        IOUtils.closeWhileHandlingException(testWriter, cvWriter, trainingWriter, originalIndex);
      }
    }
  }

  /**
   * Returns how many groups to request from {@link GroupingSearch}: an upper bound on the number of
   * distinct class values, plus one for the group of documents that have no class value.
   *
   * @param reader the source index
   * @param classFieldName name of the class field
   * @return a group limit in the range {@code [1, max(1, maxDoc)]}
   * @throws IOException if reading doc values or terms fails
   */
  private static int estimateMaxGroups(IndexReader reader, String classFieldName)
      throws IOException {
    // summed per segment, so a value present in several segments is counted several times; that
    // only makes the bound looser
    long estimate = 0;
    for (LeafReaderContext leaf : reader.leaves()) {
      SortedDocValues classValues = leaf.reader().getSortedDocValues(classFieldName);
      if (classValues != null) {
        estimate += classValues.getValueCount();
        continue;
      }
      SortedSetDocValues sortedSetDocValues = leaf.reader().getSortedSetDocValues(classFieldName);
      if (sortedSetDocValues != null) {
        estimate += sortedSetDocValues.getValueCount();
      }
      // approximate with no. of terms; a segment may not contain the field at all (null terms),
      // and size() is -1 when the codec cannot report it
      var terms = leaf.reader().terms(classFieldName);
      if (terms != null && terms.size() > 0) {
        estimate += terms.size();
      }
    }
    // +1 for the null group (documents without a class value), which would otherwise push a real
    // class out of the top groups. Every group holds at least one document, so maxDoc caps the
    // estimate (keeping it within int range); GroupingSearch rejects a limit below 1.
    return (int) Math.max(1L, Math.min(estimate + 1, reader.maxDoc()));
  }

  /**
   * Creates a new document holding copies of the stored fields of the given hit, re-typed with
   * {@code ft}.
   *
   * @param originalFields stored fields of the source index
   * @param ft field type applied to every copied field
   * @param scoreDoc the hit whose stored fields are copied
   * @param fieldNames names of the fields to copy, or null / empty to copy all stored fields
   * @return the new document
   * @throws IOException if the stored fields cannot be read
   */
  private Document createNewDoc(
      StoredFields originalFields, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames)
      throws IOException {
    Document doc = new Document();
    Document document = originalFields.document(scoreDoc.doc);
    if (fieldNames != null && fieldNames.length > 0) {
      for (String fieldName : fieldNames) {
        // copy every stored value, not just the first one, so multi-valued fields are preserved
        for (IndexableField field : document.getFields(fieldName)) {
          addCopy(doc, field, ft);
        }
      }
    } else {
      for (IndexableField field : document.getFields()) {
        addCopy(doc, field, ft);
      }
    }
    return doc;
  }

  /**
   * Adds to {@code target} a copy of {@code field} with type {@code ft}; fields without a reader,
   * binary, string or numeric value are skipped. Numeric values are copied as their string form.
   */
  private static void addCopy(Document target, IndexableField field, FieldType ft) {
    if (field.readerValue() != null) {
      target.add(new Field(field.name(), field.readerValue(), ft));
    } else if (field.binaryValue() != null) {
      // NOTE(review): ft is tokenized, which may be rejected for binary values at indexing time
      target.add(new Field(field.name(), field.binaryValue(), ft));
    } else if (field.stringValue() != null) {
      target.add(new Field(field.name(), field.stringValue(), ft));
    } else if (field.numericValue() != null) {
      target.add(new Field(field.name(), field.numericValue().toString(), ft));
    }
  }
}