All downloads are free. Search and download functionality uses the official Maven repository.

io.trino.operator.join.BigintPagesHash Maven / Gradle / Ivy

There is a newer version: 465
Show newest version
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.trino.operator.join;

import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import io.trino.operator.HashArraySizeSupplier;
import io.trino.operator.PagesHashStrategy;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.Block;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;

import java.util.Arrays;
import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.slice.SizeOf.instanceSize;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.airlift.slice.SizeOf.sizeOfIntArray;
import static io.airlift.slice.SizeOf.sizeOfLongArray;
import static io.airlift.units.DataSize.Unit.KILOBYTE;
import static io.trino.operator.SyntheticAddress.decodePosition;
import static io.trino.operator.SyntheticAddress.decodeSliceIndex;
import static io.trino.operator.join.PagesHash.getHashPosition;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

/**
 * This implementation assumes:
 * -There is only one join channel and it is of type bigint
 * -arrays used in the hash are always a power of 2.
 */
public final class BigintPagesHash
        implements PagesHash
{
    private static final int INSTANCE_SIZE = instanceSize(BigintPagesHash.class);
    private static final DataSize CACHE_SIZE = DataSize.of(128, KILOBYTE);

    private final LongArrayList addresses;
    private final List joinChannelBlocks;
    private final PagesHashStrategy pagesHashStrategy;

    private final int mask;
    private final int[] keys;
    private final long[] values;
    private final long size;

    public BigintPagesHash(
            LongArrayList addresses,
            PagesHashStrategy pagesHashStrategy,
            PositionLinks.FactoryBuilder positionLinks,
            HashArraySizeSupplier hashArraySizeSupplier,
            List pages,
            int joinChannel)
    {
        this.addresses = requireNonNull(addresses, "addresses is null");
        this.pagesHashStrategy = requireNonNull(pagesHashStrategy, "pagesHashStrategy is null");
        requireNonNull(pages, "pages is null");
        ImmutableList.Builder joinChannelBlocksBuilder = ImmutableList.builder();
        for (Page page : pages) {
            joinChannelBlocksBuilder.add(page.getBlock(joinChannel));
        }
        joinChannelBlocks = joinChannelBlocksBuilder.build();

        // reserve memory for the arrays
        int hashSize = hashArraySizeSupplier.getHashArraySize(addresses.size());

        mask = hashSize - 1;
        keys = new int[hashSize];
        values = new long[addresses.size()];
        Arrays.fill(keys, -1);

        // We will process addresses in batches, to improve spatial and temporal memory locality
        int positionsInStep = Math.min(addresses.size() + 1, (int) CACHE_SIZE.toBytes() / Integer.SIZE);

        for (int step = 0; step * positionsInStep <= addresses.size(); step++) {
            int stepBeginPosition = step * positionsInStep;
            int stepEndPosition = Math.min((step + 1) * positionsInStep, addresses.size());
            int stepSize = stepEndPosition - stepBeginPosition;

            indexPages(addresses, positionLinks, stepBeginPosition, stepSize);
        }

        size = sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() +
                sizeOf(keys) + sizeOf(values);
    }

    private void indexPages(LongArrayList addresses, PositionLinks.FactoryBuilder positionLinks, int stepBeginPosition, int stepSize)
    {
        // index pages
        for (int batchIndex = 0; batchIndex < stepSize; batchIndex++) {
            int addressIndex = batchIndex + stepBeginPosition;
            if (isPositionNull(addressIndex)) {
                continue;
            }

            long address = addresses.getLong(addressIndex);
            int blockIndex = decodeSliceIndex(address);
            int blockPosition = decodePosition(address);
            long value = BIGINT.getLong(joinChannelBlocks.get(blockIndex), blockPosition);

            int pos = getHashPosition(value, mask);

            insertValue(positionLinks, addressIndex, value, pos);
        }
    }

    private void insertValue(PositionLinks.FactoryBuilder positionLinks, int addressIndex, long value, int pos)
    {
        // look for an empty slot or a slot containing this key
        while (keys[pos] != -1) {
            int currentKey = keys[pos];
            if (value == values[currentKey]) {
                // found a slot for this key
                // link the new key position to the current key position
                addressIndex = positionLinks.link(addressIndex, currentKey);

                // key[pos] updated outside of this loop
                break;
            }
            // increment position and mask to handler wrap around
            pos = (pos + 1) & mask;
        }

        keys[pos] = addressIndex;
        values[addressIndex] = value;
    }

    @Override
    public int getPositionCount()
    {
        return addresses.size();
    }

    @Override
    public long getInMemorySizeInBytes()
    {
        return INSTANCE_SIZE + size;
    }

    @Override
    public int getAddressIndex(int position, Page hashChannelsPage, long rawHash)
    {
        return getAddressIndex(position, hashChannelsPage);
    }

    @Override
    public int getAddressIndex(int position, Page hashChannelsPage)
    {
        long value = BIGINT.getLong(hashChannelsPage.getBlock(0), position);
        int pos = getHashPosition(value, mask);

        while (keys[pos] != -1) {
            if (value == values[keys[pos]]) {
                return keys[pos];
            }
            // increment position and mask to handler wrap around
            pos = (pos + 1) & mask;
        }
        return -1;
    }

    @Override
    public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes)
    {
        return getAddressIndex(positions, hashChannelsPage);
    }

    @Override
    public int[] getAddressIndex(int[] positions, Page hashChannelsPage)
    {
        checkArgument(hashChannelsPage.getChannelCount() == 1, "Multiple channel page passed to BigintPagesHash");

        int positionCount = positions.length;
        long[] incomingValues = new long[positionCount];
        int[] hashPositions = new int[positionCount];

        extractAndHashValues(positions, hashChannelsPage, positionCount, incomingValues, hashPositions);

        int[] found = new int[positionCount];
        int foundCount = 0;
        int[] result = new int[positionCount];
        Arrays.fill(result, -1);
        int[] foundKeys = new int[positionCount];

        // Search for positions in the hash array. This is the most CPU-consuming part as
        // it relies on random memory accesses
        findPositions(positionCount, hashPositions, foundKeys);
        // Found positions are put into `found` array
        for (int i = 0; i < positionCount; i++) {
            if (foundKeys[i] != -1) {
                found[foundCount++] = i;
            }
        }

        // At this step we determine if the found keys were indeed the proper ones or it is a hash collision.
        // The result array is updated for the found ones, while the collisions land into `remaining` array.
        int remainingCount = checkFoundPositions(incomingValues, found, foundCount, result, foundKeys);
        int[] remaining = found; // Rename for readability

        // At this point for any reasoable load factor of a hash array (< .75), there is no more than
        // 10 - 15% of positions left. We search for them in a sequential order and update the result array.
        findRemainingPositions(incomingValues, hashPositions, result, remaining, remainingCount);

        return result;
    }

    private void findRemainingPositions(long[] incomingValues, int[] hashPositions, int[] result, int[] remaining, int remainingCount)
    {
        for (int i = 0; i < remainingCount; i++) {
            int index = remaining[i];
            int position = (hashPositions[index] + 1) & mask; // hashPositions[index] position has already been checked

            while (keys[position] != -1) {
                if (values[keys[position]] == incomingValues[index]) {
                    result[index] = keys[position];
                    break;
                }
                // increment position and mask to handler wrap around
                position = (position + 1) & mask;
            }
        }
    }

    private int checkFoundPositions(long[] incomingValues, int[] found, int foundCount, int[] result, int[] foundKeys)
    {
        int[] remaining = found; // Rename for readability
        int remainingCount = 0;
        for (int i = 0; i < foundCount; i++) {
            int index = found[i];
            if (values[foundKeys[index]] == incomingValues[index]) {
                result[index] = foundKeys[index];
            }
            else {
                remaining[remainingCount++] = index;
            }
        }
        return remainingCount;
    }

    private void findPositions(int positionCount, int[] hashPositions, int[] foundKeys)
    {
        for (int i = 0; i < positionCount; i++) {
            foundKeys[i] = keys[hashPositions[i]];
        }
    }

    private void extractAndHashValues(int[] positions, Page hashChannelsPage, int positionCount, long[] incomingValues, int[] hashPositions)
    {
        for (int i = 0; i < positionCount; i++) {
            incomingValues[i] = BIGINT.getLong(hashChannelsPage.getBlock(0), positions[i]);
            hashPositions[i] = getHashPosition(incomingValues[i], mask);
        }
    }

    @Override
    public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset)
    {
        long pageAddress = addresses.getLong(toIntExact(position));
        int blockIndex = decodeSliceIndex(pageAddress);
        int blockPosition = decodePosition(pageAddress);

        pagesHashStrategy.appendTo(blockIndex, blockPosition, pageBuilder, outputChannelOffset);
    }

    private boolean isPositionNull(int position)
    {
        long pageAddress = addresses.getLong(position);
        int blockIndex = decodeSliceIndex(pageAddress);
        int blockPosition = decodePosition(pageAddress);

        return joinChannelBlocks.get(blockIndex).isNull(blockPosition);
    }

    public static long getEstimatedRetainedSizeInBytes(
            int positionCount,
            HashArraySizeSupplier hashArraySizeSupplier,
            LongArrayList addresses,
            List> channels,
            long blocksSizeInBytes)
    {
        return sizeOf(addresses.elements()) +
                (channels.size() > 0 ? sizeOf(channels.get(0).elements()) * channels.size() : 0) +
                blocksSizeInBytes +
                sizeOfIntArray(hashArraySizeSupplier.getHashArraySize(positionCount)) +
                sizeOfLongArray(positionCount);
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy