All Downloads are FREE. Search and download functionalities are using the official Maven repository.

hivemall.evaluation.F1ScoreUDAF Maven / Gradle / Ivy

The newest version!
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package hivemall.evaluation;

import hivemall.utils.hadoop.WritableUtils;

import java.util.List;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDAF;
import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;

@SuppressWarnings("deprecation")
@Description(name = "f1score", value = "_FUNC_(array[int], array[int]) - Return a F1 score")
public final class F1ScoreUDAF extends UDAF {

    public static class Evaluator implements UDAFEvaluator {

        public static class PartialResult {
            long tp;
            /** tp + fn */
            long totalActual;
            /** tp + fp */
            long totalPredicted;

            PartialResult() {
                this.tp = 0L;
                this.totalPredicted = 0L;
                this.totalActual = 0L;
            }

            void updateScore(final List actual, final List predicted) {
                final int numActual = actual.size();
                final int numPredicted = predicted.size();
                int countTp = 0;
                for (int i = 0; i < numPredicted; i++) {
                    IntWritable p = predicted.get(i);
                    if (actual.contains(p)) {
                        countTp++;
                    }
                }
                this.tp += countTp;
                this.totalActual += numActual;
                this.totalPredicted += numPredicted;
            }

            void merge(PartialResult other) {
                this.tp += other.tp;
                this.totalActual += other.totalActual;
                this.totalPredicted += other.totalPredicted;
            }
        }

        private PartialResult partial;

        @Override
        public void init() {
            this.partial = null;
        }

        public boolean iterate(List actual, List predicted) {
            if (partial == null) {
                this.partial = new PartialResult();
            }
            partial.updateScore(actual, predicted);
            return true;
        }

        public PartialResult terminatePartial() {
            return partial;
        }

        public boolean merge(PartialResult other) {
            if (other == null) {
                return true;
            }
            if (partial == null) {
                this.partial = new PartialResult();
            }
            partial.merge(other);
            return true;
        }

        public DoubleWritable terminate() {
            if (partial == null) {
                return null;
            }
            double score = f1Score(partial);
            return WritableUtils.val(score);
        }

        /**
         * @return 2 * precision * recall / (precision + recall)
         */
        private static double f1Score(final PartialResult partial) {
            double precision = precision(partial);
            double recall = recall(partial);
            double divisor = precision + recall;
            if (divisor > 0) {
                return (2.d * precision * recall) / divisor;
            } else {
                return -1d;
            }
        }

        private static double precision(final PartialResult partial) {
            return (partial.totalPredicted == 0L) ? 0d
                    : partial.tp / (double) partial.totalPredicted;
        }

        private static double recall(final PartialResult partial) {
            return (partial.totalActual == 0L) ? 0d : partial.tp / (double) partial.totalActual;
        }

    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy