All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apfloat.internal.StepCarryCRTStrategy Maven / Gradle / Ivy

There is a newer version: 1.14.0
Show newest version
package org.apfloat.internal;

import org.apfloat.ApfloatContext;
import org.apfloat.ApfloatRuntimeException;
import org.apfloat.spi.BuilderFactory;
import org.apfloat.spi.CarryCRTStrategy;
import org.apfloat.spi.CarryCRTStepStrategy;
import org.apfloat.spi.DataStorageBuilder;
import org.apfloat.spi.DataStorage;

/**
 * Class for performing the final step of a three-modulus
 * Number Theoretic Transform based convolution. Works with blocks of data.

* * The algorithm is parallelized for multiprocessor computers, * if the data fits in memory.

* * The parallelization works so that the carry-CRT is done in * blocks in parallel. As a final step, a second pass is done * through the data set to propagate the carries from one block * to the next.

* * All access to this class must be externally synchronized. * * @see CarryCRTStepStrategy * * @since 1.7.0 * @version 1.8.0 * @author Mikko Tommila */ public class StepCarryCRTStrategy implements CarryCRTStrategy, Parallelizable { // Runnable for calculating the carry-CRT in blocks, then get the carry of the previous block, add it, and provide carry to the next block private class CarryCRTRunnable implements Runnable { public CarryCRTRunnable(DataStorage resultMod0, DataStorage resultMod1, DataStorage resultMod2, DataStorage dataStorage, long size, long resultSize, long offset, long length, MessagePasser messagePasser, CarryCRTStepStrategy stepStrategy) { this.resultMod0 = resultMod0; this.resultMod1 = resultMod1; this.resultMod2 = resultMod2; this.dataStorage = dataStorage; this.size = size; this.resultSize = resultSize; this.offset = offset; this.length = length; this.messagePasser = messagePasser; this.stepStrategy = stepStrategy; } public void run() { T results = this.stepStrategy.crt(this.resultMod0, this.resultMod1, this.resultMod2, this.dataStorage, this.size, this.resultSize, this.offset, this.length); // Finishing step - get the carry from the previous block and propagate it through the data if (this.offset > 0) { T previousResults = this.messagePasser.receiveMessage(this.offset); results = this.stepStrategy.carry(this.dataStorage, this.size, this.resultSize, this.offset, this.length, results, previousResults); } // Finally, send the carry to the next block this.messagePasser.sendMessage(this.offset + this.length, results); // Last block sanity check if (this.offset + this.length == this.size) { assert (results != null); assert (java.lang.reflect.Array.getLength(results) == 2); assert (((Number) java.lang.reflect.Array.get(results, 0)).longValue() == 0); assert (((Number) java.lang.reflect.Array.get(results, 1)).longValue() == 0); } } private DataStorage resultMod0, resultMod1, resultMod2, dataStorage; private long size, resultSize, offset, length; private MessagePasser messagePasser; private CarryCRTStepStrategy stepStrategy; } /** * Creates a carry-CRT object using the specified radix. * * @param radix The radix that will be used. */ public StepCarryCRTStrategy(int radix) { this.radix = radix; } /** * Calculate the final result of a three-NTT convolution.

* * Performs a Chinese Remainder Theorem (CRT) on each element * of the three result data sets to get the result of each element * modulo the product of the three moduli. Then it calculates the carries * to get the final result.

* * Note that the return value's initial word may be zero or non-zero, * depending on how large the result is.

* * Assumes that MODULUS[0] > MODULUS[1] > MODULUS[2]. * * @param resultMod0 The result modulo MODULUS[0]. * @param resultMod1 The result modulo MODULUS[1]. * @param resultMod2 The result modulo MODULUS[2]. * @param resultSize The number of elements needed in the final result. * * @return The final result with the CRT performed and the carries calculated. */ public DataStorage carryCRT(final DataStorage resultMod0, final DataStorage resultMod1, final DataStorage resultMod2, final long resultSize) throws ApfloatRuntimeException { ApfloatContext ctx = ApfloatContext.getContext(); BuilderFactory builderFactory = ctx.getBuilderFactory(); Class elementArrayType = builderFactory.getElementArrayType(); return doCarryCRT(elementArrayType, resultMod0, resultMod1, resultMod2, resultSize); } private DataStorage doCarryCRT(Class elementArrayType, DataStorage resultMod0, DataStorage resultMod1, DataStorage resultMod2, long resultSize) throws ApfloatRuntimeException { long size = Math.min(resultSize + 2, resultMod0.getSize()); // Some extra precision if not full result is required ApfloatContext ctx = ApfloatContext.getContext(); BuilderFactory builderFactory = ctx.getBuilderFactory(); DataStorageBuilder dataStorageBuilder = builderFactory.getDataStorageBuilder(); final DataStorage dataStorage = dataStorageBuilder.createDataStorage(resultSize * builderFactory.getElementSize()); dataStorage.setSize(resultSize); ParallelRunnable parallelRunnable = createCarryCRTParallelRunnable(elementArrayType, resultMod0, resultMod1, resultMod2, dataStorage, size, resultSize); if (size <= Integer.MAX_VALUE && // Only if the size fits in an integer, but with memory arrays it should resultMod0.isCached() && // Only if the data storage supports efficient parallel random access resultMod1.isCached() && resultMod2.isCached() && dataStorage.isCached()) { ParallelRunner.runParallel(parallelRunnable); } else { parallelRunnable.getRunnable(0, size).run(); // Just run in current thread without parallelization } return dataStorage; } /** * Create a ParallelRunnable object for doing the carry-CRT in parallel. * * @param The element array type used. * @param elementArrayType The element array type used. * @param resultMod0 The result modulo MODULUS[0]. * @param resultMod1 The result modulo MODULUS[1]. * @param resultMod2 The result modulo MODULUS[2]. * @param dataStorage The destination data storage of the computation. * @param size The number of elements in the whole data set. * @param resultSize The number of elements needed in the final result. * * @return An suitable object for performing the carry-CRT in parallel. */ protected ParallelRunnable createCarryCRTParallelRunnable(Class elementArrayType, final DataStorage resultMod0, final DataStorage resultMod1, final DataStorage resultMod2, final DataStorage dataStorage, final long size, final long resultSize) { ApfloatContext ctx = ApfloatContext.getContext(); BuilderFactory builderFactory = ctx.getBuilderFactory(); final MessagePasser messagePasser = new MessagePasser(); final CarryCRTStepStrategy stepStrategy = builderFactory.getCarryCRTBuilder(elementArrayType).createCarryCRTSteps(this.radix); ParallelRunnable parallelRunnable = new ParallelRunnable(size) { public Runnable getRunnable(long offset, long length) { return new CarryCRTRunnable(resultMod0, resultMod1, resultMod2, dataStorage, size, resultSize, offset, length, messagePasser, stepStrategy); } }; return parallelRunnable; } private int radix; }





© 2015 - 2024 Weber Informatics LLC | Privacy Policy