org.apfloat.aparapi.IntKernel Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of apfloat-aparapi Show documentation
Show all versions of apfloat-aparapi Show documentation
Fast number theoretic transform implementation using Aparapi
/*
* MIT License
*
* Copyright (c) 2002-2021 Mikko Tommila
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package org.apfloat.aparapi;
import com.aparapi.Kernel;
import org.apfloat.ApfloatRuntimeException;
import org.apfloat.spi.ArrayAccess;
/**
 * Kernel for the <code>int</code> element type. Contains everything needed for the NTT.
 * The data is organized in columns, not rows, for efficient processing on the GPU.<p>
 *
 * Due to the extreme parallelization requirements (global size should be at least 1024)
 * this algorithm works efficiently only with 4 million decimal digit calculations or bigger.
 * However with 4 million digits, it's only approximately as fast as the pure-Java
 * version (depending on the GPU and CPU hardware). On the other hand, the algorithm
 * mathematically only works up to about 226 million digits. So the useful range is only
 * somewhere around 10-200 million digits.<p>
 *
 * Some notes about the aparapi specific requirements for code that must be converted to OpenCL:
 * <ul>
 *   <li><code>assert()</code> does not work</li>
 *   <li>Can't check for null</li>
 *   <li>Can't get array length</li>
 *   <li>Arrays referenced by the kernel can't be null even if they are not accessed</li>
 *   <li>Arrays referenced by the kernel can't be zero-length even if they are not accessed</li>
 *   <li>Can't invoke methods in other classes e.g. enclosing class of an inner class</li>
 *   <li>Early return statements do not work</li>
 *   <li>Variables used inside loops must be initialized before the loop</li>
 *   <li>Must compile the class with full debug information i.e. with <code>-g</code></li>
 * </ul>
 *
 * @since 1.8.3
 * @version 1.9.0
 * @author Mikko Tommila
 */
class IntKernel
    extends Kernel
{
    /**
     * Instances are obtained only through {@link #getInstance()}.
     */
    private IntKernel()
    {
    }

    /**
     * Returns the kernel instance bound to the calling thread.
     * All kernel parameters are held in mutable fields set through the
     * setters below, so every thread gets its own instance instead of sharing one.
     *
     * @return The kernel for the current thread.
     */
    public static IntKernel getInstance()
    {
        return IntKernel.kernel.get();
    }

    // One kernel per thread. The type argument is required: with a raw ThreadLocal,
    // get() returns Object and getInstance() would not compile without a cast.
    private static final ThreadLocal<IntKernel> kernel = ThreadLocal.withInitial(IntKernel::new);

    // Methods for calculating the column transforms in parallel

    /** Operation code for {@link #run()}: forward transform of every column. */
    public static final int TRANSFORM_ROWS = 1;
    /** Operation code for {@link #run()}: inverse transform of every column. */
    public static final int INVERSE_TRANSFORM_ROWS = 2;

    /**
     * Sets the transform length, i.e. the number of elements in one column.
     *
     * @param length The transform length.
     */
    public void setLength(int length)
    {
        this.length = length; // Transform length
    }

    /**
     * Binds the kernel to the data array, offset and element stride of the given array access.
     * The stride is the total accessed length divided by the transform length, so
     * {@link #setLength(int)} should be called first. If the length is still 0, the
     * previous stride is kept.
     *
     * @param arrayAccess The data to operate on.
     *
     * @throws ApfloatRuntimeException If accessing the data fails.
     */
    public void setArrayAccess(ArrayAccess arrayAccess)
        throws ApfloatRuntimeException
    {
        this.data = arrayAccess.getIntData();
        this.offset = arrayAccess.getOffset();
        if (this.length != 0)
        {
            this.stride = arrayAccess.getLength() / this.length;
        }
    }

    /**
     * Sets the table of powers of the root of unity used by the column transforms.
     *
     * @param wTable The root-of-unity table.
     */
    public void setWTable(int[] wTable)
    {
        this.wTable = wTable;
    }

    /**
     * Sets the table of index pairs swapped by the column scrambling step.
     * A <code>null</code> table disables scrambling. Aparapi cannot handle null or
     * zero-length arrays, so a one-element placeholder is stored and the
     * effective length is recorded separately as 0.
     *
     * @param permutationTable Pairs of row indices to swap, or <code>null</code>.
     */
    public void setPermutationTable(int[] permutationTable)
    {
        this.permutationTable = (permutationTable == null ? new int[1] : permutationTable); // Zero-length array or null won't work
        this.permutationTableLength = (permutationTable == null ? 0 : permutationTable.length);
    }

    /**
     * Forward transform of one column; the column is chosen by the global id.
     * Column elements are <code>stride</code> apart. Each stage applies the butterfly
     * <code>a' = a + b</code>, <code>b' = w * (a - b)</code>, with the root-of-unity
     * table index advancing by <code>r</code>, which doubles every stage. The
     * optional scrambling is done after the butterflies.
     */
    private void columnTableFNT()
    {
        // Locals are declared and initialized before the loops as Aparapi requires
        int nn, istep = 0, mmax = 0, r = 0;
        int[] data = this.data;
        int offset = this.offset + getGlobalId();
        int stride = this.stride;
        nn = this.length;

        if (nn >= 2)
        {
            r = 1;
            mmax = nn >> 1;
            while (mmax > 0)
            {
                istep = mmax << 1;

                // Optimize first step when wr = 1
                for (int i = offset; i < offset + nn * stride; i += istep * stride)
                {
                    int j = i + mmax * stride;
                    int a = data[i];
                    int b = data[j];
                    data[i] = modAdd(a, b);
                    data[j] = modSubtract(a, b);
                }

                int t = r;

                for (int m = 1; m < mmax; m++)
                {
                    for (int i = offset + m * stride; i < offset + nn * stride; i += istep * stride)
                    {
                        int j = i + mmax * stride;
                        int a = data[i];
                        int b = data[j];
                        data[i] = modAdd(a, b);
                        data[j] = modMultiply(this.wTable[t], modSubtract(a, b));
                    }
                    t += r;
                }
                r <<= 1;
                mmax >>= 1;
            }

            if (this.permutationTableLength > 0)
            {
                columnScramble(offset);
            }
        }
    }

    /**
     * Inverse transform of one column; the column is chosen by the global id.
     * The optional scrambling is done first. Each stage then applies the butterfly
     * <code>a' = a + w * b</code>, <code>b' = a - w * b</code>, with the butterfly
     * span doubling every stage. No <code>1/n</code> scaling is applied here.
     */
    private void inverseColumnTableFNT()
    {
        // Locals are declared and initialized before the loops as Aparapi requires
        int nn, istep = 0, mmax = 0, r = 0;
        int[] data = this.data;
        int offset = this.offset + getGlobalId();
        int stride = this.stride;
        nn = this.length;

        if (nn >= 2)
        {
            if (this.permutationTableLength > 0)
            {
                columnScramble(offset);
            }

            r = nn;
            mmax = 1;
            while (nn > mmax)
            {
                istep = mmax << 1;
                r >>= 1;

                // Optimize first step when w = 1
                for (int i = offset; i < offset + nn * stride; i += istep * stride)
                {
                    int j = i + mmax * stride;
                    int wTemp = data[j];
                    data[j] = modSubtract(data[i], wTemp);
                    data[i] = modAdd(data[i], wTemp);
                }

                int t = r;

                for (int m = 1; m < mmax; m++)
                {
                    for (int i = offset + m * stride; i < offset + nn * stride; i += istep * stride)
                    {
                        int j = i + mmax * stride;
                        int wTemp = modMultiply(this.wTable[t], data[j]);
                        data[j] = modSubtract(data[i], wTemp);
                        data[i] = modAdd(data[i], wTemp);
                    }
                    t += r;
                }
                mmax = istep;
            }
        }
    }

    /**
     * Permutes one column in place. Consecutive entries
     * <code>(permutationTable[k], permutationTable[k + 1])</code> name two rows whose
     * elements in this column are swapped.
     *
     * @param offset Index of the first element of the column.
     */
    private void columnScramble(int offset)
    {
        for (int k = 0; k < this.permutationTableLength; k += 2)
        {
            int i = offset + this.permutationTable[k] * this.stride,
                j = offset + this.permutationTable[k + 1] * this.stride;
            int tmp = this.data[i];
            this.data[i] = this.data[j];
            this.data[j] = tmp;
        }
    }

    /**
     * Modular multiplication <code>a * b mod modulus</code>. It uses the fixed-point
     * reciprocal set by {@link #setModulus(int)} instead of a division, followed by one
     * conditional subtraction.
     * NOTE(review): assumes <code>a</code> and <code>b</code> are already reduced
     * into <code>[0, modulus)</code>; confirm against callers.
     */
    private int modMultiply(int a, int b)
    {
        long t = (long) a * (long) b;
        //int r1 = a * b - (int) (this.inverseModulus * (double) a * (double) b) * this.modulus,
        int r1 = (int) t - (int) ((t >>> 30) * this.inverseModulus >>> 33) * this.modulus,
            r2 = r1 - this.modulus;

        return (r2 < 0 ? r1 : r2);
    }

    /**
     * Modular addition with a single conditional subtraction of the modulus.
     * Assumes both operands are in <code>[0, modulus)</code>.
     */
    private int modAdd(int a, int b)
    {
        int r1 = a + b,
            r2 = r1 - this.modulus;

        return (r2 < 0 ? r1 : r2);
    }

    /**
     * Modular subtraction with a single conditional addition of the modulus.
     * Assumes both operands are in <code>[0, modulus)</code>.
     */
    private int modSubtract(int a, int b)
    {
        int r1 = a - b,
            r2 = r1 + this.modulus;

        return (r1 < 0 ? r2 : r1);
    }

    /**
     * Sets the modulus. Also precomputes <code>floor(2^63 / modulus)</code>, which
     * {@link #modMultiply(int, int)} uses for reduction.
     *
     * @param modulus The modulus.
     */
    public void setModulus(int modulus)
    {
        //this.inverseModulus = 1.0 / (modulus + 0.5); // Round down
        this.inverseModulus = (long) (9223372036854775808.0 / (double) modulus);
        this.modulus = modulus;
    }

    /**
     * Returns the modulus set by {@link #setModulus(int)}.
     *
     * @return The modulus.
     */
    public int getModulus()
    {
        return this.modulus;
    }

    private int stride;
    private int length;
    private int[] data;
    private int offset;
    private int[] wTable = { 0 };           // Non-empty placeholder: Aparapi rejects null/zero-length arrays
    private int[] permutationTable = { 0 }; // Non-empty placeholder: Aparapi rejects null/zero-length arrays
    private int permutationTableLength;
    private int modulus;
    //private double inverseModulus;
    private long inverseModulus;

    // Methods for transposing the matrix

    /** Operation code for {@link #run()}: in-place transpose. */
    public static final int TRANSPOSE = 3;
    /** Operation code for {@link #run()}: permute rows following index cycles. */
    public static final int PERMUTE = 4;

    /**
     * Sets the row length used by the transpose and permute operations.
     *
     * @param n2 The row length.
     */
    public void setN2(int n2)
    {
        this.n2 = n2;
    }

    /**
     * Sets the cycle table used by the permute operation.
     *
     * @param index Row indices of each cycle, each cycle followed by a 0 terminator.
     */
    public void setIndex(int[] index)
    {
        this.index = index;
    }

    /**
     * Sets the number of entries of the cycle table that are processed.
     *
     * @param indexCount Number of used entries in the index table.
     */
    public void setIndexCount(int indexCount)
    {
        this.indexCount = indexCount;
    }

    /**
     * Swaps the elements at <code>(j, i)</code> and <code>(i, j)</code> of a row-major
     * matrix with row length <code>n2</code>. The pair is given by the two-dimensional
     * global id. Only work items with <code>i &lt; j</code> swap, so each pair is
     * exchanged exactly once.
     */
    private void transpose()
    {
        int i = getGlobalId(0),
            j = getGlobalId(1);
        if (i < j)
        {
            int position1 = this.offset + j * this.n2 + i,
                position2 = this.offset + i * this.n2 + j;
            int tmp = this.data[position1];
            this.data[position1] = this.data[position2];
            this.data[position2] = tmp;
        }
    }

    /**
     * Rotates the elements of column <code>getGlobalId()</code> along each cycle in the index table.
     * For a cycle <code>o0, o1, ..., ok, 0</code>, the rows are shifted as
     * <code>row[o0] = row[o1], ..., row[ok] = old row[o0]</code>. The outer and
     * inner loops share <code>i</code>, so the outer increment skips the 0 terminator.
     */
    private void permute()
    {
        int j = getGlobalId();
        for (int i = 0; i < this.indexCount; i++)
        {
            int o = this.index[i];
            int tmp = this.data[this.offset + this.n2 * o + j];
            for (i++; this.index[i] != 0; i++)
            {
                int m = this.index[i];
                this.data[this.offset + this.n2 * o + j] = this.data[this.offset + this.n2 * m + j];
                o = m;
            }
            this.data[this.offset + this.n2 * o + j] = tmp;
        }
    }

    private int n2;
    private int[] index = { 0 }; // Non-empty placeholder: Aparapi rejects null/zero-length arrays
    private int indexCount;

    // Methods for multiplying elements in the matrix

    /** Operation code for {@link #run()}: multiply elements by twiddle factors. */
    public static final int MULTIPLY_ELEMENTS = 5;

    /**
     * Sets the row index of the first processed row, used in the twiddle exponents.
     *
     * @param startRow The first row index.
     */
    public void setStartRow(int startRow)
    {
        this.startRow = startRow;
    }

    /**
     * Sets the column index of the first processed column, used in the twiddle exponents.
     *
     * @param startColumn The first column index.
     */
    public void setStartColumn(int startColumn)
    {
        this.startColumn = startColumn;
    }

    /**
     * Sets the number of rows processed by each work item.
     *
     * @param rows The number of rows.
     */
    public void setRows(int rows)
    {
        this.rows = rows;
    }

    /**
     * Sets the row length, i.e. the distance between vertically adjacent elements.
     *
     * @param columns The number of columns.
     */
    public void setColumns(int columns)
    {
        this.columns = columns;
    }

    /**
     * Sets the root of unity used for the twiddle factors.
     *
     * @param w The root of unity.
     */
    public void setW(int w)
    {
        this.w = w;
    }

    /**
     * Sets the constant factor applied to every element by the multiply operation.
     *
     * @param scaleFactor The scale factor.
     */
    public void setScaleFactor(int scaleFactor)
    {
        this.scaleFactor = scaleFactor;
    }

    /**
     * Multiplies each element in column <code>c = startColumn + getGlobalId()</code>
     * and row <code>startRow + i</code> (<code>0 &lt;= i &lt; rows</code>) by
     * <code>scaleFactor * w^((startRow + i) * c)</code>. The factor is updated
     * incrementally by multiplying by <code>w^c</code> after each row.
     */
    private void multiplyElements()
    {
        int[] data = this.data;
        int position = this.offset + getGlobalId();
        int rowFactor = modPow(this.w, (int) this.startRow);
        int columnFactor = modPow(this.w, (int) this.startColumn + getGlobalId());
        int rowStartFactor = modMultiply(this.scaleFactor, modPow(rowFactor, (int) this.startColumn + getGlobalId()));

        for (int i = 0; i < this.rows; i++)
        {
            data[position] = modMultiply(data[position], rowStartFactor);

            position += this.columns;
            rowStartFactor = modMultiply(rowStartFactor, columnFactor);
        }
    }

    /**
     * Computes <code>a^n mod modulus</code> by square-and-multiply.
     * A negative exponent is replaced by <code>modulus - 1 + n</code>.
     * NOTE(review): that mapping is only an inverse power if the modulus is prime
     * (Fermat); presumably always the case for the NTT moduli — confirm.
     */
    private int modPow(int a, int n)
    {
        if (n == 0)
        {
            return 1;
        }
        else if (n < 0)
        {
            n = getModulus() - 1 + n;
        }

        int exponent = (int) n;

        // Skip trailing zero bits: square without multiplying into the result
        while ((exponent & 1) == 0)
        {
            a = modMultiply(a, a);
            exponent >>= 1;
        }

        int r = a;

        for (exponent >>= 1; exponent > 0; exponent >>= 1)
        {
            a = modMultiply(a, a);
            if ((exponent & 1) != 0)
            {
                r = modMultiply(r, a);
            }
        }

        return r;
    }

    private int startRow;
    private int startColumn;
    private int rows;
    private int columns;
    private int w;
    private int scaleFactor;

    // Methods for factor-3 transform

    /** Operation code for {@link #run()}: forward 3-point column transform. */
    public static final int TRANSFORM_COLUMNS = 6;
    /** Operation code for {@link #run()}: inverse 3-point column transform. */
    public static final int INVERSE_TRANSFORM_COLUMNS = 7;

    /**
     * Selects the operation performed by {@link #run()}; one of the operation code constants.
     *
     * @param op The operation code.
     */
    public void setOp(int op)
    {
        this.op = op;
    }

    /**
     * Sets the root whose powers multiply the third row in the 3-point column transform.
     *
     * @param ww The root for the third row.
     */
    public void setWw(int ww)
    {
        this.ww = ww;
    }

    /**
     * Sets the first constant of the 3-point WFTA.
     *
     * @param w1 The first WFTA constant.
     */
    public void setW1(int w1)
    {
        this.w1 = w1;
    }

    /**
     * Sets the second constant of the 3-point WFTA.
     *
     * @param w2 The second WFTA constant.
     */
    public void setW2(int w2)
    {
        this.w2 = w2;
    }

    /**
     * Kernel entry point. Dispatches on the operation code set with {@link #setOp(int)}.
     * Unknown codes do nothing.
     */
    @Override
    public void run()
    {
        if (this.op == TRANSFORM_ROWS)
        {
            columnTableFNT();
        }
        else if (this.op == INVERSE_TRANSFORM_ROWS)
        {
            inverseColumnTableFNT();
        }
        else if (this.op == TRANSPOSE)
        {
            transpose();
        }
        else if (this.op == PERMUTE)
        {
            permute();
        }
        else if (this.op == MULTIPLY_ELEMENTS)
        {
            multiplyElements();
        }
        else if (this.op == TRANSFORM_COLUMNS || this.op == INVERSE_TRANSFORM_COLUMNS)
        {
            transformColumns();
        }
    }

    /**
     * Applies a 3-point WFTA to column <code>i = getGlobalId()</code> of a three-row block.
     * The block's rows are <code>columns</code> elements apart. With
     * <code>c = startColumn + i</code>, rows 1 and 2 are multiplied by <code>w^c</code>
     * and <code>ww^c</code>. The multiplication happens before the WFTA for the
     * inverse transform and after it for the forward transform.
     */
    private void transformColumns()
    {
        int i = getGlobalId();
        int tmp1 = modPow(this.w, (int) this.startColumn + i),
            tmp2 = modPow(this.ww, (int) this.startColumn + i);

        // 3-point WFTA on the corresponding array elements

        int x0 = this.data[this.offset + i],
            x1 = this.data[this.offset + this.columns + i],
            x2 = this.data[this.offset + 2 * this.columns + i],
            t;

        if (this.op == INVERSE_TRANSFORM_COLUMNS)
        {
            // Multiply before transform
            x1 = modMultiply(x1, tmp1);
            x2 = modMultiply(x2, tmp2);
        }

        // Transform column
        t = modAdd(x1, x2);
        x2 = modSubtract(x1, x2);
        x0 = modAdd(x0, t);
        t = modMultiply(t, this.w1);
        x2 = modMultiply(x2, this.w2);
        t = modAdd(t, x0);
        x1 = modAdd(t, x2);
        x2 = modSubtract(t, x2);

        if (this.op == TRANSFORM_COLUMNS)
        {
            // Multiply after transform
            x1 = modMultiply(x1, tmp1);
            x2 = modMultiply(x2, tmp2);
        }

        this.data[this.offset + i] = x0;
        this.data[this.offset + this.columns + i] = x1;
        this.data[this.offset + 2 * this.columns + i] = x2;
    }

    private int op;
    private int ww;
    private int w1;
    private int w2;
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy