All Downloads are FREE. Search and download functionalities are using the official Maven repository.

gate.plugin.learningframework.data.CorpusRepresentationMalletSeq Maven / Gradle / Ivy

Go to download

A GATE plugin that provides many different machine learning algorithms for a wide range of NLP-related machine learning tasks like text classification, tagging, or chunking.

There is a newer version: 4.2
Show newest version
/*
 * Copyright (c) 2015-2016 The University Of Sheffield.
 *
 * This file is part of gateplugin-LearningFramework 
 * (see https://github.com/GateNLP/gateplugin-LearningFramework).
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation, either version 2.1 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this software. If not, see <http://www.gnu.org/licenses/>.
 */
package gate.plugin.learningframework.data;

import gate.Annotation;
import gate.AnnotationSet;
import java.util.ArrayList;
import java.util.List;
import cc.mallet.pipe.Noop;
import cc.mallet.pipe.Pipe;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.FeatureVectorSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import gate.plugin.learningframework.ScalingMethod;
import gate.plugin.learningframework.features.FeatureExtractionMalletSparse;
import gate.plugin.learningframework.features.FeatureInfo;
import gate.plugin.learningframework.features.TargetType;
import gate.plugin.learningframework.mallet.LFPipe;
import gate.util.GateRuntimeException;
import org.apache.log4j.Logger;
import static gate.plugin.learningframework.data.CorpusRepresentationMalletTarget.extractIndependentFeaturesHelper;
import gate.plugin.learningframework.features.SeqEncoder;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.net.URL;
import static gate.plugin.learningframework.LFUtils.newURL;
import gate.plugin.learningframework.mallet.LFAlphabet;
import gate.plugin.learningframework.mallet.LFInstanceList;

public class CorpusRepresentationMalletSeq extends CorpusRepresentationMallet {

  static final Logger LOGGER = Logger.getLogger("CorpusRepresentationMalletSeq");

  /**
   * Create an empty sequence corpus representation for the given feature information.
   *
   * The pipe is an LFPipe wrapping a single Noop pipe with a fresh LFAlphabet (features)
   * and a fresh LabelAlphabet (targets). The target type is initialised to NOMINAL.
   *
   * @param fi feature information; its global scaling method is recorded as well
   */
  public CorpusRepresentationMalletSeq(FeatureInfo fi) {
    featureInfo = fi;
    scalingMethod = fi.getGlobalScalingMethod();

    Pipe innerPipe = new Noop(new LFAlphabet(), new LabelAlphabet());
    List<Pipe> pipes = new ArrayList<>();
    pipes.add(innerPipe);
    pipe = new LFPipe(pipes);
    pipe.setFeatureInfo(fi);
    instances = new LFInstanceList(pipe);
    targetType = TargetType.NOMINAL;
  }

  /**
   * Non-public constructor for use when creating from a serialized pipe.
   *
   * The feature info and the global scaling method are taken from the pipe, and a new,
   * empty instance list backed by that pipe is created.
   *
   * @param pipe the previously saved pipe
   */
  CorpusRepresentationMalletSeq(LFPipe pipe) {
    this.pipe = pipe;
    this.featureInfo = pipe.getFeatureInfo();
    this.scalingMethod = this.featureInfo.getGlobalScalingMethod();
    this.instances = new LFInstanceList(pipe);
    this.targetType = TargetType.NOMINAL;
  }

  /**
   * Create a new instance based on the pipe stored in directory.
   *
   * The pipe is read from the file "pipe.pipe" inside the directory using Java object
   * deserialization.
   *
   * @param directory directory with the saved model
   * @return corpus representation instance
   * @throws GateRuntimeException if the pipe file cannot be opened, read or cast to LFPipe
   */
  public static CorpusRepresentationMalletSeq load(URL directory) {
    // load the pipe from a Java object serialization representation
    // NOTE(review): Java-native deserialization; only load models from trusted locations.
    URL inFile = newURL(directory, "pipe.pipe");
    LFPipe lfpipe = null;
    try (InputStream bom = inFile.openStream();
         ObjectInputStream ois = new ObjectInputStream(bom)) {
      lfpipe = (LFPipe) ois.readObject();
    } catch (Exception ex) {
      throw new GateRuntimeException("Could not read pipe from "+inFile,ex);
    }
    CorpusRepresentationMalletSeq crms = new CorpusRepresentationMalletSeq(lfpipe);
    return crms;
  }


  /**
   * Add instances. 
   * 
   * The exact way of how the target is created to the instances depends on which
   * parameters are given and which are null. The parameter sequenceAS must always be non-null for
   * this corpus representation since this corpus representation is always used with sequence
   * tagging algorithms. If the parameter classAS is non-null then instances for a sequence tagging
   * task are created, in that case targetFeatureName must be null. If targetFeatureName is non-null
   * then instances for a regression or classification problem are created (depending on targetType)
   * and classAS must be null. If the parameter nameFeatureName is non-null, then a Mallet instance
   * name is added from the source document and annotation.
   *
   * One Mallet instance is added per sequence annotation, in document order.
   *
   * @param instancesAS instance annotation set
   * @param sequenceAS sequence annotation set
   * @param inputAS input annotation set
   * @param classAS class annotation set
   * @param targetFeatureName feature name of target
   * @param targetType type of target
   * @param instanceWeightFeature ignored, this is only relevant for classification/regression
   * @param nameFeatureName feature for instance name, not used at the moment
   * @param seqEncoder sequence encoder instance
   * @throws GateRuntimeException if sequenceAS is null
   */
  @Override
  public void add(AnnotationSet instancesAS, AnnotationSet sequenceAS, AnnotationSet inputAS, AnnotationSet classAS, String targetFeatureName, TargetType targetType, String instanceWeightFeature, String nameFeatureName, SeqEncoder seqEncoder) {
    if (sequenceAS == null) {
      // Resolve the document name defensively so a missing inputAS/document does not
      // replace this diagnostic with a NullPointerException.
      String docName = (inputAS != null && inputAS.getDocument() != null)
              ? inputAS.getDocument().getName() : null;
      throw new GateRuntimeException("LF invalid call to CorpusRepresentationMallet.add: sequenceAS must not be null "
              + " for document " + docName);
    }
    for (Annotation sequenceAnnotation : sequenceAS.inDocumentOrder()) {
      Instance inst = getInstanceForSequence(instancesAS, sequenceAnnotation, inputAS, classAS, targetFeatureName, targetType, nameFeatureName, seqEncoder);
      instances.add(inst);
    }
  }

  /**
   * Finish adding instances.
   *
   * Scaling is not supported for this representation, so any scaling method other than
   * NONE is rejected.
   *
   * @throws GateRuntimeException if a scaling method other than NONE is configured
   */
  @Override
  public void finishAdding() {  
    if(scalingMethod != ScalingMethod.NONE) {
      throw new GateRuntimeException("Scaling not allowed/not yet implemented for sequence tagging representation");
    }
  }
  
  /**
   * Get a single Instance for a sequence annotation.
   *
   * The instance annotations contained in the sequence annotation are processed in document
   * order. For each one the independent features are extracted and, unless targetType is NONE,
   * a target is extracted:
   * <ul>
   *   <li>from classAS when classAS is non-null (sequence tagging);</li>
   *   <li>otherwise as a nominal target when targetType is NOMINAL;</li>
   *   <li>otherwise as a numeric target when targetType is NUMERIC.</li>
   * </ul>
   * Instances for which FeatureExtractionMalletSparse.ignoreInstanceWithMV returns true are
   * dropped. The remaining feature vectors become the data (a FeatureVectorSequence). Unless
   * targetType is NONE, their label indices become the target (a FeatureSequence over the
   * pipe's target alphabet); otherwise the target is null.
   *
   * @param instancesAS instance annotation set
   * @param sequenceAnnotation sequence annotation set
   * @param inputAS input annotation set
   * @param classAS class annotation set
   * @param targetFeatureName name of target feature
   * @param targetType type of target
   * @param nameFeatureName name feature; only checked for non-null to decide whether a name is added
   * @param seqEncoder sequence encoder instance
   * @return the Instance instance
   */
  public Instance getInstanceForSequence(
          AnnotationSet instancesAS,
          Annotation sequenceAnnotation,
          AnnotationSet inputAS,
          AnnotationSet classAS,
          String targetFeatureName,
          TargetType targetType,
          String nameFeatureName,
          SeqEncoder seqEncoder) {

    List<Annotation> instanceAnnotations = gate.Utils.getContainedAnnotations(instancesAS, sequenceAnnotation).inDocumentOrder();
    List<Instance> instanceList = new ArrayList<>(instanceAnnotations.size());
    for (Annotation instanceAnnotation : instanceAnnotations) {
      Instance inst = extractIndependentFeaturesHelper(instanceAnnotation, inputAS, featureInfo, pipe);
      if (targetType != TargetType.NONE) {
        if (classAS != null) {
          // extract the target as required for sequence tagging
          FeatureExtractionMalletSparse.extractClassForSeqTagging(inst, pipe.getTargetAlphabet(), classAS, instanceAnnotation, seqEncoder);
        } else if (targetType == TargetType.NOMINAL) {
          FeatureExtractionMalletSparse.extractClassTarget(inst, pipe.getTargetAlphabet(), targetFeatureName, instanceAnnotation, inputAS);
        } else if (targetType == TargetType.NUMERIC) {
          FeatureExtractionMalletSparse.extractNumericTarget(inst, targetFeatureName, instanceAnnotation, inputAS);
        }
      }
      if (!FeatureExtractionMalletSparse.ignoreInstanceWithMV(inst)) {
        instanceList.add(inst);
      }
    }
    FeatureVector[] vectors = new FeatureVector[instanceList.size()];
    for (int i = 0; i < vectors.length; i++) {
      vectors[i] = (FeatureVector) instanceList.get(i).getData();
    }
    FeatureVectorSequence fvseq = new FeatureVectorSequence(vectors);
    FeatureSequence fseq = null;
    if (targetType != TargetType.NONE) {
      int[] labelidxs = new int[instanceList.size()];
      for (int i = 0; i < labelidxs.length; i++) {
        // NOTE(review): this cast is also reached for NUMERIC targets; if
        // extractNumericTarget stores a non-Label target it will fail here -- confirm.
        labelidxs[i] = ((Label) instanceList.get(i).getTarget()).getIndex();
      }
      fseq = new FeatureSequence(pipe.getTargetAlphabet(), labelidxs);
    }
    // create the final instance, if a name feature is given also add the name
    Instance finalInst = new Instance(fvseq, fseq, null, null);
    if (nameFeatureName != null) {
      FeatureExtractionMalletSparse.extractName(finalInst, sequenceAnnotation, inputAS.getDocument());
    }
    return finalInst;

  }

  /**
   * Return the number of instances added so far.
   *
   * @return the size of the instance list, or 0 if there is no instance list
   */
  @Override
  public int nrInstances() {
    if(instances == null) {
      return 0;
    } else {
      return instances.size();
    }
  }
  
  
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy