All downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.ctakes.ytex.kernel.dao.ClassifierEvaluationDaoImpl Maven / Gradle / Ivy

The newest version!
/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.ctakes.ytex.kernel.dao;

import org.apache.ctakes.ytex.dao.DBUtil;
import org.apache.ctakes.ytex.kernel.InfoContentEvaluator;
import org.apache.ctakes.ytex.kernel.IntrinsicInfoContentEvaluator;
import org.apache.ctakes.ytex.kernel.metric.ConceptInfo;
import org.apache.ctakes.ytex.kernel.model.*;
import org.hibernate.Query;
import org.hibernate.SessionFactory;

import java.util.*;


/**
 * Hibernate-based {@link ClassifierEvaluationDao}. It persists classifier
 * evaluations, cross-validation folds, feature evaluations/ranks and
 * feature parent-child links, and reads them back through named queries.
 * <p>
 * Every operation uses {@code getSessionFactory().getCurrentSession()}.
 * NOTE(review): session/transaction demarcation is not visible in this class.
 * Presumably it is provided by the caller or a surrounding transaction proxy.
 */
public class ClassifierEvaluationDaoImpl implements ClassifierEvaluationDao {
	private SessionFactory sessionFactory;

	/**
	 * @return the session factory whose current session is used by every
	 *         operation of this DAO
	 */
	public SessionFactory getSessionFactory() {
		return sessionFactory;
	}

	/**
	 * @param sessionFactory
	 *            the session factory whose current session is used by every
	 *            operation of this DAO
	 */
	public void setSessionFactory(SessionFactory sessionFactory) {
		this.sessionFactory = sessionFactory;
	}

	/**
	 * Deletes every fold returned by the named query
	 * {@code getCrossValidationFoldByName} for the given corpus and split.
	 *
	 * @param corpusName
	 *            corpus name
	 * @param splitName
	 *            split name; passed through {@code DBUtil.nullToEmptyString}
	 */
	@SuppressWarnings("unchecked")
	@Override
	public void deleteCrossValidationFoldByName(String corpusName,
			String splitName) {
		Query q = this.getSessionFactory().getCurrentSession()
				.getNamedQuery("getCrossValidationFoldByName");
		q.setString("corpusName", corpusName);
		q.setString("splitName", nullToEmptyString(splitName));
		List<CrossValidationFold> folds = q.list();
		for (CrossValidationFold fold : folds) {
			this.getSessionFactory().getCurrentSession().delete(fold);
		}
	}

	/**
	 * Looks up a single fold via the named query
	 * {@code getCrossValidationFold}.
	 *
	 * @return the matching fold, or {@code null} when the query returns no row
	 */
	@Override
	public CrossValidationFold getCrossValidationFold(String corpusName,
			String splitName, String label, int run, int fold) {
		Query q = this.getSessionFactory().getCurrentSession()
				.getNamedQuery("getCrossValidationFold");
		q.setString("corpusName", corpusName);
		q.setString("splitName", nullToEmptyString(splitName));
		q.setString("label", nullToEmptyString(label));
		q.setInteger("run", run);
		q.setInteger("fold", fold);
		return (CrossValidationFold) q.uniqueResult();
	}

	/**
	 * Saves the evaluation and IR statistics for every class, optionally with
	 * the per-instance evaluations. Equivalent to
	 * {@code saveClassifierEvaluation(eval, irClassMap, saveInstanceEval, true, null)}.
	 */
	public void saveClassifierEvaluation(ClassifierEvaluation eval,
			Map<Integer, String> irClassMap, boolean saveInstanceEval) {
		saveClassifierEvaluation(eval, irClassMap, saveInstanceEval, true, null);
	}

	/**
	 * Saves {@code eval}. Then, if requested, saves one
	 * {@link ClassifierEvaluationIRStat} per class id and every
	 * {@link ClassifierInstanceEvaluation} of the evaluation.
	 *
	 * @param eval
	 *            evaluation to save
	 * @param irClassMap
	 *            optional class id to class label map for the IR stats. Ids
	 *            missing from the map, or a {@code null} map, are labelled
	 *            with the id's decimal string.
	 * @param saveInstanceEval
	 *            whether to save the per-instance evaluations
	 * @param saveIRStats
	 *            whether to compute and save per-class IR statistics
	 * @param excludeTargetClassId
	 *            instances with this target class are ignored for IR stats;
	 *            {@code null} to include all
	 */
	public void saveClassifierEvaluation(ClassifierEvaluation eval,
			Map<Integer, String> irClassMap, boolean saveInstanceEval,
			boolean saveIRStats, Integer excludeTargetClassId) {
		this.getSessionFactory().getCurrentSession().save(eval);
		if (saveIRStats) {
			this.saveIRStats(eval, irClassMap, excludeTargetClassId);
		}
		if (saveInstanceEval) {
			for (ClassifierInstanceEvaluation instanceEval : eval
					.getClassifierInstanceEvaluations().values()) {
				this.getSessionFactory().getCurrentSession().save(instanceEval);
			}
		}
	}

	/**
	 * Computes and saves one IR stat for every class id that occurs as a
	 * predicted class or as a (non-excluded) target class in {@code eval}.
	 */
	void saveIRStats(ClassifierEvaluation eval,
			Map<Integer, String> irClassMap, Integer excludeTargetClassId) {
		Set<Integer> classIds = this.getClassIds(eval, excludeTargetClassId);
		for (Integer irClassId : classIds) {
			String irClass = null;
			if (irClassMap != null) {
				irClass = irClassMap.get(irClassId);
			}
			if (irClass == null) {
				irClass = Integer.toString(irClassId);
			}
			ClassifierEvaluationIRStat irStat = calcIRStats(irClass, irClassId,
					eval, excludeTargetClassId);
			this.getSessionFactory().getCurrentSession().save(irStat);
		}
	}

	/**
	 * Builds the confusion-matrix counts for one class treated as the
	 * positive class. Instances whose target class is {@code null}, or equal
	 * to {@code excludeTargetClassId}, are skipped.
	 *
	 * @param irClass
	 *            label stored on the returned stat
	 * @param irClassId
	 *            the class id treated as the positive class
	 * @param eval
	 *            evaluation whose instance evaluations are counted
	 * @param excludeTargetClassId
	 *            target class id to skip, or {@code null} to skip none
	 * @return a new (not yet saved) stat holding the tp/tn/fp/fn counts
	 */
	private ClassifierEvaluationIRStat calcIRStats(String irClass,
			Integer irClassId, ClassifierEvaluation eval,
			Integer excludeTargetClassId) {
		int tp = 0;
		int tn = 0;
		int fp = 0;
		int fn = 0;
		for (ClassifierInstanceEvaluation instanceEval : eval
				.getClassifierInstanceEvaluations().values()) {
			if (!isCountedTarget(instanceEval, excludeTargetClassId)) {
				continue;
			}
			// Compare by value. '==' on two Integer objects compares
			// references and only works by accident for cached values
			// (-128..127).
			boolean actualPositive = sameClassId(
					instanceEval.getTargetClassId(), irClassId);
			boolean predictedPositive = sameClassId(
					instanceEval.getPredictedClassId(), irClassId);
			if (actualPositive) {
				if (predictedPositive) {
					tp++;
				} else {
					fn++;
				}
			} else if (predictedPositive) {
				fp++;
			} else {
				tn++;
			}
		}
		return new ClassifierEvaluationIRStat(eval, null, irClass, irClassId,
				tp, tn, fp, fn);
	}

	/**
	 * Collects every predicted class id plus every target class id that
	 * passes {@link #isCountedTarget}.
	 */
	private Set<Integer> getClassIds(ClassifierEvaluation eval,
			Integer excludeTargetClassId) {
		Set<Integer> classIds = new HashSet<Integer>();
		for (ClassifierInstanceEvaluation instanceEval : eval
				.getClassifierInstanceEvaluations().values()) {
			classIds.add(instanceEval.getPredictedClassId());
			if (isCountedTarget(instanceEval, excludeTargetClassId)) {
				classIds.add(instanceEval.getTargetClassId());
			}
		}
		return classIds;
	}

	/**
	 * @return true if the instance has a non-null target class that is not
	 *         {@code excludeTargetClassId}
	 */
	private static boolean isCountedTarget(
			ClassifierInstanceEvaluation instanceEval,
			Integer excludeTargetClassId) {
		Integer targetClassId = instanceEval.getTargetClassId();
		return targetClassId != null
				&& (excludeTargetClassId == null || targetClassId.intValue() != excludeTargetClassId
						.intValue());
	}

	/** Null-safe value equality for class ids. */
	private static boolean sameClassId(Integer a, Integer b) {
		return a == null ? b == null : a.equals(b);
	}

	@Override
	public void saveFold(CrossValidationFold fold) {
		this.getSessionFactory().getCurrentSession().save(fold);
	}

	/**
	 * Saves the feature evaluation followed by each of its feature ranks.
	 */
	@Override
	public void saveFeatureEvaluation(FeatureEvaluation featureEvaluation,
			List<FeatureRank> features) {
		this.getSessionFactory().getCurrentSession().save(featureEvaluation);
		for (FeatureRank r : features) {
			this.getSessionFactory().getCurrentSession().save(r);
		}
	}

	/**
	 * Deletes every feature evaluation returned by the named query
	 * {@code getFeatureEvaluationByNameAndType}. Only {@code featureSetName}
	 * is passed through {@code DBUtil.nullToEmptyString}.
	 */
	@SuppressWarnings("unchecked")
	@Override
	public void deleteFeatureEvaluationByNameAndType(String corpusName,
			String featureSetName, String type) {
		Query q = this.getSessionFactory().getCurrentSession()
				.getNamedQuery("getFeatureEvaluationByNameAndType");
		q.setString("corpusName", corpusName);
		q.setString("featureSetName", nullToEmptyString(featureSetName));
		q.setString("type", type);
		for (FeatureEvaluation fe : (List<FeatureEvaluation>) q.list()) {
			this.getSessionFactory().getCurrentSession().delete(fe);
		}
	}

	/**
	 * Runs the named query {@code getTopFeatures} for the identified feature
	 * evaluation.
	 *
	 * @param parentConceptTopThreshold
	 *            maximum number of rows to return; {@code null} for no limit
	 */
	@SuppressWarnings("unchecked")
	@Override
	public List<FeatureRank> getTopFeatures(String corpusName,
			String featureSetName, String label, String evaluationType,
			Integer foldId, double param1, String param2,
			Integer parentConceptTopThreshold) {
		Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
				label, evaluationType, foldId, param1, param2, "getTopFeatures");
		// setMaxResults takes an int; unboxing a null threshold would throw NPE
		if (parentConceptTopThreshold != null) {
			q.setMaxResults(parentConceptTopThreshold);
		}
		return q.list();
	}

	/**
	 * @return the single value returned by the named query
	 *         {@code getMaxFeatureEvaluation}
	 */
	@Override
	public Double getMaxFeatureEvaluation(String corpusName,
			String featureSetName, String label, String evaluationType,
			Integer foldId, double param1, String param2) {
		Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
				label, evaluationType, foldId, param1, param2,
				"getMaxFeatureEvaluation");
		return (Double) q.uniqueResult();
	}

	/**
	 * Creates the named query {@code queryName} and binds the parameters that
	 * identify a feature evaluation.
	 * <ul>
	 * <li>The strings {@code corpusName}, {@code featureSetName},
	 * {@code label} and {@code param2} are passed through
	 * {@link #nullToEmptyString(String)}.</li>
	 * <li>{@code evaluationType} is bound as given.</li>
	 * <li>A {@code null} {@code param1} or {@code foldId} is bound as 0.</li>
	 * </ul>
	 */
	private Query prepareUniqueFeatureEvalQuery(String corpusName,
			String featureSetName, String label, String evaluationType,
			Integer foldId, Double param1, String param2, String queryName) {
		Query q = this.getSessionFactory().getCurrentSession()
				.getNamedQuery(queryName);
		q.setString("corpusName", nullToEmptyString(corpusName));
		q.setString("featureSetName", nullToEmptyString(featureSetName));
		q.setString("label", nullToEmptyString(label));
		q.setString("evaluationType", evaluationType);
		q.setDouble("param1", param1 == null ? 0 : param1);
		q.setString("param2", nullToEmptyString(param2));
		q.setInteger("crossValidationFoldId", foldId == null ? 0 : foldId);
		return q;
	}

	/**
	 * Delegates to {@code DBUtil.nullToEmptyString}.
	 * <p>
	 * TODO: Oracle needs different handling of empty strings.
	 *
	 * @param value
	 *            string to normalize, may be {@code null}
	 * @return the value produced by {@code DBUtil.nullToEmptyString}
	 */
	private String nullToEmptyString(String value) {
		return DBUtil.nullToEmptyString(value);
	}

	/**
	 * Runs the named query {@code getThresholdFeatures}, binding
	 * {@code evaluationThreshold} as the {@code evaluation} parameter.
	 */
	@SuppressWarnings("unchecked")
	@Override
	public List<FeatureRank> getThresholdFeatures(String corpusName,
			String featureSetName, String label, String evaluationType,
			Integer foldId, double param1, String param2,
			double evaluationThreshold) {
		Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
				label, evaluationType, foldId, param1, param2,
				"getThresholdFeatures");
		q.setDouble("evaluation", evaluationThreshold);
		return q.list();
	}

	/**
	 * Deletes the identified feature evaluation and its feature ranks, if the
	 * evaluation exists.
	 */
	@Override
	public void deleteFeatureEvaluation(String corpusName,
			String featureSetName, String label, String evaluationType,
			Integer foldId, Double param1, String param2) {
		Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
				label, evaluationType, foldId, param1, param2,
				"getFeatureEvaluationByNK");
		FeatureEvaluation fe = (FeatureEvaluation) q.uniqueResult();
		if (fe != null) {
			// Session.delete(fe) reportedly did not work here, so bulk HQL
			// deletes are issued instead. Ranks go first, presumably because
			// they reference the evaluation.
			q = this.getSessionFactory().getCurrentSession()
					.getNamedQuery("deleteFeatureRank");
			q.setInteger("featureEvaluationId", fe.getFeatureEvaluationId());
			q.executeUpdate();
			q = this.getSessionFactory().getCurrentSession()
					.getNamedQuery("deleteFeatureEval");
			q.setInteger("featureEvaluationId", fe.getFeatureEvaluationId());
			q.executeUpdate();
		}
	}

	/**
	 * @return the feature ranks returned by the named query
	 *         {@code getFeatureRankEvaluations} for {@code featureNames},
	 *         keyed by feature name
	 */
	public Map<String, FeatureRank> getFeatureRanks(Set<String> featureNames,
			String corpusName, String featureSetName, String label,
			String evaluationType, Integer foldId, double param1, String param2) {
		Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
				label, evaluationType, foldId, param1, param2,
				"getFeatureRankEvaluations");
		q.setParameterList("featureNames", featureNames);
		@SuppressWarnings("unchecked")
		List<FeatureRank> featureRanks = q.list();
		Map<String, FeatureRank> frMap = new HashMap<String, FeatureRank>(
				featureRanks.size());
		for (FeatureRank fr : featureRanks) {
			frMap.put(fr.getFeatureName(), fr);
		}
		return frMap;
	}

	/**
	 * @return the evaluation value of each rank returned by the named query
	 *         {@code getFeatureRankEvaluations} for {@code featureNames},
	 *         keyed by feature name
	 */
	public Map<String, Double> getFeatureRankEvaluations(
			Set<String> featureNames, String corpusName, String featureSetName,
			String label, String evaluationType, Integer foldId, double param1,
			String param2) {
		Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
				label, evaluationType, foldId, param1, param2,
				"getFeatureRankEvaluations");
		q.setParameterList("featureNames", featureNames);
		@SuppressWarnings("unchecked")
		List<FeatureRank> featureRanks = q.list();
		Map<String, Double> evalMap = new HashMap<String, Double>(
				featureRanks.size());
		for (FeatureRank fr : featureRanks) {
			evalMap.put(fr.getFeatureName(), fr.getEvaluation());
		}
		return evalMap;
	}

	/**
	 * @return the evaluation value of every rank returned by the named query
	 *         {@code getTopFeatures} (no row limit), keyed by feature name
	 */
	@Override
	public Map<String, Double> getFeatureRankEvaluations(String corpusName,
			String featureSetName, String label, String evaluationType,
			Integer foldId, double param1, String param2) {
		Query q = prepareUniqueFeatureEvalQuery(corpusName, featureSetName,
				label, evaluationType, foldId, param1, param2, "getTopFeatures");
		@SuppressWarnings("unchecked")
		List<FeatureRank> listFeatureRank = q.list();
		Map<String, Double> mapFeatureEval = new HashMap<String, Double>(
				listFeatureRank.size());
		for (FeatureRank r : listFeatureRank) {
			mapFeatureEval.put(r.getFeatureName(), r.getEvaluation());
		}
		return mapFeatureEval;
	}

	/**
	 * Runs the named query {@code getCorpusCuiTuis} against the INFOCONTENT
	 * evaluation. The concept set is bound as the feature set and the concept
	 * graph as {@code param2}.
	 * NOTE(review): row type assumed to be {@code Object[]} — confirm against
	 * the query mapping.
	 */
	@Override
	@SuppressWarnings("unchecked")
	public List<Object[]> getCorpusCuiTuis(String corpusName,
			String conceptGraphName, String conceptSetName) {
		Query q = prepareUniqueFeatureEvalQuery(corpusName, conceptSetName,
				null, InfoContentEvaluator.INFOCONTENT, 0, 0d,
				conceptGraphName, "getCorpusCuiTuis");
		return q.list();
	}

	/**
	 * @return concept to information-content value for the INFOCONTENT
	 *         evaluation of the given corpus, concept set and concept graph
	 */
	@Override
	public Map<String, Double> getInfoContent(String corpusName,
			String conceptGraphName, String conceptSet) {
		return getFeatureRankEvaluations(corpusName, conceptSet, null,
				InfoContentEvaluator.INFOCONTENT, 0, 0, conceptGraphName);
	}

	/**
	 * Runs the named query {@code getIntrinsicInfoContent} for the
	 * INTRINSIC_INFOCONTENT evaluation of the given concept graph.
	 */
	@Override
	@SuppressWarnings("unchecked")
	public List<ConceptInfo> getIntrinsicInfoContent(String conceptGraphName) {
		Query q = prepareUniqueFeatureEvalQuery(null, null, null,
				IntrinsicInfoContentEvaluator.INTRINSIC_INFOCONTENT, null, null,
				conceptGraphName, "getIntrinsicInfoContent");
		return (List<ConceptInfo>) q.list();
	}

	/**
	 * @return the single value returned by the named query
	 *         {@code getMaxFeatureRank} for the INTRINSIC_INFOCONTENT
	 *         evaluation of the given concept graph
	 */
	public Integer getMaxDepth(String conceptGraphName) {
		Query q = prepareUniqueFeatureEvalQuery(null, null, null,
				IntrinsicInfoContentEvaluator.INTRINSIC_INFOCONTENT, null, null,
				conceptGraphName, "getMaxFeatureRank");
		return (Integer) q.uniqueResult();
	}

	@Override
	public void saveFeatureParentChild(FeatureParentChild parchd) {
		this.getSessionFactory().getCurrentSession().save(parchd);
	}

	/**
	 * Runs the named query {@code getImputedFeaturesByPropagatedCutoff}. It
	 * binds the evaluation's natural key plus {@code propRankCutoff} and
	 * {@code propEvaluationType}.
	 */
	@Override
	@SuppressWarnings("unchecked")
	public List<FeatureRank> getImputedFeaturesByPropagatedCutoff(
			String corpusName, String conceptSetName, String label,
			String evaluationType, String conceptGraphName,
			String propEvaluationType, int propRankCutoff) {
		Query q = prepareUniqueFeatureEvalQuery(corpusName, conceptSetName,
				label, evaluationType, 0, 0d, conceptGraphName,
				"getImputedFeaturesByPropagatedCutoff");
		q.setInteger("propRankCutoff", propRankCutoff);
		q.setString("propEvaluationType", propEvaluationType);
		return q.list();
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy