
jdk.graal.compiler.lir.amd64.AMD64CounterModeAESCryptOp Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of compiler Show documentation
Show all versions of compiler Show documentation
The GraalVM compiler and the Graal-truffle optimizer.
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package jdk.graal.compiler.lir.amd64;
import static jdk.vm.ci.amd64.AMD64.r11;
import static jdk.vm.ci.amd64.AMD64.rax;
import static jdk.vm.ci.amd64.AMD64.rbx;
import static jdk.vm.ci.amd64.AMD64.xmm0;
import static jdk.vm.ci.amd64.AMD64.xmm1;
import static jdk.vm.ci.amd64.AMD64.xmm10;
import static jdk.vm.ci.amd64.AMD64.xmm11;
import static jdk.vm.ci.amd64.AMD64.xmm12;
import static jdk.vm.ci.amd64.AMD64.xmm13;
import static jdk.vm.ci.amd64.AMD64.xmm14;
import static jdk.vm.ci.amd64.AMD64.xmm2;
import static jdk.vm.ci.amd64.AMD64.xmm3;
import static jdk.vm.ci.amd64.AMD64.xmm4;
import static jdk.vm.ci.amd64.AMD64.xmm5;
import static jdk.vm.ci.amd64.AMD64.xmm6;
import static jdk.vm.ci.amd64.AMD64.xmm7;
import static jdk.vm.ci.amd64.AMD64.xmm8;
import static jdk.vm.ci.amd64.AMD64.xmm9;
import static jdk.vm.ci.code.ValueUtil.asRegister;
import static jdk.graal.compiler.lir.amd64.AMD64AESEncryptOp.AES_BLOCK_SIZE;
import static jdk.graal.compiler.lir.amd64.AMD64AESEncryptOp.keyShuffleMask;
import static jdk.graal.compiler.lir.amd64.AMD64AESEncryptOp.loadKey;
import static jdk.graal.compiler.lir.amd64.AMD64LIRHelper.pointerConstant;
import static jdk.graal.compiler.lir.amd64.AMD64LIRHelper.recordExternalAddress;
import java.util.function.BiConsumer;
import jdk.graal.compiler.asm.Label;
import jdk.graal.compiler.asm.amd64.AMD64Address;
import jdk.graal.compiler.asm.amd64.AMD64Assembler.ConditionFlag;
import jdk.graal.compiler.asm.amd64.AMD64MacroAssembler;
import jdk.graal.compiler.asm.amd64.AVXKind.AVXSize;
import jdk.graal.compiler.core.common.Stride;
import jdk.graal.compiler.debug.GraalError;
import jdk.graal.compiler.lir.LIRInstructionClass;
import jdk.graal.compiler.lir.SyncPort;
import jdk.graal.compiler.lir.asm.ArrayDataPointerConstant;
import jdk.graal.compiler.lir.asm.CompilationResultBuilder;
import jdk.vm.ci.amd64.AMD64Kind;
import jdk.vm.ci.code.Register;
import jdk.vm.ci.meta.AllocatableValue;
import jdk.vm.ci.meta.Value;
// @formatter:off
@SyncPort(from = "https://github.com/openjdk/jdk/blob/ce8399fd6071766114f5f201b6e44a7abdba9f5a/src/hotspot/cpu/x86/stubGenerator_x86_64_aes.cpp#L441-L748",
sha1 = "f73999add65bf7ccd9ee310df5412213fac98192")
// @formatter:on
public final class AMD64CounterModeAESCryptOp extends AMD64LIRInstruction {
public static final LIRInstructionClass TYPE = LIRInstructionClass.create(AMD64CounterModeAESCryptOp.class);
private final int lengthOffset;
@Alive({OperandFlag.REG}) private Value inValue;
@Alive({OperandFlag.REG}) private Value outValue;
@Alive({OperandFlag.REG}) private Value keyValue;
@Alive({OperandFlag.REG}) private Value counterValue;
@Alive({OperandFlag.REG}) private Value lenValue;
@Alive({OperandFlag.REG}) private Value encryptedCounterValue;
@Alive({OperandFlag.REG}) private Value usedPtrValue;
@Def({OperandFlag.REG}) protected Value resultValue;
@Temp protected Value[] temps;
public AMD64CounterModeAESCryptOp(AllocatableValue inValue,
AllocatableValue outValue,
AllocatableValue keyValue,
AllocatableValue counterValue,
AllocatableValue lenValue,
AllocatableValue encryptedCounterValue,
AllocatableValue usedPtrValue,
AllocatableValue resultValue,
int lengthOffset) {
super(TYPE);
this.inValue = inValue;
this.outValue = outValue;
this.keyValue = keyValue;
this.counterValue = counterValue;
this.lenValue = lenValue;
this.encryptedCounterValue = encryptedCounterValue;
this.usedPtrValue = usedPtrValue;
this.resultValue = resultValue;
this.lengthOffset = lengthOffset;
temps = new Value[]{
r11.asValue(),
rax.asValue(),
rbx.asValue(),
xmm0.asValue(),
xmm1.asValue(),
xmm2.asValue(),
xmm3.asValue(),
xmm4.asValue(),
xmm5.asValue(),
xmm6.asValue(),
xmm7.asValue(),
xmm8.asValue(),
xmm9.asValue(),
xmm10.asValue(),
xmm11.asValue(),
xmm12.asValue(),
xmm13.asValue(),
xmm14.asValue(),
};
}
static Label[] newLabels(int len) {
Label[] labels = new Label[len];
for (int i = 0; i < len; i++) {
labels[i] = new Label();
}
return labels;
}
private static Label[][] newLabels(int lenDimension1, int lenDimension2) {
Label[][] labels = new Label[lenDimension1][lenDimension2];
for (int i = 0; i < lenDimension1; i++) {
labels[i] = new Label[lenDimension2];
for (int j = 0; j < lenDimension2; j++) {
labels[i][j] = new Label();
}
}
return labels;
}
private static final int PARALLEL_FACTOR = 6;
private static ArrayDataPointerConstant counterShuffleMask = pointerConstant(16, new int[]{
// @formatter:off
0x0c0d0e0f, 0x08090a0b, 0x04050607, 0x00010203,
// @formatter:on
});
@Override
public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
GraalError.guarantee(inValue.getPlatformKind().equals(AMD64Kind.QWORD), "Invalid inValue kind: %s", inValue);
GraalError.guarantee(outValue.getPlatformKind().equals(AMD64Kind.QWORD), "Invalid outValue kind: %s", outValue);
GraalError.guarantee(keyValue.getPlatformKind().equals(AMD64Kind.QWORD), "Invalid keyValue kind: %s", keyValue);
GraalError.guarantee(counterValue.getPlatformKind().equals(AMD64Kind.QWORD), "Invalid counterValue kind: %s", counterValue);
GraalError.guarantee(lenValue.getPlatformKind().equals(AMD64Kind.DWORD), "Invalid lenValue kind: %s", lenValue);
GraalError.guarantee(encryptedCounterValue.getPlatformKind().equals(AMD64Kind.QWORD), "Invalid encryptedCounterValue kind: %s", encryptedCounterValue);
GraalError.guarantee(usedPtrValue.getPlatformKind().equals(AMD64Kind.QWORD), "Invalid usedPtrValue kind: %s", usedPtrValue);
GraalError.guarantee(resultValue.getPlatformKind().equals(AMD64Kind.DWORD), "Invalid resultValue kind: %s", resultValue);
Register from = asRegister(inValue);
Register to = asRegister(outValue);
Register key = asRegister(keyValue);
Register counter = asRegister(counterValue);
Register lenReg = asRegister(resultValue);
Register savedEncCounterStart = asRegister(encryptedCounterValue);
Register usedAddr = asRegister(usedPtrValue);
Register used = r11;
Register pos = rax;
Register xmmCounterShufMask = xmm0;
// used temporarily to swap key bytes up front
Register xmmKeyShufMask = xmm1;
Register xmmCurrCounter = xmm2;
Register xmmKeyTmp0 = xmm3;
Register xmmKeyTmp1 = xmm4;
// registers holding the four results in the parallelized loop
Register xmmResult0 = xmm5;
Register xmmResult1 = xmm6;
Register xmmResult2 = xmm7;
Register xmmResult3 = xmm8;
Register xmmResult4 = xmm9;
Register xmmResult5 = xmm10;
Register xmmFrom0 = xmm11;
Register xmmFrom1 = xmm12;
Register xmmFrom2 = xmm13;
Register xmmFrom3 = xmm14;
// reuse xmm3~4. Because xmmKeyTmp0~1 are useless when loading input text
Register xmmFrom4 = xmm3;
Register xmmFrom5 = xmm4;
// for key_128, key_192, key_256
int[] rounds = new int[]{10, 12, 14};
Label labelExitPreLoop = new Label();
Label labelPreLoopStart = new Label();
Label[] labelMultiBlockLoopTop = newLabels(3);
Label[] labelSingleBlockLoopTop = newLabels(3);
// for 6 blocks
Label[][] labelIncCounter = newLabels(3, 6);
// for single block, key128, key192, key256
Label[] labelIncCounterSingle = newLabels(3);
Label[] labelProcessTailInsr = newLabels(3);
Label[] labelProcessTail4Insr = newLabels(3);
Label[] labelProcessTail2Insr = newLabels(3);
Label[] labelProcessTail1Insr = newLabels(3);
Label[] labelProcessTailExitInsr = newLabels(3);
Label[] labelProcessTail4Extr = newLabels(3);
Label[] labelProcessTail2Extr = newLabels(3);
Label[] labelProcessTail1Extr = newLabels(3);
Label[] labelProcessTailExitExtr = newLabels(3);
Label labelExit = new Label();
masm.movl(lenReg, asRegister(lenValue));
masm.movl(used, new AMD64Address(usedAddr));
// initialize counter with initial counter
masm.movdqu(xmmCurrCounter, new AMD64Address(counter));
// pos as scratch
masm.movdqu(xmmCounterShufMask, recordExternalAddress(crb, counterShuffleMask));
// counter is shuffled
masm.pshufb(AVXSize.XMM, xmmCurrCounter, xmmCounterShufMask);
masm.movq(pos, 0);
// Use the partially used encrpyted counter from last invocation
masm.bind(labelPreLoopStart);
masm.cmplAndJcc(used, 16, ConditionFlag.AboveEqual, labelExitPreLoop, false);
masm.cmplAndJcc(lenReg, 0, ConditionFlag.LessEqual, labelExitPreLoop, false);
masm.movb(rbx, new AMD64Address(savedEncCounterStart, used, Stride.S1));
masm.xorb(rbx, new AMD64Address(from, pos, Stride.S1));
masm.movb(new AMD64Address(to, pos, Stride.S1), rbx);
masm.addq(pos, 1);
masm.addl(used, 1);
masm.subl(lenReg, 1);
masm.jmp(labelPreLoopStart);
masm.bind(labelExitPreLoop);
masm.movl(new AMD64Address(usedAddr), used);
// key length could be only {11, 13, 15} * 4 = {44, 52, 60}
masm.movdqu(xmmKeyShufMask, recordExternalAddress(crb, keyShuffleMask));
masm.movl(rbx, new AMD64Address(key, lengthOffset));
masm.cmplAndJcc(rbx, 52, ConditionFlag.Equal, labelMultiBlockLoopTop[1], false);
masm.cmplAndJcc(rbx, 60, ConditionFlag.Equal, labelMultiBlockLoopTop[2], false);
// k == 0 : generate code for key_128
// k == 1 : generate code for key_192
// k == 2 : generate code for key_256
for (int k = 0; k < 3; k++) {
// multi blocks starts here
masm.align(preferredLoopAlignment(crb));
masm.bind(labelMultiBlockLoopTop[k]);
// see if at least PARALLEL_FACTOR blocks left
masm.cmplAndJcc(lenReg, PARALLEL_FACTOR * AES_BLOCK_SIZE, ConditionFlag.LessEqual, labelSingleBlockLoopTop[k], false);
loadKey(masm, xmmKeyTmp0, key, 0x00, xmmKeyShufMask);
// load, then increase counters
applyCTRDoSix(masm::movdqa, xmmCurrCounter);
incCounter(masm, rbx, xmmResult1, 0x01, labelIncCounter[k][0]);
incCounter(masm, rbx, xmmResult2, 0x02, labelIncCounter[k][1]);
incCounter(masm, rbx, xmmResult3, 0x03, labelIncCounter[k][2]);
incCounter(masm, rbx, xmmResult4, 0x04, labelIncCounter[k][3]);
incCounter(masm, rbx, xmmResult5, 0x05, labelIncCounter[k][4]);
incCounter(masm, rbx, xmmCurrCounter, 0x06, labelIncCounter[k][5]);
// after increased, shuffled counters back for PXOR
applyCTRDoSix((dst, src) -> masm.pshufb(AVXSize.XMM, dst, src), xmmCounterShufMask);
// PXOR with Round 0 key
applyCTRDoSix(masm::pxor, xmmKeyTmp0);
// load two ROUND_KEYs at a time
for (int i = 1; i < rounds[k];) {
loadKey(masm, xmmKeyTmp1, key, i * 0x10, xmmKeyShufMask);
loadKey(masm, xmmKeyTmp0, key, (i + 1) * 0x10, xmmKeyShufMask);
applyCTRDoSix(masm::aesenc, xmmKeyTmp1);
i++;
if (i != rounds[k]) {
applyCTRDoSix(masm::aesenc, xmmKeyTmp0);
} else {
applyCTRDoSix(masm::aesenclast, xmmKeyTmp0);
}
i++;
}
// get next PARALLEL_FACTOR blocks into xmmResult registers
masm.movdqu(xmmFrom0, new AMD64Address(from, pos, Stride.S1, 0 * AES_BLOCK_SIZE));
masm.movdqu(xmmFrom1, new AMD64Address(from, pos, Stride.S1, 1 * AES_BLOCK_SIZE));
masm.movdqu(xmmFrom2, new AMD64Address(from, pos, Stride.S1, 2 * AES_BLOCK_SIZE));
masm.movdqu(xmmFrom3, new AMD64Address(from, pos, Stride.S1, 3 * AES_BLOCK_SIZE));
masm.movdqu(xmmFrom4, new AMD64Address(from, pos, Stride.S1, 4 * AES_BLOCK_SIZE));
masm.movdqu(xmmFrom5, new AMD64Address(from, pos, Stride.S1, 5 * AES_BLOCK_SIZE));
masm.pxor(xmmResult0, xmmFrom0);
masm.pxor(xmmResult1, xmmFrom1);
masm.pxor(xmmResult2, xmmFrom2);
masm.pxor(xmmResult3, xmmFrom3);
masm.pxor(xmmResult4, xmmFrom4);
masm.pxor(xmmResult5, xmmFrom5);
// store 6 results into the next 64 bytes of output
masm.movdqu(new AMD64Address(to, pos, Stride.S1, 0 * AES_BLOCK_SIZE), xmmResult0);
masm.movdqu(new AMD64Address(to, pos, Stride.S1, 1 * AES_BLOCK_SIZE), xmmResult1);
masm.movdqu(new AMD64Address(to, pos, Stride.S1, 2 * AES_BLOCK_SIZE), xmmResult2);
masm.movdqu(new AMD64Address(to, pos, Stride.S1, 3 * AES_BLOCK_SIZE), xmmResult3);
masm.movdqu(new AMD64Address(to, pos, Stride.S1, 4 * AES_BLOCK_SIZE), xmmResult4);
masm.movdqu(new AMD64Address(to, pos, Stride.S1, 5 * AES_BLOCK_SIZE), xmmResult5);
// increase the length of crypt text
masm.addq(pos, PARALLEL_FACTOR * AES_BLOCK_SIZE);
// decrease the remaining length
masm.subl(lenReg, PARALLEL_FACTOR * AES_BLOCK_SIZE);
masm.jmp(labelMultiBlockLoopTop[k]);
// singleBlock starts here
masm.align(preferredLoopAlignment(crb));
masm.bind(labelSingleBlockLoopTop[k]);
masm.cmplAndJcc(lenReg, 0, ConditionFlag.LessEqual, labelExit, false);
loadKey(masm, xmmKeyTmp0, key, 0x00, xmmKeyShufMask);
masm.movdqa(xmmResult0, xmmCurrCounter);
incCounter(masm, rbx, xmmCurrCounter, 0x01, labelIncCounterSingle[k]);
masm.pshufb(AVXSize.XMM, xmmResult0, xmmCounterShufMask);
masm.pxor(xmmResult0, xmmKeyTmp0);
for (int i = 1; i < rounds[k]; i++) {
loadKey(masm, xmmKeyTmp0, key, i * 0x10, xmmKeyShufMask);
masm.aesenc(xmmResult0, xmmKeyTmp0);
}
loadKey(masm, xmmKeyTmp0, key, rounds[k] * 0x10, xmmKeyShufMask);
masm.aesenclast(xmmResult0, xmmKeyTmp0);
masm.cmplAndJcc(lenReg, AES_BLOCK_SIZE, ConditionFlag.Less, labelProcessTailInsr[k], false);
masm.movdqu(xmmFrom0, new AMD64Address(from, pos, Stride.S1, 0 * AES_BLOCK_SIZE));
masm.pxor(xmmResult0, xmmFrom0);
masm.movdqu(new AMD64Address(to, pos, Stride.S1, 0 * AES_BLOCK_SIZE), xmmResult0);
masm.addq(pos, AES_BLOCK_SIZE);
masm.subl(lenReg, AES_BLOCK_SIZE);
masm.jmp(labelSingleBlockLoopTop[k]);
// Process the tail part of the input array
masm.bind(labelProcessTailInsr[k]);
// 1. Insert bytes from src array into xmmFrom0 register
masm.addq(pos, lenReg);
masm.testlAndJcc(lenReg, 8, ConditionFlag.Zero, labelProcessTail4Insr[k], false);
masm.subq(pos, 8);
masm.pinsrq(xmmFrom0, new AMD64Address(from, pos, Stride.S1), 0);
masm.bind(labelProcessTail4Insr[k]);
masm.testlAndJcc(lenReg, 4, ConditionFlag.Zero, labelProcessTail2Insr[k], false);
masm.subq(pos, 4);
masm.pslldq(xmmFrom0, 4);
masm.pinsrd(xmmFrom0, new AMD64Address(from, pos, Stride.S1), 0);
masm.bind(labelProcessTail2Insr[k]);
masm.testlAndJcc(lenReg, 2, ConditionFlag.Zero, labelProcessTail1Insr[k], false);
masm.subq(pos, 2);
masm.pslldq(xmmFrom0, 2);
masm.pinsrw(xmmFrom0, new AMD64Address(from, pos, Stride.S1), 0);
masm.bind(labelProcessTail1Insr[k]);
masm.testlAndJcc(lenReg, 1, ConditionFlag.Zero, labelProcessTailExitInsr[k], false);
masm.subq(pos, 1);
masm.pslldq(xmmFrom0, 1);
masm.pinsrb(xmmFrom0, new AMD64Address(from, pos, Stride.S1), 0);
masm.bind(labelProcessTailExitInsr[k]);
// 2. Perform pxor of the encrypted counter and plaintext Bytes.
// Also the encrypted counter is saved for next invocation.
masm.movdqu(new AMD64Address(savedEncCounterStart), xmmResult0);
masm.pxor(xmmResult0, xmmFrom0);
// 3. Extract bytes from xmmResult0 into the dest. array
masm.testlAndJcc(lenReg, 8, ConditionFlag.Zero, labelProcessTail4Extr[k], false);
masm.pextrq(new AMD64Address(to, pos, Stride.S1), xmmResult0, 0);
masm.psrldq(xmmResult0, 8);
masm.addq(pos, 8);
masm.bind(labelProcessTail4Extr[k]);
masm.testlAndJcc(lenReg, 4, ConditionFlag.Zero, labelProcessTail2Extr[k], false);
masm.pextrd(new AMD64Address(to, pos, Stride.S1), xmmResult0, 0);
masm.psrldq(xmmResult0, 4);
masm.addq(pos, 4);
masm.bind(labelProcessTail2Extr[k]);
masm.testlAndJcc(lenReg, 2, ConditionFlag.Zero, labelProcessTail1Extr[k], false);
masm.pextrw(new AMD64Address(to, pos, Stride.S1), xmmResult0, 0);
masm.psrldq(xmmResult0, 2);
masm.addq(pos, 2);
masm.bind(labelProcessTail1Extr[k]);
masm.testlAndJcc(lenReg, 1, ConditionFlag.Zero, labelProcessTailExitExtr[k], false);
masm.pextrb(new AMD64Address(to, pos, Stride.S1), xmmResult0, 0);
masm.bind(labelProcessTailExitExtr[k]);
masm.movl(new AMD64Address(usedAddr), lenReg);
masm.jmp(labelExit);
}
masm.bind(labelExit);
// counter is shuffled back.
masm.pshufb(AVXSize.XMM, xmmCurrCounter, xmmCounterShufMask);
masm.movdqu(new AMD64Address(counter), xmmCurrCounter); // save counter back
masm.movl(asRegister(resultValue), asRegister(lenValue));
}
private static void incCounter(AMD64MacroAssembler masm, Register reg, Register xmmdst, int incDelta, Label nextBlock) {
masm.pextrq(reg, xmmdst, 0x00);
masm.addq(reg, incDelta);
masm.pinsrq(xmmdst, reg, 0x00);
masm.jcc(ConditionFlag.CarryClear, nextBlock); // jump if no carry
masm.pextrq(reg, xmmdst, 0x01); // Carry-> D1
masm.addq(reg, 0x01);
masm.pinsrq(xmmdst, reg, 0x01);
masm.bind(nextBlock); // next instruction
}
private static void applyCTRDoSix(BiConsumer op, Register src) {
Register xmmResult0 = xmm5;
Register xmmResult1 = xmm6;
Register xmmResult2 = xmm7;
Register xmmResult3 = xmm8;
Register xmmResult4 = xmm9;
Register xmmResult5 = xmm10;
op.accept(xmmResult0, src);
op.accept(xmmResult1, src);
op.accept(xmmResult2, src);
op.accept(xmmResult3, src);
op.accept(xmmResult4, src);
op.accept(xmmResult5, src);
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy