All downloads are free. The search and download functionality uses the official Maven repository.

edu.stanford.nlp.ie.crf.CRFLogConditionalObjectiveFunctionWithDropout Maven / Gradle / Ivy

Go to download

Stanford CoreNLP provides a set of natural language analysis tools which can take raw English language text input and give the base forms of words, their parts of speech, whether they are names of companies, people, etc., normalize dates, times, and numeric quantities, mark up the structure of sentences in terms of phrases and word dependencies, and indicate which noun phrases refer to the same entities. It provides the foundational building blocks for higher level text understanding applications.

There is a newer version: 4.5.7
Show newest version
package edu.stanford.nlp.ie.crf; 
import edu.stanford.nlp.util.logging.Redwood;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.util.concurrent.*;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Quadruple;

import java.util.*;

/**
 * @author Mengqiu Wang
 */

public class CRFLogConditionalObjectiveFunctionWithDropout extends CRFLogConditionalObjectiveFunction  {

  /** A logger for this class */
  private static Redwood.RedwoodChannels log = Redwood.channels(CRFLogConditionalObjectiveFunctionWithDropout.class);

  private final double delta;
  private final double dropoutScale;
  private double[][] dropoutPriorGradTotal;
  private final boolean dropoutApprox;
  private double[][] weightSquare;

  private final int[][][][] totalData;  // data[docIndex][tokenIndex][][]
  private int unsupDropoutStartIndex;
  private final double unsupDropoutScale;

  private List>> dataFeatureHash;
  private List>> condensedMap;
  private int[][] dataFeatureHashByDoc;
  private int edgeLabelIndexSize;
  private int nodeLabelIndexSize;
  private int[][] edgeLabels;
  private Map> currPrevLabelsMap;
  private Map> currNextLabelsMap;

  private ThreadsafeProcessor, Quadruple, Map>> dropoutPriorThreadProcessor =
        new ThreadsafeProcessor, Quadruple, Map>>() {
    @Override
    public Quadruple, Map> process(Pair docIndexUnsup) {
      return expectedCountsAndValueForADoc(docIndexUnsup.first(), false, docIndexUnsup.second());
    }
    @Override
    public ThreadsafeProcessor, Quadruple, Map>> newInstance() {
      return this;
    }
  };

  //TODO(Mengqiu) Need to figure out what to do with dataDimension() in case of
  // mixed supervised+unsupervised data for SGD (AdaGrad)
  CRFLogConditionalObjectiveFunctionWithDropout(int[][][][] data, int[][] labels, int window, Index classIndex, List> labelIndices, int[] map, String priorType, String backgroundSymbol, double sigma, double[][][][] featureVal, double delta, double dropoutScale, int multiThreadGrad, boolean dropoutApprox, double unsupDropoutScale, int[][][][] unsupDropoutData) {
    super(data, labels, window, classIndex, labelIndices, map, priorType, backgroundSymbol, sigma, featureVal, multiThreadGrad);
    this.delta = delta;
    this.dropoutScale = dropoutScale;
    this.dropoutApprox = dropoutApprox;
    dropoutPriorGradTotal = empty2D();
    this.unsupDropoutStartIndex = data.length;
    this.unsupDropoutScale = unsupDropoutScale;
    if (unsupDropoutData != null) {
      this.totalData = new int[data.length + unsupDropoutData.length][][][];
      for (int i=0; i edgeLabelIndex = labelIndices.get(1);
    edgeLabelIndexSize = edgeLabelIndex.size();
    Index nodeLabelIndex = labelIndices.get(0);
    nodeLabelIndexSize = nodeLabelIndex.size();
    currPrevLabelsMap = new HashMap<>();
    currNextLabelsMap = new HashMap<>();
    edgeLabels = new int[edgeLabelIndexSize][];
    for (int k=0; k < edgeLabelIndexSize; k++) {
      int[] labelPair = edgeLabelIndex.get(k).getLabel();
      edgeLabels[k] = labelPair;
      int curr = labelPair[1];
      int prev = labelPair[0];
      if (!currPrevLabelsMap.containsKey(curr))
        currPrevLabelsMap.put(curr, new ArrayList<>(numClasses));
      currPrevLabelsMap.get(curr).add(prev);
      if (!currNextLabelsMap.containsKey(prev))
        currNextLabelsMap.put(prev, new ArrayList<>(numClasses));
      currNextLabelsMap.get(prev).add(curr);
    }
  }

  /**
   * Allocates one zero-filled row per feature in {@code activeFeatures}.
   * A feature with {@code map[f] == 0} gets a row of length {@code nodeLabelIndexSize};
   * every other feature gets a row of length {@code edgeLabelIndexSize}.
   *
   * @param activeFeatures feature indices to allocate rows for
   * @return a new map from feature index to its zeroed row
   */
  private Map<Integer, double[]> sparseE(Set<Integer> activeFeatures) {
    Map<Integer, double[]> aMap = new HashMap<>(activeFeatures.size());
    for (int f : activeFeatures) {
      aMap.put(f, new double[map[f] == 0 ? nodeLabelIndexSize : edgeLabelIndexSize]);
    }
    return aMap;
  }

  /**
   * Array overload: allocates one zero-filled row per feature index in {@code activeFeatures}.
   * A feature with {@code map[f] == 0} gets a row of length {@code nodeLabelIndexSize};
   * every other feature gets a row of length {@code edgeLabelIndexSize}.
   *
   * @param activeFeatures feature indices to allocate rows for
   * @return a new map from feature index to its zeroed row
   */
  private Map<Integer, double[]> sparseE(int[] activeFeatures) {
    Map<Integer, double[]> aMap = new HashMap<>(activeFeatures.length);
    for (int f : activeFeatures) {
      aMap.put(f, new double[map[f] == 0 ? nodeLabelIndexSize : edgeLabelIndexSize]);
    }
    return aMap;
  }

  private Quadruple, Map> expectedCountsAndValueForADoc(int docIndex,
      boolean skipExpectedCountCalc, boolean skipValCalc) {

    int[] activeFeatures = dataFeatureHashByDoc[docIndex];
    List> docDataHash = dataFeatureHash.get(docIndex);
    Map> condensedFeaturesMap = condensedMap.get(docIndex);

    double prob = 0;
    int[][][] docData = totalData[docIndex];
    int[] docLabels = null;
    if (docIndex < labels.length)
      docLabels = labels[docIndex];

    Timing timer = new Timing();
    double[][][] featureVal3DArr = null;
    if (featureVal != null)
      featureVal3DArr = featureVal[docIndex];

    // make a clique tree for this document
    CRFCliqueTree cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, labelIndices, numClasses, classIndex, backgroundSymbol, cliquePotentialFunc, featureVal3DArr);

    if (!skipValCalc) {
      if (TIMED)
        timer.start();
      // compute the log probability of the document given the model with the parameters x
      int[] given = new int[window - 1];
      Arrays.fill(given, classIndex.indexOf(backgroundSymbol));
      if (docLabels.length>docData.length) { // only true for self-training
        // fill the given array with the extra docLabels
        System.arraycopy(docLabels, 0, given, 0, given.length);
        // shift the docLabels array left
        int[] newDocLabels = new int[docData.length];
        System.arraycopy(docLabels, docLabels.length-newDocLabels.length, newDocLabels, 0, newDocLabels.length);
        docLabels = newDocLabels;
      }

      double startPosLogProb = cliqueTree.logProbStartPos();
      if (VERBOSE)
        System.err.printf("P_-1(Background) = % 5.3f\n", startPosLogProb);
      prob += startPosLogProb;

      // iterate over the positions in this document
      for (int i = 0; i < docData.length; i++) {
        int label = docLabels[i];
        double p = cliqueTree.condLogProbGivenPrevious(i, label, given);
        if (VERBOSE) {
          log.info("P(" + label + "|" + ArrayMath.toString(given) + ")=" + Math.exp(p));
        }
        prob += p;
        System.arraycopy(given, 1, given, 0, given.length - 1);
        given[given.length - 1] = label;
      }
      if (TIMED) {
        long elapsedMs = timer.stop();
        log.info("Calculate objective took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }

    Map EForADoc = sparseE(activeFeatures);
    List> EForADocPos = null;
    if (dropoutApprox) {
      EForADocPos = new ArrayList<>(docData.length);
    }

    if (!skipExpectedCountCalc) {
      if (TIMED)
        timer.start();
      // compute the expected counts for this document, which we will need to compute the derivative
      // iterate over the positions in this document
      double fVal = 1.0;
      for (int i = 0; i < docData.length; i++) {
        Set docDataHashI = docDataHash.get(i);
        Map EForADocPosAtI = null;
        if (dropoutApprox)
          EForADocPosAtI = sparseE(docDataHashI);

        for (int fIndex: docDataHashI) {
          int j= map[fIndex];
          Index labelIndex = labelIndices.get(j);
          // for each possible labeling for that clique
          for (int k = 0; k < labelIndex.size(); k++) {
            int[] label = labelIndex.get(k).getLabel();
            double p = cliqueTree.prob(i, label); // probability of these labels occurring in this clique with these features
            if (dropoutApprox)
              increScore(EForADocPosAtI, fIndex, k, fVal * p);
            increScore(EForADoc, fIndex, k, fVal * p);
          }
        }
        if (dropoutApprox) {
          for (int fIndex: docDataHashI) {
            if (condensedFeaturesMap.containsKey(fIndex)) {
              List aList = condensedFeaturesMap.get(fIndex);
              for (int toCopyInto: aList) {
                double[] arr = EForADocPosAtI.get(fIndex);
                double[] targetArr = new double[arr.length];
                for (int q=0; q < arr.length; q++)
                  targetArr[q] = arr[q];
                EForADocPosAtI.put(toCopyInto, targetArr);
              }
            }
          }
          EForADocPos.add(EForADocPosAtI);
        }
      }

      // copy for condensedFeaturesMap
      for (Map.Entry> entry: condensedFeaturesMap.entrySet()) {
        int key = entry.getKey();
        List aList = entry.getValue();
        for (int toCopyInto: aList) {
          double[] arr = EForADoc.get(key);
          double[] targetArr = new double[arr.length];
          for (int i=0; i < arr.length; i++)
            targetArr[i] = arr[i];
          EForADoc.put(toCopyInto, targetArr);
        }
      }

      if (TIMED) {
        long elapsedMs = timer.stop();
        log.info("Expected count took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }

    Map dropoutPriorGrad = null;
    if (prior == DROPOUT_PRIOR) {
      if (TIMED)
        timer.start();
      // we can optimize this, this is too large, don't need this big
      dropoutPriorGrad = sparseE(activeFeatures);

      // log.info("computing dropout prior for doc " + docIndex + " ... ");
      prob -= getDropoutPrior(cliqueTree, docData, EForADoc, docDataHash, activeFeatures, dropoutPriorGrad, condensedFeaturesMap, EForADocPos);
      // log.info(" done!");
      if (TIMED) {
        long elapsedMs = timer.stop();
        log.info("Dropout took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }

    return new Quadruple<>(docIndex, prob, EForADoc, dropoutPriorGrad);
  }

  /**
   * Adds {@code val} to slot {@code k} of the row stored for {@code fIndex}.
   * The row must already exist in {@code aMap}; a missing row results in a NullPointerException.
   */
  private void increScore(Map<Integer, double[]> aMap, int fIndex, int k, double val) {
    aMap.get(fIndex)[k] += val;
  }

  /**
   * Adds {@code val} to slot {@code k} of the row stored for {@code fIndex}, first creating a zeroed row
   * when none is present. The new row has length {@code nodeLabelIndexSize} when {@code map[fIndex] == 0}
   * and {@code edgeLabelIndexSize} otherwise.
   */
  private void increScoreAllowNull(Map<Integer, double[]> aMap, int fIndex, int k, double val) {
    double[] row = aMap.get(fIndex);
    if (row == null) {
      row = new double[map[fIndex] == 0 ? nodeLabelIndexSize : edgeLabelIndexSize];
      aMap.put(fIndex, row);
    }
    row[k] += val;
  }

  /**
   * Builds the per-document feature caches from {@code totalData}:
   * {@code dataFeatureHash} (feature set per position), {@code dataFeatureHashByDoc} (distinct features per
   * document) and {@code condensedMap} (representative feature -> features merged into it).
   *
   * When {@code CONDENSE} is on, features from {@code aDoc[i][0]} that were seen exactly once in the document
   * are merged per position: the first one encountered becomes the representative and the others are removed
   * from the position set and the document set and recorded under the representative.
   *
   * {@code condensedMap} always receives one map per document (empty when nothing was condensed), so it can
   * be indexed by document index regardless of the {@code CONDENSE} setting.
   */
  private void initializeDataFeatureHash() {
    int macroActiveFeatureTotalCount = 0;
    int macroCondensedTotalCount = 0;
    int macroDocPosCount = 0;

    log.info("initializing data feature hash, sup-data size: " + data.length + ", unsup data size: " + (totalData.length-data.length));
    dataFeatureHash = new ArrayList<>(totalData.length);
    condensedMap = new ArrayList<>(totalData.length);
    dataFeatureHashByDoc = new int[totalData.length][];
    for (int m = 0; m < totalData.length; m++) {
      // feature -> the position of its only occurrence, or -1 once it has been seen more than once
      Map<Integer, Integer> occurPos = new HashMap<>();

      int[][][] aDoc = totalData[m];
      List<Set<Integer>> aList = new ArrayList<>(aDoc.length);
      Set<Integer> setOfFeatures = new HashSet<>();
      for (int i = 0; i < aDoc.length; i++) { // positions in the document
        Set<Integer> aSet = new HashSet<>();
        int[][] dataI = aDoc[i];
        for (int j = 0; j < dataI.length; j++) {
          for (int item : dataI[j]) {
            if (j == 0) {
              // only features listed under index 0 are candidates for condensing
              if (occurPos.containsKey(item)) {
                occurPos.put(item, -1);
              } else {
                occurPos.put(item, i);
              }
            }
            aSet.add(item);
          }
        }
        aList.add(aSet);
        setOfFeatures.addAll(aSet);
      }
      macroDocPosCount += aDoc.length;
      macroActiveFeatureTotalCount += setOfFeatures.size();

      // Declared outside the CONDENSE branch so that every document contributes exactly one entry to condensedMap.
      Map<Integer, List<Integer>> condensedFeaturesMap = new HashMap<>();
      if (CONDENSE) {
        if (DEBUG3) {
          log.info("Before condense, activeFeatures = " + setOfFeatures.size());
        }
        // examine all singletons, merge ones in the same position
        int[] representFeatures = new int[aDoc.length];
        Arrays.fill(representFeatures, -1);

        for (Map.Entry<Integer, Integer> entry : occurPos.entrySet()) {
          Integer key = entry.getKey();
          int pos = entry.getValue();
          if (pos != -1) {
            if (representFeatures[pos] == -1) { // first singleton at this position becomes the representative
              representFeatures[pos] = key;
              condensedFeaturesMap.put(key, new ArrayList<>());
            } else { // condense this one into the representative
              int rep = representFeatures[pos];
              condensedFeaturesMap.get(rep).add(key);
              // the condensed feature no longer counts as active
              aList.get(pos).remove(key);
              setOfFeatures.remove(key);
            }
          }
        }
        int condensedCount = 0;
        // drop representatives that ended up with nothing merged into them
        for (Iterator<Map.Entry<Integer, List<Integer>>> it = condensedFeaturesMap.entrySet().iterator(); it.hasNext(); ) {
          Map.Entry<Integer, List<Integer>> entry = it.next();
          if (entry.getValue().isEmpty()) {
            it.remove();
          } else if (DEBUG3) {
            condensedCount += entry.getValue().size();
            for (int cond : entry.getValue()) {
              log.info("condense " + cond + " to " + entry.getKey());
            }
          }
        }
        if (DEBUG3) {
          log.info("After condense, activeFeatures = " + setOfFeatures.size() + ", condensedCount = " + condensedCount);
        }
        macroCondensedTotalCount += setOfFeatures.size();
      }
      condensedMap.add(condensedFeaturesMap);

      dataFeatureHash.add(aList);
      int[] arrOfIndex = new int[setOfFeatures.size()];
      int pos2 = 0;
      for (int ind : setOfFeatures) {
        arrOfIndex[pos2++] = ind;
      }
      dataFeatureHashByDoc[m] = arrOfIndex;
    }
    log.info("Avg. active features per position: " + (macroActiveFeatureTotalCount/ (macroDocPosCount+0.0)));
    log.info("Avg. condensed features per position: " + (macroCondensedTotalCount / (macroDocPosCount+0.0)));
    log.info("initializing data feature hash done!");
  }

  private double getDropoutPrior(CRFCliqueTree cliqueTree, int[][][] docData,
      Map EForADoc, List> docDataHash, int[] activeFeatures, Map dropoutPriorGrad,
      Map> condensedFeaturesMap, List> EForADocPos) {

    Map dropoutPriorGradFirstHalf = sparseE(activeFeatures);

    if (TIMED)
      log.info("activeFeatures size: "+activeFeatures.length + ", dataLen: " + docData.length);

    Timing timer = new Timing();
    if (TIMED)
      timer.start();

    double priorValue = 0;

    long elapsedMs = 0;
    Pair condProbs = getCondProbs(cliqueTree, docData);
    if (TIMED) {
      elapsedMs = timer.stop();
      log.info("\t cond prob took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
    }

    // first index position is curr index, second index curr-class, third index prev-class
    // e.g. [1][2][3] means curr is at position 1 with class 2, prev is at position 0 with class 3
    double[][][] prevGivenCurr = condProbs.first();
    // first index position is curr index, second index curr-class, third index next-class
    // e.g. [0][2][3] means curr is at position 0 with class 2, next is at position 1 with class 3
    double[][][] nextGivenCurr = condProbs.second();

    // first dim is doc length (i)
    // second dim is numOfFeatures (fIndex)
    // third dim is numClasses (y)
    // fourth dim is labelIndexSize (matching the clique type of fIndex, for \theta)
    double[][][][] FAlpha = null;
    double[][][][] FBeta  = null;
    if (!dropoutApprox) {
      FAlpha = new double[docData.length][][][];
      FBeta  = new double[docData.length][][][];
    }
    for (int i = 0; i < docData.length; i++) {
      if (!dropoutApprox) {
        FAlpha[i] = new double[activeFeatures.length][][];
        FBeta[i]  = new double[activeFeatures.length][][];
      }
    }

    if (!dropoutApprox) {
      if (TIMED) {
        timer.start();
      }
      // computing FAlpha
      int fIndex = 0;
      double aa, bb, cc = 0;
      boolean prevFeaturePresent  = false;
      for (int i = 1; i < docData.length; i++) {
        // for each possible clique at this position
        Set docDataHashIMinusOne = docDataHash.get(i-1);
        for (int fIndexPos = 0; fIndexPos < activeFeatures.length; fIndexPos++) {
          fIndex = activeFeatures[fIndexPos];
          prevFeaturePresent = docDataHashIMinusOne.contains(fIndex);
          int j = map[fIndex];
          Index labelIndex = labelIndices.get(j);
          int labelIndexSize = labelIndex.size();

          if (FAlpha[i-1][fIndexPos] == null) {
            FAlpha[i-1][fIndexPos] = new double[numClasses][labelIndexSize];
            for (int q = 0; q < numClasses; q++)
              FAlpha[i-1][fIndexPos][q] = new double[labelIndexSize];
          }

          for (Map.Entry> entry : currPrevLabelsMap.entrySet()) {
            int y = entry.getKey(); // value at i-1
            double[] sum = new double[labelIndexSize];
            for (int yPrime: entry.getValue()) { // value at i-2
              for (int kk = 0; kk < labelIndexSize; kk++) {
                int[] prevLabel = labelIndex.get(kk).getLabel();
                aa = (prevGivenCurr[i-1][y][yPrime]);
                bb = (prevFeaturePresent && ((j == 0 && prevLabel[0] == y) || (j == 1 && prevLabel[1] == y && prevLabel[0] == yPrime)) ? 1 : 0);
                cc = 0;
                if (FAlpha[i-1][fIndexPos][yPrime] != null)
                  cc = FAlpha[i-1][fIndexPos][yPrime][kk];
                sum[kk] +=  aa * (bb + cc);
                // sum[kk] += (prevGivenCurr[i-1][y][yPrime]) * ((prevFeaturePresent && ((j == 0 && prevLabel[0] == y) || (j == 1 && prevLabel[1] == y && prevLabel[0] == yPrime)) ? 1 : 0) + FAlpha[i-1][fIndexPos][yPrime][kk]);
                if (DEBUG2)
                  System.err.printf("alpha[%d][%d][%d][%d] += % 5.3f * (%d + % 5.3f), prevLabel=%s\n", i, fIndex, y, kk, (prevGivenCurr[i-1][y][yPrime]), (prevFeaturePresent && ((j == 0 && prevLabel[0] == y) || (j == 1 && prevLabel[1] == y && prevLabel[0] == yPrime)) ? 1 : 0) , FAlpha[i-1][fIndexPos][yPrime][kk], Arrays.toString(prevLabel));
              }
            }
            if (FAlpha[i][fIndexPos] == null) {
              FAlpha[i][fIndexPos] = new double[numClasses][];
            }
            FAlpha[i][fIndexPos][y] = sum;
            if (DEBUG2)
              log.info("FAlpha["+i+"]["+fIndexPos+"]["+y+"] = " + Arrays.toString(sum));

          }
        }
      }
      if (TIMED) {
        elapsedMs = timer.stop();
        log.info("\t alpha took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
        timer.start();
      }
      // computing FBeta
      int docDataLen = docData.length;
      for (int i = docDataLen-2; i >= 0; i--) {
        Set docDataHashIPlusOne = docDataHash.get(i+1);
        // for each possible clique at this position
        for (int fIndexPos = 0; fIndexPos < activeFeatures.length; fIndexPos++) {
          fIndex = activeFeatures[fIndexPos];
          boolean nextFeaturePresent = docDataHashIPlusOne.contains(fIndex);
          int j = map[fIndex];
          Index labelIndex = labelIndices.get(j);
          int labelIndexSize = labelIndex.size();

          if (FBeta[i+1][fIndexPos] == null) {
            FBeta[i+1][fIndexPos] = new double[numClasses][labelIndexSize];
            for (int q = 0; q < numClasses; q++)
              FBeta[i+1][fIndexPos][q] = new double[labelIndexSize];
          }

          for (Map.Entry> entry : currNextLabelsMap.entrySet()) {
            int y = entry.getKey(); // value at i
            double[] sum = new double[labelIndexSize];
            for (int yPrime: entry.getValue()) { // value at i+1
              for (int kk=0; kk < labelIndexSize; kk++) {
                int[] nextLabel = labelIndex.get(kk).getLabel();
                // log.info("labelIndexSize:"+labelIndexSize+", nextGivenCurr:"+nextGivenCurr+", nextLabel:"+nextLabel+", FBeta["+(i+1)+"]["+ fIndexPos +"]["+yPrime+"] :"+FBeta[i+1][fIndexPos][yPrime]);
                aa = (nextGivenCurr[i][y][yPrime]);
                bb = (nextFeaturePresent && ((j == 0 && nextLabel[0] == yPrime) || (j == 1 && nextLabel[0] == y && nextLabel[1] == yPrime)) ? 1 : 0);
                cc = 0;
                if (FBeta[i+1][fIndexPos][yPrime] != null)
                  cc = FBeta[i+1][fIndexPos][yPrime][kk];
                sum[kk] +=  aa * ( bb + cc);
                // sum[kk] += (nextGivenCurr[i][y][yPrime]) * ( (nextFeaturePresent && ((j == 0 && nextLabel[0] == yPrime) || (j == 1 && nextLabel[0] == y && nextLabel[1] == yPrime)) ? 1 : 0) + FBeta[i+1][fIndexPos][yPrime][kk]);
                if (DEBUG2)
                  System.err.printf("beta[%d][%d][%d][%d] += % 5.3f * (%d + % 5.3f)\n", i, fIndex, y, kk, (nextGivenCurr[i][y][yPrime]), (nextFeaturePresent && ((j == 0 && nextLabel[0] == yPrime) || (j == 1 && nextLabel[0] == y && nextLabel[1] == yPrime)) ? 1 : 0), FBeta[i+1][fIndexPos][yPrime][kk]);
              }
            }
            if (FBeta[i][fIndexPos] == null) {
              FBeta[i][fIndexPos] = new double[numClasses][];
            }
            FBeta[i][fIndexPos][y] = sum;
            if (DEBUG2)
              log.info("FBeta["+i+"]["+fIndexPos+"]["+y+"] = " + Arrays.toString(sum));
          }
        }
      }
      if (TIMED) {
        elapsedMs = timer.stop();
        log.info("\t beta took: " + Timing.toMilliSecondsString(elapsedMs) + " ms");
      }
    }
    if (TIMED) {
      timer.start();
    }

    // derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1-PtYYp) + VarU * PtYYp * (1-PtYYp)'
    // derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1-PtYYp) + VarU * PtYYp * -PtYYp'
    // derivative equals: VarU' * PtYYp * (1-PtYYp) + VarU * PtYYp' * (1 - 2 * PtYYp)

    double deltaDivByOneMinusDelta = delta / (1.0-delta);

    Timing innerTimer = new Timing();
    long eTiming = 0;
    long dropoutTiming= 0;

    boolean containsFeature = false;
    // iterate over the positions in this document
    for (int i = 1; i < docData.length; i++) {
      Set docDataHashI = docDataHash.get(i);
      Map EForADocPosAtI = null;
      if (dropoutApprox)
        EForADocPosAtI = EForADocPos.get(i);

      // for each possible clique at this position
      for (int k = 0; k < edgeLabelIndexSize; k++) { // sum over (y, y')
        int[] label = edgeLabels[k];
        int y = label[0];
        int yP = label[1];

        if (TIMED)
          innerTimer.start();

        // important to use label as an int[] for calculating cliqueTree.prob()
        // if it's a node clique, and label index is 2, if we don't use int[]{2} but just pass 2,
        // cliqueTree is going to treat it as index of the edge clique labels, and convert 2
        // into int[]{0,2}, and return the edge prob marginal instead of node marginal
        double PtYYp = cliqueTree.prob(i, label);
        double PtYYpTimesOneMinusPtYYp = PtYYp * (1.0 - PtYYp);
        double oneMinus2PtYYp = (1.0 - 2 * PtYYp);
        double USum = 0;
        int fIndex;
        for (int jjj=0; jjj fLabelIndex = labelIndices.get(jj);
          for (int kk = 0; kk < fLabelIndex.size(); kk++) { // for all parameter \theta
            int[] fLabel = fLabelIndex.get(kk).getLabel();
            // if (FAlpha[i] != null)
            //   log.info("fIndex: " + fIndex+", FAlpha[i].size:"+FAlpha[i].length);
            double fCount = containsFeature && ((jj == 0 && fLabel[0] == yP) || (jj == 1 && k == kk)) ? 1 : 0;

            double alpha;
            double beta;
            double condE;
            double PtYYpPrime;
            if (!dropoutApprox) {
              alpha = ((FAlpha[i][fIndexPos] == null || FAlpha[i][fIndexPos][y] == null) ? 0 : FAlpha[i][fIndexPos][y][kk]);
              beta = ((FBeta[i][fIndexPos] == null || FBeta[i][fIndexPos][yP] == null) ? 0 : FBeta[i][fIndexPos][yP][kk]);
              condE = fCount + alpha + beta;
              if (DEBUG2)
                System.err.printf("fLabel=%s, yP = %d, fCount:%f = ((jj == 0 && fLabel[0] == yP)=%b || (jj == 1 && k == kk))=%b\n", Arrays.toString(fLabel),yP, fCount,(jj == 0 && fLabel[0] == yP) , (jj == 1 && k == kk));
              PtYYpPrime = PtYYp * (condE - EForADoc.get(fIndex)[kk]);
            } else {
              double E = 0;
              if (EForADocPosAtI.containsKey(fIndex))
                E = EForADocPosAtI.get(fIndex)[kk];
              condE = fCount;
              PtYYpPrime = PtYYp * (condE - E);
            }

            if (DEBUG2)
              System.err.printf("for i=%d, k=%d, y=%d, yP=%d, fIndex=%d, kk=%d, PtYYpPrime=% 5.3f, PtYYp=% 3.3f, (condE-E[fIndex][kk])=% 3.3f, condE=% 3.3f, E[fIndex][k]=% 3.3f, alpha=% 3.3f, beta=% 3.3f, fCount=% 3.3f\n", i, k, y, yP, fIndex, kk, PtYYpPrime, PtYYp, (condE - EForADoc.get(fIndex)[kk]), condE, EForADoc.get(fIndex)[kk], alpha, beta, fCount);

            increScore(dropoutPriorGrad, fIndex, kk, VarUTimesOneMinus2PtYYp * PtYYpPrime);
          }

          if (DEBUG2)
            log.info();
        }
        if (TIMED)
          dropoutTiming += innerTimer.stop();
      }
    }
    if (CONDENSE) {
      // copy for condensedFeaturesMap
      for (Map.Entry> entry: condensedFeaturesMap.entrySet()) {
        int key = entry.getKey();
        List aList = entry.getValue();
        for (int toCopyInto: aList) {
          double[] arr = dropoutPriorGrad.get(key);
          double[] targetArr = new double[arr.length];
          for (int i=0; i < arr.length; i++)
            targetArr[i] = arr[i];
          dropoutPriorGrad.put(toCopyInto, targetArr);
        }
      }
    }

    if (DEBUG3) {
      log.info("dropoutPriorGradFirstHalf.keys:[");
      for (int key: dropoutPriorGradFirstHalf.keySet())
        log.info(" "+key);
      log.info("]");

      log.info("dropoutPriorGrad.keys:[");
      for (int key: dropoutPriorGrad.keySet())
        log.info(" "+key);
      log.info("]");
    }

    for (Map.Entry entry: dropoutPriorGrad.entrySet()) {
      Integer key = entry.getKey();
      double[] target = entry.getValue();
      if (dropoutPriorGradFirstHalf.containsKey(key)) {
        double[] source = dropoutPriorGradFirstHalf.get(key);
        for (int i=0; i, Quadruple, Map>> wrapper =
            new MulticoreWrapper<>(multiThreadGrad, dropoutPriorThreadProcessor);
    // supervised part
    for (int m = 0; m < totalData.length; m++) {
      boolean submitIsUnsup = (m >= unsupDropoutStartIndex);
      wrapper.put(new Pair<>(m, submitIsUnsup));
      while (wrapper.peek()) {
        Quadruple, Map> result = wrapper.poll();
        int docIndex = result.first();
        boolean isUnsup = docIndex >= unsupDropoutStartIndex;
        if (isUnsup) {
          prob += unsupDropoutScale * result.second();
        } else {
          prob += result.second();
        }

        Map partialDropout = result.fourth();
        if (partialDropout != null) {
          if (isUnsup) {
            combine2DArr(dropoutPriorGradTotal, partialDropout, unsupDropoutScale);
          } else {
            combine2DArr(dropoutPriorGradTotal, partialDropout);
          }
        }

        if (!isUnsup) {
          Map partialE = result.third();
          if (partialE != null)
            combine2DArr(E, partialE);
        }
      }
    }
    wrapper.join();
    while (wrapper.peek()) {
      Quadruple, Map> result = wrapper.poll();
      int docIndex = result.first();
      boolean isUnsup = docIndex >= unsupDropoutStartIndex;
      if (isUnsup) {
        prob += unsupDropoutScale * result.second();
      } else {
        prob += result.second();
      }

      Map partialDropout = result.fourth();
      if (partialDropout != null) {
        if (isUnsup) {
          combine2DArr(dropoutPriorGradTotal, partialDropout, unsupDropoutScale);
        } else {
          combine2DArr(dropoutPriorGradTotal, partialDropout);
        }
      }

      if (!isUnsup) {
        Map partialE = result.third();
        if (partialE != null)
          combine2DArr(E, partialE);
      }
    }


    if (Double.isNaN(prob)) { // shouldn't be the case
      throw new RuntimeException("Got NaN for prob in CRFLogConditionalObjectiveFunctionWithDropout.calculate()" +
              " - this may well indicate numeric underflow due to overly long documents.");
    }

    // because we minimize -L(\theta)
    value = -prob;
    if (VERBOSE) {
      log.info("value is " + Math.exp(-value));
    }

    // compute the partial derivative for each feature by comparing expected counts to empirical counts
    int index = 0;
    for (int i = 0; i < E.length; i++) {
      for (int j = 0; j < E[i].length; j++) {
        // because we minimize -L(\theta)
        derivative[index] = (E[i][j] - Ehat[i][j]);
        derivative[index] += dropoutScale * dropoutPriorGradTotal[i][j];
        if (VERBOSE) {
          log.info("deriv(" + i + ',' + j + ") = " + E[i][j] + " - " + Ehat[i][j] + " = " + derivative[index]);
        }
        index++;
      }
    }
  }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy