All Downloads are FREE. Search and download functionalities are using the official Maven repository.

ch.randelshofer.fastdoubleparser.FftMultiplier Maven / Gradle / Ivy

The newest version!
/*
 * @(#)FftMultiplier.java
 * Copyright © 2024 Werner Randelshofer, Switzerland. MIT License.
 */
package ch.randelshofer.fastdoubleparser;

import java.math.BigInteger;

import static ch.randelshofer.fastdoubleparser.FastDoubleMath.fastScalb;
import static ch.randelshofer.fastdoubleparser.FastDoubleSwar.fma;

/**
 * Provides methods for multiplying two {@link BigInteger}s using the
 * {@code FFT algorithm}.
 * <p>
 * This code is based on {@code bigint} by Timothy Buktu.
 * <p>
 * References:
 * <dl>
 *     <dt>bigint, Copyright 2013 Timothy Buktu, 2-clause BSD License.
 *     <p>
 *     Note: We only use portions of this project that have been marked with the
 *     2-clause BSD License; see the file LICENSE.
 *     </dt>
 *     <dd><a href="https://github.com/tbuktu/bigint">github.com</a></dd>
 * </dl>
*/ final class FftMultiplier { public static final double COS_0_25 = Math.cos(0.25 * Math.PI); public static final double SIN_0_25 = Math.sin(0.25 * Math.PI); /** * The threshold value for using floating point FFT multiplication. * If the number of bits in each mag array is greater than the * Toom-Cook threshold, and the number of bits in at least one of * the mag arrays is greater than this threshold, then FFT * multiplication will be used. */ private static final int FFT_THRESHOLD = 33220; /** * This constant limits {@code mag.length} of BigIntegers to the supported * range. */ private static final int MAX_MAG_LENGTH = Integer.MAX_VALUE / Integer.SIZE + 1; // (1 << 26) /** * for FFTs of length up to 3*2^19 */ private static final int ROOTS3_CACHE_SIZE = 20; /** * for FFTs of length up to 2^19 */ private static final int ROOTS_CACHE2_SIZE = 20; /** * The threshold value for using 3-way Toom-Cook multiplication. */ private static final int TOOM_COOK_THRESHOLD = 240 * 8; /** * Sets of complex roots of unity. The set at index k contains 2^k * elements representing all (2^(k+2))-th roots between 0 and pi/2. * Used for FFT multiplication. */ private volatile static ComplexVector[] ROOTS2_CACHE = new ComplexVector[ROOTS_CACHE2_SIZE]; /** * Sets of complex roots of unity. The set at index k contains 3*2^k * elements representing all (3*2^(k+2))-th roots between 0 and pi/2. * Used for FFT multiplication. */ private volatile static ComplexVector[] ROOTS3_CACHE = new ComplexVector[ROOTS3_CACHE_SIZE]; /** * Returns the maximum number of bits that one double precision number can fit without * causing the multiplication to be incorrect. 
* * @param bitLen length of this number in bits * @return the maximum number of bits */ static int bitsPerFftPoint(int bitLen) { if (bitLen <= 19 * (1 << 9)) { return 19; } if (bitLen <= 18 * (1 << 10)) { return 18; } if (bitLen <= 17 * (1 << 12)) { return 17; } if (bitLen <= 16 * (1 << 14)) { return 16; } if (bitLen <= 15 * (1 << 16)) { return 15; } if (bitLen <= 14 * (1 << 18)) { return 14; } if (bitLen <= 13 * (1 << 20)) { return 13; } if (bitLen <= 12 * (1 << 21)) { return 12; } if (bitLen <= 11 * (1 << 23)) { return 11; } if (bitLen <= 10 * (1 << 25)) { return 10; } if (bitLen <= 9 * (1 << 27)) { return 9; } return 8; } /** * Returns n-th complex roots of unity for the angles 0..pi/2, suitable * for a transform of length n. * They are used as twiddle factors and as weights for the right-angle transform. * n must be 1 or an even number. */ private static ComplexVector calculateRootsOfUnity(int n) { if (n == 1) { ComplexVector v = new ComplexVector(1); v.real(0, 1); v.imag(0, 0); return v; } ComplexVector roots = new ComplexVector(n); roots.set(0, 1.0, 0.0); double cos = COS_0_25; double sin = SIN_0_25; roots.set(n / 2, cos, sin); double angleTerm = 0.5 * Math.PI / n; for (int i = 1; i < n / 2; i++) { double angle = angleTerm * i; cos = Math.cos(angle); sin = Math.sin(angle); roots.set(i, cos, sin); roots.set(n - i, sin, cos); } return roots; } /** * Performs an FFT of length 2^n on the vector {@code a}. * This is a decimation-in-frequency implementation. * * @param a input and output, must be a power of two in size * @param roots an array that contains one set of roots at indices * log2(a.length), log2(a.length)-2, log2(a.length)-4, ... * Each roots[s] must contain 2^s roots of unity such that * {@code roots[s][k] = e^(pi*k*i/(2*roots.length))}, * i.e., they must cover the first quadrant. 
*/ private static void fft(ComplexVector a, ComplexVector[] roots) { int n = a.length; int logN = 31 - Integer.numberOfLeadingZeros(n); MutableComplex a0 = new MutableComplex(); MutableComplex a1 = new MutableComplex(); MutableComplex a2 = new MutableComplex(); MutableComplex a3 = new MutableComplex(); // do two FFT stages at a time (radix-4) MutableComplex omega1 = new MutableComplex(); MutableComplex omega2 = new MutableComplex(); int s = logN; for (; s >= 2; s -= 2) { ComplexVector rootsS = roots[s - 2]; int m = 1 << s; for (int i = 0; i < n; i += m) { for (int j = 0; j < m / 4; j++) { omega1.set(rootsS, j); // computing omega2 from omega1 is less accurate than Math.cos() and Math.sin(), // but it is the same error we'd incur with radix-2, so we're not breaking the // assumptions of the Percival paper. omega1.squareInto(omega2); int idx0 = i + j; int idx1 = i + j + m / 4; int idx2 = i + j + m / 2; int idx3 = i + j + m * 3 / 4; // radix-4 butterfly: // a[idx0] = (a[idx0] + a[idx1] + a[idx2] + a[idx3]) * w^0 // a[idx1] = (a[idx0] + a[idx1]*(-i) + a[idx2]*(-1) + a[idx3]*i) * w^1 // a[idx2] = (a[idx0] + a[idx1]*(-1) + a[idx2] + a[idx3]*(-1)) * w^2 // a[idx3] = (a[idx0] + a[idx1]*i + a[idx2]*(-1) + a[idx3]*(-i)) * w^3 // where w = omega1^(-1) = conjugate(omega1) a.addInto(idx0, a, idx1, a0); a0.add(a, idx2); a0.add(a, idx3); a.subtractTimesIInto(idx0, a, idx1, a1); a1.subtract(a, idx2); a1.addTimesI(a, idx3); a1.multiplyConjugate(omega1); a.subtractInto(idx0, a, idx1, a2); a2.add(a, idx2); a2.subtract(a, idx3); a2.multiplyConjugate(omega2); a.addTimesIInto(idx0, a, idx1, a3); a3.subtract(a, idx2); a3.subtractTimesI(a, idx3); a3.multiply(omega1); // Bernstein's trick: multiply by omega^(-1) instead of omega^3 a0.copyInto(a, idx0); a1.copyInto(a, idx1); a2.copyInto(a, idx2); a3.copyInto(a, idx3); } } } // do one final radix-2 step if there is an odd number of stages if (s > 0) { for (int i = 0; i < n; i += 2) { // omega = 1 // a0 = a[i]; // a1 = a[i + IMAG]; // a[i] += 
a1; // a[i + IMAG] = a0 - a1; a.copyInto(i, a0); a.copyInto(i + ComplexVector.IMAG, a1); a.add(i, a1); a0.subtractInto(a1, a, i + 1); } } } /** * Performs FFTs or IFFTs of size 3 on the vector {@code (a0[i], a1[i], a2[i])} * for each {@code i}. The output is placed back into {@code a0, a1, and a2}. * * @param a0 inputs / outputs for the first FFT coefficient * @param a1 inputs / outputs for the second FFT coefficient * @param a2 inputs / outputs for the third FFT coefficient * @param sign 1 for a forward FFT, -1 for an inverse FFT * @param scale 1 for a forward FFT, 1/3 for an inverse FFT */ private static void fft3(ComplexVector a0, ComplexVector a1, ComplexVector a2, int sign, double scale) { double omegaImag = sign * -0.5 * Math.sqrt(3); // imaginary part of omega for n=3: sin(sign*(-2)*pi*1/3) for (int i = 0; i < a0.length; i++) { double a0Real = a0.real(i) + a1.real(i) + a2.real(i); double a0Imag = a0.imag(i) + a1.imag(i) + a2.imag(i); double c = omegaImag * (a2.imag(i) - a1.imag(i)); double d = omegaImag * (a1.real(i) - a2.real(i)); double e = 0.5 * (a1.real(i) + a2.real(i)); double f = 0.5 * (a1.imag(i) + a2.imag(i)); double a1Real = a0.real(i) - e + c; double a1Imag = a0.imag(i) + d - f; double a2Real = a0.real(i) - e - c; double a2Imag = a0.imag(i) - d - f; a0.real(i, a0Real * scale); a0.imag(i, a0Imag * scale); a1.real(i, a1Real * scale); a1.imag(i, a1Imag * scale); a2.real(i, a2Real * scale); a2.imag(i, a2Imag * scale); } } /** * Performs an FFT of length 3*2^n on the vector {@code a}. * Uses the 4-step algorithm to decompose the 3*2^n FFT into 2^n FFTs of * length 3 and 3 FFTs of length 2^n. * See https://www.nas.nasa.gov/assets/pdf/techreports/1989/rnr-89-004.pdf * * @param a input and output, must be 3*2^n in size for some n>=2 * @param roots2 an array that contains one set of roots at indices * log2(a.length/3), log2(a.length/3)-2, log2(a.length/3)-4, ... 
* Each roots[s] must contain 2^s roots of unity such that * {@code roots[s][k] = e^(pi*k*i/(2*roots.length))}, * i.e., they must cover the first quadrant. * @param roots3 must be the same length as {@code a} and contain roots of * unity such that {@code roots[k] = e^(pi*k*i/(2*roots3.length))}, * i.e., they need to cover the first quadrant. */ private static void fftMixedRadix(ComplexVector a, ComplexVector[] roots2, ComplexVector roots3) { int oneThird = a.length / 3; ComplexVector a0 = new ComplexVector(a, 0, oneThird); ComplexVector a1 = new ComplexVector(a, oneThird, oneThird * 2); ComplexVector a2 = new ComplexVector(a, oneThird * 2, a.length); // step 1: perform a.length/3 transforms of length 3 fft3(a0, a1, a2, 1, 1); // step 2: multiply by roots of unity MutableComplex omega = new MutableComplex(); for (int i = 0; i < a.length / 4; i++) { omega.set(roots3, i); // a0[i] *= omega^0; a1[i] *= omega^1; a2[i] *= omega^2 a1.multiplyConjugate(i, omega); a2.multiplyConjugate(i, omega); a2.multiplyConjugate(i, omega); } for (int i = a.length / 4; i < oneThird; i++) { omega.set(roots3, i - a.length / 4); // a0[i] *= omega^0; a1[i] *= omega^1; a2[i] *= omega^2 a1.multiplyConjugateTimesI(i, omega); a2.multiplyConjugateTimesI(i, omega); a2.multiplyConjugateTimesI(i, omega); } // step 3 is not needed // step 4: perform 3 transforms of length a.length/3 fft(a0, roots2); fft(a1, roots2); fft(a2, roots2); } static BigInteger fromFftVector(ComplexVector fftVec, int signum, int bitsPerFftPoint) { assert bitsPerFftPoint <= 25 : bitsPerFftPoint + " does not fit into an int with slack"; int fftLen = (int) Math.min(fftVec.length, ((long) MAX_MAG_LENGTH * 32) / bitsPerFftPoint + 1); int magLen = (int) (8 * ((long) fftLen * bitsPerFftPoint + 31) / 32); byte[] mag = new byte[magLen]; int base = 1 << bitsPerFftPoint; int bitMask = base - 1; int bitPadding = 32 - bitsPerFftPoint; long carry = 0; int bitLength = mag.length * 8; int bitIdx = bitLength - bitsPerFftPoint; int magComponent 
= 0; int prevIdx = Math.min(Math.max(0, bitIdx >> 3), mag.length - 4); for (int part = 0; part <= 1; part++) { // 0=real, 1=imaginary for (int fftIdx = 0; fftIdx < fftLen; fftIdx++) { long fftElem = Math.round(fftVec.part(fftIdx, part)) + carry; carry = fftElem >> bitsPerFftPoint; int idx = Math.min(Math.max(0, bitIdx >> 3), mag.length - 4); magComponent >>>= (prevIdx - idx) << 3; int shift = bitPadding - bitIdx + (idx << 3); magComponent |= (int) ((fftElem & bitMask) << shift); FastDoubleSwar.writeIntBE(mag, idx, magComponent); prevIdx = idx; bitIdx -= bitsPerFftPoint; } } return new BigInteger(signum, mag); } /** * Returns sets of complex roots of unity. For k=logN, logN-2, logN-4, ..., * the return value contains all k-th roots between 0 and pi/2. * * @param logN for a transform of length 2^logN */ private static ComplexVector[] getRootsOfUnity2(int logN) { ComplexVector[] roots = new ComplexVector[logN + 1]; for (int i = logN; i >= 0; i -= 2) { if (i < ROOTS_CACHE2_SIZE) { if (ROOTS2_CACHE[i] == null) { ROOTS2_CACHE[i] = calculateRootsOfUnity(1 << i); } roots[i] = ROOTS2_CACHE[i]; } else { roots[i] = calculateRootsOfUnity(1 << i); } } return roots; } /** * Returns sets of complex roots of unity. For k=logN, logN-2, logN-4, ..., * the return value contains all k-th roots between 0 and pi/2. * * @param logN for a transform of length 3*2^logN */ private static ComplexVector getRootsOfUnity3(int logN) { if (logN < ROOTS3_CACHE_SIZE) { if (ROOTS3_CACHE[logN] == null) { ROOTS3_CACHE[logN] = calculateRootsOfUnity(3 << logN); } return ROOTS3_CACHE[logN]; } else { return calculateRootsOfUnity(3 << logN); } } /** * Performs an inverse FFT of length 2^n on the vector {@code a}. * This is a decimation-in-time implementation. * * @param a input and output, must be a power of two in size * @param roots an array that contains one set of roots at indices * log2(a.length), log2(a.length)-2, log2(a.length)-4, ... 
* Each roots[s] must contain 2^s roots of unity such that * {@code roots[s][k] = e^(pi*k*i/(2*roots.length))}, * i.e., they must cover the first quadrant. */ private static void ifft(ComplexVector a, ComplexVector[] roots) { int n = a.length; int logN = 31 - Integer.numberOfLeadingZeros(n); MutableComplex a0 = new MutableComplex(); MutableComplex a1 = new MutableComplex(); MutableComplex a2 = new MutableComplex(); MutableComplex a3 = new MutableComplex(); MutableComplex b0 = new MutableComplex(); MutableComplex b1 = new MutableComplex(); MutableComplex b2 = new MutableComplex(); MutableComplex b3 = new MutableComplex(); int s = 1; // do one radix-2 step if there is an odd number of stages if (logN % 2 != 0) { for (int i = 0; i < n; i += 2) { // omega = 1 a.copyInto(i + 1, a2); a.copyInto(i, a0); a.add(i, a2); a0.subtractInto(a2, a, i + 1); } s++; } // do the remaining stages two at a time (radix-4) MutableComplex omega1 = new MutableComplex(); MutableComplex omega2 = new MutableComplex(); for (; s <= logN; s += 2) { ComplexVector rootsS = roots[s - 1]; int m = 1 << (s + 1); for (int i = 0; i < n; i += m) { for (int j = 0; j < m / 4; j++) { omega1.set(rootsS, j); // computing omega2 from omega1 is less accurate than Math.cos() and Math.sin(), // but it is the same error we'd incur with radix-2, so we're not breaking the // assumptions of the Percival paper. 
omega1.squareInto(omega2); int idx0 = i + j; int idx1 = i + j + m / 4; int idx2 = i + j + m / 2; int idx3 = i + j + m * 3 / 4; // radix-4 butterfly: // a[idx0] = a[idx0]*w^0 + a[idx1]*w^1 + a[idx2]*w^2 + a[idx3]*w^3 // a[idx1] = a[idx0]*w^0 + a[idx1]*i*w^1 + a[idx2]*(-1)*w^2 + a[idx3]*(-i)*w^3 // a[idx2] = a[idx0]*w^0 + a[idx1]*(-1)*w^1 + a[idx2]*w^2 + a[idx3]*(-1)*w^3 // a[idx3] = a[idx0]*w^0 + a[idx1]*(-i)*w^1 + a[idx2]*(-1)*w^2 + a[idx3]*i*w^3 // where w = omega1 a.copyInto(idx0, a0); a.multiplyInto(idx1, omega1, a1); a.multiplyInto(idx2, omega2, a2); a.multiplyConjugateInto(idx3, omega1, a3); // Bernstein's trick: multiply by omega^(-1) instead of omega^3 a0.addInto(a1, b0); b0.add(a2); b0.add(a3); a0.addTimesIInto(a1, b1); b1.subtract(a2); b1.subtractTimesI(a3); a0.subtractInto(a1, b2); b2.add(a2); b2.subtract(a3); a0.subtractTimesIInto(a1, b3); b3.subtract(a2); b3.addTimesI(a3); b0.copyInto(a, idx0); b1.copyInto(a, idx1); b2.copyInto(a, idx2); b3.copyInto(a, idx3); } } } // divide all vector elements by n for (int i = 0; i < n; i++) { a.timesTwoToThe(i, -logN); } } /** * Performs an inverse FFT of length 3*2^n on the vector {@code a}. * Uses the 4-step algorithm to decompose the 3*2^n FFT into 2^n FFTs of * length 3 and 3 FFTs of length 2^n. * See https://www.nas.nasa.gov/assets/pdf/techreports/1989/rnr-89-004.pdf * * @param a input and output, must be 3*2^n in size for some n>=2 * @param roots2 an array that contains one set of roots at indices * log2(a.length/3), log2(a.length/3)-2, log2(a.length/3)-4, ... * Each roots[s] must contain 2^s roots of unity such that * {@code roots[s][k] = e^(pi*k*i/(2*roots.length))}, * i.e., they must cover the first quadrant. * @param roots3 must be the same length as {@code a} and contain roots of * unity such that {@code roots[k] = e^(pi*k*i/(2*roots3.length))}, * i.e., they need to cover the first quadrant. 
*/ private static void ifftMixedRadix(ComplexVector a, ComplexVector[] roots2, ComplexVector roots3) { int oneThird = a.length / 3; ComplexVector a0 = new ComplexVector(a, 0, oneThird); ComplexVector a1 = new ComplexVector(a, oneThird, oneThird * 2); ComplexVector a2 = new ComplexVector(a, oneThird * 2, a.length); // step 1: perform 3 transforms of length a.length/3 ifft(a0, roots2); ifft(a1, roots2); ifft(a2, roots2); // step 2: multiply by roots of unity MutableComplex omega = new MutableComplex(); for (int i = 0; i < a.length / 4; i++) { omega.set(roots3, i); // a0[i] *= omega^0; a1[i] *= omega^1; a2[i] *= omega^2 a1.multiply(i, omega); a2.multiply(i, omega); a2.multiply(i, omega); } for (int i = a.length / 4; i < oneThird; i++) { omega.set(roots3, i - a.length / 4); // a0[i] *= omega^0; a1[i] *= omega^1; a2[i] *= omega^2 a1.multiplyByIAnd(i, omega); a2.multiplyByIAnd(i, omega); a2.multiplyByIAnd(i, omega); } // step 3 is not needed // step 4: perform a.length/3 transforms of length 3 fft3(a0, a1, a2, -1, 1.0 / 3); } /** * Returns a BigInteger whose value is {@code (a * b)}. * * @param a value a * @param b value b * @return {@code this * val} * @implNote An implementation may offer better algorithmic * performance when {@code a == b}. */ static BigInteger multiply(BigInteger a, BigInteger b) { assert a != null : "a==null"; assert b != null : "b==null"; if (b.signum() == 0 || a.signum() == 0) { return BigInteger.ZERO; } // Squaring is slightly faster than multiplication. // We check for identity here and not for equality, because an equality check of big integers is very expensive. 
if (b == a) { return square(b); } int xlen = a.bitLength(); int ylen = b.bitLength(); if ((long) xlen + ylen > 32L * MAX_MAG_LENGTH) { throw new ArithmeticException("BigInteger would overflow supported range"); } if (xlen > TOOM_COOK_THRESHOLD && ylen > TOOM_COOK_THRESHOLD && (xlen > FFT_THRESHOLD || ylen > FFT_THRESHOLD)) { return multiplyFft(a, b); } return a.multiply(b); } /** * Multiplies two BigIntegers using a floating-point FFT. *

     * <p>
     * Floating-point math is inaccurate; to ensure the output of the FFT and
     * IFFT rounds to the correct result for every input, the provably safe
     * FFT error bounds from "Rapid Multiplication Modulo The Sum And
     * Difference of Highly Composite Numbers" by Colin Percival, pg. 392
     * (<a href="http://www.daemonology.net/papers/fft.pdf">fft.pdf</a>) are used,
     * the vector is "balanced" before the FFT, and accurate twiddle factors are used.
     * <p>
     * This implementation incorporates several features compared to the
     * standard FFT algorithm
     * (<a href="https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm">Cooley-Tukey FFT algorithm</a>):
     * <ul>
     *     <li>It uses a variant called right-angle convolution which weights the
     *     vector before the transform. The benefit of the right-angle
     *     convolution is that when multiplying two numbers of length n, an
     *     FFT of length n suffices whereas a regular FFT needs length 2n.
     *     This is because the right-angle convolution places half of the
     *     result in the real part and the other half in the imaginary part.
     *     See: Discrete Weighted Transforms And Large-Integer Arithmetic by
     *     Richard Crandall and Barry Fagin.</li>
     *     <li>FFTs of length 3*2^n are supported in addition to 2^n.</li>
     *     <li>Radix-4 butterflies; see
     *     <a href="https://www.nxp.com/docs/en/application-note/AN3666.pdf">AN3666.pdf</a>.</li>
     *     <li>Bernstein's conjugate twiddle trick for a small speed gain at the
     *     expense of (further) reordering the output of the FFT, which is not
     *     a problem because it is reordered back in the IFFT.</li>
     *     <li>Roots of unity are cached.</li>
     * </ul>
     * <p>
* FFT vectors are stored as arrays of primitive doubles (two array * elements are needed for representing one complex number). Storing them * as arrays of primitive doubles instead of as MutableComplex objects is * memory efficient, * but in some cases below ~10^6 decimal digits, it hurts speed because * it requires additional copying. Ideally this would be implemented using * value types when they become available. * * @param a value a * @param b value b * @return a*b */ static BigInteger multiplyFft(BigInteger a, BigInteger b) { int signum = a.signum() * b.signum(); byte[] aMag = (a.signum() < 0 ? a.negate() : a).toByteArray(); byte[] bMag = (b.signum() < 0 ? b.negate() : b).toByteArray(); int bitLen = Math.max(aMag.length, bMag.length) * 8; int bitsPerPoint = bitsPerFftPoint(bitLen); int fftLen = (bitLen + bitsPerPoint - 1) / bitsPerPoint + 1; // +1 for a possible carry, see toFFTVector() int logFFTLen = 32 - Integer.numberOfLeadingZeros(fftLen - 1); // Use a 2^n or 3*2^n transform, whichever is shortest int fftLen2 = 1 << (logFFTLen); // rounded to 2^n int fftLen3 = fftLen2 * 3 / 4; // rounded to 3*2^n if (fftLen < fftLen3 && logFFTLen > 3) { ComplexVector[] roots2 = getRootsOfUnity2(logFFTLen - 2); // roots for length fftLen/3 which is a power of two ComplexVector weights = getRootsOfUnity3(logFFTLen - 2); ComplexVector twiddles = getRootsOfUnity3(logFFTLen - 4); ComplexVector aVec = toFftVector(aMag, fftLen3, bitsPerPoint); aVec.applyWeights(weights); fftMixedRadix(aVec, roots2, twiddles); ComplexVector bVec = toFftVector(bMag, fftLen3, bitsPerPoint); bVec.applyWeights(weights); fftMixedRadix(bVec, roots2, twiddles); aVec.multiplyPointwise(bVec); ifftMixedRadix(aVec, roots2, twiddles); aVec.applyInverseWeights(weights); return fromFftVector(aVec, signum, bitsPerPoint); } else { ComplexVector[] roots = getRootsOfUnity2(logFFTLen); ComplexVector aVec = toFftVector(aMag, fftLen2, bitsPerPoint); aVec.applyWeights(roots[logFFTLen]); fft(aVec, roots); ComplexVector 
bVec = toFftVector(bMag, fftLen2, bitsPerPoint); bVec.applyWeights(roots[logFFTLen]); fft(bVec, roots); aVec.multiplyPointwise(bVec); ifft(aVec, roots); aVec.applyInverseWeights(roots[logFFTLen]); return fromFftVector(aVec, signum, bitsPerPoint); } } /** * Returns a BigInteger whose value is {@code (this2)}. * * @return {@code this2} */ static BigInteger square(BigInteger a) { if (a.signum() == 0) { return BigInteger.ZERO; } return a.bitLength() < FFT_THRESHOLD ? a.multiply(a) : squareFft(a); } static BigInteger squareFft(BigInteger a) { byte[] mag = a.toByteArray(); int bitLen = mag.length * 8; int bitsPerPoint = bitsPerFftPoint(bitLen); int fftLen = (bitLen + bitsPerPoint - 1) / bitsPerPoint + 1; // +1 for a possible carry, see toFFTVector() int logFFTLen = 32 - Integer.numberOfLeadingZeros(fftLen - 1); // Use a 2^n or 3*2^n transform, whichever is shorter int fftLen2 = 1 << (logFFTLen); // rounded to 2^n int fftLen3 = fftLen2 * 3 / 4; // rounded to 3*2^n if (fftLen < fftLen3) { fftLen = fftLen3; ComplexVector vec = toFftVector(mag, fftLen, bitsPerPoint); ComplexVector[] roots2 = getRootsOfUnity2(logFFTLen - 2); // roots for length fftLen/3 which is a power of two ComplexVector weights = getRootsOfUnity3(logFFTLen - 2); ComplexVector twiddles = getRootsOfUnity3(logFFTLen - 4); vec.applyWeights(weights); fftMixedRadix(vec, roots2, twiddles); vec.squarePointwise(); ifftMixedRadix(vec, roots2, twiddles); vec.applyInverseWeights(weights); return fromFftVector(vec, 1, bitsPerPoint); } else { fftLen = fftLen2; ComplexVector vec = toFftVector(mag, fftLen, bitsPerPoint); ComplexVector[] roots = getRootsOfUnity2(logFFTLen); vec.applyWeights(roots[logFFTLen]); fft(vec, roots); vec.squarePointwise(); ifft(vec, roots); vec.applyInverseWeights(roots[logFFTLen]); return fromFftVector(vec, 1, bitsPerPoint); } } /** * Converts this BigInteger into an array of complex numbers suitable for an FFT. * Populates the real parts and sets the imaginary parts to zero. 
*/ static ComplexVector toFftVector(byte[] mag, int fftLen, int bitsPerFftPoint) { assert bitsPerFftPoint <= 25 : bitsPerFftPoint + " does not fit into an int with slack"; ComplexVector fftVec = new ComplexVector(fftLen); if (mag.length < 4) { byte[] paddedMag = new byte[4]; System.arraycopy(mag, 0, paddedMag, 4 - mag.length, mag.length); mag = paddedMag; } // Read fftPoint bits from right (least significant) to left (most significant) int base = 1 << bitsPerFftPoint; int halfBase = base / 2; int bitMask = base - 1; int bitPadding = 32 - bitsPerFftPoint; int bitLength = mag.length * 8; int carry = 0;// when we subtract base from a digit, we need to carry one int fftIdx = 0; for (int bitIdx = bitLength - bitsPerFftPoint; bitIdx > -bitsPerFftPoint; bitIdx -= bitsPerFftPoint) { int idx = Math.min(Math.max(0, bitIdx >> 3), mag.length - 4); int shift = bitPadding - bitIdx + (idx << 3); int fftPoint = (FastDoubleSwar.readIntBE(mag, idx) >>> shift) & bitMask; // "balance" the output digits so -base/2 < digit < base/2 fftPoint += carry; carry = (halfBase - fftPoint) >>> 31;// if fftPoint>halfBase then carry:=1, else carry:=0 fftPoint -= base & (-carry);//if (carry != 0) then fftPoint -= base; fftVec.real(fftIdx, fftPoint); fftIdx++; } // final carry if (carry > 0) { fftVec.real(fftIdx, carry); } return fftVec; } final static class ComplexVector { /** * A complex number in an FFT double[] vector occupies 2^1 array elements. */ private final static int COMPLEX_SIZE_SHIFT = 1; final static int IMAG = 1; final static int REAL = 0; /** * This arrays contains complex numbers. *

* A complex number occupies 2 consecutive array elements: * the real part and then the imaginary part. */ private final double[] a; /** * The number of complex numbers stored in this vector. */ private final int length; /** * Offset to the real part of a complex number. */ private final int offset; ComplexVector(int length) { this.a = new double[length << COMPLEX_SIZE_SHIFT]; this.length = length; this.offset = 0; } /** * Creates a view on another vector. * * @param c the other vector * @param from start index of the view * @param to end index of the view */ ComplexVector(ComplexVector c, int from, int to) { this.length = to - from; this.a = c.a; this.offset = from << 1; } void add(int idxa, MutableComplex c) { a[realIdx(idxa)] += c.real; a[imagIdx(idxa)] += c.imag; } void addInto(int idxa, ComplexVector c, int idxc, MutableComplex destination) { destination.real = a[realIdx(idxa)] + c.real(idxc); destination.imag = a[imagIdx(idxa)] + c.imag(idxc); } void addTimesIInto(int idxa, ComplexVector c, int idxc, MutableComplex destination) { destination.real = a[realIdx(idxa)] - c.imag(idxc); destination.imag = a[imagIdx(idxa)] + c.real(idxc); } /** * Multiplies the elements of an FFT vector by 1/weight. * Used for the right-angle convolution. */ void applyInverseWeights(ComplexVector weights) { int offw = weights.offset; double[] w = weights.a; int end = offset + length << 1; for (int offa = offset; offa < end; offa += 2) { // the following code is the same as: this.multiplyConjugate(i, weights[i]); double real = a[offa + REAL]; double imag = a[offa + IMAG]; a[offa] = fma(real, w[offw + REAL], imag * w[offw + IMAG]); a[offa + IMAG] = fma(-real, w[offw + IMAG], imag * w[offw + REAL]); offw += 2; } } /** * Multiplies the elements of an FFT vector by weights. * Doing this makes a regular FFT convolution a right-angle convolution. */ void applyWeights(ComplexVector weights) { // The following code is the same as: // for (int i=0;i





© 2015 - 2025 Weber Informatics LLC | Privacy Policy