/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.aggregation;
import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.operator.MarkDistinctHash;
import io.trino.operator.UpdateMemory;
import io.trino.operator.Work;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.block.IntArrayBlock;
import io.trino.spi.type.Type;
import io.trino.sql.gen.JoinCompiler;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;
import static com.google.common.base.Preconditions.checkState;
import static io.trino.spi.type.IntegerType.INTEGER;
import static java.util.Objects.requireNonNull;
public class DistinctAccumulatorFactory
implements AccumulatorFactory
{
private final AccumulatorFactory delegate;
private final List argumentTypes;
private final JoinCompiler joinCompiler;
private final Session session;
public DistinctAccumulatorFactory(
AccumulatorFactory delegate,
List argumentTypes,
JoinCompiler joinCompiler,
Session session)
{
this.delegate = requireNonNull(delegate, "delegate is null");
this.argumentTypes = ImmutableList.copyOf(requireNonNull(argumentTypes, "argumentTypes is null"));
this.joinCompiler = requireNonNull(joinCompiler, "joinCompiler is null");
this.session = requireNonNull(session, "session is null");
}
@Override
public List> getLambdaInterfaces()
{
return delegate.getLambdaInterfaces();
}
@Override
public Accumulator createAccumulator(List> lambdaProviders)
{
return new DistinctAccumulator(
delegate.createAccumulator(lambdaProviders),
argumentTypes,
session,
joinCompiler);
}
@Override
public Accumulator createIntermediateAccumulator(List> lambdaProviders)
{
return delegate.createIntermediateAccumulator(lambdaProviders);
}
@Override
public GroupedAccumulator createGroupedAccumulator(List> lambdaProviders)
{
return new DistinctGroupedAccumulator(
delegate.createGroupedAccumulator(lambdaProviders),
argumentTypes,
session,
joinCompiler);
}
@Override
public GroupedAccumulator createGroupedIntermediateAccumulator(List> lambdaProviders)
{
return delegate.createGroupedIntermediateAccumulator(lambdaProviders);
}
@Override
public AggregationMaskBuilder createAggregationMaskBuilder()
{
return delegate.createAggregationMaskBuilder();
}
private static class DistinctAccumulator
implements Accumulator
{
private final Accumulator accumulator;
private final MarkDistinctHash hash;
private DistinctAccumulator(
Accumulator accumulator,
List inputTypes,
Session session,
JoinCompiler joinCompiler)
{
this.accumulator = requireNonNull(accumulator, "accumulator is null");
this.hash = new MarkDistinctHash(
session,
inputTypes,
false,
joinCompiler,
UpdateMemory.NOOP);
}
@Override
public long getEstimatedSize()
{
return hash.getEstimatedSize() + accumulator.getEstimatedSize();
}
@Override
public Accumulator copy()
{
throw new UnsupportedOperationException("Distinct aggregation function state can not be copied");
}
@Override
public void addInput(Page arguments, AggregationMask mask)
{
// 1. filter out positions based on mask, if present
Page filtered = mask.filterPage(arguments);
// 2. compute a distinct mask block
Work work = hash.markDistinctRows(filtered);
checkState(work.process());
Block distinctMask = work.getResult();
// 3. update original mask to the new distinct mask block
mask.reset(filtered.getPositionCount());
mask.applyMaskBlock(distinctMask);
if (mask.isSelectNone()) {
return;
}
// 4. feed a Page with a new mask to the underlying aggregation
accumulator.addInput(filtered, mask);
}
@Override
public void addIntermediate(Block block)
{
throw new UnsupportedOperationException();
}
@Override
public void evaluateIntermediate(BlockBuilder blockBuilder)
{
throw new UnsupportedOperationException();
}
@Override
public void evaluateFinal(BlockBuilder blockBuilder)
{
accumulator.evaluateFinal(blockBuilder);
}
}
private static class DistinctGroupedAccumulator
implements GroupedAccumulator
{
private final GroupedAccumulator accumulator;
private final MarkDistinctHash hash;
private DistinctGroupedAccumulator(
GroupedAccumulator accumulator,
List inputTypes,
Session session,
JoinCompiler joinCompiler)
{
this.accumulator = requireNonNull(accumulator, "accumulator is null");
this.hash = new MarkDistinctHash(
session,
ImmutableList.builder()
.add(INTEGER) // group id column
.addAll(inputTypes)
.build(),
false,
joinCompiler,
UpdateMemory.NOOP);
}
@Override
public long getEstimatedSize()
{
return hash.getEstimatedSize() + accumulator.getEstimatedSize();
}
@Override
public void setGroupCount(long groupCount)
{
accumulator.setGroupCount(groupCount);
}
@Override
public void addInput(int[] groupIds, Page page, AggregationMask mask)
{
// 1. filter out positions based on mask
groupIds = maskGroupIds(groupIds, mask);
page = mask.filterPage(page);
// 2. compute a mask for the distinct rows (including the group id)
Work work = hash.markDistinctRows(page.prependColumn(new IntArrayBlock(page.getPositionCount(), Optional.empty(), groupIds)));
checkState(work.process());
Block distinctMask = work.getResult();
// 3. update original mask to the new distinct mask block
mask.reset(page.getPositionCount());
mask.applyMaskBlock(distinctMask);
if (mask.isSelectNone()) {
return;
}
// 4. feed a Page with a new mask to the underlying aggregation
accumulator.addInput(groupIds, page, mask);
}
private static int[] maskGroupIds(int[] groupIds, AggregationMask mask)
{
if (mask.isSelectAll() || mask.isSelectNone()) {
return groupIds;
}
int[] newGroupIds = new int[mask.getSelectedPositionCount()];
int[] selectedPositions = mask.getSelectedPositions();
for (int i = 0; i < newGroupIds.length; i++) {
newGroupIds[i] = groupIds[selectedPositions[i]];
}
return newGroupIds;
}
@Override
public void addIntermediate(int[] groupIds, Block block)
{
throw new UnsupportedOperationException();
}
@Override
public void evaluateIntermediate(int groupId, BlockBuilder output)
{
throw new UnsupportedOperationException();
}
@Override
public void evaluateFinal(int groupId, BlockBuilder output)
{
accumulator.evaluateFinal(groupId, output);
}
@Override
public void prepareFinal() {}
}
}