All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ai.h2o.sparkling.ml.params.EnumParamValidator.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ai.h2o.sparkling.ml.params

import java.util

object EnumParamValidator {

  def getValidatedEnumValue[T <: Enum[T]](name: String)(implicit ctag: reflect.ClassTag[T]): String = {
    getValidatedEnumValue(ctag.runtimeClass, name)
  }

  def getValidatedEnumValues[T <: Enum[T]](inputNames: Array[String], nullEnabled: Boolean = false)(
      implicit ctag: reflect.ClassTag[T]): Array[String] = {
    getValidatedEnumValues(ctag.runtimeClass, inputNames, nullEnabled)
  }

  // Method exposed for PySparkling so we can do the same checks there
  def getValidatedEnumValue(className: String, name: String): String = {
    getValidatedEnumValue(Class.forName(className), name)
  }

  // Method exposed for PySparkling so we can do the same checks there
  def getValidatedEnumValues(
      className: String,
      inputNames: util.ArrayList[String],
      nullEnabled: Boolean): Array[String] = {
    val arr = if (inputNames == null) null else inputNames.toArray(new Array[String](inputNames.size()))
    getValidatedEnumValues(Class.forName(className), arr, nullEnabled)
  }

  private def getValidatedEnumValue(clazz: Class[_], name: String): String = {
    val names = getEnumValues(clazz)
    if (name == null) {
      throw new IllegalArgumentException(s"Null is not a valid value. Allowed values are: ${names.mkString(", ")}")
    }

    if (!names.map(_.toLowerCase()).contains(name.toLowerCase())) {
      throw new IllegalArgumentException(s"'$name' is not a valid value. Allowed values are: ${names.mkString(", ")}")
    }
    names.find(_.toLowerCase() == name.toLowerCase).get
  }

  private def getValidatedEnumValues(
      clazz: Class[_],
      inputNames: Array[String],
      nullEnabled: Boolean): Array[String] = {
    val names = getEnumValues(clazz)
    if (inputNames == null) {
      if (nullEnabled) {
        return null
      } else {
        throw new IllegalArgumentException(
          s"Null is not a valid value. Allowed input is array with any of the following elements: ${names.mkString(", ")}")
      }
    }

    inputNames.foreach { name =>
      val nullStr = if (nullEnabled) " null or " else " "
      if (name == null) {
        throw new IllegalArgumentException(
          s"Null can not be specified as the input array element. " +
            s"Allowed input is${nullStr}array with any of the following elements: ${names.mkString(", ")}")
      }
      if (!names.map(_.toLowerCase()).contains(name.toLowerCase())) {
        throw new IllegalArgumentException(
          s"'$name' is not a valid value. Allowed input is${nullStr}array with" +
            s" any of the following elements: ${names.mkString(", ")}")
      }
    }

    names.filter(name => inputNames.map(_.toLowerCase).contains(name.toLowerCase()))
  }

  private def getEnumValues(clazz: Class[_]): Array[String] = {
    clazz.getDeclaredMethod("values").invoke(null).asInstanceOf[Array[Enum[_]]].map(_.name())
  }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy