/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.runtime.operators.hash;

import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.typeutils.SameTypePairComparator;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypePairComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.runtime.io.disk.RandomAccessInputView;
import org.apache.flink.runtime.memory.AbstractPagedOutputView;
import org.apache.flink.util.MathUtils;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.EOFException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * This hash table supports updating elements. If the new element has the same size as the old element, then
 * the update is done in-place. Otherwise a hole is created at the place of the old record, which will
 * eventually be removed by a compaction.
 *
 * The memory is divided into three areas:
 *  - Bucket area: this contains the bucket heads:
 *    an 8 byte pointer to the first link of a linked list in the record area
 *  - Record area: this contains the actual data in linked list elements. A linked list element starts
 *    with an 8 byte pointer to the next element, and then the record follows.
 *  - Staging area: This is a small, temporary storage area for writing updated records. This is needed,
 *    because before serializing a record, there is no way to know in advance how large it will be.
 *    Therefore, we can't serialize directly into the record area when we are doing an update, because
 *    if it turns out to be larger than the old record, then it would overwrite some other record
 *    that happens to be after the old one in memory. The solution is to serialize to the staging area first,
 *    and then copy it to the place of the original if it has the same size, otherwise allocate a new linked
 *    list element at the end of the record area, and mark the old one as abandoned. This creates "holes" in
 *    the record area, so compactions are eventually needed.
 *
 * Compaction happens by deleting everything in the bucket area, and then reinserting all elements.
 * The reinsertion happens by forgetting the structure (the linked lists) of the record area, and reading it
 * sequentially, and inserting all non-abandoned records, starting from the beginning of the record area.
 * Note that insertions never overwrite a record that hasn't been read by the reinsertion sweep, because
 * both the insertions and readings happen sequentially in the record area, and the insertions obviously
 * never overtake the reading sweep.
 *
 * Note: we have to abandon the old linked list element even when the updated record has a smaller size
 * than the original, because otherwise we wouldn't know where the next record starts during a reinsertion
 * sweep.
 *
 * The number of buckets depends on how large the records are. The serializer might be able to tell us this,
 * so in this case, we will calculate the number of buckets upfront, and won't do resizes.
 * If the serializer doesn't know the size, then we start with a small number of buckets, and do resizes as more
 * elements are inserted than the number of buckets.
 *
 * The number of memory segments given to the staging area is usually one, because it just needs to hold
 * one record.
 *
 * Note: For hashing, we can't just take the lower bits of the original hash code, but have to use a proper
 * hash function from MathUtils because of its avalanche property: changing only some high bits of
 * the original value must not leave the lower bits of the hash unaffected.
 * This is because when choosing the bucket for a record, we mask only the
 * lower bits (see numBucketsMask). Lots of collisions would occur if, for example,
 * the hashed values were bitsets that differ only in their higher bits.
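 *
 * A minimal usage sketch (for illustration only; the record type {@code MyRecord}, its serializer and
 * comparator, and the list of memory segments are assumed to be provided by the caller):
 * <pre>{@code
 * InPlaceMutableHashTable<MyRecord> table =
 *     new InPlaceMutableHashTable<>(serializer, comparator, memorySegments);
 * table.open();
 * table.insertOrReplaceRecord(record); // inserts, or updates in place the record with the same key
 * table.close();                       // gives the memory segments back
 * }</pre>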
 */

public class InPlaceMutableHashTable<T> extends AbstractMutableHashTable<T> {

	private static final Logger LOG = LoggerFactory.getLogger(InPlaceMutableHashTable.class);

	/** The minimum number of memory segments InPlaceMutableHashTable needs to be supplied with in order to work. */
	private static final int MIN_NUM_MEMORY_SEGMENTS = 3;

	// Note: the following two constants can't be negative, because negative values are reserved for storing the
	// negated size of the record, when it is abandoned (not part of any linked list).

	/** The last link in the linked lists will have this as next pointer. */
	private static final long END_OF_LIST = Long.MAX_VALUE;

	/** This value means that prevElemPtr is "pointing to the bucket head", and not into the record segments. */
	private static final long INVALID_PREV_POINTER = Long.MAX_VALUE - 1;


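	/**
	 * A link in the record area starts with an 8 byte pointer to the next link, followed by the record itself,
	 * so the record starts at this offset within the link.
	 */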
	private static final long RECORD_OFFSET_IN_LINK = 8;


	/**
	 * This initially contains all the memory we have, and then segments
	 * are taken from it by bucketSegments, recordArea, and stagingSegments.
	 */
	private final ArrayList<MemorySegment> freeMemorySegments;

	private final int numAllMemorySegments;

	private final int segmentSize;

	/**
	 * These will contain the bucket heads.
	 * The bucket heads are pointers to the linked lists containing the actual records.
	 */
	private MemorySegment[] bucketSegments;

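	/** A bucket is a single 8 byte pointer (the bucket head), hence bucketSize = 8 and bucketSizeBits = log2(8) = 3. */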
	private static final int bucketSize = 8, bucketSizeBits = 3;

	private int numBuckets;
	private int numBucketsMask;
	private final int numBucketsPerSegment, numBucketsPerSegmentBits, numBucketsPerSegmentMask;

	/**
	 * The segments where the actual data is stored.
	 */
	private final RecordArea recordArea;

	/**
	 * Segments for the staging area.
	 * (It should contain at most one record at all times.)
	 */
	private final ArrayList<MemorySegment> stagingSegments;
	private final RandomAccessInputView stagingSegmentsInView;
	private final StagingOutputView stagingSegmentsOutView;

	private T reuse;

	/** This is the internal prober that insertOrReplaceRecord uses. */
	private final HashTableProber<T> prober;

	/** The number of elements currently held by the table. */
	private long numElements = 0;

	/** The number of bytes wasted by updates that couldn't overwrite the old record due to size change. */
	private long holes = 0;

	/**
	 * If the serializer knows the size of the records, then we can calculate the optimal number of buckets
	 * upfront, so we don't need resizes.
	 */
	private boolean enableResize;


	public InPlaceMutableHashTable(TypeSerializer<T> serializer, TypeComparator<T> comparator, List<MemorySegment> memory) {
		super(serializer, comparator);
		this.numAllMemorySegments = memory.size();
		this.freeMemorySegments = new ArrayList<>(memory);

		// some sanity checks first
		if (freeMemorySegments.size() < MIN_NUM_MEMORY_SEGMENTS) {
			throw new IllegalArgumentException("Too few memory segments provided. InPlaceMutableHashTable needs at least " +
				MIN_NUM_MEMORY_SEGMENTS + " memory segments.");
		}

		// Get the size of the first memory segment and record it. All further buffers must have the same size.
		// The size must also be a power of 2.
		segmentSize = freeMemorySegments.get(0).size();
		if ((segmentSize & (segmentSize - 1)) != 0) {
			throw new IllegalArgumentException("Hash Table requires buffers whose size is a power of 2.");
		}

		this.numBucketsPerSegment = segmentSize / bucketSize;
		this.numBucketsPerSegmentBits = MathUtils.log2strict(this.numBucketsPerSegment);
		this.numBucketsPerSegmentMask = (1 << this.numBucketsPerSegmentBits) - 1;

		recordArea = new RecordArea(segmentSize);

		stagingSegments = new ArrayList<>();
		stagingSegments.add(forcedAllocateSegment());
		stagingSegmentsInView = new RandomAccessInputView(stagingSegments, segmentSize);
		stagingSegmentsOutView = new StagingOutputView(stagingSegments, segmentSize);

		prober = new HashTableProber<>(buildSideComparator, new SameTypePairComparator<>(buildSideComparator));

		enableResize = buildSideSerializer.getLength() == -1;
	}

	/**
	 * Gets the total capacity of this hash table, in bytes.
	 *
	 * @return The hash table's total capacity.
	 */
	public long getCapacity() {
		return numAllMemorySegments * (long)segmentSize;
	}

	/**
	 * Gets the number of bytes currently occupied in this hash table.
	 *
	 * @return The number of bytes occupied.
	 */
	public long getOccupancy() {
		return (long) numAllMemorySegments * segmentSize - (long) freeMemorySegments.size() * segmentSize;
	}

	private void open(int numBucketSegments) {
		synchronized (stateLock) {
			if (!closed) {
				throw new IllegalStateException("currently not closed.");
			}
			closed = false;
		}

		allocateBucketSegments(numBucketSegments);

		stagingSegments.add(forcedAllocateSegment());

		reuse = buildSideSerializer.createInstance();
	}

	/**
	 * Initialize the hash table
	 */
	@Override
	public void open() {
		open(calcInitialNumBucketSegments());
	}

	@Override
	public void close() {
		// make sure that we close only once
		synchronized (stateLock) {
			if (closed) {
				// We have to do this here, because the ctor already allocates a segment to the record area and
				// the staging area, even before we are opened. So we might have segments to free, even if we
				// are closed.
				recordArea.giveBackSegments();
				freeMemorySegments.addAll(stagingSegments);
				stagingSegments.clear();

				return;
			}
			closed = true;
		}

		LOG.debug("Closing InPlaceMutableHashTable and releasing resources.");

		releaseBucketSegments();

		recordArea.giveBackSegments();

		freeMemorySegments.addAll(stagingSegments);
		stagingSegments.clear();

		numElements = 0;
		holes = 0;
	}

	@Override
	public void abort() {
		LOG.debug("Aborting InPlaceMutableHashTable.");
		close();
	}

	@Override
	public List<MemorySegment> getFreeMemory() {
		if (!this.closed) {
			throw new IllegalStateException("Cannot return memory while InPlaceMutableHashTable is open.");
		}

		return freeMemorySegments;
	}

	private int calcInitialNumBucketSegments() {
		int recordLength = buildSideSerializer.getLength();
		double fraction; // fraction of memory to use for the buckets
		if (recordLength == -1) {
			// We don't know the record length, so we start with a small number of buckets, and do resizes if
			// necessary.
			// It seems that resizing is quite efficient, so we can err here on the too few bucket segments side.
			// Even with small records, we lose only ~15% speed.
			fraction = 0.1;
		} else {
			// We know the record length, so we can find a good value for the number of buckets right away, and
			// won't need any resizes later. (enableResize is false in this case, so no resizing will happen.)
			// Reasoning behind the formula:
			// We are aiming for one bucket per record, and one bucket contains one 8 byte pointer. The total
			// memory overhead of an element will be approximately 8+8 bytes, as the record in the record area
			// is preceded by a pointer (for the linked list).
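			// For example, with 8 byte records the fraction is 8.0 / 24, i.e. roughly a third of the
			// memory segments become bucket segments (rounded down to a power of 2 below).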
			fraction = 8.0 / (16 + recordLength);
		}

		// We make the number of buckets a power of 2 so that taking modulo is efficient.
		int ret = Math.max(1, MathUtils.roundDownToPowerOf2((int)(numAllMemorySegments * fraction)));

		// We can't handle more than Integer.MAX_VALUE buckets (e.g. because hash functions return int)
		if ((long)ret * numBucketsPerSegment > Integer.MAX_VALUE) {
			ret = MathUtils.roundDownToPowerOf2(Integer.MAX_VALUE / numBucketsPerSegment);
		}
		return ret;
	}

	private void allocateBucketSegments(int numBucketSegments) {
		if (numBucketSegments < 1) {
			throw new RuntimeException("Bug in InPlaceMutableHashTable");
		}

		bucketSegments = new MemorySegment[numBucketSegments];
		for (int i = 0; i < bucketSegments.length; i++) {
			bucketSegments[i] = forcedAllocateSegment();
			// Init all pointers in all buckets to END_OF_LIST
			for (int j = 0; j < numBucketsPerSegment; j++) {
				bucketSegments[i].putLong(j << bucketSizeBits, END_OF_LIST);
			}
		}
		numBuckets = numBucketSegments * numBucketsPerSegment;
		numBucketsMask = (1 << MathUtils.log2strict(numBuckets)) - 1;
	}

	private void releaseBucketSegments() {
		freeMemorySegments.addAll(Arrays.asList(bucketSegments));
		bucketSegments = null;
	}

	private MemorySegment allocateSegment() {
		int s = freeMemorySegments.size();
		if (s > 0) {
			return freeMemorySegments.remove(s - 1);
		} else {
			return null;
		}
	}

	private MemorySegment forcedAllocateSegment() {
		MemorySegment segment = allocateSegment();
		if (segment == null) {
			throw new RuntimeException("Bug in InPlaceMutableHashTable: A free segment should have been available.");
		}
		return segment;
	}

	/**
	 * Searches the hash table for a record with the given key.
	 * If it is found, then it is overridden with the specified record.
	 * Otherwise, the specified record is inserted.
	 * @param record The record to insert or to replace with.
	 * @throws IOException (EOFException specifically, if memory ran out)
     */
	@Override
	public void insertOrReplaceRecord(T record) throws IOException {
		if (closed) {
			return;
		}

		T match = prober.getMatchFor(record, reuse);
		if (match == null) {
			prober.insertAfterNoMatch(record);
		} else {
			prober.updateMatch(record);
		}
	}

	/**
	 * Inserts the given record into the hash table.
	 * Note: this method doesn't care about whether a record with the same key is already present.
	 * @param record The record to insert.
	 * @throws IOException (EOFException specifically, if memory ran out)
     */
	@Override
	public void insert(T record) throws IOException {
		if (closed) {
			return;
		}

		final int hashCode = MathUtils.jenkinsHash(buildSideComparator.hash(record));
		final int bucket = hashCode & numBucketsMask;
		final int bucketSegmentIndex = bucket >>> numBucketsPerSegmentBits; // which segment contains the bucket
		final MemorySegment bucketSegment = bucketSegments[bucketSegmentIndex];
		final int bucketOffset = (bucket & numBucketsPerSegmentMask) << bucketSizeBits; // offset of the bucket in the segment
		final long firstPointer = bucketSegment.getLong(bucketOffset);

		try {
			final long newFirstPointer = recordArea.appendPointerAndRecord(firstPointer, record);
			bucketSegment.putLong(bucketOffset, newFirstPointer);
		} catch (EOFException ex) {
			compactOrThrow();
			insert(record);
			return;
		}

		numElements++;
		resizeTableIfNecessary();
	}

	private void resizeTableIfNecessary() throws IOException {
		if (enableResize && numElements > numBuckets) {
			final long newNumBucketSegments = 2L * bucketSegments.length;
			// Checks:
			// - we can't handle more than Integer.MAX_VALUE buckets
			// - don't take more memory than the free memory we have left
			// - the buckets shouldn't occupy more than half of all our memory
			if (newNumBucketSegments * numBucketsPerSegment < Integer.MAX_VALUE &&
				newNumBucketSegments - bucketSegments.length < freeMemorySegments.size() &&
				newNumBucketSegments < numAllMemorySegments / 2) {
				// do the resize
				rebuild(newNumBucketSegments);
			}
		}
	}

	/**
	 * Returns an iterator that can be used to iterate over all the elements in the table.
	 * WARNING: Doing any other operation on the table invalidates the iterator! (Even
	 * using getMatchFor of a prober!)
	 * @return the iterator
     */
	@Override
	public EntryIterator getEntryIterator() {
		return new EntryIterator();
	}

	public <PT> HashTableProber<PT> getProber(TypeComparator<PT> probeTypeComparator, TypePairComparator<PT, T> pairComparator) {
		return new HashTableProber<>(probeTypeComparator, pairComparator);
	}

	/**
	 * This function reinitializes the bucket segments,
	 * reads all records from the record segments (sequentially, without using the pointers or the buckets),
	 * and rebuilds the hash table.
	 */
	private void rebuild() throws IOException {
		rebuild(bucketSegments.length);
	}

	/** Same as above, but the number of bucket segments of the new table can be specified. */
	private void rebuild(long newNumBucketSegments) throws IOException {
		// Get new bucket segments
		releaseBucketSegments();
		allocateBucketSegments((int)newNumBucketSegments);

		T record = buildSideSerializer.createInstance();
		try {
			EntryIterator iter = getEntryIterator();
			recordArea.resetAppendPosition();
			recordArea.setWritePosition(0);
			while ((record = iter.next(record)) != null && !closed) {
				final int hashCode = MathUtils.jenkinsHash(buildSideComparator.hash(record));
				final int bucket = hashCode & numBucketsMask;
				final int bucketSegmentIndex = bucket >>> numBucketsPerSegmentBits; // which segment contains the bucket
				final MemorySegment bucketSegment = bucketSegments[bucketSegmentIndex];
				final int bucketOffset = (bucket & numBucketsPerSegmentMask) << bucketSizeBits; // offset of the bucket in the segment
				final long firstPointer = bucketSegment.getLong(bucketOffset);

				long ptrToAppended = recordArea.noSeekAppendPointerAndRecord(firstPointer, record);
				bucketSegment.putLong(bucketOffset, ptrToAppended);
			}
			recordArea.freeSegmentsAfterAppendPosition();
			holes = 0;

		} catch (EOFException ex) {
			throw new RuntimeException("Bug in InPlaceMutableHashTable: we shouldn't get out of memory during a rebuild, " +
				"because we aren't allocating any new memory.");
		}
	}

	/**
	 * If there is wasted space (due to updated records not fitting in their old places), then do a compaction.
	 * Else, throw EOFException to indicate that memory ran out.
	 * @throws IOException
	 */
	private void compactOrThrow() throws IOException {
		if (holes > (double)recordArea.getTotalSize() * 0.05) {
			rebuild();
		} else {
			throw new EOFException("InPlaceMutableHashTable memory ran out. " + getMemoryConsumptionString());
		}
	}

	/**
	 * @return String containing a summary of the memory consumption for error messages
	 */
	private String getMemoryConsumptionString() {
		return "InPlaceMutableHashTable memory stats:\n" +
			"Total memory:     " + numAllMemorySegments * segmentSize + "\n" +
			"Free memory:      " + freeMemorySegments.size() * segmentSize + "\n" +
			"Bucket area:      " + numBuckets * 8  + "\n" +
			"Record area:      " + recordArea.getTotalSize() + "\n" +
			"Staging area:     " + stagingSegments.size() * segmentSize + "\n" +
			"Num of elements:  " + numElements + "\n" +
			"Holes total size: " + holes;
	}


	/**
	 * This class encapsulates the memory segments that belong to the record area. It
	 *  - can append a record
	 *  - can overwrite a record at an arbitrary position (WARNING: the new record must have the same size
	 *    as the old one)
	 *  - can be rewritten by calling resetAppendPosition
	 *  - takes memory from InPlaceMutableHashTable.freeMemorySegments on append
	 */
	private final class RecordArea
	{
		private final ArrayList<MemorySegment> segments = new ArrayList<>();

		private final RecordAreaOutputView outView;
		private final RandomAccessInputView inView;

		private final int segmentSizeBits;
		private final int segmentSizeMask;

		private long appendPosition = 0;


		public RecordArea(int segmentSize) {
			int segmentSizeBits = MathUtils.log2strict(segmentSize);

			if ((segmentSize & (segmentSize - 1)) != 0) {
				throw new IllegalArgumentException("Segment size must be a power of 2!");
			}

			this.segmentSizeBits = segmentSizeBits;
			this.segmentSizeMask = segmentSize - 1;

			outView = new RecordAreaOutputView(segmentSize);
			try {
				addSegment();
			} catch (EOFException ex) {
				throw new RuntimeException("Bug in InPlaceMutableHashTable: we should have caught it earlier " +
					"that we don't have enough segments.");
			}
			inView = new RandomAccessInputView(segments, segmentSize);
		}


		private void addSegment() throws EOFException {
			MemorySegment m = allocateSegment();
			if (m == null) {
				throw new EOFException();
			}
			segments.add(m);
		}

		/**
		 * Moves all its memory segments to freeMemorySegments.
		 * Warning: this will leave the RecordArea in an unwritable state: you have to
		 * call setWritePosition before writing again.
		 */
		public void giveBackSegments() {
			freeMemorySegments.addAll(segments);
			segments.clear();

			resetAppendPosition();
		}

		public long getTotalSize() {
			return segments.size() * (long)segmentSize;
		}

		// ----------------------- Output -----------------------

		private void setWritePosition(long position) throws EOFException {
			if (position > appendPosition) {
				throw new IndexOutOfBoundsException();
			}

			final int segmentIndex = (int) (position >>> segmentSizeBits);
			final int offset = (int) (position & segmentSizeMask);

			// If position == appendPosition and the last buffer is full,
			// then we will be seeking to the beginning of a new segment
			if (segmentIndex == segments.size()) {
				addSegment();
			}

			outView.currentSegmentIndex = segmentIndex;
			outView.seekOutput(segments.get(segmentIndex), offset);
		}

		/**
		 * Sets appendPosition and the write position to 0, so that appending starts
		 * overwriting elements from the beginning. (This is used in rebuild.)
		 *
		 * Note: if data was written to the area after the current appendPosition
		 * before a call to resetAppendPosition, it should still be readable. To
		 * release the segments after the current append position, call
		 * freeSegmentsAfterAppendPosition()
		 */
		public void resetAppendPosition() {
			appendPosition = 0;

			// this is just for safety (making sure that we fail immediately
			// if a write happens without calling setWritePosition)
			outView.currentSegmentIndex = -1;
			outView.seekOutput(null, -1);
		}

		/**
		 * Releases the memory segments that are after the current append position.
		 * Note: The situation that there are segments after the current append position
		 * can arise from a call to resetAppendPosition().
		 */
		public void freeSegmentsAfterAppendPosition() {
			final int appendSegmentIndex = (int)(appendPosition >>> segmentSizeBits);
			while (segments.size() > appendSegmentIndex + 1 && !closed) {
				freeMemorySegments.add(segments.get(segments.size() - 1));
				segments.remove(segments.size() - 1);
			}
		}

		/**
		 * Overwrites the long value at the specified position.
		 * @param pointer Points to the position to overwrite.
		 * @param value The value to write.
		 * @throws IOException
		 */
		public void overwritePointerAt(long pointer, long value) throws IOException {
			setWritePosition(pointer);
			outView.writeLong(value);
		}

		/**
		 * Overwrites a record at the specified position. The record is read from a DataInputView  (this will be the staging area).
		 * WARNING: The record must not be larger than the original record.
		 * @param pointer Points to the position to overwrite.
		 * @param input The DataInputView to read the record from
		 * @param size The size of the record
		 * @throws IOException
		 */
		public void overwriteRecordAt(long pointer, DataInputView input, int size) throws IOException {
			setWritePosition(pointer);
			outView.write(input, size);
		}

		/**
		 * Appends a pointer and a record. The record is read from a DataInputView (this will be the staging area).
		 * @param pointer The pointer to write (Note: this is NOT the position to write to!)
		 * @param input The DataInputView to read the record from
		 * @param recordSize The size of the record
		 * @return A pointer to the written data
		 * @throws IOException (EOFException specifically, if memory ran out)
		 */
		public long appendPointerAndCopyRecord(long pointer, DataInputView input, int recordSize) throws IOException {
			setWritePosition(appendPosition);
			final long oldLastPosition = appendPosition;
			outView.writeLong(pointer);
			outView.write(input, recordSize);
			appendPosition += 8 + recordSize;
			return oldLastPosition;
		}

		/**
		 * Appends a pointer and a record.
		 * @param pointer The pointer to write (Note: this is NOT the position to write to!)
		 * @param record The record to write
		 * @return A pointer to the written data
		 * @throws IOException (EOFException specifically, if memory ran out)
		 */
		public long appendPointerAndRecord(long pointer, T record) throws IOException {
			setWritePosition(appendPosition);
			return noSeekAppendPointerAndRecord(pointer, record);
		}

		/**
		 * Appends a pointer and a record. Call this function only if the write position is at the end!
		 * @param pointer The pointer to write (Note: this is NOT the position to write to!)
		 * @param record The record to write
		 * @return A pointer to the written data
		 * @throws IOException (EOFException specifically, if memory ran out)
		 */
		public long noSeekAppendPointerAndRecord(long pointer, T record) throws IOException {
			final long oldLastPosition = appendPosition;
			final long oldPositionInSegment = outView.getCurrentPositionInSegment();
			final long oldSegmentIndex = outView.currentSegmentIndex;
			outView.writeLong(pointer);
			buildSideSerializer.serialize(record, outView);
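			// Advance the append position by the number of bytes just written: the offset change within the
			// current segment, plus a full segment size for every segment boundary that was crossed.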
			appendPosition += outView.getCurrentPositionInSegment() - oldPositionInSegment +
				outView.getSegmentSize() * (outView.currentSegmentIndex - oldSegmentIndex);
			return oldLastPosition;
		}

		public long getAppendPosition() {
			return appendPosition;
		}

		// ----------------------- Input -----------------------

		public void setReadPosition(long position) {
			inView.setReadPosition(position);
		}

		public long getReadPosition() {
			return inView.getReadPosition();
		}

		/**
		 * Note: this is sometimes a negated length instead of a pointer (see HashTableProber.updateMatch).
		 */
		public long readPointer() throws IOException {
			return inView.readLong();
		}

		public T readRecord(T reuse) throws IOException {
			return buildSideSerializer.deserialize(reuse, inView);
		}

		public void skipBytesToRead(int numBytes) throws IOException {
			inView.skipBytesToRead(numBytes);
		}

		// -----------------------------------------------------

		private final class RecordAreaOutputView extends AbstractPagedOutputView {

			public int currentSegmentIndex;

			public RecordAreaOutputView(int segmentSize) {
				super(segmentSize, 0);
			}

			@Override
			protected MemorySegment nextSegment(MemorySegment current, int positionInCurrent) throws EOFException {
				currentSegmentIndex++;
				if (currentSegmentIndex == segments.size()) {
					addSegment();
				}
				return segments.get(currentSegmentIndex);
			}

			@Override
			public void seekOutput(MemorySegment seg, int position) {
				super.seekOutput(seg, position);
			}
		}
	}


	private final class StagingOutputView extends AbstractPagedOutputView {

		private final ArrayList<MemorySegment> segments;

		private final int segmentSizeBits;

		private int currentSegmentIndex;


		public StagingOutputView(ArrayList<MemorySegment> segments, int segmentSize)
		{
			super(segmentSize, 0);
			this.segmentSizeBits = MathUtils.log2strict(segmentSize);
			this.segments = segments;
		}

		/**
		 * Seeks to the beginning.
		 */
		public void reset() {
			seekOutput(segments.get(0), 0);
			currentSegmentIndex = 0;
		}

		@Override
		protected MemorySegment nextSegment(MemorySegment current, int positionInCurrent) throws EOFException {
			currentSegmentIndex++;
			if (currentSegmentIndex == segments.size()) {
				MemorySegment m = allocateSegment();
				if (m == null) {
					throw new EOFException();
				}
				segments.add(m);
			}
			return segments.get(currentSegmentIndex);
		}

		public long getWritePosition() {
			return (((long) currentSegmentIndex) << segmentSizeBits) + getCurrentPositionInSegment();
		}
	}


	/**
	 * A prober for accessing the table.
	 * In addition to getMatchFor and updateMatch, it also has insertAfterNoMatch.
	 * Warning: Don't modify the table between calling getMatchFor and the other methods!
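	 *
	 * A sketch of the intended call pattern (this mirrors what insertOrReplaceRecord does internally;
	 * the probe record, the reuse instance, and the new record are assumed to come from the caller):
	 * <pre>{@code
	 * T match = prober.getMatchFor(probeRecord, reuse);
	 * if (match == null) {
	 *     prober.insertAfterNoMatch(newRecord);
	 * } else {
	 *     prober.updateMatch(newRecord);
	 * }
	 * }</pre>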
	 * @param <PT> The type of the records that we are probing with
     */
	public final class HashTableProber<PT> extends AbstractHashTableProber<PT, T> {

		public HashTableProber(TypeComparator<PT> probeTypeComparator, TypePairComparator<PT, T> pairComparator) {
			super(probeTypeComparator, pairComparator);
		}

		private int bucketSegmentIndex;
		private int bucketOffset;
		private long curElemPtr;
		private long prevElemPtr;
		private long nextPtr;
		private long recordEnd;

		/**
		 * Searches the hash table for the record with the given key.
		 * (If there are multiple matches, only one is returned.)
		 * @param record The record whose key we are searching for
		 * @param targetForMatch If a match is found, it will be written here
         * @return targetForMatch if a match is found, otherwise null.
         */
		@Override
		public T getMatchFor(PT record, T targetForMatch) {
			if (closed) {
				return null;
			}

			final int hashCode = MathUtils.jenkinsHash(probeTypeComparator.hash(record));
			final int bucket = hashCode & numBucketsMask;
			bucketSegmentIndex = bucket >>> numBucketsPerSegmentBits; // which segment contains the bucket
			final MemorySegment bucketSegment = bucketSegments[bucketSegmentIndex];
			bucketOffset = (bucket & numBucketsPerSegmentMask) << bucketSizeBits; // offset of the bucket in the segment

			curElemPtr = bucketSegment.getLong(bucketOffset);

			pairComparator.setReference(record);

			T currentRecordInList = targetForMatch;

			prevElemPtr = INVALID_PREV_POINTER;
			try {
				while (curElemPtr != END_OF_LIST && !closed) {
					recordArea.setReadPosition(curElemPtr);
					nextPtr = recordArea.readPointer();

					currentRecordInList = recordArea.readRecord(currentRecordInList);
					recordEnd = recordArea.getReadPosition();
					if (pairComparator.equalToReference(currentRecordInList)) {
						// we found an element with a matching key, and not just a hash collision
						return currentRecordInList;
					}

					prevElemPtr = curElemPtr;
					curElemPtr = nextPtr;
				}
			} catch (IOException ex) {
				throw new RuntimeException("Error deserializing record from the hashtable: " + ex.getMessage(), ex);
			}
			return null;
		}

		@Override
		public T getMatchFor(PT probeSideRecord) {
			return getMatchFor(probeSideRecord, buildSideSerializer.createInstance());
		}

		/**
		 * This method can be called after getMatchFor returned a match.
		 * It will overwrite the record that was found by getMatchFor.
		 * Warning: The new record should have the same key as the old!
		 * WARNING: Don't do any modifications to the table between
		 * getMatchFor and updateMatch!
		 * @param newRecord The record to override the old record with.
		 * @throws IOException (EOFException specifically, if memory ran out)
         */
		@Override
		public void updateMatch(T newRecord) throws IOException {
			if (closed) {
				return;
			}
			if (curElemPtr == END_OF_LIST) {
				throw new RuntimeException("updateMatch was called after getMatchFor returned no match");
			}

			try {
				// determine the new size
				stagingSegmentsOutView.reset();
				buildSideSerializer.serialize(newRecord, stagingSegmentsOutView);
				final int newRecordSize = (int)stagingSegmentsOutView.getWritePosition();
				stagingSegmentsInView.setReadPosition(0);

				// Determine the size of the place of the old record.
				final int oldRecordSize = (int)(recordEnd - (curElemPtr + RECORD_OFFSET_IN_LINK));

				if (newRecordSize == oldRecordSize) {
					// overwrite record at its original place
					recordArea.overwriteRecordAt(curElemPtr + RECORD_OFFSET_IN_LINK, stagingSegmentsInView, newRecordSize);
				} else {
					// new record has a different size than the old one, append new at the end of the record area.
					// Note: we have to do this, even if the new record is smaller, because otherwise EntryIterator
				// wouldn't know the size of this place, and wouldn't know where the next record starts.

					final long pointerToAppended =
						recordArea.appendPointerAndCopyRecord(nextPtr, stagingSegmentsInView, newRecordSize);

					// modify the pointer in the previous link
					if (prevElemPtr == INVALID_PREV_POINTER) {
						// list had only one element, so prev is in the bucketSegments
						bucketSegments[bucketSegmentIndex].putLong(bucketOffset, pointerToAppended);
					} else {
						recordArea.overwritePointerAt(prevElemPtr, pointerToAppended);
					}

					// write the negated size of the hole to the place where the next pointer was, so that EntryIterator
					// will know the size of the place without reading the old record.
				// The negative sign will mean that the record is abandoned, and the -1 is for avoiding trouble
				// in case of a record having 0 size (though I think this should never actually happen).
					// Note: the last record in the record area can't be abandoned. (EntryIterator makes use of this fact.)
					recordArea.overwritePointerAt(curElemPtr, -oldRecordSize - 1);

					holes += oldRecordSize;
				}
			} catch (EOFException ex) {
				compactOrThrow();
				insertOrReplaceRecord(newRecord);
			}
		}

		/**
		 * This method can be called after getMatchFor returned null.
		 * It inserts the given record to the hash table.
		 * Important: The given record should have the same key as the record
		 * that was given to getMatchFor!
		 * WARNING: Don't do any modifications to the table between
		 * getMatchFor and insertAfterNoMatch!
		 * @throws IOException (EOFException specifically, if memory ran out)
		 */
		public void insertAfterNoMatch(T record) throws IOException {
			if (closed) {
				return;
			}

			// create new link
			long pointerToAppended;
			try {
				pointerToAppended = recordArea.appendPointerAndRecord(END_OF_LIST, record);
			} catch (EOFException ex) {
				compactOrThrow();
				insert(record);
				return;
			}

			// add new link to the end of the list
			if (prevElemPtr == INVALID_PREV_POINTER) {
				// list was empty
				bucketSegments[bucketSegmentIndex].putLong(bucketOffset, pointerToAppended);
			} else {
				// update the pointer of the last element of the list.
				recordArea.overwritePointerAt(prevElemPtr, pointerToAppended);
			}

			numElements++;
			resizeTableIfNecessary();
		}
	}


	/**
	 * WARNING: Doing any other operation on the table invalidates the iterator! (Even
	 * using getMatchFor of a prober!)
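	 *
	 * A typical iteration, sketched (this mirrors how {@link ReduceFacade#emit()} uses it; {@code table}
	 * denotes the enclosing InPlaceMutableHashTable instance):
	 * <pre>{@code
	 * T record = buildSideSerializer.createInstance();
	 * EntryIterator iter = table.getEntryIterator();
	 * while ((record = iter.next(record)) != null) {
	 *     // ... use record ...
	 * }
	 * }</pre>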
	 */
	public final class EntryIterator implements MutableObjectIterator<T> {

		private final long endPosition;

		public EntryIterator() {
			endPosition = recordArea.getAppendPosition();
			if (endPosition == 0) {
				return;
			}
			recordArea.setReadPosition(0);
		}

		@Override
		public T next(T reuse) throws IOException {
			if (endPosition != 0 && recordArea.getReadPosition() < endPosition) {
				// Loop until we find a non-abandoned record.
				// Note: the last record in the record area can't be abandoned.
				while (!closed) {
					final long pointerOrNegatedLength = recordArea.readPointer();
					final boolean isAbandoned = pointerOrNegatedLength < 0;
					if (!isAbandoned) {
						reuse = recordArea.readRecord(reuse);
						return reuse;
					} else {
						// pointerOrNegatedLength is storing a length, because the record was abandoned.
						recordArea.skipBytesToRead((int)-(pointerOrNegatedLength + 1));
					}
				}
				return null; // (we were closed)
			} else {
				return null;
			}
		}

		@Override
		public T next() throws IOException {
			return next(buildSideSerializer.createInstance());
		}
	}

	/**
	 * A facade for the hash table operations that a reduce operator driver needs.
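	 *
	 * An illustrative sketch of how a reduce driver might use it (the input iterator, reduce function,
	 * and collector are assumed to be supplied by the driver):
	 * <pre>{@code
	 * ReduceFacade facade = table.new ReduceFacade(reduceFunction, collector, objectReuseEnabled);
	 * T record;
	 * while ((record = input.next()) != null) {
	 *     facade.updateTableEntryWithReduce(record);
	 * }
	 * facade.emitAndReset();
	 * }</pre>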
	 */
	public final class ReduceFacade {

		private final HashTableProber<T> prober;

		private final boolean objectReuseEnabled;

		private final ReduceFunction<T> reducer;

		private final Collector<T> outputCollector;

		private T reuse;


		public ReduceFacade(ReduceFunction<T> reducer, Collector<T> outputCollector, boolean objectReuseEnabled) {
			this.reducer = reducer;
			this.outputCollector = outputCollector;
			this.objectReuseEnabled = objectReuseEnabled;
			this.prober = getProber(buildSideComparator, new SameTypePairComparator<>(buildSideComparator));
			this.reuse = buildSideSerializer.createInstance();
		}

		/**
		 * Looks up the table entry that has the same key as the given record, and updates it by performing
		 * a reduce step.
		 * @param record The record to update.
		 * @throws Exception
         */
		public void updateTableEntryWithReduce(T record) throws Exception {
			T match = prober.getMatchFor(record, reuse);
			if (match == null) {
				prober.insertAfterNoMatch(record);
			} else {
				// do the reduce step
				T res = reducer.reduce(match, record);

				// We have given reuse to the reducer UDF, so create a new one if object reuse is disabled
				if (!objectReuseEnabled) {
					reuse = buildSideSerializer.createInstance();
				}

				prober.updateMatch(res);
			}
		}

		/**
		 * Emits all elements currently held by the table to the collector.
		 */
		public void emit() throws IOException {
			T record = buildSideSerializer.createInstance();
			EntryIterator iter = getEntryIterator();
			while ((record = iter.next(record)) != null && !closed) {
				outputCollector.collect(record);
				if (!objectReuseEnabled) {
					record = buildSideSerializer.createInstance();
				}
			}
		}

		/**
		 * Emits all elements currently held by the table to the collector,
		 * and resets the table. The table will have the same number of buckets
		 * as before the reset, to avoid doing resizes again.
		 */
		public void emitAndReset() throws IOException {
			final int oldNumBucketSegments = bucketSegments.length;
			emit();
			close();
			open(oldNumBucketSegments);
		}
	}
}