All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.tencent.angel.sona.ml.feature.CountVectorizer.scala Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.tencent.angel.sona.ml.feature

import org.apache.hadoop.fs.Path
import com.tencent.angel.sona.ml.{Estimator, Model}
import com.tencent.angel.sona.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import com.tencent.angel.sona.ml.param.{BooleanParam, DoubleParam, IntParam, ParamMap, ParamValidators, Params}
import com.tencent.angel.sona.ml.param.shared.{HasInputCol, HasOutputCol}
import com.tencent.angel.sona.ml.util._
import org.apache.spark.sql.util.SONASchemaUtils
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.{DatasetUtil, OpenHashMap}
import org.apache.spark.linalg._

/**
  * Params for [[CountVectorizer]] and [[CountVectorizerModel]].
  *
  * Declares the vocabulary-building params (`vocabSize`, `minDF`, `maxDF`) read when fitting,
  * the vectorization params (`minTF`, `binary`) read when transforming, and the schema check
  * shared by both stages. Defaults are registered at the end of the trait body.
  */
private[sona] trait CountVectorizerParams extends Params with HasInputCol with HasOutputCol {

  /**
    * Max size of the vocabulary.
    * CountVectorizer will build a vocabulary that only considers the top
    * vocabSize terms ordered by term frequency across the corpus.
    *
    * Default: 2^18^
    *
    * @group param
    */
  val vocabSize: IntParam =
    new IntParam(this, "vocabSize", "max size of the vocabulary", ParamValidators.gt(0))

  /** @group getParam */
  def getVocabSize: Int = $(vocabSize)

  /**
    * Specifies the minimum number of different documents a term must appear in to be included
    * in the vocabulary.
    * If this is an integer greater than or equal to 1, this specifies the number of documents
    * the term must appear in; if this is a double in [0,1), then this specifies the fraction of
    * documents.
    *
    * Default: 1.0
    *
    * @group param
    */
  val minDF: DoubleParam = new DoubleParam(this, "minDF", "Specifies the minimum number of" +
    " different documents a term must appear in to be included in the vocabulary." +
    " If this is an integer >= 1, this specifies the number of documents the term must" +
    " appear in; if this is a double in [0,1), then this specifies the fraction of documents.",
    ParamValidators.gtEq(0.0))

  /** @group getParam */
  def getMinDF: Double = $(minDF)

  /**
    * Specifies the maximum number of different documents a term could appear in to be included
    * in the vocabulary. A term that appears more than the threshold will be ignored. If this is an
    * integer greater than or equal to 1, this specifies the maximum number of documents the term
    * could appear in; if this is a double in [0,1), then this specifies the maximum fraction of
    * documents the term could appear in.
    *
    * Default: `Long.MaxValue` (stored as a Double), i.e. effectively no upper bound
    *
    * @group param
    */
  val maxDF: DoubleParam = new DoubleParam(this, "maxDF", "Specifies the maximum number of" +
    " different documents a term could appear in to be included in the vocabulary." +
    " A term that appears more than the threshold will be ignored. If this is an integer >= 1," +
    " this specifies the maximum number of documents the term could appear in;" +
    " if this is a double in [0,1), then this specifies the maximum fraction of" +
    " documents the term could appear in.",
    ParamValidators.gtEq(0.0))

  /** @group getParam */
  def getMaxDF: Double = $(maxDF)

  /**
    * Validates and transforms the input schema.
    *
    * Requires `inputCol` to be an array of strings (with nullable or non-nullable elements) and
    * returns `schema` with `outputCol` appended as a vector column.
    */
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    // Both nullability variants of array<string> are accepted.
    val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
    SONASchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
    SONASchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
  }

  /**
    * Filter to ignore rare words in a document. For each document, terms with
    * frequency/count less than the given threshold are ignored.
    * If this is an integer greater than or equal to 1, then this specifies a count (of times the
    * term must appear in the document);
    * if this is a double in [0,1), then this specifies a fraction (out of the document's token
    * count).
    *
    * Note that the parameter is only used in transform of [[CountVectorizerModel]] and does not
    * affect fitting.
    *
    * Default: 1.0
    *
    * @group param
    */
  val minTF: DoubleParam = new DoubleParam(this, "minTF", "Filter to ignore rare words in" +
    " a document. For each document, terms with frequency/count less than the given threshold are" +
    " ignored. If this is an integer >= 1, then this specifies a count (of times the term must" +
    " appear in the document); if this is a double in [0,1), then this specifies a fraction (out" +
    " of the document's token count). Note that the parameter is only used in transform of" +
    " CountVectorizerModel and does not affect fitting.", ParamValidators.gtEq(0.0))

  /** @group getParam */
  def getMinTF: Double = $(minTF)

  /**
    * Binary toggle to control the output vector values.
    * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for
    * discrete probabilistic models that model binary events rather than integer counts.
    * Default: false
    *
    * @group param
    */
  val binary: BooleanParam =
    new BooleanParam(this, "binary", "If True, all non zero counts are set to 1.")

  /** @group getParam */
  def getBinary: Boolean = $(binary)

  // Must stay after every param val above: a trait body initializes top to bottom, so the
  // param objects have to exist before defaults can be attached to them.
  // 1 << 18 == 262144.
  setDefault(vocabSize -> (1 << 18),
    minDF -> 1.0,
    maxDF -> Long.MaxValue,
    minTF -> 1.0,
    binary -> false)
}

/**
  * Extracts a vocabulary from document collections and generates a [[CountVectorizerModel]].
  *
  * Each row's `inputCol` holds the tokens of one document. Fitting computes, for every distinct
  * term, its total number of occurrences (term frequency) and the number of documents containing
  * it (document frequency). It keeps the terms whose document frequency lies within
  * `[minDF, maxDF]`, then retains at most `vocabSize` of them with the highest term frequency.
  */
class CountVectorizer(override val uid: String)
  extends Estimator[CountVectorizerModel] with CountVectorizerParams with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("cntVec"))

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  def setVocabSize(value: Int): this.type = set(vocabSize, value)

  /** @group setParam */
  def setMinDF(value: Double): this.type = set(minDF, value)

  /** @group setParam */
  def setMaxDF(value: Double): this.type = set(maxDF, value)

  /** @group setParam */
  def setMinTF(value: Double): this.type = set(minTF, value)

  /** @group setParam */
  def setBinary(value: Boolean): this.type = set(binary, value)

  /**
    * Builds the vocabulary from `dataset` and returns a model holding it, with this estimator's
    * param values copied onto the model.
    *
    * Every RDD cached here is released before returning, including when a requirement fails.
    *
    * @param dataset input data whose `inputCol` is an array of string tokens
    * @return a [[CountVectorizerModel]] whose vocabulary is ordered by descending term frequency
    * @throws IllegalArgumentException if the resolved maxDF is below the resolved minDF, or if no
    *                                  term survives the document-frequency filter
    */
  override def fit(dataset: Dataset[_]): CountVectorizerModel = {
    transformSchema(dataset.schema, logging = true)
    val vocSize = $(vocabSize)
    val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
    // Fractional thresholds are relative to the number of documents, so a counting pass is only
    // needed then. The input is cached for that pass and kept cached through the word-count pass
    // below, so it is not recomputed for the second traversal.
    val countingRequired = $(minDF) < 1.0 || $(maxDF) < 1.0
    if (countingRequired) {
      input.cache()
    }
    try {
      // Only consulted for fractional thresholds, which imply countingRequired.
      val numDocs = if (countingRequired) input.count() else 0L
      val minDf = toAbsoluteDF($(minDF), numDocs)
      val maxDf = toAbsoluteDF($(maxDF), numDocs)
      require(maxDf >= minDf, "maxDF must be >= minDF.")

      // Per term: (occurrences across the corpus, number of documents containing the term).
      val allWordCounts = input.flatMap { tokens =>
        val wc = new OpenHashMap[String, Long]
        tokens.foreach { w =>
          wc.changeValue(w, 1L, _ + 1L)
        }
        wc.map { case (word, count) => (word, (count, 1)) }
      }.reduceByKey { case ((wc1, df1), (wc2, df2)) =>
        (wc1 + wc2, df1 + df2)
      }

      // The defaults (minDF = 1, maxDF = Long.MaxValue) cannot exclude any term, so the filter
      // only runs when a bound was set explicitly.
      val filteringRequired = isSet(minDF) || isSet(maxDF)
      val maybeFilteredWordCounts = if (filteringRequired) {
        allWordCounts.filter { case (_, (_, df)) => df >= minDf && df <= maxDf }
      } else {
        allWordCounts
      }

      // Cached because it is traversed twice (count, then top); released in the finally below.
      val wordCounts = maybeFilteredWordCounts
        .map { case (word, (count, _)) => (word, count) }
        .cache()
      try {
        val fullVocabSize = wordCounts.count()

        val vocab = wordCounts
          .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
          .map(_._1)

        require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
        copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
      } finally {
        wordCounts.unpersist()
      }
    } finally {
      if (countingRequired) {
        input.unpersist()
      }
    }
  }

  /**
    * Resolves a document-frequency threshold to an absolute document count. Values >= 1 are
    * already absolute and are returned unchanged. Values in [0, 1) are fractions of `numDocs`.
    */
  private def toAbsoluteDF(threshold: Double, numDocs: Long): Double =
    if (threshold >= 1.0) threshold else threshold * numDocs

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra)
}


object CountVectorizer extends DefaultParamsReadable[CountVectorizer] {

  /**
    * Loads a [[CountVectorizer]] previously saved at `path`, delegating to this companion's
    * params reader.
    */
  override def load(path: String): CountVectorizer = read.load(path)
}

/**
  * Converts a text document to a sparse vector of token counts.
  *
  * @param uid        unique identifier of this model
  * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted; the
  *                   position of a term in the array is its index in the output vector.
  */
class CountVectorizerModel(
    override val uid: String,
    val vocabulary: Array[String])
  extends Model[CountVectorizerModel] with CountVectorizerParams with MLWritable {

  import CountVectorizerModel._

  /** Creates a model with a random uid and sets `vocabSize` to the vocabulary length. */
  def this(vocabulary: Array[String]) = {
    this(Identifiable.randomUID("cntVecModel"), vocabulary)
    set(vocabSize, vocabulary.length)
  }

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  /** @group setParam */
  def setMinTF(value: Double): this.type = set(minTF, value)

  /** @group setParam */
  def setBinary(value: Boolean): this.type = set(binary, value)

  /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
  private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None

  /**
    * Appends `outputCol` to `dataset`. Each value is a sparse vector of length
    * `vocabulary.length`; entry `i` is the number of occurrences of `vocabulary(i)` in the
    * document, or 1.0 when `binary` is set.
    *
    * Terms whose count is below `minTF` are dropped. A fractional `minTF` is relative to the
    * document's total token count, which includes tokens missing from the vocabulary.
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    transformSchema(dataset.schema, logging = true)
    val dictBr = broadcastDict.getOrElse {
      val dict = vocabulary.zipWithIndex.toMap
      val created = dataset.sparkSession.sparkContext.broadcast(dict)
      broadcastDict = Some(created)
      created
    }
    // Params are read into locals so the UDF closure does not capture (and ship to every task)
    // this model, and so the param map is not consulted once per row.
    val minTf = $(minTF)
    val isBinary = $(binary)
    // Sized by the vocabulary, matching the attribute metadata built below. If the vocabulary
    // contains duplicate terms, the dictionary is smaller than the largest index it maps to.
    // Sizing by the dictionary would then make Vectors.sparse reject the vector.
    val numFeatures = vocabulary.length
    val vectorizer = udf { (document: Seq[String]) =>
      val termCounts = new OpenHashMap[Int, Double]
      var tokenCount = 0L
      document.foreach { term =>
        // Terms not in the vocabulary are ignored but still counted as tokens.
        dictBr.value.get(term).foreach(index => termCounts.changeValue(index, 1.0, _ + 1.0))
        tokenCount += 1
      }
      val effectiveMinTF = if (minTf >= 1.0) minTf else tokenCount * minTf
      val keptCounts = termCounts.filter(_._2 >= effectiveMinTF)
      val effectiveCounts = if (isBinary) {
        keptCounts.map { case (index, _) => (index, 1.0) }.toSeq
      } else {
        keptCounts.toSeq
      }

      Vectors.sparse(numFeatures, effectiveCounts)
    }
    val attrs = vocabulary.map(_ => new NumericAttribute).asInstanceOf[Array[Attribute]]
    val metadata = new AttributeGroup($(outputCol), attrs).toMetadata
    DatasetUtil.withColumn(dataset, $(outputCol), vectorizer(col($(inputCol))), metadata)
  }

  override def transformSchema(schema: StructType): StructType = {
    validateAndTransformSchema(schema)
  }

  override def copy(extra: ParamMap): CountVectorizerModel = {
    val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent)
    copyValues(copied, extra)
  }

  override def write: MLWriter = new CountVectorizerModelWriter(this)
}


object CountVectorizerModel extends MLReadable[CountVectorizerModel] {

  /**
    * Persists a model as param metadata plus a single-partition parquet file under
    * `<path>/data` holding one row with the vocabulary.
    */
  private[CountVectorizerModel] class CountVectorizerModelWriter(instance: CountVectorizerModel)
    extends MLWriter {

    private case class Data(vocabulary: Seq[String])

    override protected def saveImpl(path: String): Unit = {
      DefaultParamsWriter.saveMetadata(instance, path, sc)
      val target = new Path(path, "data").toString
      sparkSession
        .createDataFrame(Seq(Data(instance.vocabulary)))
        .repartition(1)
        .write
        .parquet(target)
    }
  }

  /**
    * Reads back what [[CountVectorizerModelWriter]] wrote: param metadata first, then the
    * vocabulary row from `<path>/data`.
    */
  private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] {

    private val className = classOf[CountVectorizerModel].getName

    override def load(path: String): CountVectorizerModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val vocabulary = sparkSession.read
        .parquet(new Path(path, "data").toString)
        .select("vocabulary")
        .head()
        .getAs[Seq[String]](0)
        .toArray
      val restored = new CountVectorizerModel(metadata.uid, vocabulary)
      metadata.getAndSetParams(restored)
      restored
    }
  }

  override def read: MLReader[CountVectorizerModel] = new CountVectorizerModelReader

  /** Loads a [[CountVectorizerModel]] previously saved at `path`. */
  override def load(path: String): CountVectorizerModel = read.load(path)
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy