All downloads are free. Search and download functionality uses the official Maven repository.

org.apache.ctakes.ytex.weka.WekaResultsImporterImpl Maven / Gradle / Ivy

The newest version!
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.ctakes.ytex.weka;

import org.apache.ctakes.ytex.kernel.model.ClassifierEvaluation;
import org.apache.ctakes.ytex.kernel.model.ClassifierInstanceEvaluation;
import org.hibernate.SessionFactory;

import java.io.BufferedReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;


/**
 * Parses Weka per-instance prediction output (produced when a classifier is
 * run with the -p option) and passes each parsed prediction to a
 * {@link WekaResultInstanceImporter}, which is responsible for loading it
 * into the db.
 */
public class WekaResultsImporterImpl implements WekaResultsImporter {
	/**
	 * Matches the header line that precedes the per-instance results, e.g.
	 * 
	 * <pre>
	 * inst# actual predicted error prediction (instance_id,C0000726,C0000731)
	 * inst#     actual  predicted error prediction (instance_id)
	 * </pre>
	 */
	private static final Pattern patHeader = Pattern
			.compile("\\s*inst#.*actual.*predicted.*error");
	/**
	 * Matches a single result line, e.g.
	 * 
	 * <pre>
	 * 1 1:0 1:0 0.988 (330478,101,0)
	 * </pre>
	 * 
	 * Groups: 1 = instance number, 2 = digits before the colon of the actual
	 * class, 3 = digits before the colon of the predicted class, 4 = prediction
	 * text (one or more numbers), 5 = text between the parentheses (the
	 * comma-separated instance key).
	 */
	private static final Pattern patResult = Pattern
			.compile("\\s+(\\d+)\\s+(\\d+)\\:.*\\s+(\\d+)\\:.*\\s+\\+{0,1}\\s+(.*)\\s+\\((.*)\\)");
	private SessionFactory sessionFactory;

	public SessionFactory getSessionFactory() {
		return sessionFactory;
	}

	public void setSessionFactory(SessionFactory sessionFactory) {
		this.sessionFactory = sessionFactory;
	}

	/**
	 * this imports the classification results for a document
	 */
	private DocumentResultInstanceImporter docResultInstanceImporter;

	public DocumentResultInstanceImporter getDocResultInstanceImporter() {
		return docResultInstanceImporter;
	}

	public void setDocResultInstanceImporter(
			DocumentResultInstanceImporter docResultInstanceImporter) {
		this.docResultInstanceImporter = docResultInstanceImporter;
	}

	/**
	 * Delegate to importResults, using the configured
	 * docResultInstanceImporter.
	 * 
	 * @see org.apache.ctakes.ytex.weka.WekaResultsImporter#importDocumentResults(java.lang.String,
	 *      java.io.BufferedReader)
	 */
	public void importDocumentResults(String task, BufferedReader reader)
			throws IOException {
		this.importResults(docResultInstanceImporter, task, reader);
	}

	/**
	 * Parse results, pass them off to the given WekaResultInstanceImporter.
	 * <p>
	 * Lines are skipped until one matches the results header. After that,
	 * every line matching the result pattern is parsed and handed to the
	 * importer. The first line that does not match ends the results section;
	 * parsing stops there and no further lines are read from the reader.
	 * 
	 * @param resultInstanceImporter
	 *            receives one call per parsed result line
	 * @param task
	 *            passed through unchanged to the importer
	 * @param reader
	 *            source of the weka output; not closed by this method
	 * @throws IOException
	 *             if reading from the reader fails
	 * @throws NumberFormatException
	 *             if a matched line contains a number that cannot be parsed
	 */
	public void importResults(
			WekaResultInstanceImporter resultInstanceImporter, String task,
			BufferedReader reader) throws IOException {
		boolean inResults = false;
		String line;
		while ((line = reader.readLine()) != null) {
			if (!inResults) {
				// not yet in the results section - see if we found the header
				inResults = patHeader.matcher(line).find();
				continue;
			}
			// in results section - see if this line contains some results
			Matcher matcher = patResult.matcher(line);
			if (!matcher.find()) {
				// hit end of results - stop right here so that no additional
				// line is consumed from the reader
				break;
			}
			// matches - parse it, pass it to the instance importer
			int instanceNum = Integer.parseInt(matcher.group(1));
			int classGold = Integer.parseInt(matcher.group(2));
			int classAuto = Integer.parseInt(matcher.group(3));
			List<Double> listPredictions = parsePredictions(matcher.group(4));
			List<String> instanceKey = Arrays.asList(matcher.group(5).split(
					","));
			resultInstanceImporter.importInstanceResult(instanceNum,
					instanceKey, task, classAuto, classGold, listPredictions);
		}
	}

	/**
	 * Split the prediction column into its numeric values. Handles something
	 * like <code>*0.988 0.012</code>, <code>*0.988,0.012</code> or just
	 * <code>0.988</code>; empty tokens produced by the split are dropped.
	 * 
	 * @param predictionText
	 *            text captured from the prediction column
	 * @return the parsed values, in the order they appear
	 */
	private static List<Double> parsePredictions(String predictionText) {
		String[] tokens = predictionText.split("\\*|,|\\s");
		List<Double> predictions = new ArrayList<Double>(tokens.length);
		for (String token : tokens) {
			if (token.length() > 0) {
				predictions.add(Double.parseDouble(token));
			}
		}
		return predictions;
	}

	/**
	 * Save a new ClassifierEvaluation built from the given attributes, then
	 * import the per-instance results from the reader, attaching each one to
	 * that evaluation. Probabilities are not stored (the instance importer is
	 * created with storeProbabilities = false). The label is used as the task
	 * passed to importResults.
	 * 
	 * @throws IOException
	 *             if reading from the reader fails
	 */
	public void importClassifierEvaluation(String name, Integer fold,
			String algorithm, String label, String options, String experiment,
			BufferedReader reader) throws IOException {
		ClassifierEvaluation ce = new ClassifierEvaluation();
		ce.setName(name);
		ce.setFold(fold);
		ce.setAlgorithm(algorithm);
		ce.setLabel(label);
		ce.setOptions(options);
		ce.setExperiment(experiment);
		this.getSessionFactory().getCurrentSession().save(ce);
		ClassifierEvaluationInstanceImporter instanceImporter = new ClassifierEvaluationInstanceImporter(
				ce, false);
		this.importResults(instanceImporter, label, reader);
	}

	/**
	 * Saves each instance result as a ClassifierInstanceEvaluation linked to a
	 * fixed ClassifierEvaluation, using the enclosing importer's
	 * sessionFactory.
	 */
	public class ClassifierEvaluationInstanceImporter implements
			WekaResultInstanceImporter {
		private final ClassifierEvaluation classifierEvaluation;
		private final boolean storeProbabilities;

		/**
		 * @param classifierEvaluation
		 *            evaluation every imported instance is attached to
		 * @param storeProbabilities
		 *            if true, the predictions are stored with each instance
		 */
		public ClassifierEvaluationInstanceImporter(
				ClassifierEvaluation classifierEvaluation,
				boolean storeProbabilities) {
			this.classifierEvaluation = classifierEvaluation;
			this.storeProbabilities = storeProbabilities;
		}

		/**
		 * Build and save a ClassifierInstanceEvaluation for one result line.
		 * The instance id is taken from the first element of the instance key
		 * if the key is non-empty, otherwise from the instance number.
		 */
		@Override
		public void importInstanceResult(Integer instanceNumber,
				List<String> instanceKey, String task, int classAuto,
				int classGold, List<Double> predictions) {
			ClassifierInstanceEvaluation ci = new ClassifierInstanceEvaluation();
			if (instanceKey.size() > 0) {
				try {
					ci.setInstanceId(Integer.parseInt(instanceKey.get(0)));
				} catch (NumberFormatException ignored) {
					// best effort: a key whose first element is not an
					// integer leaves the instance id unset
				}
			} else {
				ci.setInstanceId(instanceNumber);
			}
			ci.setPredictedClassId(classAuto);
			ci.setTargetClassId(classGold);
			if (storeProbabilities && predictions != null
					&& predictions.size() > 0) {
				// probabilities are keyed by their position in the list
				for (int i = 0; i < predictions.size(); i++) {
					ci.getClassifierInstanceProbabilities().put(i,
							predictions.get(i));
				}
			}
			ci.setClassifierEvaluation(classifierEvaluation);
			sessionFactory.getCurrentSession().save(ci);
		}

	}

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy