All downloads are free. Search and download functionality uses the official Maven repository.

opennlp.tools.ml.model.TwoPassDataIndexer Maven / Gradle / Ivy

There is a newer version: 2.5.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package opennlp.tools.ml.model;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.zip.CRC32C;
import java.util.zip.CheckedInputStream;
import java.util.zip.CheckedOutputStream;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import opennlp.tools.util.ObjectStream;

/**
 * Collecting event and context counts by making two passes over the events.
 * 

* The first pass determines which contexts will be used by the model, and the * second pass creates the events in memory containing only the contexts which * will be used. This greatly reduces the amount of memory required for storing * the events. During the first pass a temporary event file is created which * is read during the second pass. * * @see DataIndexer * @see AbstractDataIndexer */ public class TwoPassDataIndexer extends AbstractDataIndexer { private static final Logger logger = LoggerFactory.getLogger(TwoPassDataIndexer.class); public TwoPassDataIndexer() {} /** * {@inheritDoc} */ @Override public void index(ObjectStream eventStream) throws IOException { int cutoff = trainingParameters.getIntParameter(CUTOFF_PARAM, CUTOFF_DEFAULT); boolean sort = trainingParameters.getBooleanParameter(SORT_PARAM, SORT_DEFAULT); logger.info("Indexing events with TwoPass using cutoff of {}", cutoff); logger.info("Computing event counts..."); long start = System.currentTimeMillis(); Map predicateIndex = new HashMap<>(); File tmp = Files.createTempFile("events", null).toFile(); tmp.deleteOnExit(); int numEvents; long writeChecksum; try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(tmp)); CheckedOutputStream writeStream = new CheckedOutputStream(out, new CRC32C()); DataOutputStream dos = new DataOutputStream(writeStream)) { numEvents = computeEventCounts(eventStream, dos, predicateIndex, cutoff); writeChecksum = writeStream.getChecksum().getValue(); logger.info("done. 
{} events", numEvents); } List eventsToCompare; long readChecksum; try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(tmp)); CheckedInputStream readStream = new CheckedInputStream(in, new CRC32C()); EventStream readEventsStream = new EventStream(new DataInputStream(readStream))) { logger.info("Indexing..."); eventsToCompare = index(readEventsStream, predicateIndex); readChecksum = readStream.getChecksum().getValue(); } tmp.delete(); if (readChecksum != writeChecksum) { throw new IOException("Checksum for writing and reading events did not match."); } else { logger.info("done."); if (sort) { logger.info("Sorting and merging events... "); } else { logger.info("Collecting events... "); } sortAndMerge(eventsToCompare,sort); logger.info(String.format("Done indexing in %.2f s.", (System.currentTimeMillis() - start) / 1000d)); } } /** * Reads events from eventStream into a linked list. The * predicates associated with each event are counted and any which * occur at least cutoff times are added to the * predicatesInOut map along with a unique integer index. *

* Protocol: * 1 - (utf string) - Event outcome * 2 - (int) - Event context array length * 3+ - (utf string) - Event context string * 4 - (int) - Event values array length * 5+ - (float) - Event value * * @param eventStream an EventStream value * @param eventStore a writer to which the events are written to for later processing. * @param predicatesInOut a TObjectIntHashMap value * @param cutoff an int value */ private int computeEventCounts(ObjectStream eventStream, DataOutputStream eventStore, Map predicatesInOut, int cutoff) throws IOException { Map counter = new HashMap<>(); int eventCount = 0; Event ev; while ((ev = eventStream.read()) != null) { eventCount++; eventStore.writeUTF(ev.getOutcome()); eventStore.writeInt(ev.getContext().length); String[] ec = ev.getContext(); update(ec, counter); for (String ctxString : ec) eventStore.writeUTF(ctxString); if (ev.getValues() == null) { eventStore.writeInt(0); } else { eventStore.writeInt(ev.getValues().length); for (float value : ev.getValues()) eventStore.writeFloat(value); } } String[] predicateSet = counter.entrySet().stream() .filter(entry -> entry.getValue() >= cutoff) .map(Map.Entry::getKey).sorted() .toArray(String[]::new); predCounts = new int[predicateSet.length]; for (int i = 0; i < predicateSet.length; i++) { predCounts[i] = counter.get(predicateSet[i]); predicatesInOut.put(predicateSet[i], i); } return eventCount; } private static class EventStream implements ObjectStream { private final DataInputStream inputStream; public EventStream(DataInputStream dataInputStream) { this.inputStream = dataInputStream; } @Override public Event read() throws IOException { if (inputStream.available() != 0) { String outcome = inputStream.readUTF(); int contextLength = inputStream.readInt(); String[] context = new String[contextLength]; for (int i = 0; i < contextLength; i++) context[i] = inputStream.readUTF(); int valuesLength = inputStream.readInt(); float[] values = null; if (valuesLength > 0) { values = new 
float[valuesLength]; for (int i = 0; i < valuesLength; i++) values[i] = inputStream.readFloat(); } return new Event(outcome, context, values); } else { return null; } } @Override public void reset() throws IOException, UnsupportedOperationException { throw new UnsupportedOperationException(); } @Override public void close() throws IOException { inputStream.close(); } } }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy