All Downloads are FREE. Search and download functionality uses the official Maven repository.

org.apache.cassandra.cql3.functions.UDAggregate Maven / Gradle / Ivy

There is a newer version: 4.3.1.0
Show newest version
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.cassandra.cql3.functions;

import java.nio.ByteBuffer;
import java.util.*;

import com.google.common.base.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.schema.Functions;
import org.apache.cassandra.tracing.Tracing;

/**
 * Base class for user-defined-aggregates.
 */
public class UDAggregate extends AbstractFunction implements AggregateFunction
{
    protected static final Logger logger = LoggerFactory.getLogger(UDAggregate.class);

    protected final AbstractType stateType;
    protected final ByteBuffer initcond;
    private final ScalarFunction stateFunction;
    private final ScalarFunction finalFunction;

    public UDAggregate(FunctionName name,
                       List> argTypes,
                       AbstractType returnType,
                       ScalarFunction stateFunc,
                       ScalarFunction finalFunc,
                       ByteBuffer initcond)
    {
        super(name, argTypes, returnType);
        this.stateFunction = stateFunc;
        this.finalFunction = finalFunc;
        this.stateType = stateFunc != null ? stateFunc.returnType() : null;
        this.initcond = initcond;
    }

    public static UDAggregate create(Functions functions,
                                     FunctionName name,
                                     List> argTypes,
                                     AbstractType returnType,
                                     FunctionName stateFunc,
                                     FunctionName finalFunc,
                                     AbstractType stateType,
                                     ByteBuffer initcond)
    throws InvalidRequestException
    {
        List> stateTypes = new ArrayList<>(argTypes.size() + 1);
        stateTypes.add(stateType);
        stateTypes.addAll(argTypes);
        List> finalTypes = Collections.>singletonList(stateType);
        return new UDAggregate(name,
                               argTypes,
                               returnType,
                               resolveScalar(functions, name, stateFunc, stateTypes),
                               finalFunc != null ? resolveScalar(functions, name, finalFunc, finalTypes) : null,
                               initcond);
    }

    public static UDAggregate createBroken(FunctionName name,
                                           List> argTypes,
                                           AbstractType returnType,
                                           ByteBuffer initcond,
                                           final InvalidRequestException reason)
    {
        return new UDAggregate(name, argTypes, returnType, null, null, initcond)
        {
            public Aggregate newAggregate() throws InvalidRequestException
            {
                throw new InvalidRequestException(String.format("Aggregate '%s' exists but hasn't been loaded successfully for the following reason: %s. "
                                                                + "Please see the server log for more details",
                                                                this,
                                                                reason.getMessage()));
            }
        };
    }

    public boolean hasReferenceTo(Function function)
    {
        return stateFunction == function || finalFunction == function;
    }

    @Override
    public void addFunctionsTo(List functions)
    {
        functions.add(this);
        if (stateFunction != null)
        {
            stateFunction.addFunctionsTo(functions);

            if (finalFunction != null)
                finalFunction.addFunctionsTo(functions);
        }
    }

    public boolean isAggregate()
    {
        return true;
    }

    public boolean isNative()
    {
        return false;
    }

    public ScalarFunction stateFunction()
    {
        return stateFunction;
    }

    public ScalarFunction finalFunction()
    {
        return finalFunction;
    }

    public ByteBuffer initialCondition()
    {
        return initcond;
    }

    public AbstractType stateType()
    {
        return stateType;
    }

    public Aggregate newAggregate() throws InvalidRequestException
    {
        return new Aggregate()
        {
            private long stateFunctionCount;
            private long stateFunctionDuration;

            private ByteBuffer state;
            {
                reset();
            }

            public void addInput(int protocolVersion, List values) throws InvalidRequestException
            {
                long startTime = System.nanoTime();
                stateFunctionCount++;
                List fArgs = new ArrayList<>(values.size() + 1);
                fArgs.add(state);
                fArgs.addAll(values);
                if (stateFunction instanceof UDFunction)
                {
                    UDFunction udf = (UDFunction)stateFunction;
                    if (udf.isCallableWrtNullable(fArgs))
                        state = udf.execute(protocolVersion, fArgs);
                }
                else
                {
                    state = stateFunction.execute(protocolVersion, fArgs);
                }
                stateFunctionDuration += (System.nanoTime() - startTime) / 1000;
            }

            public ByteBuffer compute(int protocolVersion) throws InvalidRequestException
            {
                // final function is traced in UDFunction
                Tracing.trace("Executed UDA {}: {} call(s) to state function {} in {}\u03bcs", name(), stateFunctionCount, stateFunction.name(), stateFunctionDuration);
                if (finalFunction == null)
                    return state;

                List fArgs = Collections.singletonList(state);
                ByteBuffer result = finalFunction.execute(protocolVersion, fArgs);
                return result;
            }

            public void reset()
            {
                state = initcond != null ? initcond.duplicate() : null;
                stateFunctionDuration = 0;
                stateFunctionCount = 0;
            }
        };
    }

    private static ScalarFunction resolveScalar(Functions functions, FunctionName aName, FunctionName fName, List> argTypes) throws InvalidRequestException
    {
        Optional fun = functions.find(fName, argTypes);
        if (!fun.isPresent())
            throw new InvalidRequestException(String.format("Referenced state function '%s %s' for aggregate '%s' does not exist",
                                                            fName,
                                                            Arrays.toString(UDHelper.driverTypes(argTypes)),
                                                            aName));

        if (!(fun.get() instanceof ScalarFunction))
            throw new InvalidRequestException(String.format("Referenced state function '%s %s' for aggregate '%s' is not a scalar function",
                                                            fName,
                                                            Arrays.toString(UDHelper.driverTypes(argTypes)),
                                                            aName));
        return (ScalarFunction) fun.get();
    }

    @Override
    public boolean equals(Object o)
    {
        if (!(o instanceof UDAggregate))
            return false;

        UDAggregate that = (UDAggregate) o;
        return Objects.equal(name, that.name)
            && Functions.typesMatch(argTypes, that.argTypes)
            && Functions.typesMatch(returnType, that.returnType)
            && Objects.equal(stateFunction, that.stateFunction)
            && Objects.equal(finalFunction, that.finalFunction)
            && ((stateType == that.stateType) || ((stateType != null) && stateType.equals(that.stateType, true)))  // ignore freezing
            && Objects.equal(initcond, that.initcond);
    }

    @Override
    public int hashCode()
    {
        return Objects.hashCode(name, Functions.typeHashCode(argTypes), Functions.typeHashCode(returnType), stateFunction, finalFunction, stateType, initcond);
    }
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy