org.apache.flink.runtime.operators.hash.InPlaceMutableHashTable
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.operators.hash;

import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.typeutils.SameTypePairComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.disk.RandomAccessInputView;
import org.apache.flink.runtime.memory.AbstractPagedOutputView;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MathUtils;
import org.apache.flink.util.MutableObjectIterator;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.EOFException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * This hash table supports updating elements. If the new element has the same size as the old
 * element, then the update is done in-place. Otherwise a hole is created at the place of the old
 * record, which will eventually be removed by a compaction.
 *
 * <p>The memory is divided into three areas:
 *
 * <ul>
 *   <li>Bucket area: contains the bucket heads: an 8 byte pointer to the first link of a linked
 *       list in the record area
 *   <li>Record area: contains the actual data, stored in linked list elements. A linked list
 *       element starts with an 8 byte pointer to the next element, and then the record follows.
 *   <li>Staging area: a small, temporary storage area for writing updated records. This is needed
 *       because, before serializing a record, there is no way to know in advance how large it will
 *       be. Therefore, we can't serialize directly into the record area when we are doing an
 *       update, because if it turns out to be larger than the old record, it would override some
 *       other record that happens to be after the old one in memory. The solution is to serialize
 *       to the staging area first, and then copy it to the place of the original if it has the
 *       same size, otherwise allocate a new linked list element at the end of the record area and
 *       mark the old one as abandoned. This creates "holes" in the record area, so compactions are
 *       eventually needed.
 * </ul>
 *
 * <p>Compaction happens by deleting everything in the bucket area, and then reinserting all
 * elements. The reinsertion happens by forgetting the structure (the linked lists) of the record
 * area, reading it sequentially, and inserting all non-abandoned records, starting from the
 * beginning of the record area. Note that insertions never override a record that hasn't been read
 * by the reinsertion sweep, because both the insertions and the readings happen sequentially in
 * the record area, and the insertions obviously never overtake the reading sweep.
 *
 * <p>Note: we have to abandon the old linked list element even when the updated record has a
 * smaller size than the original, because otherwise we wouldn't know where the next record starts
 * during a reinsertion sweep.
 *
 * <p>The number of buckets depends on how large the records are. The serializer might be able to
 * tell us this, in which case we calculate the number of buckets upfront and won't do resizes. If
 * the serializer doesn't know the size, then we start with a small number of buckets, and do
 * resizes once more elements have been inserted than the number of buckets.
 *
 * <p>The number of memory segments given to the staging area is usually one, because it just needs
 * to hold one record.
 *
 * <p>Note: For hashing, we couldn't just take the lower bits, but have to use a proper hash
 * function from MathUtils because of its avalanche property, so that changing only some high bits
 * of the original value won't leave the lower bits of the hash unaffected. This is because when
 * choosing the bucket for a record, we mask only the lower bits (see numBucketsMask). Lots of
 * collisions would occur when, for example, the hashed value is some bitset in which the actually
 * occurring values differ only in the higher bits.
 */
public class InPlaceMutableHashTable<T> extends AbstractMutableHashTable<T> {

    private static final Logger LOG = LoggerFactory.getLogger(InPlaceMutableHashTable.class);

    /**
     * The minimum number of memory segments InPlaceMutableHashTable needs to be supplied with in
     * order to work.
     */
    private static final int MIN_NUM_MEMORY_SEGMENTS = 3;

    // Note: the following two constants can't be negative, because negative values are reserved
    // for storing the negated size of the record, when it is abandoned (not part of any linked
    // list).

    /** The last link in the linked lists will have this as next pointer. */
    private static final long END_OF_LIST = Long.MAX_VALUE;

    /**
     * This value means that prevElemPtr is "pointing to the bucket head", and not into the record
     * segments.
     */
    private static final long INVALID_PREV_POINTER = Long.MAX_VALUE - 1;

    private static final long RECORD_OFFSET_IN_LINK = 8;

    /**
     * This initially contains all the memory we have, and then segments are taken from it by
     * bucketSegments, recordArea, and stagingSegments.
     */
    private final ArrayList<MemorySegment> freeMemorySegments;

    private final int numAllMemorySegments;

    private final int segmentSize;

    /**
     * These will contain the bucket heads. The bucket heads are pointers to the linked lists
     * containing the actual records.
     */
    private MemorySegment[] bucketSegments;

    private static final int bucketSize = 8, bucketSizeBits = 3;

    private int numBuckets;
    private int numBucketsMask;
    private final int numBucketsPerSegment, numBucketsPerSegmentBits, numBucketsPerSegmentMask;

    /** The segments where the actual data is stored. */
    private final RecordArea recordArea;

    /** Segments for the staging area. (It should contain at most one record at all times.) */
    private final ArrayList<MemorySegment> stagingSegments;

    private final RandomAccessInputView stagingSegmentsInView;
    private final StagingOutputView stagingSegmentsOutView;

    private T reuse;

    /** This is the internal prober that insertOrReplaceRecord uses. */
    private final HashTableProber<T> prober;

    /** The number of elements currently held by the table. */
    private long numElements = 0;

    /**
     * The number of bytes wasted by updates that couldn't overwrite the old record due to size
     * change.
     */
    private long holes = 0;

    /**
     * If the serializer knows the size of the records, then we can calculate the optimal number of
     * buckets upfront, so we don't need resizes.
     */
    private boolean enableResize;

    public InPlaceMutableHashTable(
            TypeSerializer<T> serializer, TypeComparator<T> comparator, List<MemorySegment> memory) {
        super(serializer, comparator);
        this.numAllMemorySegments = memory.size();
        this.freeMemorySegments = new ArrayList<>(memory);

        // some sanity checks first
        if (freeMemorySegments.size() < MIN_NUM_MEMORY_SEGMENTS) {
            throw new IllegalArgumentException(
                    "Too few memory segments provided. InPlaceMutableHashTable needs at least "
                            + MIN_NUM_MEMORY_SEGMENTS
                            + " memory segments.");
        }

        // Get the size of the first memory segment and record it. All further buffers must have
        // the same size.
        // The size must also be a power of 2.
        segmentSize = freeMemorySegments.get(0).size();
        if ((segmentSize & segmentSize - 1) != 0) {
            throw new IllegalArgumentException(
                    "Hash Table requires buffers whose size is a power of 2.");
        }

        this.numBucketsPerSegment = segmentSize / bucketSize;
        this.numBucketsPerSegmentBits = MathUtils.log2strict(this.numBucketsPerSegment);
        this.numBucketsPerSegmentMask = (1 << this.numBucketsPerSegmentBits) - 1;

        recordArea = new RecordArea(segmentSize);

        stagingSegments = new ArrayList<>();
        stagingSegments.add(forcedAllocateSegment());
        stagingSegmentsInView = new RandomAccessInputView(stagingSegments, segmentSize);
        stagingSegmentsOutView = new StagingOutputView(stagingSegments, segmentSize);

        prober =
                new HashTableProber<>(
                        buildSideComparator, new SameTypePairComparator<>(buildSideComparator));

        enableResize = buildSideSerializer.getLength() == -1;
    }

    /**
     * Gets the total capacity of this hash table, in bytes.
     *
     * @return The hash table's total capacity.
     */
    public long getCapacity() {
        return numAllMemorySegments * (long) segmentSize;
    }

    /**
     * Gets the number of bytes currently occupied in this hash table.
     *
     * @return The number of bytes occupied.
     */
    public long getOccupancy() {
        return numAllMemorySegments * segmentSize - freeMemorySegments.size() * segmentSize;
    }

    private void open(int numBucketSegments) {
        synchronized (stateLock) {
            if (!closed) {
                throw new IllegalStateException("currently not closed.");
            }
            closed = false;
        }

        allocateBucketSegments(numBucketSegments);

        stagingSegments.add(forcedAllocateSegment());

        reuse = buildSideSerializer.createInstance();
    }

    /** Initialize the hash table. */
    @Override
    public void open() {
        open(calcInitialNumBucketSegments());
    }

    @Override
    public void close() {
        // make sure that we close only once
        synchronized (stateLock) {
            if (closed) {
                // We have to do this here, because the ctor already allocates a segment to the
                // record area and the staging area, even before we are opened. So we might have
                // segments to free, even if we are closed.
                recordArea.giveBackSegments();
                freeMemorySegments.addAll(stagingSegments);
                stagingSegments.clear();
                return;
            }
            closed = true;
        }

        LOG.debug("Closing InPlaceMutableHashTable and releasing resources.");

        releaseBucketSegments();

        recordArea.giveBackSegments();

        freeMemorySegments.addAll(stagingSegments);
        stagingSegments.clear();

        numElements = 0;
        holes = 0;
    }

    @Override
    public void abort() {
        LOG.debug("Aborting InPlaceMutableHashTable.");
        close();
    }

    @Override
    public List<MemorySegment> getFreeMemory() {
        if (!this.closed) {
            throw new IllegalStateException(
                    "Cannot return memory while InPlaceMutableHashTable is open.");
        }

        return freeMemorySegments;
    }

    private int calcInitialNumBucketSegments() {
        int recordLength = buildSideSerializer.getLength();
        double fraction; // fraction of memory to use for the buckets
        if (recordLength == -1) {
            // We don't know the record length, so we start with a small number of buckets, and do
            // resizes if necessary. It seems that resizing is quite efficient, so we can err here
            // on the too few bucket segments side. Even with small records, we lose only ~15%
            // speed.
            fraction = 0.1;
        } else {
            // We know the record length, so we can find a good value for the number of buckets
            // right away, and won't need any resizes later. (enableResize is false in this case,
            // so no resizing will happen.)
            // Reasoning behind the formula:
            // We are aiming for one bucket per record, and one bucket contains one 8 byte pointer.
            // The total memory overhead of an element will be approximately 8+8 bytes, as the
            // record in the record area is preceded by a pointer (for the linked list).
            fraction = 8.0 / (16 + recordLength);
        }

        // We make the number of buckets a power of 2 so that taking modulo is efficient.
        int ret =
                Math.max(1, MathUtils.roundDownToPowerOf2((int) (numAllMemorySegments * fraction)));

        // We can't handle more than Integer.MAX_VALUE buckets (eg. because hash functions return
        // int)
        if ((long) ret * numBucketsPerSegment > Integer.MAX_VALUE) {
            ret = MathUtils.roundDownToPowerOf2(Integer.MAX_VALUE / numBucketsPerSegment);
        }
        return ret;
    }

    private void allocateBucketSegments(int numBucketSegments) {
        if (numBucketSegments < 1) {
            throw new RuntimeException("Bug in InPlaceMutableHashTable");
        }

        bucketSegments = new MemorySegment[numBucketSegments];
        for (int i = 0; i < bucketSegments.length; i++) {
            bucketSegments[i] = forcedAllocateSegment();
            // Init all pointers in all buckets to END_OF_LIST
            for (int j = 0; j < numBucketsPerSegment; j++) {
                bucketSegments[i].putLong(j << bucketSizeBits, END_OF_LIST);
            }
        }
        numBuckets = numBucketSegments * numBucketsPerSegment;
        numBucketsMask = (1 << MathUtils.log2strict(numBuckets)) - 1;
    }

    private void releaseBucketSegments() {
        freeMemorySegments.addAll(Arrays.asList(bucketSegments));
        bucketSegments = null;
    }

    private MemorySegment allocateSegment() {
        int s = freeMemorySegments.size();
        if (s > 0) {
            return freeMemorySegments.remove(s - 1);
        } else {
            return null;
        }
    }

    private MemorySegment forcedAllocateSegment() {
        MemorySegment segment = allocateSegment();
        if (segment == null) {
            throw new RuntimeException(
                    "Bug in InPlaceMutableHashTable: A free segment should have been available.");
        }
        return segment;
    }

    /**
     * Searches the hash table for a record with the given key. If it is found, then it is
     * overridden with the specified record. Otherwise, the specified record is inserted.
     *
     * @param record The record to insert or to replace with.
     * @throws IOException (EOFException specifically, if memory ran out)
     */
    @Override
    public void insertOrReplaceRecord(T record) throws IOException {
        if (closed) {
            return;
        }

        T match = prober.getMatchFor(record, reuse);
        if (match == null) {
            prober.insertAfterNoMatch(record);
        } else {
            prober.updateMatch(record);
        }
    }

    /**
     * Inserts the given record into the hash table. Note: this method doesn't care about whether a
     * record with the same key is already present.
     *
     * @param record The record to insert.
     * @throws IOException (EOFException specifically, if memory ran out)
     */
    @Override
    public void insert(T record) throws IOException {
        if (closed) {
            return;
        }

        final int hashCode = MathUtils.jenkinsHash(buildSideComparator.hash(record));
        final int bucket = hashCode & numBucketsMask;
        final int bucketSegmentIndex =
                bucket >>> numBucketsPerSegmentBits; // which segment contains the bucket
        final MemorySegment bucketSegment = bucketSegments[bucketSegmentIndex];
        final int bucketOffset =
                (bucket & numBucketsPerSegmentMask)
                        << bucketSizeBits; // offset of the bucket in the segment
        final long firstPointer = bucketSegment.getLong(bucketOffset);

        try {
            final long newFirstPointer = recordArea.appendPointerAndRecord(firstPointer, record);
            bucketSegment.putLong(bucketOffset, newFirstPointer);
        } catch (EOFException ex) {
            compactOrThrow();
            insert(record);
            return;
        }

        numElements++;
        resizeTableIfNecessary();
    }

    private void resizeTableIfNecessary() throws IOException {
        if (enableResize && numElements > numBuckets) {
            final long newNumBucketSegments = 2L * bucketSegments.length;
            // Checks:
            // - we can't handle more than Integer.MAX_VALUE buckets
            // - don't take more memory than the free memory we have left
            // - the buckets shouldn't occupy more than half of all our memory
            if (newNumBucketSegments * numBucketsPerSegment < Integer.MAX_VALUE
                    && newNumBucketSegments - bucketSegments.length < freeMemorySegments.size()
                    && newNumBucketSegments < numAllMemorySegments / 2) {
                // do the resize
                rebuild(newNumBucketSegments);
            }
        }
    }

    /**
     * Returns an iterator that can be used to iterate over all the elements in the table. WARNING:
     * Doing any other operation on the table invalidates the iterator! (Even using getMatchFor of
     * a prober!)
     *
     * @return the iterator
     */
    @Override
    public EntryIterator getEntryIterator() {
        return new EntryIterator();
    }

    public <PT> HashTableProber<PT> getProber(
            TypeComparator<PT> probeTypeComparator, TypePairComparator<PT, T> pairComparator) {
        return new HashTableProber<>(probeTypeComparator, pairComparator);
    }

    /**
     * This function reinitializes the bucket segments, reads all records from the record segments
     * (sequentially, without using the pointers or the buckets), and rebuilds the hash table.
     */
    private void rebuild() throws IOException {
        rebuild(bucketSegments.length);
    }

    /**
     * Same as above, but the number of bucket segments of the new table can be specified.
     */
    private void rebuild(long newNumBucketSegments) throws IOException {
        // Get new bucket segments
        releaseBucketSegments();
        allocateBucketSegments((int) newNumBucketSegments);

        T record = buildSideSerializer.createInstance();
        try {
            EntryIterator iter = getEntryIterator();
            recordArea.resetAppendPosition();
            recordArea.setWritePosition(0);
            while ((record = iter.next(record)) != null && !closed) {
                final int hashCode = MathUtils.jenkinsHash(buildSideComparator.hash(record));
                final int bucket = hashCode & numBucketsMask;
                final int bucketSegmentIndex =
                        bucket >>> numBucketsPerSegmentBits; // which segment contains the bucket
                final MemorySegment bucketSegment = bucketSegments[bucketSegmentIndex];
                final int bucketOffset =
                        (bucket & numBucketsPerSegmentMask)
                                << bucketSizeBits; // offset of the bucket in the segment
                final long firstPointer = bucketSegment.getLong(bucketOffset);

                long ptrToAppended = recordArea.noSeekAppendPointerAndRecord(firstPointer, record);
                bucketSegment.putLong(bucketOffset, ptrToAppended);
            }
            recordArea.freeSegmentsAfterAppendPosition();
            holes = 0;
        } catch (EOFException ex) {
            throw new RuntimeException(
                    "Bug in InPlaceMutableHashTable: we shouldn't get out of memory during a rebuild, "
                            + "because we aren't allocating any new memory.");
        }
    }

    /**
     * If there is wasted space (due to updated records not fitting in their old places), then do a
     * compaction. Else, throw EOFException to indicate that memory ran out.
     *
     * @throws IOException
     */
    private void compactOrThrow() throws IOException {
        if (holes > (double) recordArea.getTotalSize() * 0.05) {
            rebuild();
        } else {
            throw new EOFException(
                    "InPlaceMutableHashTable memory ran out. " + getMemoryConsumptionString());
        }
    }

    /** @return String containing a summary of the memory consumption for error messages */
    private String getMemoryConsumptionString() {
        return "InPlaceMutableHashTable memory stats:\n"
                + "Total memory: "
                + numAllMemorySegments * segmentSize
                + "\n"
                + "Free memory: "
                + freeMemorySegments.size() * segmentSize
                + "\n"
                + "Bucket area: "
                + numBuckets * 8
                + "\n"
                + "Record area: "
                + recordArea.getTotalSize()
                + "\n"
                + "Staging area: "
                + stagingSegments.size() * segmentSize
                + "\n"
                + "Num of elements: "
                + numElements
                + "\n"
                + "Holes total size: "
                + holes;
    }

    /**
     * This class encapsulates the memory segments that belong to the record area.
     * It
     *
     * <ul>
     *   <li>can append a record
     *   <li>can overwrite a record at an arbitrary position (WARNING: the new record must have the
     *       same size as the old one)
     *   <li>can be rewritten by calling resetAppendPosition
     *   <li>takes memory from InPlaceMutableHashTable.freeMemorySegments on append
     * </ul>
     */
    private final class RecordArea {
        private final ArrayList<MemorySegment> segments = new ArrayList<>();

        private final RecordAreaOutputView outView;
        private final RandomAccessInputView inView;

        private final int segmentSizeBits;
        private final int segmentSizeMask;

        private long appendPosition = 0;

        public RecordArea(int segmentSize) {
            int segmentSizeBits = MathUtils.log2strict(segmentSize);

            if ((segmentSize & (segmentSize - 1)) != 0) {
                throw new IllegalArgumentException("Segment size must be a power of 2!");
            }
            this.segmentSizeBits = segmentSizeBits;
            this.segmentSizeMask = segmentSize - 1;

            outView = new RecordAreaOutputView(segmentSize);
            try {
                addSegment();
            } catch (EOFException ex) {
                throw new RuntimeException(
                        "Bug in InPlaceMutableHashTable: we should have caught it earlier "
                                + "that we don't have enough segments.");
            }
            inView = new RandomAccessInputView(segments, segmentSize);
        }

        private void addSegment() throws EOFException {
            MemorySegment m = allocateSegment();
            if (m == null) {
                throw new EOFException();
            }
            segments.add(m);
        }

        /**
         * Moves all its memory segments to freeMemorySegments. Warning: this will leave the
         * RecordArea in an unwritable state: you have to call setWritePosition before writing
         * again.
         */
        public void giveBackSegments() {
            freeMemorySegments.addAll(segments);
            segments.clear();

            resetAppendPosition();
        }

        public long getTotalSize() {
            return segments.size() * (long) segmentSize;
        }

        // ----------------------- Output -----------------------

        private void setWritePosition(long position) throws EOFException {
            if (position > appendPosition) {
                throw new IndexOutOfBoundsException();
            }

            final int segmentIndex = (int) (position >>> segmentSizeBits);
            final int offset = (int) (position & segmentSizeMask);

            // If position == appendPosition and the last buffer is full,
            // then we will be seeking to the beginning of a new segment
            if (segmentIndex == segments.size()) {
                addSegment();
            }

            outView.currentSegmentIndex = segmentIndex;
            outView.seekOutput(segments.get(segmentIndex), offset);
        }

        /**
         * Sets appendPosition and the write position to 0, so that appending starts overwriting
         * elements from the beginning. (This is used in rebuild.)
         *
         * <p>Note: if data was written to the area after the current appendPosition before a call
         * to resetAppendPosition, it should still be readable. To release the segments after the
         * current append position, call freeSegmentsAfterAppendPosition()
         */
        public void resetAppendPosition() {
            appendPosition = 0;

            // this is just for safety (making sure that we fail immediately
            // if a write happens without calling setWritePosition)
            outView.currentSegmentIndex = -1;
            outView.seekOutput(null, -1);
        }

        /**
         * Releases the memory segments that are after the current append position. Note: The
         * situation that there are segments after the current append position can arise from a
         * call to resetAppendPosition().
         */
        public void freeSegmentsAfterAppendPosition() {
            final int appendSegmentIndex = (int) (appendPosition >>> segmentSizeBits);
            while (segments.size() > appendSegmentIndex + 1 && !closed) {
                freeMemorySegments.add(segments.get(segments.size() - 1));
                segments.remove(segments.size() - 1);
            }
        }

        /**
         * Overwrites the long value at the specified position.
         *
         * @param pointer Points to the position to overwrite.
         * @param value The value to write.
         * @throws IOException
         */
        public void overwritePointerAt(long pointer, long value) throws IOException {
            setWritePosition(pointer);
            outView.writeLong(value);
        }

        /**
         * Overwrites a record at the specified position. The record is read from a DataInputView
         * (this will be the staging area). WARNING: The record must not be larger than the
         * original record.
         *
         * @param pointer Points to the position to overwrite.
         * @param input The DataInputView to read the record from
         * @param size The size of the record
         * @throws IOException
         */
        public void overwriteRecordAt(long pointer, DataInputView input, int size)
                throws IOException {
            setWritePosition(pointer);
            outView.write(input, size);
        }

        /**
         * Appends a pointer and a record. The record is read from a DataInputView (this will be
         * the staging area).
         *
         * @param pointer The pointer to write (Note: this is NOT the position to write to!)
         * @param input The DataInputView to read the record from
         * @param recordSize The size of the record
         * @return A pointer to the written data
         * @throws IOException (EOFException specifically, if memory ran out)
         */
        public long appendPointerAndCopyRecord(long pointer, DataInputView input, int recordSize)
                throws IOException {
            setWritePosition(appendPosition);
            final long oldLastPosition = appendPosition;
            outView.writeLong(pointer);
            outView.write(input, recordSize);
            appendPosition += 8 + recordSize;

            return oldLastPosition;
        }

        /**
         * Appends a pointer and a record.
         *
         * @param pointer The pointer to write (Note: this is NOT the position to write to!)
         * @param record The record to write
         * @return A pointer to the written data
         * @throws IOException (EOFException specifically, if memory ran out)
         */
        public long appendPointerAndRecord(long pointer, T record) throws IOException {
            setWritePosition(appendPosition);
            return noSeekAppendPointerAndRecord(pointer, record);
        }

        /**
         * Appends a pointer and a record. Call this function only if the write position is at the
         * end!
         *
         * @param pointer The pointer to write (Note: this is NOT the position to write to!)
         * @param record The record to write
         * @return A pointer to the written data
         * @throws IOException (EOFException specifically, if memory ran out)
         */
        public long noSeekAppendPointerAndRecord(long pointer, T record) throws IOException {
            final long oldLastPosition = appendPosition;
            final long oldPositionInSegment = outView.getCurrentPositionInSegment();
            final long oldSegmentIndex = outView.currentSegmentIndex;
            outView.writeLong(pointer);
            buildSideSerializer.serialize(record, outView);
            appendPosition +=
                    outView.getCurrentPositionInSegment()
                            - oldPositionInSegment
                            + outView.getSegmentSize()
                                    * (outView.currentSegmentIndex - oldSegmentIndex);

            return oldLastPosition;
        }

        public long getAppendPosition() {
            return appendPosition;
        }

        // ----------------------- Input -----------------------

        public void setReadPosition(long position) {
            inView.setReadPosition(position);
        }

        public long getReadPosition() {
            return inView.getReadPosition();
        }

        /**
         * Note: this is sometimes a negated length instead of a pointer (see
         * HashTableProber.updateMatch).
         */
        public long readPointer() throws IOException {
            return inView.readLong();
        }

        public T readRecord(T reuse) throws IOException {
            return buildSideSerializer.deserialize(reuse, inView);
        }

        public void skipBytesToRead(int numBytes) throws IOException {
            inView.skipBytesToRead(numBytes);
        }

        // -----------------------------------------------------

        private final class RecordAreaOutputView extends AbstractPagedOutputView {

            public int currentSegmentIndex;

            public RecordAreaOutputView(int segmentSize) {
                super(segmentSize, 0);
            }

            @Override
            protected MemorySegment nextSegment(MemorySegment current, int positionInCurrent)
                    throws EOFException {
                currentSegmentIndex++;
                if (currentSegmentIndex == segments.size()) {
                    addSegment();
                }
                return segments.get(currentSegmentIndex);
            }

            @Override
            public void seekOutput(MemorySegment seg, int position) {
                super.seekOutput(seg, position);
            }
        }
    }

    private final class StagingOutputView extends AbstractPagedOutputView {

        private final ArrayList<MemorySegment> segments;

        private final int segmentSizeBits;

        private int currentSegmentIndex;

        public StagingOutputView(ArrayList<MemorySegment> segments, int segmentSize) {
            super(segmentSize, 0);
            this.segmentSizeBits = MathUtils.log2strict(segmentSize);
            this.segments = segments;
        }

        /** Seeks to the beginning. */
        public void reset() {
            seekOutput(segments.get(0), 0);
            currentSegmentIndex = 0;
        }

        @Override
        protected MemorySegment nextSegment(MemorySegment current, int positionInCurrent)
                throws EOFException {
            currentSegmentIndex++;
            if (currentSegmentIndex == segments.size()) {
                MemorySegment m = allocateSegment();
                if (m == null) {
                    throw new EOFException();
                }
                segments.add(m);
            }
            return segments.get(currentSegmentIndex);
        }

        public long getWritePosition() {
            return (((long) currentSegmentIndex) << segmentSizeBits)
                    + getCurrentPositionInSegment();
        }
    }

    /**
     * A prober for accessing the table. In addition to getMatchFor and updateMatch, it also has
     * insertAfterNoMatch. Warning: Don't modify the table between calling getMatchFor and the
     * other methods!
     *
     * @param <PT> The type of the records that we are probing with
     */
    public final class HashTableProber<PT> extends AbstractHashTableProber<PT, T> {

        public HashTableProber(
                TypeComparator<PT> probeTypeComparator, TypePairComparator<PT, T> pairComparator) {
            super(probeTypeComparator, pairComparator);
        }

        private int bucketSegmentIndex;
        private int bucketOffset;
        private long curElemPtr;
        private long prevElemPtr;
        private long nextPtr;
        private long recordEnd;

        /**
         * Searches the hash table for the record with the given key.
         * (If there would be multiple matches, only one is returned.)
         *
         * @param record The record whose key we are searching for
         * @param targetForMatch If a match is found, it will be written here
         * @return targetForMatch if a match is found, otherwise null.
         */
        @Override
        public T getMatchFor(PT record, T targetForMatch) {
            if (closed) {
                return null;
            }

            final int hashCode = MathUtils.jenkinsHash(probeTypeComparator.hash(record));
            final int bucket = hashCode & numBucketsMask;
            bucketSegmentIndex =
                    bucket >>> numBucketsPerSegmentBits; // which segment contains the bucket
            final MemorySegment bucketSegment = bucketSegments[bucketSegmentIndex];
            bucketOffset =
                    (bucket & numBucketsPerSegmentMask)
                            << bucketSizeBits; // offset of the bucket in the segment

            curElemPtr = bucketSegment.getLong(bucketOffset);

            pairComparator.setReference(record);

            T currentRecordInList = targetForMatch;

            prevElemPtr = INVALID_PREV_POINTER;
            try {
                while (curElemPtr != END_OF_LIST && !closed) {
                    recordArea.setReadPosition(curElemPtr);
                    nextPtr = recordArea.readPointer();
                    currentRecordInList = recordArea.readRecord(currentRecordInList);
                    recordEnd = recordArea.getReadPosition();
                    if (pairComparator.equalToReference(currentRecordInList)) {
                        // we found an element with a matching key, and not just a hash collision
                        return currentRecordInList;
                    }

                    prevElemPtr = curElemPtr;
                    curElemPtr = nextPtr;
                }
            } catch (IOException ex) {
                throw new RuntimeException(
                        "Error deserializing record from the hashtable: " + ex.getMessage(), ex);
            }
            return null;
        }

        @Override
        public T getMatchFor(PT probeSideRecord) {
            return getMatchFor(probeSideRecord, buildSideSerializer.createInstance());
        }

        /**
         * This method can be called after getMatchFor returned a match. It will overwrite the
         * record that was found by getMatchFor. Warning: The new record should have the same key
         * as the old! WARNING: Don't do any modifications to the table between getMatchFor and
         * updateMatch!
         *
         * @param newRecord The record to override the old record with.
         * @throws IOException (EOFException specifically, if memory ran out)
         */
        @Override
        public void updateMatch(T newRecord) throws IOException {
            if (closed) {
                return;
            }

            if (curElemPtr == END_OF_LIST) {
                throw new RuntimeException(
                        "updateMatch was called after getMatchFor returned no match");
            }

            try {
                // determine the new size
                stagingSegmentsOutView.reset();
                buildSideSerializer.serialize(newRecord, stagingSegmentsOutView);
                final int newRecordSize = (int) stagingSegmentsOutView.getWritePosition();
                stagingSegmentsInView.setReadPosition(0);

                // Determine the size of the place of the old record.
                final int oldRecordSize = (int) (recordEnd - (curElemPtr + RECORD_OFFSET_IN_LINK));

                if (newRecordSize == oldRecordSize) {
                    // overwrite record at its original place
                    recordArea.overwriteRecordAt(
                            curElemPtr + RECORD_OFFSET_IN_LINK,
                            stagingSegmentsInView,
                            newRecordSize);
                } else {
                    // The new record has a different size than the old one, so append the new one
                    // at the end of the record area.
                    // Note: we have to do this even if the new record is smaller, because
                    // otherwise EntryIterator wouldn't know the size of this place, and wouldn't
                    // know where the next record starts.
                    final long pointerToAppended =
                            recordArea.appendPointerAndCopyRecord(
                                    nextPtr, stagingSegmentsInView, newRecordSize);

                    // modify the pointer in the previous link
                    if (prevElemPtr == INVALID_PREV_POINTER) {
                        // list had only one element, so prev is in the bucketSegments
                        bucketSegments[bucketSegmentIndex].putLong(bucketOffset, pointerToAppended);
                    } else {
                        recordArea.overwritePointerAt(prevElemPtr, pointerToAppended);
                    }

                    // Write the negated size of the hole to the place where the next pointer was,
                    // so that EntryIterator will know the size of the place without reading the
                    // old record. The negative sign means that the record is abandoned, and the
                    // -1 is for avoiding trouble in case of a record having 0 size. (Though I
                    // think this should never actually happen.)
                    // Note: the last record in the record area can't be abandoned. (EntryIterator
                    // makes use of this fact.)
                    recordArea.overwritePointerAt(curElemPtr, -oldRecordSize - 1);

                    holes += oldRecordSize;
                }
            } catch (EOFException ex) {
                compactOrThrow();
                insertOrReplaceRecord(newRecord);
            }
        }

        /**
         * This method can be called after getMatchFor returned null. It inserts the given record
         * to the hash table. Important: The given record should have the same key as the record
         * that was given to getMatchFor! WARNING: Don't do any modifications to the table between
         * getMatchFor and insertAfterNoMatch!
         *
         * @throws IOException (EOFException specifically, if memory ran out)
         */
        public void insertAfterNoMatch(T record) throws IOException {
            if (closed) {
                return;
            }

            // create new link
            long pointerToAppended;
            try {
                pointerToAppended = recordArea.appendPointerAndRecord(END_OF_LIST, record);
            } catch (EOFException ex) {
                compactOrThrow();
                insert(record);
                return;
            }

            // add new link to the end of the list
            if (prevElemPtr == INVALID_PREV_POINTER) {
                // list was empty
                bucketSegments[bucketSegmentIndex].putLong(bucketOffset, pointerToAppended);
            } else {
                // update the pointer of the last element of the list.
                recordArea.overwritePointerAt(prevElemPtr, pointerToAppended);
            }

            numElements++;
            resizeTableIfNecessary();
        }
    }

    /**
     * WARNING: Doing any other operation on the table invalidates the iterator! (Even using
     * getMatchFor of a prober!)
     */
    public final class EntryIterator implements MutableObjectIterator<T> {

        private final long endPosition;

        public EntryIterator() {
            endPosition = recordArea.getAppendPosition();
            if (endPosition == 0) {
                return;
            }
            recordArea.setReadPosition(0);
        }

        @Override
        public T next(T reuse) throws IOException {
            if (endPosition != 0 && recordArea.getReadPosition() < endPosition) {
                // Loop until we find a non-abandoned record.
                // Note: the last record in the record area can't be abandoned.
                while (!closed) {
                    final long pointerOrNegatedLength = recordArea.readPointer();
                    final boolean isAbandoned = pointerOrNegatedLength < 0;

                    if (!isAbandoned) {
                        reuse = recordArea.readRecord(reuse);
                        return reuse;
                    } else {
                        // pointerOrNegatedLength is storing a length, because the record was
                        // abandoned.
                        recordArea.skipBytesToRead((int) -(pointerOrNegatedLength + 1));
                    }
                }
                return null; // (we were closed)
            } else {
                return null;
            }
        }

        @Override
        public T next() throws IOException {
            return next(buildSideSerializer.createInstance());
        }
    }

    /**
     * A facade for doing such operations on the hash table that are needed for a reduce operator
     * driver.
     */
    public final class ReduceFacade {

        private final HashTableProber<T> prober;

        private final boolean objectReuseEnabled;

        private final ReduceFunction<T> reducer;

        private final Collector<T> outputCollector;

        private T reuse;

        public ReduceFacade(
                ReduceFunction<T> reducer, Collector<T> outputCollector, boolean objectReuseEnabled) {
            this.reducer = reducer;
            this.outputCollector = outputCollector;
            this.objectReuseEnabled = objectReuseEnabled;
            this.prober =
                    getProber(
                            buildSideComparator,
                            new SameTypePairComparator<>(buildSideComparator));
            this.reuse = buildSideSerializer.createInstance();
        }

        /**
         * Looks up the table entry that has the same key as the given record, and updates it by
         * performing a reduce step.
         *
         * @param record The record to update.
         * @throws Exception
         */
        public void updateTableEntryWithReduce(T record) throws Exception {
            T match = prober.getMatchFor(record, reuse);
            if (match == null) {
                prober.insertAfterNoMatch(record);
            } else {
                // do the reduce step
                T res = reducer.reduce(match, record);

                // We have given reuse to the reducer UDF, so create a new one if object reuse is
                // disabled
                if (!objectReuseEnabled) {
                    reuse = buildSideSerializer.createInstance();
                }

                prober.updateMatch(res);
            }
        }

        /** Emits all elements currently held by the table to the collector. */
        public void emit() throws IOException {
            T record = buildSideSerializer.createInstance();
            EntryIterator iter = getEntryIterator();
            while ((record = iter.next(record)) != null && !closed) {
                outputCollector.collect(record);
                if (!objectReuseEnabled) {
                    record = buildSideSerializer.createInstance();
                }
            }
        }

        /**
         * Emits all elements currently held by the table to the collector, and resets the table.
         * The table will have the same number of buckets as before the reset, to avoid doing
         * resizes again.
         */
        public void emitAndReset() throws IOException {
            final int oldNumBucketSegments = bucketSegments.length;
            emit();
            close();
            open(oldNumBucketSegments);
        }
    }
}
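
For orientation, the sketch below shows how the public surface of the class above fits together: build the table from a list of MemorySegments, upsert records with insertOrReplaceRecord, drain it with the EntryIterator, and release the memory with close. It is an illustrative sketch only, not part of the Flink source: MyRecord, serializer, comparator and input are placeholders the caller would supply (the serializer and comparator normally come from the record type's TypeInformation), and the segment count and size are arbitrary example values.

    // Illustrative usage sketch (not part of the listing above). Assumes imports of
    // org.apache.flink.core.memory.MemorySegmentFactory and the types used in the class.
    static void deduplicate(
            TypeSerializer<MyRecord> serializer,      // placeholder record type
            TypeComparator<MyRecord> comparator,
            Iterable<MyRecord> input) throws IOException {

        // Working memory: e.g. 128 unpooled pages of 32 KiB. Page size must be a power of 2,
        // and at least MIN_NUM_MEMORY_SEGMENTS (3) pages are required.
        List<MemorySegment> memory = new ArrayList<>();
        for (int i = 0; i < 128; i++) {
            memory.add(MemorySegmentFactory.allocateUnpooledSegment(32 * 1024));
        }

        InPlaceMutableHashTable<MyRecord> table =
                new InPlaceMutableHashTable<>(serializer, comparator, memory);
        table.open();
        try {
            for (MyRecord r : input) {
                table.insertOrReplaceRecord(r); // keeps only the latest record per key
            }

            // WARNING (from the Javadoc above): no other table operation while iterating.
            MutableObjectIterator<MyRecord> it = table.getEntryIterator();
            MyRecord reuse = serializer.createInstance();
            while ((reuse = it.next(reuse)) != null) {
                // consume the deduplicated records here
            }
        } finally {
            table.close(); // gives all memory segments back to the free list
        }
    }

Because insertOrReplaceRecord overwrites the entry with a matching key, the final iteration sees exactly one record per key. ReduceFacade wraps the same pattern for reduce drivers, combining matching records with a ReduceFunction instead of simply overwriting them.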




