All Downloads are FREE. Search and download functionalities are using the official Maven repository.

com.etsy.conjecture.scalding.evaluate.GenericEvaluator.scala Maven / Gradle / Ivy

There is a newer version: 0.2.3
Show newest version
package com.etsy.conjecture.scalding.evaluate

import com.twitter.scalding._
import com.etsy.conjecture._
import com.etsy.conjecture.data._
import com.etsy.conjecture.evaluation._
import com.etsy.conjecture.model._

import cascading.pipe.Pipe

abstract class GenericEvaluator[L <: Label] extends Serializable {

    import Dsl._

    def build(): ModelEvaluation[L]

    def evaluate(instance_pipe: Pipe, predict_field: Symbol, label_field: Symbol, evaluation_field: Symbol): Pipe = {
        val partialEval = '__partial_eval
        instance_pipe
          .map((label_field, predict_field) -> partialEval){
            pair : (L, L) =>
            val eval = build
            eval.add(pair._1, pair._2)
            eval
          }
          .groupAll{
            _.reduce(partialEval -> evaluation_field){
              (eval : ModelEvaluation[L], final_eval : ModelEvaluation[L]) =>
              final_eval.merge(eval)
              final_eval
            }
          }
          .project(evaluation_field)
    }

    def evaluate(instance_pipe: Pipe, instance_field: Symbol, label_field: Symbol, model_pipe: Pipe, model_field: Symbol, evaluation_field: Symbol): Pipe = {
        val instances_with_predictions = assign_predictions(instance_pipe, instance_field, label_field, model_pipe, model_field, 'prediction)
        evaluate(instances_with_predictions, label_field, 'prediction, evaluation_field)
    }

    def assign_predictions(instance_pipe: Pipe, instance_field: Symbol, label_field: Symbol, model_pipe: Pipe, model_field: Symbol, prediction_field: Symbol = 'prediction) = {
        instance_pipe.crossWithTiny(model_pipe)
            .map((instance_field, model_field) -> (label_field, prediction_field)) { x: (LabeledInstance[L], Model[L]) => (x._1.getLabel, x._2.predict(x._1.getVector)) }
            .project(label_field, prediction_field)
    }
}

class BinaryEvaluator extends GenericEvaluator[BinaryLabel] {
    def build() = new BinaryModelEvaluation()
}

class MulticlassEvaluator(categories: Array[String]) extends GenericEvaluator[MulticlassLabel] {
    def build() = new MulticlassModelEvaluation(categories)
}

class RegressionEvaluator extends GenericEvaluator[RealValuedLabel] {
    def build() = new RegressionModelEvaluation()
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy