All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.intel.analytics.bigdl.example.utils.TextTransformer.scala Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 The BigDL Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.intel.analytics.bigdl.example.utils

object SimpleTokenizer {
  /**
   * Simple tokenizer to split text into separated tokens.
   * @param text text to be split.
   * @param lower convert to lower case or not.
   * @return An array of separated tokens.
   */
  def toTokens(text: String, lower: Boolean = true): Array[String] = {
    text.replaceAll("[^a-zA-Z]", " ").toLowerCase().split("\\s+").filter(_.size > 2)
  }

  /**
   * Transform sample text into tokens and ignore those unknown tokens.
   * @param word2Meta Indicate the included words.
   */
  def toTokens(text: String, word2Meta: Map[String, WordMeta]): Array[Float] = {
    SimpleTokenizer.toTokens(text).map { word: String =>
      if (word2Meta.contains(word)) {
        Some(word2Meta(word).index.toFloat)
      } else {
        None
      }
    }.flatten
  }

  /**
   * Shape the token sequence to the specified length.
   * The sequence would be either padded or truncated.
   * @param sequenceLen the desired seq length
   * @param trunc truncated from pre or post.
   */
  def shaping(tokens: Array[Float], sequenceLen: Int, trunc: String = "pre")
  : Array[Float] = {
    val paddedTokens = if (tokens.length > sequenceLen) {
      if ("pre" == trunc) {
        tokens.slice(tokens.length - sequenceLen, tokens.length)
      } else {
        tokens.slice(0, sequenceLen)
      }
    } else {
      tokens ++ Array.fill[Float](sequenceLen - tokens.length)(0)
    }
    paddedTokens
  }

  /**
   * Transform word to pre-trained vector.
   * @param embeddingSize size of the pre-trained vector
   * @param word2Vec pre-trained word2Vec
   */
  def vectorization(tokens: Array[Float], embeddingSize: Int, word2Vec: Map[Float, Array[Float]])
  : Array[Array[Float]] = {
    tokens.map { word =>
      if (word2Vec.contains(word)) {
        word2Vec(word)
      } else {
        // Treat it as zeros if cannot be found from pre-trained word2Vec
        Array.fill[Float](embeddingSize)(0)
      }
    }
  }
}

/**
 * @param count frequency of the word.
 * @param index index of the word which ranked by the frequency from high to low.
 */
case class WordMeta(count: Int, index: Int)




© 2015 - 2024 Weber Informatics LLC | Privacy Policy