All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.johnsnowlabs.nlp.HasInputAnnotationCols.scala Maven / Gradle / Ivy

/*
 * Copyright 2017-2022 John Snow Labs
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.johnsnowlabs.nlp

import org.apache.spark.ml.param.{Params, StringArrayParam}
import org.apache.spark.sql.types.StructType

/** Mixin for annotators that read one or more annotation columns.
  *
  * It declares the annotator types expected as input, stores the configured input column names
  * in the `inputCols` param, validates how many columns are supplied, and offers helpers to check
  * a schema against the expected annotator types.
  */
trait HasInputAnnotationCols extends Params {

  /** Annotator types this annotator requires as input. The number of entries is the number of
    * columns [[setInputCols(value:Array[String])*]] expects when no optional types are declared.
    */
  val inputAnnotatorTypes: Array[String]

  /** Annotator types that may additionally be supplied as input. Empty by default; when
    * non-empty, `setInputCols` accepts either the required columns alone or the required plus all
    * optional columns.
    */
  val optionalInputAnnotatorTypes: Array[String] = Array()

  /** Names of the columns that hold the annotations this annotator consumes. */
  protected final val inputCols: StringArrayParam =
    new StringArrayParam(this, "inputCols", "the input annotation columns")

  /** Sets the input annotation columns after validating how many were supplied.
    *
    * Without optional annotator types, exactly `inputAnnotatorTypes.length` columns are required.
    * With optional annotator types, either that amount or that amount plus
    * `optionalInputAnnotatorTypes.length` is accepted; nothing in between.
    *
    * @param value
    *   the input column names
    * @return
    *   this instance, for chaining
    * @throws IllegalArgumentException
    *   if the number of columns is not one of the accepted amounts
    */
  def setInputCols(value: Array[String]): this.type = {
    val requiredCount = inputAnnotatorTypes.length
    if (optionalInputAnnotatorTypes.isEmpty) {
      require(
        value.length == requiredCount,
        s"setInputCols in ${this.uid} expecting ${inputAnnotatorTypes.length} columns. " +
          s"Provided column amount: ${value.length}. " +
          s"Which should be columns from the following annotators: ${inputAnnotatorTypes.mkString(", ")} ")
    } else {
      val expectedColumns = requiredCount + optionalInputAnnotatorTypes.length
      // The message spells out both accepted amounts: a count strictly between them is
      // rejected, so "at least N" would contradict the check.
      require(
        value.length == requiredCount || value.length == expectedColumns,
        s"setInputCols in ${this.uid} expecting either $requiredCount columns " +
          s"(required annotators only) or $expectedColumns columns " +
          s"(required and optional annotators). " +
          s"Provided column amount: ${value.length}. " +
          s"Required annotators: ${inputAnnotatorTypes.mkString(", ")}. " +
          s"Optional annotators: ${optionalInputAnnotatorTypes.mkString(", ")}.")
    }
    set(inputCols, value)
  }

  /** Builds a diagnostic string listing the current input columns and, for every field of
    * `schema`, its name, whether its metadata carries an `annotatorType` entry, and that type
    * when present.
    *
    * @param schema
    *   the dataset schema to describe
    * @return
    *   a multi-line description suitable for appending to an error message
    */
  protected def msgHelper(schema: StructType): String = {
    val schemaInfo = schema.map { field =>
      // Look the metadata key up once and reuse the answer for both parts of the entry.
      val isAnnotator = field.metadata.contains("annotatorType")
      val typeSuffix =
        if (isAnnotator) ",type=" + field.metadata.getString("annotatorType") else ""
      ("column_name=" + field.name, "is_nlp_annotator=" + isAnnotator + typeSuffix)
    }
    s"\nCurrent inputCols: ${getInputCols.mkString(",")}. Dataset's columns:\n${schemaInfo.mkString("\n")}."
  }

  /** Checks whether `schema` has a field that is one of the configured input columns and whose
    * metadata declares the given annotator type.
    *
    * @param schema
    *   the dataset schema to inspect
    * @param inputAnnotatorType
    *   the annotator type to look for
    * @return
    *   true if at least one configured input column carries that annotator type
    */
  final protected def checkSchema(schema: StructType, inputAnnotatorType: String): Boolean =
    schema.exists { field =>
      // getInputCols stays last so it is only evaluated (and can only throw) once a field with
      // the requested annotator type has been found.
      field.metadata.contains("annotatorType") &&
      field.metadata.getString("annotatorType") == inputAnnotatorType &&
      getInputCols.contains(field.name)
    }

  /** Varargs convenience overload; delegates to the array-based `setInputCols`. */
  final def setInputCols(value: String*): this.type = setInputCols(value.toArray)

  /** Returns the input annotation columns currently in use: the explicitly set value if any,
    * otherwise the param's default.
    *
    * @return
    *   input annotations columns currently used
    * @throws Exception
    *   if neither an explicit value nor a default is available
    */
  def getInputCols: Array[String] =
    get(inputCols)
      .orElse(getDefault(inputCols))
      .getOrElse(throw new Exception(s"inputCols not provided." +
        s" Requires columns for ${inputAnnotatorTypes.mkString(", ")} annotators"))
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy