All downloads are free. Search and download functionality uses the official Maven repository.

org.teavm.model.optimization.Inlining Maven / Gradle / Ivy

There is a newer version: 0.2.8
Show newest version
/*
 *  Copyright 2016 Alexey Andreev.
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package org.teavm.model.optimization;

import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.ObjectIntHashMap;
import com.carrotsearch.hppc.ObjectIntMap;
import com.carrotsearch.hppc.cursors.ObjectCursor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.teavm.dependency.DependencyInfo;
import org.teavm.model.BasicBlock;
import org.teavm.model.BasicBlockReader;
import org.teavm.model.ClassHierarchy;
import org.teavm.model.ClassReader;
import org.teavm.model.ElementModifier;
import org.teavm.model.Incoming;
import org.teavm.model.InliningInfo;
import org.teavm.model.Instruction;
import org.teavm.model.ListableClassReaderSource;
import org.teavm.model.MethodReader;
import org.teavm.model.MethodReference;
import org.teavm.model.Phi;
import org.teavm.model.Program;
import org.teavm.model.ProgramReader;
import org.teavm.model.TextLocation;
import org.teavm.model.TryCatchBlock;
import org.teavm.model.VariableReader;
import org.teavm.model.analysis.ClassInference;
import org.teavm.model.instructions.AbstractInstructionReader;
import org.teavm.model.instructions.AssignInstruction;
import org.teavm.model.instructions.ExitInstruction;
import org.teavm.model.instructions.InvocationType;
import org.teavm.model.instructions.InvokeInstruction;
import org.teavm.model.instructions.JumpInstruction;
import org.teavm.model.util.BasicBlockMapper;
import org.teavm.model.util.InstructionVariableMapper;
import org.teavm.model.util.ProgramUtils;
import org.teavm.model.util.TransitionExtractor;
import org.teavm.runtime.Fiber;

public class Inlining {
    private IntArrayList depthsByBlock;
    private Set instructionsToSkip;
    private ClassHierarchy hierarchy;
    private ListableClassReaderSource classes;
    private DependencyInfo dependencyInfo;
    private InliningStrategy strategy;
    private MethodUsageCounter usageCounter;
    private Set methodsUsedOnce = new HashSet<>();
    private boolean devirtualization;
    private ClassInference classInference;
    private InliningFilterFactory filterFactory;

    public Inlining(ClassHierarchy hierarchy, DependencyInfo dependencyInfo, InliningStrategy strategy,
            ListableClassReaderSource classes, Predicate externalMethods,
            boolean devirtualization, InliningFilterFactory filterFactory) {
        this.hierarchy = hierarchy;
        this.classes = classes;
        this.dependencyInfo = dependencyInfo;
        this.strategy = strategy;
        this.devirtualization = devirtualization;
        this.filterFactory = filterFactory;
        usageCounter = new MethodUsageCounter(externalMethods);

        for (String className : classes.getClassNames()) {
            ClassReader cls = classes.get(className);
            for (MethodReader method : cls.getMethods()) {
                ProgramReader program = method.getProgram();
                if (program != null) {
                    usageCounter.currentMethod = method.getReference();
                    for (BasicBlockReader block : program.getBasicBlocks()) {
                        block.readAllInstructions(usageCounter);
                    }
                }
            }
        }

        for (ObjectCursor cursor : usageCounter.methodUsageCount.keys()) {
            if (usageCounter.methodUsageCount.get(cursor.value) == 1) {
                methodsUsedOnce.add(cursor.value);
            }
        }
    }

    public List getOrder() {
        List order = new ArrayList<>();
        Set visited = new HashSet<>();
        for (String className : classes.getClassNames()) {
            ClassReader cls = classes.get(className);
            for (MethodReader method : cls.getMethods()) {
                if (method.getProgram() != null) {
                    computeOrder(method.getReference(), order, visited);
                }
            }
        }
        Collections.reverse(order);
        return order;
    }

    private void computeOrder(MethodReference method, List order, Set visited) {
        if (!visited.add(method)) {
            return;
        }
        Set invokedMethods = usageCounter.methodDependencies.get(method);
        if (invokedMethods != null) {
            for (MethodReference invokedMethod : invokedMethods) {
                computeOrder(invokedMethod, order, visited);
            }
        }
        order.add(method);
    }

    public boolean hasUsages(MethodReference method) {
        return usageCounter.methodUsageCount.getOrDefault(method, -1) != 0;
    }

    public void removeUsages(Program program) {
        for (BasicBlock block : program.getBasicBlocks()) {
            for (Instruction instruction : block) {
                if (!(instruction instanceof InvokeInstruction)) {
                    continue;
                }

                InvokeInstruction invoke = (InvokeInstruction) instruction;
                if (invoke.getType() != InvocationType.SPECIAL) {
                    continue;
                }

                int usageCount = usageCounter.methodUsageCount.getOrDefault(invoke.getMethod(), -1);
                if (usageCount > 0) {
                    usageCounter.methodUsageCount.put(invoke.getMethod(), usageCount - 1);
                }
            }
        }
    }

    public void apply(Program program, MethodReference method) {
        depthsByBlock = new IntArrayList(program.basicBlockCount());
        for (int i = 0; i < program.basicBlockCount(); ++i) {
            depthsByBlock.add(0);
        }
        instructionsToSkip = new HashSet<>();

        if (devirtualization) {
            while (applyOnce(program, method)) {
                devirtualize(program, method, dependencyInfo);
            }
        } else {
            applyOnce(program, method);
        }
        depthsByBlock = null;
        instructionsToSkip = null;

        new UnreachableBasicBlockEliminator().optimize(program);
    }

    private boolean applyOnce(Program program, MethodReference method) {
        InliningStep step = strategy.start(method, program);
        if (step == null) {
            return false;
        }
        List plan = buildPlan(program, -1, step, method, null);
        if (plan.isEmpty()) {
            return false;
        }
        execPlan(program, plan, 0);
        return true;
    }

    private void execPlan(Program program, List plan, int offset) {
        for (PlanEntry entry : plan) {
            execPlanEntry(program, entry, offset);
        }
    }

    private void execPlanEntry(Program program, PlanEntry planEntry, int offset) {
        int usageCount = usageCounter.methodUsageCount.getOrDefault(planEntry.method, -1);
        if (usageCount > 0) {
            usageCounter.methodUsageCount.put(planEntry.method, usageCount - 1);
        }

        BasicBlock block = program.basicBlockAt(planEntry.targetBlock + offset);
        InvokeInstruction invoke = (InvokeInstruction) planEntry.targetInstruction;
        BasicBlock splitBlock = program.createBasicBlock();
        BasicBlock firstInlineBlock = program.createBasicBlock();
        Program inlineProgram = planEntry.program;
        for (int i = 1; i < inlineProgram.basicBlockCount(); ++i) {
            program.createBasicBlock();
        }
        while (depthsByBlock.size() < program.basicBlockCount()) {
            depthsByBlock.add(planEntry.depth + 1);
        }

        int variableOffset = program.variableCount();
        for (int i = 0; i < inlineProgram.variableCount(); ++i) {
            program.createVariable();
        }

        while (planEntry.targetInstruction.getNext() != null) {
            Instruction insn = planEntry.targetInstruction.getNext();
            insn.delete();
            splitBlock.add(insn);
        }
        splitBlock.getTryCatchBlocks().addAll(ProgramUtils.copyTryCatches(block, program));

        invoke.delete();
        JumpInstruction jumpToInlinedProgram = new JumpInstruction();
        jumpToInlinedProgram.setTarget(firstInlineBlock);
        block.add(jumpToInlinedProgram);

        InliningInfoMerger inliningInfoMerger = new InliningInfoMerger(planEntry.locationInfo);

        for (int i = 0; i < inlineProgram.basicBlockCount(); ++i) {
            BasicBlock blockToInline = inlineProgram.basicBlockAt(i);
            BasicBlock inlineBlock = program.basicBlockAt(firstInlineBlock.getIndex() + i);
            while (blockToInline.getFirstInstruction() != null) {
                Instruction insn = blockToInline.getFirstInstruction();
                insn.delete();
                inlineBlock.add(insn);

                if (insn instanceof InvokeInstruction) {
                    InvokeInstruction invokeInsn = (InvokeInstruction) insn;
                    if (invokeInsn.getType() == InvocationType.SPECIAL) {
                        usageCount = usageCounter.methodUsageCount.getOrDefault(invokeInsn.getMethod(), -1);
                        if (usageCount >= 0) {
                            usageCounter.methodUsageCount.put(invokeInsn.getMethod(), usageCount + 1);
                        }
                    }
                }

                TextLocation location = insn.getLocation();
                if (location == null) {
                    location = TextLocation.EMPTY;
                }
                location = new TextLocation(location.getFileName(), location.getLine(),
                        inliningInfoMerger.merge(location.getInlining()));
                insn.setLocation(location);
            }

            List phis = new ArrayList<>(blockToInline.getPhis());
            blockToInline.getPhis().clear();
            inlineBlock.getPhis().addAll(phis);

            List tryCatches = new ArrayList<>(blockToInline.getTryCatchBlocks());
            blockToInline.getTryCatchBlocks().clear();
            inlineBlock.getTryCatchBlocks().addAll(tryCatches);

            inlineBlock.setExceptionVariable(blockToInline.getExceptionVariable());
        }

        BasicBlockMapper blockMapper = new BasicBlockMapper((BasicBlock b) ->
                program.basicBlockAt(b.getIndex() + firstInlineBlock.getIndex()));
        InstructionVariableMapper variableMapper = new InstructionVariableMapper(var -> {
            if (var.getIndex() == 0) {
                return invoke.getInstance();
            } else if (var.getIndex() <= invoke.getArguments().size()) {
                return invoke.getArguments().get(var.getIndex() - 1);
            } else {
                return program.variableAt(var.getIndex() + variableOffset);
            }
        });

        List resultVariables = new ArrayList<>();
        for (int i = 0; i < inlineProgram.basicBlockCount(); ++i) {
            BasicBlock mappedBlock = program.basicBlockAt(firstInlineBlock.getIndex() + i);
            blockMapper.transform(mappedBlock);
            variableMapper.apply(mappedBlock);
            mappedBlock.getTryCatchBlocks().addAll(ProgramUtils.copyTryCatches(block, program));
            Instruction lastInsn = mappedBlock.getLastInstruction();
            if (lastInsn instanceof ExitInstruction) {
                ExitInstruction exit = (ExitInstruction) lastInsn;
                JumpInstruction exitReplacement = new JumpInstruction();
                exitReplacement.setTarget(splitBlock);
                exitReplacement.setLocation(exit.getLocation());
                exit.replace(exitReplacement);
                if (exit.getValueToReturn() != null) {
                    Incoming resultIncoming = new Incoming();
                    resultIncoming.setSource(mappedBlock);
                    resultIncoming.setValue(exit.getValueToReturn());
                    resultVariables.add(resultIncoming);
                }
            }
        }

        if (!resultVariables.isEmpty() && invoke.getReceiver() != null) {
            if (resultVariables.size() == 1) {
                AssignInstruction resultAssignment = new AssignInstruction();
                resultAssignment.setReceiver(invoke.getReceiver());
                resultAssignment.setAssignee(resultVariables.get(0).getValue());
                splitBlock.addFirst(resultAssignment);
            } else {
                Phi resultPhi = new Phi();
                resultPhi.setReceiver(invoke.getReceiver());
                resultPhi.getIncomings().addAll(resultVariables);
                splitBlock.getPhis().add(resultPhi);
            }
        }

        TransitionExtractor transitionExtractor = new TransitionExtractor();
        Instruction splitLastInsn = splitBlock.getLastInstruction();
        if (splitLastInsn != null) {
            splitLastInsn.acceptVisitor(transitionExtractor);
            if (transitionExtractor.getTargets() != null) {
                List incomings = Arrays.stream(transitionExtractor.getTargets())
                        .flatMap(bb -> bb.getPhis().stream())
                        .flatMap(phi -> phi.getIncomings().stream())
                        .filter(incoming -> incoming.getSource() == block)
                        .collect(Collectors.toList());
                for (Incoming incoming : incomings) {
                    incoming.setSource(splitBlock);
                }
            }
        }

        execPlan(program, planEntry.innerPlan, firstInlineBlock.getIndex());
    }

    private List buildPlan(Program program, int depth, InliningStep step, MethodReference method,
            InliningInfo inliningInfo) {
        List plan = new ArrayList<>();
        int originalDepth = depth;
        InliningFilter filter = filterFactory.createFilter(method);

        ContextImpl context = new ContextImpl();
        for (BasicBlock block : program.getBasicBlocks()) {
            if (!block.getTryCatchBlocks().isEmpty()) {
                continue;
            }

            if (originalDepth < 0) {
                depth = depthsByBlock.get(block.getIndex());
            }

            for (Instruction insn : block) {
                if (instructionsToSkip.contains(insn)) {
                    continue;
                }

                if (!(insn instanceof InvokeInstruction)) {
                    continue;
                }
                InvokeInstruction invoke = (InvokeInstruction) insn;
                if (invoke.getType() == InvocationType.VIRTUAL) {
                    continue;
                }

                if (invoke.getMethod().getClassName().equals(Fiber.class.getName())
                        != method.getClassName().equals(Fiber.class.getName())) {
                    continue;
                }
                if (!filter.apply(invoke.getMethod())) {
                    continue;
                }

                MethodReader invokedMethod = getMethod(invoke.getMethod());
                if (invokedMethod == null || invokedMethod.getProgram() == null
                        || invokedMethod.getProgram().basicBlockCount() == 0
                        || invokedMethod.hasModifier(ElementModifier.SYNCHRONIZED)) {
                    instructionsToSkip.add(insn);
                    continue;
                }

                context.depth = depth;
                InliningStep innerStep = step.tryInline(invokedMethod.getReference(), invokedMethod.getProgram(),
                        context);
                if (innerStep == null) {
                    instructionsToSkip.add(insn);
                    continue;
                }
                Program invokedProgram = ProgramUtils.copy(invokedMethod.getProgram());

                TextLocation location = insn.getLocation();
                InliningInfo innerInliningInfo = new InliningInfo(
                        invoke.getMethod(),
                        location != null ? location.getFileName() : null,
                        location != null ? location.getLine() : -1,
                        inliningInfo);

                PlanEntry entry = new PlanEntry();
                entry.targetBlock = block.getIndex();
                entry.targetInstruction = insn;
                entry.program = invokedProgram;
                entry.innerPlan.addAll(buildPlan(invokedProgram, depth + 1, innerStep, invokedMethod.getReference(),
                        innerInliningInfo));
                entry.depth = depth;
                entry.method = invokedMethod.getReference();
                entry.locationInfo = innerInliningInfo;
                plan.add(entry);
            }
        }
        Collections.reverse(plan);

        return plan;
    }

    private MethodReader getMethod(MethodReference methodRef) {
        ClassReader cls = classes.get(methodRef.getClassName());
        return cls != null ? cls.getMethod(methodRef.getDescriptor()) : null;
    }

    private void devirtualize(Program program, MethodReference method, DependencyInfo dependencyInfo) {
        if (classInference == null) {
            classInference = new ClassInference(dependencyInfo, hierarchy, classes.getClassNames(), 30);
        }
        classInference.infer(program, method);

        for (BasicBlock block : program.getBasicBlocks()) {
            for (Instruction instruction : block) {
                if (!(instruction instanceof InvokeInstruction)) {
                    continue;
                }
                InvokeInstruction invoke = (InvokeInstruction) instruction;
                if (invoke.getType() != InvocationType.VIRTUAL) {
                    continue;
                }

                Set implementations = new HashSet<>();
                if (classInference.isOverflow(invoke.getInstance().getIndex())) {
                    List knownImplementations = classInference.getMethodImplementations(
                            invoke.getMethod().getDescriptor());
                    if (knownImplementations != null) {
                        implementations.addAll(knownImplementations);
                    }
                } else {
                    for (String className : classInference.classesOf(invoke.getInstance().getIndex())) {
                        MethodReference rawMethod = new MethodReference(className, invoke.getMethod().getDescriptor());
                        MethodReader resolvedMethod = dependencyInfo.getClassSource().resolveImplementation(rawMethod);
                        if (resolvedMethod != null) {
                            implementations.add(resolvedMethod.getReference());
                        }
                    }
                }

                if (implementations.size() == 1) {
                    invoke.setType(InvocationType.SPECIAL);
                    invoke.setMethod(implementations.iterator().next());
                }
            }
        }
    }

    static class PlanEntry {
        int targetBlock;
        Instruction targetInstruction;
        MethodReference method;
        Program program;
        int depth;
        final List innerPlan = new ArrayList<>();
        InliningInfo locationInfo;
    }

    static class MethodUsageCounter extends AbstractInstructionReader {
        ObjectIntMap methodUsageCount = new ObjectIntHashMap<>();
        Map> methodDependencies = new LinkedHashMap<>();
        Predicate externalMethods;
        MethodReference currentMethod;

        MethodUsageCounter(Predicate externalMethods) {
            this.externalMethods = externalMethods;
        }

        @Override
        public void invoke(VariableReader receiver, VariableReader instance, MethodReference method,
                List arguments, InvocationType type) {
            if (type == InvocationType.SPECIAL && !externalMethods.test(method)) {
                methodUsageCount.put(method, methodUsageCount.get(method) + 1);
                methodDependencies.computeIfAbsent(currentMethod, k -> new LinkedHashSet<>()).add(method);
            }
        }
    }

    class ContextImpl implements InliningContext {
        int depth;

        @Override
        public boolean isUsedOnce(MethodReference method) {
            return methodsUsedOnce.contains(method);
        }

        @Override
        public ProgramReader getProgram(MethodReference method) {
            ClassReader cls = classes.get(method.getClassName());
            if (cls == null) {
                return null;
            }
            MethodReader methodReader = cls.getMethod(method.getDescriptor());
            return methodReader != null ? methodReader.getProgram() : null;
        }

        @Override
        public int getDepth() {
            return depth;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy