/*
* Copyright (c) 2017, 2020, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation. Oracle designates this
* particular file as subject to the "Classpath" exception as provided
* by Oracle in the LICENSE file that accompanied this code.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
package org.graalvm.compiler.lir.amd64;

import static jdk.vm.ci.amd64.AMD64.k7;
import static jdk.vm.ci.amd64.AMD64.rax;
import static jdk.vm.ci.amd64.AMD64.rcx;
import static jdk.vm.ci.amd64.AMD64.rdx;
import static jdk.vm.ci.code.ValueUtil.asRegister;
import static org.graalvm.compiler.lir.LIRInstruction.OperandFlag.ILLEGAL;
import static org.graalvm.compiler.lir.LIRInstruction.OperandFlag.REG;

import java.util.EnumSet;

import org.graalvm.compiler.asm.Label;
import org.graalvm.compiler.asm.amd64.AMD64Address;
import org.graalvm.compiler.asm.amd64.AMD64Assembler.ConditionFlag;
import org.graalvm.compiler.asm.amd64.AMD64MacroAssembler;
import org.graalvm.compiler.asm.amd64.AVXKind.AVXSize;
import org.graalvm.compiler.core.common.LIRKind;
import org.graalvm.compiler.core.common.Stride;
import org.graalvm.compiler.lir.LIRInstructionClass;
import org.graalvm.compiler.lir.Opcode;
import org.graalvm.compiler.lir.asm.CompilationResultBuilder;
import org.graalvm.compiler.lir.gen.LIRGeneratorTool;

import jdk.vm.ci.amd64.AMD64.CPUFeature;
import jdk.vm.ci.amd64.AMD64Kind;
import jdk.vm.ci.code.CodeUtil;
import jdk.vm.ci.code.Register;
import jdk.vm.ci.meta.Value;
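
/*
 * For orientation: the generated code is semantically equivalent to the scalar loop sketched
 * below (byte/byte case only; the method and parameter names are illustrative and not part of
 * Graal's API):
 *
 *   static int compareTo(byte[] a, int lenA, byte[] b, int lenB) {
 *       int min = Math.min(lenA, lenB);
 *       for (int i = 0; i < min; i++) {
 *           // compare elements as unsigned bytes
 *           int diff = (a[i] & 0xff) - (b[i] & 0xff);
 *           if (diff != 0) {
 *               return diff;
 *           }
 *       }
 *       // one array is a prefix of the other: the length difference decides
 *       return lenA - lenB;
 *   }
 */
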
/**
 * Emits code that compares two arrays lexicographically. If the CPU supports vector
 * instructions, specialized code is emitted to leverage them.
 */
@Opcode("ARRAY_COMPARE_TO")
public final class AMD64ArrayCompareToOp extends AMD64ComplexVectorOp {
    public static final LIRInstructionClass<AMD64ArrayCompareToOp> TYPE = LIRInstructionClass.create(AMD64ArrayCompareToOp.class);

private final Stride strideA;
private final Stride strideB;
    private final int useAVX3Threshold;

@Def({REG}) protected Value resultValue;
@Alive({REG}) protected Value arrayAValue;
@Alive({REG}) protected Value arrayBValue;
@Use({REG}) protected Value lengthAValue;
@Use({REG}) protected Value lengthBValue;
@Temp({REG}) protected Value lengthAValueTemp;
@Temp({REG}) protected Value lengthBValueTemp;
@Temp({REG}) protected Value temp1;
@Temp({REG}) protected Value temp2;
@Temp({REG, ILLEGAL}) protected Value vectorTemp1;
    @Temp({REG, ILLEGAL}) protected Value maskRegister;

public AMD64ArrayCompareToOp(LIRGeneratorTool tool, int useAVX3Threshold, Stride strideA, Stride strideB,
                    EnumSet<CPUFeature> runtimeCheckedCPUFeatures,
Value result,
Value arrayA, Value lengthA,
Value arrayB, Value lengthB) {
super(TYPE, tool, runtimeCheckedCPUFeatures, AVXSize.ZMM);
assert CodeUtil.isPowerOf2(useAVX3Threshold) : "AVX3Threshold must be power of 2";
this.useAVX3Threshold = useAVX3Threshold;
this.strideA = strideA;
this.strideB = strideB;
this.resultValue = result;
this.arrayAValue = arrayA;
this.arrayBValue = arrayB;
/*
* The length values are inputs but are also killed like temporaries so need both Use and
* Temp annotations, which will only work with fixed registers.
*/
this.lengthAValue = lengthAValueTemp = lengthA;
this.lengthBValue = lengthBValueTemp = lengthB;
// Allocate some temporaries.
this.temp1 = tool.newVariable(LIRKind.unknownReference(tool.target().arch.getWordKind()));
this.temp2 = tool.newVariable(LIRKind.unknownReference(tool.target().arch.getWordKind()));
// We only need the vector temporaries if we generate SSE code.
if (supports(tool.target(), runtimeCheckedCPUFeatures, CPUFeature.SSE4_2)) {
this.vectorTemp1 = tool.newVariable(LIRKind.value(AMD64Kind.DOUBLE));
} else {
this.vectorTemp1 = Value.ILLEGAL;
}
if (canUseAVX512Variant()) {
maskRegister = k7.asValue();
} else {
maskRegister = Value.ILLEGAL;
}
}
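
    /**
     * The AVX-512 code path is enabled only when {@code useAVX3Threshold} is 0 (i.e. AVX-512
     * instructions may be used regardless of the input length) and the required AVX512VL/BW
     * features and ZMM registers are available.
     */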
private boolean canUseAVX512Variant() {
return useAVX3Threshold == 0 && supportsAVX512VLBWAndZMM();
    }

@Override
public void emitCode(CompilationResultBuilder crb, AMD64MacroAssembler masm) {
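        // Strategy (this code closely follows HotSpot's string_compare stub): compare the
        // first element, then run a vectorized loop (AVX-512, AVX2, or SSE4.2 pcmpestri,
        // whichever is available), and finish the remaining elements in a scalar loop. The
        // length difference is pushed on the stack up front and becomes the result when one
        // array is a prefix of the other.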
Register result = asRegister(resultValue);
Register str1 = asRegister(temp1);
Register str2 = asRegister(temp2);
// Load array base addresses.
masm.movq(str1, asRegister(arrayAValue));
masm.movq(str2, asRegister(arrayBValue));
Register cnt1 = asRegister(lengthAValue);
Register cnt2 = asRegister(lengthBValue);
Label labelLengthDiff = new Label();
Label labelPop = new Label();
Label labelDone = new Label();
Label labelWhileHead = new Label();
Label labelCompareWideVectorsLoopFailed = new Label(); // used only _LP64 && AVX3
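        // stride2x2 is the element count consumed per AVX-512 loop iteration: one 64-byte
        // ZMM load, i.e. 64 byte elements, or 32 elements once char elements are involved.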
int stride2x2 = 0x40;
Stride maxStride = Stride.max(strideA, strideB);
Stride scale1 = Stride.S1;
Stride scale2 = Stride.S2;
int elementsPerXMMVector = 16 >> maxStride.log2;
int elementsPerYMMVector = 32 >> maxStride.log2;
if (!(strideA == Stride.S1 && strideB == Stride.S1)) {
stride2x2 = 0x20;
}
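        // The incoming lengths are byte counts (compact strings); in the mixed-stride case
        // cnt2 refers to the char-stride operand, so halve it to get its element count.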
if (strideA != strideB) {
masm.shrl(cnt2, 1);
}
        // Compute the minimum of the two lengths (in cnt2, via conditional move) and push
        // the difference of the lengths onto the stack for later.
masm.movl(result, cnt1);
masm.subl(cnt1, cnt2);
masm.push(cnt1);
masm.cmovl(ConditionFlag.LessEqual, cnt2, result); // cnt2 = min(cnt1, cnt2)
// Is the minimum length zero?
masm.testlAndJcc(cnt2, cnt2, ConditionFlag.Zero, labelLengthDiff, false);
if (strideA == Stride.S1 && strideB == Stride.S1) {
// Load first bytes
masm.movzbl(result, new AMD64Address(str1, 0)); // result = str1[0]
masm.movzbl(cnt1, new AMD64Address(str2, 0)); // cnt1 = str2[0]
} else if (strideA == Stride.S2 && strideB == Stride.S2) {
// Load first characters
masm.movzwl(result, new AMD64Address(str1, 0));
masm.movzwl(cnt1, new AMD64Address(str2, 0));
} else {
masm.movzbl(result, new AMD64Address(str1, 0));
masm.movzwl(cnt1, new AMD64Address(str2, 0));
}
masm.sublAndJcc(result, cnt1, ConditionFlag.NotZero, labelPop, false);
if (strideA == Stride.S2 && strideB == Stride.S2) {
// Divide length by 2 to get number of chars
masm.shrl(cnt2, 1);
}
masm.cmplAndJcc(cnt2, 1, ConditionFlag.Equal, labelLengthDiff, false);
// Check if the strings start at the same location and setup scale and stride
if (strideA == strideB) {
masm.cmpqAndJcc(str1, str2, ConditionFlag.Equal, labelLengthDiff, false);
}
if (supportsAVX2AndYMM() && masm.supports(CPUFeature.SSE4_2)) {
Register vec1 = asRegister(vectorTemp1, AMD64Kind.DOUBLE);
Label labelCompareWideVectors = new Label();
Label labelVectorNotEqual = new Label();
Label labelCompareWideTail = new Label();
Label labelCompareSmallStr = new Label();
Label labelCompareWideVectorsLoop = new Label();
Label labelCompare16Chars = new Label();
Label labelCompareIndexChar = new Label();
Label labelCompareWideVectorsLoopAVX2 = new Label();
Label labelCompareTailLong = new Label();
Label labelCompareWideVectorsLoopAVX3 = new Label(); // used only _LP64 && AVX3
int pcmpmask = 0x19;
if (strideA == Stride.S1 && strideB == Stride.S1) {
pcmpmask &= ~0x01;
}
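            // pcmpestri control byte 0x19 = 0b11001: bits 1:0 = 01 (unsigned words),
            // bits 3:2 = 10 ("equal each", element-wise compare), bits 5:4 = 01 (negate the
            // result); clearing bit 0 above switches to unsigned bytes.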
            // Set up to compare 16-char (32-byte) vectors, starting from the first character
            // again because its address is aligned.
assert result.equals(rax) && cnt2.equals(rdx) && cnt1.equals(rcx) : "pcmpestri";
            // rax and rdx are used by pcmpestri as element counters
masm.movl(result, cnt2);
masm.andlAndJcc(cnt2, ~(elementsPerYMMVector - 1), ConditionFlag.Zero, labelCompareTailLong, false);
            // Fast path: compare the first two 8-char (16-byte) vectors.
masm.bind(labelCompare16Chars);
if (strideA == strideB) {
masm.movdqu(vec1, new AMD64Address(str1, 0));
} else {
masm.pmovzxbw(vec1, new AMD64Address(str1, 0));
}
masm.pcmpestri(vec1, new AMD64Address(str2, 0), pcmpmask);
masm.jccb(ConditionFlag.Below, labelCompareIndexChar);
if (strideA == strideB) {
masm.movdqu(vec1, new AMD64Address(str1, 16));
} else {
masm.pmovzxbw(vec1, new AMD64Address(str1, 8));
}
masm.pcmpestri(vec1, new AMD64Address(str2, 16), pcmpmask);
masm.jccb(ConditionFlag.AboveEqual, labelCompareWideVectors);
masm.addl(cnt1, elementsPerXMMVector);
// Compare the characters at index in cnt1
masm.bind(labelCompareIndexChar); // cnt1 has the offset of the mismatching character
loadNextElements(masm, result, cnt2, str1, str2, maxStride, scale1, scale2, cnt1);
masm.subl(result, cnt2);
masm.jmp(labelPop);
// Setup the registers to start vector comparison loop
masm.bind(labelCompareWideVectors);
if (strideA == strideB) {
masm.leaq(str1, new AMD64Address(str1, result, maxStride));
masm.leaq(str2, new AMD64Address(str2, result, maxStride));
} else {
masm.leaq(str1, new AMD64Address(str1, result, scale1));
masm.leaq(str2, new AMD64Address(str2, result, scale2));
}
masm.subl(result, elementsPerYMMVector);
masm.sublAndJcc(cnt2, elementsPerYMMVector, ConditionFlag.Zero, labelCompareWideTail, false);
masm.negq(result);
            // In a loop, compare 16 chars (32 bytes) at a time using vpxor + vptest
            masm.bind(labelCompareWideVectorsLoop);
            // try the 64-byte fast loop first
if (canUseAVX512Variant()) {
masm.cmplAndJcc(cnt2, stride2x2, ConditionFlag.Below, labelCompareWideVectorsLoopAVX2, true);
                // a non-zero remainder means cnt2 cannot be consumed in whole 0x40-element
                // steps, so use the AVX2 loop instead
masm.testlAndJcc(cnt2, stride2x2 - 1, ConditionFlag.NotZero, labelCompareWideVectorsLoopAVX2, true);
masm.bind(labelCompareWideVectorsLoopAVX3); // the hottest loop
if (strideA == strideB) {
masm.evmovdqu64(vec1, new AMD64Address(str1, result, maxStride));
                    // k7 == 11..11 if the operands are equal, otherwise k7 has some 0 bits
masm.evpcmpeqb(k7, vec1, new AMD64Address(str2, result, maxStride));
} else {
masm.evpmovzxbw(vec1, new AMD64Address(str1, result, scale1));
                    // k7 == 11..11 if the operands are equal, otherwise k7 has some 0 bits
masm.evpcmpeqb(k7, vec1, new AMD64Address(str2, result, scale2));
}
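                // kortestq sets CF only if the OR of the operands is all ones; AboveEqual
                // tests CF == 0, i.e. at least one of the 64 compared bytes differed.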
masm.kortestq(k7, k7);
masm.jcc(ConditionFlag.AboveEqual, labelCompareWideVectorsLoopFailed); // miscompare
masm.addq(result, stride2x2); // update since we already compared at this addr
// and sub the size too
masm.sublAndJcc(cnt2, stride2x2, ConditionFlag.NotZero, labelCompareWideVectorsLoopAVX3, true);
masm.vpxor(vec1, vec1, vec1, AVXSize.YMM);
masm.jmpb(labelCompareWideTail);
}
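            // AVX2 loop: XOR the 32-byte vectors and vptest the result; ZF is set iff all
            // bytes were equal.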
masm.bind(labelCompareWideVectorsLoopAVX2);
if (strideA == strideB) {
masm.vmovdqu(vec1, new AMD64Address(str1, result, maxStride));
masm.vpxor(vec1, vec1, new AMD64Address(str2, result, maxStride), AVXSize.YMM);
} else {
masm.vpmovzxbw(vec1, new AMD64Address(str1, result, scale1));
masm.vpxor(vec1, vec1, new AMD64Address(str2, result, scale2), AVXSize.YMM);
}
masm.vptest(vec1, vec1, AVXSize.YMM);
masm.jcc(ConditionFlag.NotZero, labelVectorNotEqual);
masm.addq(result, elementsPerYMMVector);
masm.sublAndJcc(cnt2, elementsPerYMMVector, ConditionFlag.NotZero, labelCompareWideVectorsLoop, false);
// clean upper bits of YMM registers
masm.vpxor(vec1, vec1, vec1, AVXSize.YMM);
// compare wide vectors tail
masm.bind(labelCompareWideTail);
masm.testqAndJcc(result, result, ConditionFlag.Zero, labelLengthDiff, false);
masm.movl(result, elementsPerYMMVector);
masm.movl(cnt2, result);
masm.negq(result);
masm.jmp(labelCompareWideVectorsLoopAVX2);
            // Identify which of the two 16-byte halves of the 32-byte vectors mismatches.
masm.bind(labelVectorNotEqual);
// clean upper bits of YMM registers
masm.vpxor(vec1, vec1, vec1, AVXSize.YMM);
if (strideA == strideB) {
masm.leaq(str1, new AMD64Address(str1, result, maxStride));
masm.leaq(str2, new AMD64Address(str2, result, maxStride));
} else {
masm.leaq(str1, new AMD64Address(str1, result, scale1));
masm.leaq(str2, new AMD64Address(str2, result, scale2));
}
masm.jmp(labelCompare16Chars);
            // Compare tail chars, length between 1 and 15 chars
masm.bind(labelCompareTailLong);
masm.movl(cnt2, result);
masm.cmplAndJcc(cnt2, elementsPerXMMVector, ConditionFlag.Less, labelCompareSmallStr, false);
if (strideA == strideB) {
masm.movdqu(vec1, new AMD64Address(str1, 0));
} else {
masm.pmovzxbw(vec1, new AMD64Address(str1, 0));
}
masm.pcmpestri(vec1, new AMD64Address(str2, 0), pcmpmask);
masm.jcc(ConditionFlag.Below, labelCompareIndexChar);
masm.subqAndJcc(cnt2, elementsPerXMMVector, ConditionFlag.Zero, labelLengthDiff, false);
if (strideA == strideB) {
masm.leaq(str1, new AMD64Address(str1, result, maxStride));
masm.leaq(str2, new AMD64Address(str2, result, maxStride));
} else {
masm.leaq(str1, new AMD64Address(str1, result, scale1));
masm.leaq(str2, new AMD64Address(str2, result, scale2));
}
masm.negq(cnt2);
masm.jmpb(labelWhileHead);
masm.bind(labelCompareSmallStr);
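            // Fewer than elementsPerXMMVector elements remain; fall through to the scalar
            // tail loop below.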
} else if (masm.supports(CPUFeature.SSE4_2)) {
Register vec1 = asRegister(vectorTemp1, AMD64Kind.DOUBLE);
Label labelCompareWideVectors = new Label();
Label labelVectorNotEqual = new Label();
Label labelCompareTail = new Label();
int pcmpmask = 0x19;
            // Set up to compare 8-char (16-byte) vectors, starting from the first character
            // again because its address is aligned.
masm.movl(result, cnt2);
if (strideA == Stride.S1 && strideB == Stride.S1) {
pcmpmask &= ~0x01;
}
masm.andlAndJcc(cnt2, ~(elementsPerXMMVector - 1), ConditionFlag.Zero, labelCompareTail, false);
if (strideA == strideB) {
masm.leaq(str1, new AMD64Address(str1, result, maxStride));
masm.leaq(str2, new AMD64Address(str2, result, maxStride));
} else {
masm.leaq(str1, new AMD64Address(str1, result, scale1));
masm.leaq(str2, new AMD64Address(str2, result, scale2));
}
masm.negq(result);
// pcmpestri
// inputs:
            // vec1 - substring
// rax - negative string length (elements count)
// mem - scanned string
// rdx - string length (elements count)
// pcmpmask - cmp mode: 11000 (string compare with negated result)
// + 00 (unsigned bytes) or + 01 (unsigned shorts)
// outputs:
// rcx - first mismatched element index
assert result.equals(rax) && cnt2.equals(rdx) && cnt1.equals(rcx) : "pcmpestri";
masm.bind(labelCompareWideVectors);
if (strideA == strideB) {
masm.movdqu(vec1, new AMD64Address(str1, result, maxStride));
masm.pcmpestri(vec1, new AMD64Address(str2, result, maxStride), pcmpmask);
} else {
masm.pmovzxbw(vec1, new AMD64Address(str1, result, scale1));
masm.pcmpestri(vec1, new AMD64Address(str2, result, scale2), pcmpmask);
}
// After pcmpestri cnt1(rcx) contains mismatched element index
masm.jccb(ConditionFlag.Below, labelVectorNotEqual); // CF==1
masm.addq(result, elementsPerXMMVector);
masm.subqAndJcc(cnt2, elementsPerXMMVector, ConditionFlag.NotZero, labelCompareWideVectors, true);
// compare wide vectors tail
masm.testqAndJcc(result, result, ConditionFlag.Zero, labelLengthDiff, false);
masm.movl(cnt2, elementsPerXMMVector);
masm.movl(result, elementsPerXMMVector);
masm.negq(result);
if (strideA == strideB) {
masm.movdqu(vec1, new AMD64Address(str1, result, maxStride));
masm.pcmpestri(vec1, new AMD64Address(str2, result, maxStride), pcmpmask);
} else {
masm.pmovzxbw(vec1, new AMD64Address(str1, result, scale1));
masm.pcmpestri(vec1, new AMD64Address(str2, result, scale2), pcmpmask);
}
masm.jccb(ConditionFlag.AboveEqual, labelLengthDiff);
// Mismatched characters in the vectors
masm.bind(labelVectorNotEqual);
masm.addq(cnt1, result);
loadNextElements(masm, result, cnt2, str1, str2, maxStride, scale1, scale2, cnt1);
masm.subl(result, cnt2);
masm.jmpb(labelPop);
masm.bind(labelCompareTail); // limit is zero
masm.movl(cnt2, result);
// Fallthru to tail compare
}
// Shift str2 and str1 to the end of the arrays, negate min
if (strideA == strideB) {
masm.leaq(str1, new AMD64Address(str1, cnt2, maxStride));
masm.leaq(str2, new AMD64Address(str2, cnt2, maxStride));
} else {
masm.leaq(str1, new AMD64Address(str1, cnt2, scale1));
masm.leaq(str2, new AMD64Address(str2, cnt2, scale2));
}
masm.decrementl(cnt2); // first character was compared already
masm.negq(cnt2);
// Compare the rest of the elements
masm.bind(labelWhileHead);
loadNextElements(masm, result, cnt1, str1, str2, maxStride, scale1, scale2, cnt2);
masm.sublAndJcc(result, cnt1, ConditionFlag.NotZero, labelPop, true);
masm.incqAndJcc(cnt2, ConditionFlag.NotZero, labelWhileHead, true);
// Strings are equal up to min length. Return the length difference.
masm.bind(labelLengthDiff);
masm.pop(result);
if (strideA == Stride.S2 && strideB == Stride.S2) {
// Divide diff by 2 to get number of chars
masm.sarl(result, 1);
}
masm.jmpb(labelDone);
if (supportsAVX512VLBWAndZMM()) {
masm.bind(labelCompareWideVectorsLoopFailed);
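            // k7 has one bit per compared byte (1 = equal). Invert it and find the lowest
            // set bit, which is the byte offset of the first mismatch within the block.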
masm.kmovq(cnt1, k7);
masm.notq(cnt1);
masm.bsfq(cnt2, cnt1);
            if (!(strideA == Stride.S1 && strideB == Stride.S1)) {
// Divide diff by 2 to get number of chars
masm.sarl(cnt2, 1);
}
masm.addq(result, cnt2);
if (strideA == Stride.S1 && strideB == Stride.S1) {
masm.movzbl(cnt1, new AMD64Address(str2, result, Stride.S1));
masm.movzbl(result, new AMD64Address(str1, result, Stride.S1));
} else if (strideA == Stride.S2 && strideB == Stride.S2) {
masm.movzwl(cnt1, new AMD64Address(str2, result, maxStride));
masm.movzwl(result, new AMD64Address(str1, result, maxStride));
} else {
masm.movzwl(cnt1, new AMD64Address(str2, result, scale2));
masm.movzbl(result, new AMD64Address(str1, result, scale1));
}
masm.subl(result, cnt1);
masm.jmpb(labelPop);
}
// Discard the stored length difference
masm.bind(labelPop);
masm.pop(cnt1);
// That's it
masm.bind(labelDone);
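        // For the char/byte (UL) stride combination the operands were compared in swapped
        // order, so the sign of the result must be flipped.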
if (strideA == Stride.S2 && strideB == Stride.S1) {
masm.negl(result);
}
}
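
    /**
     * Loads and zero-extends the element at {@code index} from each array into {@code elem1}
     * and {@code elem2}, using the byte stride for the narrow operand and the char stride for
     * the wide one in the mixed case.
     */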
private void loadNextElements(AMD64MacroAssembler masm, Register elem1, Register elem2, Register str1, Register str2,
Stride stride, Stride stride1,
Stride stride2, Register index) {
if (strideA == Stride.S1 && strideB == Stride.S1) {
masm.movzbl(elem1, new AMD64Address(str1, index, stride, 0));
masm.movzbl(elem2, new AMD64Address(str2, index, stride, 0));
} else if (strideA == Stride.S2 && strideB == Stride.S2) {
masm.movzwl(elem1, new AMD64Address(str1, index, stride, 0));
masm.movzwl(elem2, new AMD64Address(str2, index, stride, 0));
} else {
masm.movzbl(elem1, new AMD64Address(str1, index, stride1, 0));
masm.movzwl(elem2, new AMD64Address(str2, index, stride2, 0));
}
}
}