io.trino.plugin.ml.EvaluateClassifierPredictionsAggregation Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of trino-ml Show documentation
Show all versions of trino-ml Show documentation
Trino - Machine Learning Plugin
The newest version!
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.ml;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.LiteralParameters;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.type.StandardTypes;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import static com.google.common.collect.Sets.union;
import static io.airlift.slice.SizeOf.SIZE_OF_INT;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
@AggregationFunction("evaluate_classifier_predictions")
public final class EvaluateClassifierPredictionsAggregation
{
private EvaluateClassifierPredictionsAggregation() {}
@InputFunction
public static void input(@AggregationState EvaluateClassifierPredictionsState state, @SqlType(StandardTypes.BIGINT) long truth, @SqlType(StandardTypes.BIGINT) long prediction)
{
input(state, Slices.utf8Slice(String.valueOf(truth)), Slices.utf8Slice(String.valueOf(prediction)));
}
@InputFunction
@LiteralParameters({"x", "y"})
public static void input(@AggregationState EvaluateClassifierPredictionsState state, @SqlType("varchar(x)") Slice truth, @SqlType("varchar(y)") Slice prediction)
{
if (truth.equals(prediction)) {
String key = truth.toStringUtf8();
if (!state.getTruePositives().containsKey(key)) {
state.addMemoryUsage(truth.length() + SIZE_OF_INT);
}
state.getTruePositives().put(key, state.getTruePositives().getOrDefault(key, 0) + 1);
}
else {
String truthKey = truth.toStringUtf8();
String predictionKey = prediction.toStringUtf8();
if (!state.getFalsePositives().containsKey(predictionKey)) {
state.addMemoryUsage(prediction.length() + SIZE_OF_INT);
}
state.getFalsePositives().put(predictionKey, state.getFalsePositives().getOrDefault(predictionKey, 0) + 1);
if (!state.getFalseNegatives().containsKey(truthKey)) {
state.addMemoryUsage(truth.length() + SIZE_OF_INT);
}
state.getFalseNegatives().put(truthKey, state.getFalseNegatives().getOrDefault(truthKey, 0) + 1);
}
}
@CombineFunction
public static void combine(@AggregationState EvaluateClassifierPredictionsState state, @AggregationState EvaluateClassifierPredictionsState scratchState)
{
int size = 0;
size += mergeMaps(state.getTruePositives(), scratchState.getTruePositives());
size += mergeMaps(state.getFalsePositives(), scratchState.getFalsePositives());
size += mergeMaps(state.getFalseNegatives(), scratchState.getFalseNegatives());
state.addMemoryUsage(size);
}
// Returns the estimated memory increase in map
private static int mergeMaps(Map map, Map other)
{
int deltaSize = 0;
for (Map.Entry entry : other.entrySet()) {
if (!map.containsKey(entry.getKey())) {
deltaSize += entry.getKey().getBytes(UTF_8).length + SIZE_OF_INT;
}
map.put(entry.getKey(), map.getOrDefault(entry.getKey(), 0) + other.getOrDefault(entry.getKey(), 0));
}
return deltaSize;
}
@OutputFunction(StandardTypes.VARCHAR)
public static void output(@AggregationState EvaluateClassifierPredictionsState state, BlockBuilder out)
{
StringBuilder sb = new StringBuilder();
long correct = state.getTruePositives()
.values()
.stream()
.reduce(0, Integer::sum);
long total = correct + state.getFalsePositives().values().stream().reduce(0, Integer::sum);
sb.append(format(Locale.US, "Accuracy: %d/%d (%.2f%%)\n", correct, total, 100.0 * correct / (double) total));
Set labels = union(union(state.getTruePositives().keySet(), state.getFalsePositives().keySet()), state.getFalseNegatives().keySet());
for (String label : labels) {
int truePositives = state.getTruePositives().getOrDefault(label, 0);
int falsePositives = state.getFalsePositives().getOrDefault(label, 0);
int falseNegatives = state.getFalseNegatives().getOrDefault(label, 0);
sb.append(format(Locale.US, "Class '%s'\n", label));
sb.append(format(Locale.US, "Precision: %d/%d (%.2f%%)\n", truePositives, truePositives + falsePositives, 100.0 * truePositives / (double) (truePositives + falsePositives)));
sb.append(format(Locale.US, "Recall: %d/%d (%.2f%%)\n", truePositives, truePositives + falseNegatives, 100.0 * truePositives / (double) (truePositives + falseNegatives)));
}
VARCHAR.writeString(out, sb.toString());
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy