All downloads are free. Search and download functionality uses the official Maven repository.

org.jetbrains.jet.codegen.inline.MethodInliner Maven / Gradle / Ivy

There is a newer version: 2.0.0
Show newest version
/*
 * Copyright 2010-2014 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jetbrains.jet.codegen.inline;

import com.google.common.collect.Lists;
import com.intellij.util.ArrayUtil;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.jet.codegen.ClosureCodegen;
import org.jetbrains.jet.codegen.StackValue;
import org.jetbrains.jet.codegen.state.JetTypeMapper;
import org.jetbrains.org.objectweb.asm.Label;
import org.jetbrains.org.objectweb.asm.MethodVisitor;
import org.jetbrains.org.objectweb.asm.Opcodes;
import org.jetbrains.org.objectweb.asm.Type;
import org.jetbrains.org.objectweb.asm.commons.InstructionAdapter;
import org.jetbrains.org.objectweb.asm.commons.Method;
import org.jetbrains.org.objectweb.asm.commons.RemappingMethodAdapter;
import org.jetbrains.org.objectweb.asm.tree.*;
import org.jetbrains.org.objectweb.asm.tree.analysis.*;

import java.util.*;

import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.getReturnType;
import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isAnonymousConstructorCall;
import static org.jetbrains.jet.codegen.inline.InlineCodegenUtil.isInvokeOnLambda;

/**
 * Inlines the bytecode of a single method (an inline function body or a lambda's {@code invoke})
 * into a target {@link MethodVisitor}.
 * <p>
 * The work is split into passes over {@link MethodNode}s:
 * <ol>
 *     <li>{@link #markPlacesForInlineAndRemoveInlinable} records lambda {@code invoke} calls and
 *     anonymous-object constructor calls, removes loads of inlinable lambdas and removes dead code;</li>
 *     <li>{@link #doInline(MethodNode)} inlines lambda bodies at their {@code invoke} sites and
 *     regenerates anonymous objects when required;</li>
 *     <li>{@link #doInline(MethodVisitor, LocalVarRemapper, boolean, LabelOwner)} remaps locals and
 *     returns and flushes the result into the output visitor.</li>
 * </ol>
 */
public class MethodInliner {

    private final MethodNode node;

    private final Parameters parameters;

    private final InliningContext inliningContext;

    private final FieldRemapper nodeRemapper;

    private final boolean isSameModule;

    private final String errorPrefix;

    private final JetTypeMapper typeMapper;

    // Lambda 'invoke' calls found during analysis, in bytecode order.
    private final List<InvokeCall> invokeCalls = new ArrayList<InvokeCall>();

    // Anonymous-object constructor calls found during analysis; keeps bytecode order.
    private final List<ConstructorInvocation> constructorInvocations = new ArrayList<ConstructorInvocation>();

    // Current state: original internal name -> regenerated internal name.
    private final Map<String, String> currentTypeMapping = new HashMap<String, String>();

    private final InlineResult result;

    /**
     * Creates an inliner for one method body.
     *
     * @param node         bytecode of the method (or lambda 'invoke') to inline
     * @param parameters   parameters of the inlined method, including captured ones
     * @param parent       inlining context this inliner runs in
     * @param nodeRemapper remapper used to look up and fold captured-field accesses
     * @param isSameModule whether the inlined code belongs to the module being compiled; when it does not,
     *                     owners of static calls are rewritten by {@link #changeOwnerForExternalPackage}
     * @param errorPrefix  prefix for the messages of exceptions produced by {@link #wrapException}
     */
    public MethodInliner(
            @NotNull MethodNode node,
            @NotNull Parameters parameters,
            @NotNull InliningContext parent,
            @NotNull FieldRemapper nodeRemapper,
            boolean isSameModule,
            @NotNull String errorPrefix
    ) {
        this.node = node;
        this.parameters = parameters;
        this.inliningContext = parent;
        this.nodeRemapper = nodeRemapper;
        this.isSameModule = isSameModule;
        this.errorPrefix = errorPrefix;
        this.typeMapper = parent.state.getTypeMapper();
        this.result = InlineResult.create();
    }

    /**
     * Runs all inlining passes and writes the inlined body into {@code adapter}.
     *
     * @param adapter     visitor receiving the inlined instructions
     * @param remapper    remapper for local variable slots of the inlined body
     * @param remapReturn whether return instructions should be processed by {@link #processReturns}
     * @param labelOwner  decides which non-local return labels belong to this inlined body
     * @return the accumulated {@link InlineResult} (classes to remove after inlining)
     */
    public InlineResult doInline(
            @NotNull MethodVisitor adapter,
            @NotNull LocalVarRemapper remapper,
            boolean remapReturn,
            @NotNull LabelOwner labelOwner
    ) {
        //analyze body
        MethodNode transformedNode = markPlacesForInlineAndRemoveInlinable(node);

        //end label: returns are substituted with "goto end" to keep non local returns in lambdas
        Label end = new Label();
        transformedNode = doInline(transformedNode);
        removeClosureAssertions(transformedNode);
        InsnList instructions = transformedNode.instructions;
        instructions.resetLabels();

        MethodNode resultNode = new MethodNode(InlineCodegenUtil.API, transformedNode.access, transformedNode.name, transformedNode.desc,
                                               transformedNode.signature, ArrayUtil.toStringArray(transformedNode.exceptions));
        RemapVisitor visitor = new RemapVisitor(resultNode, remapper, nodeRemapper);
        try {
            transformedNode.accept(visitor);
        }
        catch (Exception e) {
            throw wrapException(e, transformedNode, "couldn't inline method call");
        }

        resultNode.visitLabel(end);
        processReturns(resultNode, labelOwner, remapReturn, end);

        //flush transformed node to output
        resultNode.accept(new InliningInstructionAdapter(adapter));

        return result;
    }

    /**
     * Copies {@code node} into a new node, inlining lambda bodies at their 'invoke' sites and
     * regenerating anonymous objects whose {@link ConstructorInvocation#shouldRegenerate()} is true.
     * Relies on {@link #invokeCalls} and {@link #constructorInvocations} having been filled in
     * bytecode order by {@link #markPlacesForInlineAndRemoveInlinable}.
     */
    private MethodNode doInline(MethodNode node) {

        final Deque<InvokeCall> currentInvokes = new LinkedList<InvokeCall>(invokeCalls);

        MethodNode resultNode = new MethodNode(InlineCodegenUtil.API, node.access, node.name, node.desc, node.signature, null);

        final Iterator<ConstructorInvocation> iterator = constructorInvocations.iterator();

        RemappingMethodAdapter remappingMethodAdapter = new RemappingMethodAdapter(resultNode.access, resultNode.desc, resultNode,
                                                                                   new TypeRemapper(currentTypeMapping));

        InlineAdapter lambdaInliner = new InlineAdapter(remappingMethodAdapter, parameters.totalSize()) {

            // constructor invocation matching the most recent NEW of an anonymous class
            private ConstructorInvocation invocation;

            @Override
            public void anew(@NotNull Type type) {
                if (isAnonymousConstructorCall(type.getInternalName(), "")) {
                    invocation = iterator.next();

                    if (invocation.shouldRegenerate()) {
                        //TODO: need poping of type but what to do with local funs???
                        Type newLambdaType = Type.getObjectType(inliningContext.nameGenerator.genLambdaClassName());
                        currentTypeMapping.put(invocation.getOwnerInternalName(), newLambdaType.getInternalName());
                        AnonymousObjectTransformer transformer =
                                new AnonymousObjectTransformer(invocation.getOwnerInternalName(),
                                                               inliningContext
                                                                       .subInlineWithClassRegeneration(
                                                                               inliningContext.nameGenerator,
                                                                               currentTypeMapping,
                                                                               invocation),
                                                               isSameModule, newLambdaType
                                );

                        InlineResult transformResult = transformer.doTransform(invocation, nodeRemapper);
                        result.addAllClassesToRemove(transformResult);

                        if (inliningContext.isInliningLambda) {
                            //this class is transformed and original not used so we should remove original one after inlining
                            result.addClassToRemove(invocation.getOwnerInternalName());
                        }
                    }
                }

                //in case of regenerated invocation type would be remapped to new one via remappingMethodAdapter
                super.anew(type);
            }

            @Override
            public void visitMethodInsn(int opcode, String owner, String name, String desc, boolean itf) {
                if (/*INLINE_RUNTIME.equals(owner) &&*/ isInvokeOnLambda(owner, name)) { //TODO add method
                    assert !currentInvokes.isEmpty();
                    InvokeCall invokeCall = currentInvokes.remove();
                    LambdaInfo info = invokeCall.lambdaInfo;

                    if (info == null) {
                        //noninlinable lambda
                        super.visitMethodInsn(opcode, owner, name, desc, itf);
                        return;
                    }

                    int valueParamShift = getNextLocalIndex();//NB: don't inline cause it changes
                    putStackValuesIntoLocals(info.getInvokeParamsWithoutCaptured(), valueParamShift, this, desc);

                    Parameters lambdaParameters = info.addAllParameters(nodeRemapper);

                    InlinedLambdaRemapper newCapturedRemapper =
                            new InlinedLambdaRemapper(info.getLambdaClassType().getInternalName(), nodeRemapper, lambdaParameters);

                    setLambdaInlining(true);
                    MethodInliner inliner = new MethodInliner(info.getNode(), lambdaParameters,
                                                              inliningContext.subInlineLambda(info),
                                                              newCapturedRemapper, true /*cause all calls in same module as lambda*/,
                                                              "Lambda inlining " + info.getLambdaClassType().getInternalName());

                    LocalVarRemapper remapper = new LocalVarRemapper(lambdaParameters, valueParamShift);
                    InlineResult lambdaResult = inliner.doInline(this.mv, remapper, true, info);//TODO add skipped this and receiver
                    result.addAllClassesToRemove(lambdaResult);

                    //return value boxing/unboxing
                    Method bridge =
                            typeMapper.mapSignature(ClosureCodegen.getErasedInvokeFunction(info.getFunctionDescriptor())).getAsmMethod();
                    Method delegate = typeMapper.mapSignature(info.getFunctionDescriptor()).getAsmMethod();
                    StackValue.onStack(delegate.getReturnType()).put(bridge.getReturnType(), this);
                    setLambdaInlining(false);
                }
                else if (isAnonymousConstructorCall(owner, name)) { //TODO add method
                    assert invocation != null : "Constructor call doesn't correspond to a new call: " + owner + " " + name;
                    if (invocation.shouldRegenerate()) {
                        //put additional captured parameters on stack
                        for (CapturedParamDesc capturedParamDesc : invocation.getAllRecapturedParameters()) {
                            visitFieldInsn(Opcodes.GETSTATIC, capturedParamDesc.getContainingLambdaName(),
                                           "$$$" + capturedParamDesc.getFieldName(), capturedParamDesc.getType().getDescriptor());
                        }
                        super.visitMethodInsn(opcode, invocation.getNewLambdaType().getInternalName(), name, invocation.getNewConstructorDescriptor(), itf);
                        invocation = null;
                    } else {
                        super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc, itf);
                    }
                }
                else {
                    super.visitMethodInsn(opcode, changeOwnerForExternalPackage(owner, opcode), name, desc, itf);
                }
            }

        };

        node.accept(lambdaInliner);

        return resultNode;
    }

    /**
     * Resolves a {@code $$$}-prefixed captured-field template instruction to its captured parameter.
     *
     * @throws IllegalStateException if {@code fieldRemapper} has no field for the template
     */
    @NotNull
    public static CapturedParamInfo findCapturedField(FieldInsnNode node, FieldRemapper fieldRemapper) {
        assert node.name.startsWith("$$$") : "Captured field template should start with $$$ prefix";
        FieldInsnNode fin = new FieldInsnNode(node.getOpcode(), node.owner, node.name.substring(3), node.desc);
        CapturedParamInfo field = fieldRemapper.findField(fin);
        if (field == null) {
            throw new IllegalStateException("Couldn't find captured field " + node.owner + "." + node.name + " in " + fieldRemapper.getLambdaInternalName());
        }
        return field;
    }

    /**
     * Copies {@code node} into a new node whose descriptor has the captured parameter types appended,
     * shifting every local slot at or beyond the real parameters by the captured parameters' size.
     * Line numbers and local-variable entries are kept only when inlining inside a lambda.
     * Finally folds captured-field access chains (see {@link #transformCaptured}).
     */
    @NotNull
    public MethodNode prepareNode(@NotNull MethodNode node) {
        final int capturedParamsSize = parameters.getCaptured().size();
        final int realParametersSize = parameters.getReal().size();
        Type[] types = Type.getArgumentTypes(node.desc);
        Type returnType = Type.getReturnType(node.desc);

        List<Type> capturedTypes = parameters.getCapturedTypes();
        Type[] allTypes = ArrayUtil.mergeArrays(types, capturedTypes.toArray(new Type[capturedTypes.size()]));

        node.instructions.resetLabels();
        MethodNode transformedNode = new MethodNode(InlineCodegenUtil.API, node.access, node.name, Type.getMethodDescriptor(returnType, allTypes), node.signature, null) {

            private final boolean isInliningLambda = nodeRemapper.isInsideInliningLambda();

            // slots of real parameters stay in place; everything after them moves past the captured ones
            private int getNewIndex(int var) {
                return var + (var < realParametersSize ? 0 : capturedParamsSize);
            }

            @Override
            public void visitVarInsn(int opcode, int var) {
                super.visitVarInsn(opcode, getNewIndex(var));
            }

            @Override
            public void visitIincInsn(int var, int increment) {
                super.visitIincInsn(getNewIndex(var), increment);
            }

            @Override
            public void visitMaxs(int maxStack, int maxLocals) {
                super.visitMaxs(maxStack, maxLocals + capturedParamsSize);
            }

            @Override
            public void visitLineNumber(int line, @NotNull Label start) {
                if (isInliningLambda) {
                    super.visitLineNumber(line, start);
                }
            }

            @Override
            public void visitLocalVariable(
                    @NotNull String name, @NotNull String desc, String signature, @NotNull Label start, @NotNull Label end, int index
            ) {
                if (isInliningLambda) {
                    super.visitLocalVariable(name, desc, signature, start, end, getNewIndex(index));
                }
            }
        };

        node.accept(transformedNode);

        transformCaptured(transformedNode);

        return transformedNode;
    }

    /**
     * Prepares {@code node} and runs a {@link SourceInterpreter} data-flow analysis over it to:
     * record lambda 'invoke' calls into {@link #invokeCalls}, record anonymous-object constructor
     * calls into {@link #constructorInvocations}, remove the instructions loading inlinable lambdas,
     * and remove unreachable instructions (keeping their labels) plus try/catch blocks whose start
     * and end labels are both unreachable.
     */
    @NotNull
    protected MethodNode markPlacesForInlineAndRemoveInlinable(@NotNull MethodNode node) {
        node = prepareNode(node);

        Analyzer<SourceValue> analyzer = new Analyzer<SourceValue>(new SourceInterpreter()) {
            @NotNull
            @Override
            protected Frame<SourceValue> newFrame(
                    int nLocals, int nStack
            ) {
                return new Frame<SourceValue>(nLocals, nStack) {
                    @Override
                    public void execute(
                            @NotNull AbstractInsnNode insn, Interpreter<SourceValue> interpreter
                    ) throws AnalyzerException {
                        if (insn.getOpcode() == Opcodes.RETURN) {
                            //there is exception on void non local return in frame
                            return;
                        }
                        super.execute(insn, interpreter);
                    }
                };
            }
        };

        Frame<SourceValue>[] sources;
        try {
            sources = analyzer.analyze("fake", node);
        }
        catch (AnalyzerException e) {
            throw wrapException(e, node, "couldn't inline method call");
        }

        AbstractInsnNode cur = node.instructions.getFirst();
        int index = 0;
        Set<LabelNode> deadLabels = new HashSet<LabelNode>();

        while (cur != null) {
            Frame<SourceValue> frame = sources[index];

            if (frame != null) {
                if (cur.getType() == AbstractInsnNode.METHOD_INSN) {
                    MethodInsnNode methodInsnNode = (MethodInsnNode) cur;
                    String owner = methodInsnNode.owner;
                    String desc = methodInsnNode.desc;
                    String name = methodInsnNode.name;
                    //TODO check closure
                    int paramLength = Type.getArgumentTypes(desc).length + 1;//non static
                    if (isInvokeOnLambda(owner, name) /*&& methodInsnNode.owner.equals(INLINE_RUNTIME)*/) {
                        // receiver of 'invoke' is the lambda itself
                        SourceValue sourceValue = frame.getStack(frame.getStackSize() - paramLength);

                        LambdaInfo lambdaInfo = null;
                        int varIndex = -1;

                        if (sourceValue.insns.size() == 1) {
                            AbstractInsnNode insnNode = sourceValue.insns.iterator().next();

                            lambdaInfo = getLambdaIfExists(insnNode);
                            if (lambdaInfo != null) {
                                //remove inlinable access
                                node.instructions.remove(insnNode);
                            }
                        }

                        invokeCalls.add(new InvokeCall(varIndex, lambdaInfo));
                    }
                    else if (isAnonymousConstructorCall(owner, name)) {
                        // stack-slot index (0 = receiver) -> lambda passed in that slot
                        Map<Integer, LambdaInfo> lambdaMapping = new HashMap<Integer, LambdaInfo>();
                        int paramStart = frame.getStackSize() - paramLength;

                        for (int i = 0; i < paramLength; i++) {
                            SourceValue sourceValue = frame.getStack(paramStart + i);
                            if (sourceValue.insns.size() == 1) {
                                AbstractInsnNode insnNode = sourceValue.insns.iterator().next();
                                LambdaInfo lambdaInfo = getLambdaIfExists(insnNode);
                                if (lambdaInfo != null) {
                                    lambdaMapping.put(i, lambdaInfo);
                                    node.instructions.remove(insnNode);
                                }
                            }
                        }

                        constructorInvocations.add(new ConstructorInvocation(owner, desc, lambdaMapping, isSameModule, inliningContext.classRegeneration));
                    }
                }
            }

            AbstractInsnNode prevNode = cur;
            cur = cur.getNext();
            index++;

            //given frame is null if and only if the corresponding instruction cannot be reached (dead code).
            if (frame == null) {
                //clean dead code otherwise there is problems in unreachable finally block, don't touch label it cause try/catch/finally problems
                if (prevNode.getType() == AbstractInsnNode.LABEL) {
                    deadLabels.add((LabelNode) prevNode);
                } else {
                    node.instructions.remove(prevNode);
                }
            }
        }

        //clean dead try/catch blocks
        List<TryCatchBlockNode> blocks = node.tryCatchBlocks;
        for (Iterator<TryCatchBlockNode> iterator = blocks.iterator(); iterator.hasNext(); ) {
            TryCatchBlockNode block = iterator.next();
            if (deadLabels.contains(block.start) && deadLabels.contains(block.end)) {
                iterator.remove();
            }
        }

        return node;
    }

    /**
     * Returns the lambda loaded by {@code insnNode}, or {@code null} if it does not load one.
     * Recognizes {@code ALOAD} of a parameter slot and access to a {@code $$$}-prefixed captured field.
     */
    public LambdaInfo getLambdaIfExists(AbstractInsnNode insnNode) {
        if (insnNode.getOpcode() == Opcodes.ALOAD) {
            int varIndex = ((VarInsnNode) insnNode).var;
            if (varIndex < parameters.totalSize()) {
                return parameters.get(varIndex).getLambda();
            }
        }
        else if (insnNode instanceof FieldInsnNode) {
            FieldInsnNode fieldInsnNode = (FieldInsnNode) insnNode;
            if (fieldInsnNode.name.startsWith("$$$")) {
                return findCapturedField(fieldInsnNode, nodeRemapper).getLambda();
            }
        }

        return null;
    }

    /**
     * Removes every {@code ALOAD; LDC; INVOKESTATIC kotlin/jvm/internal/Intrinsics.checkParameterIsNotNull}
     * triple from {@code node}.
     */
    private static void removeClosureAssertions(MethodNode node) {
        AbstractInsnNode cur = node.instructions.getFirst();
        while (cur != null && cur.getNext() != null) {
            AbstractInsnNode next = cur.getNext();
            if (next.getType() == AbstractInsnNode.METHOD_INSN) {
                MethodInsnNode methodInsnNode = (MethodInsnNode) next;
                if (methodInsnNode.name.equals("checkParameterIsNotNull") && methodInsnNode.owner.equals("kotlin/jvm/internal/Intrinsics")) {
                    AbstractInsnNode prev = cur.getPrevious();

                    assert cur.getOpcode() == Opcodes.LDC : "checkParameterIsNotNull should go after LDC but " + cur;
                    assert prev.getOpcode() == Opcodes.ALOAD : "checkParameterIsNotNull should be invoked on local var but " + prev;

                    node.instructions.remove(prev);
                    node.instructions.remove(cur);
                    // continue scanning from the instruction after the removed call
                    cur = next.getNext();
                    node.instructions.remove(next);
                    next = cur;
                }
            }
            cur = next;
        }
    }

    /**
     * Folds captured-variable access chains starting at {@code ALOAD 0} via
     * {@link FieldRemapper#foldFieldAccessChainIfNeeded}. Does nothing for the root remapper.
     */
    private void transformCaptured(@NotNull MethodNode node) {
        if (nodeRemapper.isRoot()) {
            return;
        }

        //Fold all captured variable chain - ALOAD 0 ALOAD this$0 GETFIELD $captured - to GETFIELD $$$$captured
        //On future decoding this field could be inline or unfolded in another field access chain (it can differ in some missed this$0)
        AbstractInsnNode cur = node.instructions.getFirst();
        while (cur != null) {
            if (cur instanceof VarInsnNode && cur.getOpcode() == Opcodes.ALOAD) {
                if (((VarInsnNode) cur).var == 0) {
                    List<AbstractInsnNode> accessChain = getCapturedFieldAccessChain((VarInsnNode) cur);
                    AbstractInsnNode insnNode = nodeRemapper.foldFieldAccessChainIfNeeded(accessChain, node);
                    if (insnNode != null) {
                        cur = insnNode;
                    }
                }
            }
            cur = cur.getNext();
        }
    }

    /**
     * Collects {@code aload0} followed by the field instructions after it: every consecutive
     * {@code this$0} access plus the first non-{@code this$0} field access. Label nodes in between
     * are skipped and not included.
     */
    @NotNull
    public static List<AbstractInsnNode> getCapturedFieldAccessChain(@NotNull VarInsnNode aload0) {
        List<AbstractInsnNode> fieldAccessChain = new ArrayList<AbstractInsnNode>();
        fieldAccessChain.add(aload0);
        AbstractInsnNode next = aload0.getNext();
        // instanceof is false for null, so this also stops at the end of the instruction list
        while (next instanceof FieldInsnNode || next instanceof LabelNode) {
            if (next instanceof LabelNode) {
                next = next.getNext();
                continue; //it will be delete on transformation
            }
            fieldAccessChain.add(next);
            if ("this$0".equals(((FieldInsnNode) next).name)) {
                next = next.getNext();
            }
            else {
                break;
            }
        }

        return fieldAccessChain;
    }

    /**
     * Pops the arguments described by {@code descriptor} from the stack into consecutive locals
     * starting at {@code shift}, coercing each stack value to the expected type in {@code directOrder}.
     *
     * @param directOrder expected parameter types, in declaration order
     * @param shift       first local slot to store into
     * @param iv          adapter emitting the coercion and store instructions
     * @param descriptor  method descriptor whose argument types are currently on the stack
     */
    public static void putStackValuesIntoLocals(List<Type> directOrder, int shift, InstructionAdapter iv, String descriptor) {
        Type[] actualParams = Type.getArgumentTypes(descriptor);
        assert actualParams.length == directOrder.size() : "Number of expected and actual params should be equals!";

        int size = 0;
        for (Type next : directOrder) {
            size += next.getSize();
        }

        shift += size;
        int index = directOrder.size();

        // the last argument is on top of the stack, so store in reverse order
        for (Type next : Lists.reverse(directOrder)) {
            shift -= next.getSize();
            Type typeOnStack = actualParams[--index];
            if (!typeOnStack.equals(next)) {
                StackValue.onStack(typeOnStack).put(next, iv);
            }
            iv.store(shift, next);
        }
    }

    /**
     * For static calls into an external module, strips everything from the first {@code '-'} of
     * {@code type}; otherwise returns {@code type} unchanged.
     */
    //TODO: check annotation on class - it's package part
    //TODO: check it's external module
    //TODO?: assert method exists in facade?
    public String changeOwnerForExternalPackage(String type, int opcode) {
        // Compare for equality: a bitwise test against INVOKESTATIC (0xB8) is non-zero for every
        // invoke opcode (0xB6..0xBA), which would wrongly rewrite owners of non-static calls too.
        if (isSameModule || opcode != Opcodes.INVOKESTATIC) {
            return type;
        }

        int i = type.indexOf('-');
        if (i >= 0) {
            return type.substring(0, i);
        }
        return type;
    }

    /**
     * Wraps {@code originalException} into an {@link InlineException} prefixed with this inliner's
     * error prefix; the node text is appended unless the cause is already an {@link InlineException}.
     */
    @NotNull
    public RuntimeException wrapException(@NotNull Exception originalException, @NotNull MethodNode node, @NotNull String errorSuffix) {
        if (originalException instanceof InlineException) {
            return new InlineException(errorPrefix + ": " + errorSuffix, originalException);
        } else {
            return new InlineException(errorPrefix + ": " + errorSuffix + "\ncause: " +
                                       InlineCodegen.getNodeText(node), originalException);
        }
    }

    /**
     * Rewrites return instructions of {@code node} when {@code remapReturn} is set: local returns
     * (including non-local returns targeting a label owned by {@code labelOwner}) become a jump to
     * {@code endLabel}, and their non-local-return marker call is removed.
     *
     * @return for each return, the insertion point for a finally block and the return type;
     *         an empty list when {@code remapReturn} is false
     */
    @NotNull
    public static List<FinallyBlockInfo> processReturns(@NotNull MethodNode node, @NotNull LabelOwner labelOwner, boolean remapReturn, Label endLabel) {
        if (!remapReturn) {
            return Collections.emptyList();
        }
        List<FinallyBlockInfo> result = new ArrayList<FinallyBlockInfo>();
        InsnList instructions = node.instructions;
        AbstractInsnNode insnNode = instructions.getFirst();
        while (insnNode != null) {
            if (InlineCodegenUtil.isReturnOpcode(insnNode.getOpcode())) {
                AbstractInsnNode previous = insnNode.getPrevious();
                boolean isLocalReturn = true;
                String labelName = null;
                // a non-local return is marked by a call whose owner is NON_LOCAL_RETURN and whose name is the label
                if (previous instanceof MethodInsnNode && InlineCodegenUtil.NON_LOCAL_RETURN.equals(((MethodInsnNode) previous).owner)) {
                    labelName = ((MethodInsnNode) previous).name;
                }

                if (labelName != null) {
                    isLocalReturn = labelOwner.isMyLabel(labelName);
                    //remove global return flag
                    if (isLocalReturn) {
                        instructions.remove(previous);
                    }
                }

                if (isLocalReturn && endLabel != null) {
                    LabelNode labelNode = (LabelNode) endLabel.info;
                    JumpInsnNode jumpInsnNode = new JumpInsnNode(Opcodes.GOTO, labelNode);
                    instructions.insert(insnNode, jumpInsnNode);
                    instructions.remove(insnNode);
                    insnNode = jumpInsnNode;
                }

                //generate finally block before nonLocalReturn flag/return/goto
                result.add(new FinallyBlockInfo(isLocalReturn ? insnNode : insnNode.getPrevious(), getReturnType(insnNode.getOpcode())));
            }
            insnNode = insnNode.getNext();
        }
        return result;
    }

    /** Insertion point for a finally block together with the type returned at that point. */
    public static class FinallyBlockInfo {

        final AbstractInsnNode beforeIns;

        final Type returnType;

        public FinallyBlockInfo(AbstractInsnNode beforeIns, Type returnType) {
            this.beforeIns = beforeIns;
            this.returnType = returnType;
        }

    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy