All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.microsoft.ml.spark.vw.VowpalWabbitRegressor.scala Maven / Gradle / Ivy

The newest version!
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.vw

import com.microsoft.ml.spark.core.env.InternalWrapper
import com.microsoft.ml.spark.core.serialize.ConstructorReadable
import org.apache.spark.ml.{BaseRegressor, ComplexParamsReadable, ComplexParamsWritable}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions.col
import org.apache.spark.ml.regression.RegressionModel

/** Companion of [[VowpalWabbitRegressor]]; mixes in Spark ML's `DefaultParamsReadable` so the estimator can be read back via `read` / `load`. */
object VowpalWabbitRegressor extends DefaultParamsReadable[VowpalWabbitRegressor]

/** Regression estimator that delegates training to `trainInternal` (from [[VowpalWabbitBase]])
  * and wraps the result in a [[VowpalWabbitRegressorModel]].
  *
  * @param uid unique identifier of this estimator; it is also used as the uid of the produced model
  */
class VowpalWabbitRegressor(override val uid: String)
  extends BaseRegressor[Row, VowpalWabbitRegressor, VowpalWabbitRegressorModel]
    with VowpalWabbitBase {

  /** Creates an estimator with a random uid prefixed by "VowpalWabbitRegressor". */
  def this() = this(Identifiable.randomUID("VowpalWabbitRegressor"))

  /** Trains on `dataset` and builds the fitted model.
    *
    * @param dataset training data, passed unchanged to `trainInternal`
    * @return a model sharing this estimator's uid, carrying the trained model plus this
    *         estimator's features column, additional features and prediction column settings
    */
  override def train(dataset: Dataset[_]): VowpalWabbitRegressorModel = {
    // Training happens first; the model object is only created once it has succeeded.
    val trainedModel = trainInternal(dataset)

    // Same call sequence as a single fluent chain, just split into named steps.
    val withModel = new VowpalWabbitRegressorModel(uid).setModel(trainedModel)
    val withFeatures = withModel
      .setFeaturesCol(getFeaturesCol)
      .setAdditionalFeatures(getAdditionalFeatures)

    withFeatures.setPredictionCol(getPredictionCol)
  }

  /** Copies this estimator, applying the overrides in `extra`. */
  override def copy(extra: ParamMap): VowpalWabbitRegressor =
    defaultCopy[VowpalWabbitRegressor](extra)
}

/** Fitted regression model produced by [[VowpalWabbitRegressor]].
  *
  * Scoring is done on whole datasets through `transformImpl`; the per-row `predict`
  * entry point is intentionally left unimplemented.
  *
  * @param uid unique identifier of this model
  */
class VowpalWabbitRegressorModel(override val uid: String)
  extends RegressionModel[Row, VowpalWabbitRegressorModel]
    with VowpalWabbitBaseModel
    with ComplexParamsWritable
{
  /** Scores `dataset` via `transformImplInternal`, then exposes the raw prediction
    * column under the configured prediction column name.
    *
    * @param dataset input data to score
    * @return the scored frame with the prediction column set from the raw prediction column
    */
  protected override def transformImpl(dataset: Dataset[_]): DataFrame = {
    transformImplInternal(dataset)
      .withColumn($(predictionCol), col($(rawPredictionCol)))
  }

  /** Single-row prediction is not supported by this model.
    *
    * @param features ignored
    * @throws NotImplementedError always
    */
  override def predict(features: Row): Double = {
    // Name the method and point to the supported path so the failure is actionable.
    throw new NotImplementedError(
      "VowpalWabbitRegressorModel.predict(features: Row) is not implemented; use transform(dataset) instead")
  }

  /** Copies this model, applying the overrides in `extra`. */
  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
}

/** Companion of [[VowpalWabbitRegressorModel]]; extends `ComplexParamsReadable`, presumably the read-side counterpart of the `ComplexParamsWritable` mixed into the model class (TODO confirm). */
object VowpalWabbitRegressorModel extends ComplexParamsReadable[VowpalWabbitRegressorModel]




© 2015 - 2024 Weber Informatics LLC | Privacy Policy