All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.spark.ml.BaseAlgorithmEstimator.scala Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml

import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StringType}


/**
  * 7/27/16 WilliamZhu([email protected])
  */
trait BaseAlgorithmEstimator {

  def name: String


  def source(training: DataFrame, vectorSize: Int): DataFrame = {

    training.schema.fields.find(p => p.name == "features").head.dataType match {
      case a: StringType =>
        val t = udf { (features: String) =>

          if (!features.contains(":")) {
            val v = features.split(",|\\s+").map(_.toDouble)
            Vectors.dense(v)
          } else {
            val v = features.split(",|\\s+").map(_.split(":")).map(f => (f(0).toInt, f(1).toDouble))
            Vectors.sparse(vectorSize, v)
          }
        }

        training.select(
          col("label") cast (DoubleType),
          t(col("features")) as "features"
        )
      case _ =>
        training
    }


  }

  def mlParams(multiParams: Array[Map[String, Any]]): Array[ParamMap] = {
    val paramGrid = new ParamGridBuilder()

    val params = multiParams.flatMap(f => f).groupBy(f => f._1).map(f => (f._1, f._2.map(k => k._2))).toArray

    params.foreach { p =>
      val ab = algorithm.getParam(p._1)
      paramGrid.addGrid(ab, p._2)
    }
    paramGrid.build()
  }

  def algorithm: Estimator[_]

  def evaluator: Evaluator

  def fit: Model[_]
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy