All downloads are free. The search and download functionality uses the official Maven repository.

org.parboiled.transform.CachingGenerator Maven / Gradle / Ivy

There is a newer version: 1.2.0
Show newest version
/*
 * Copyright (c) 2009-2010 Ken Wenzel and Mathias Doenitz
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

package org.parboiled.transform;

import org.objectweb.asm.Type;
import org.objectweb.asm.tree.AbstractInsnNode;
import org.objectweb.asm.tree.FieldInsnNode;
import org.objectweb.asm.tree.FieldNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.InsnNode;
import org.objectweb.asm.tree.IntInsnNode;
import org.objectweb.asm.tree.JumpInsnNode;
import org.objectweb.asm.tree.LabelNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.TypeInsnNode;
import org.objectweb.asm.tree.VarInsnNode;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.parboiled.common.Preconditions.checkArgNotNull;
import static org.parboiled.common.Preconditions.checkState;
import static org.parboiled.common.Utils.toObjectArray;
import static org.objectweb.asm.Opcodes.AASTORE;
import static org.objectweb.asm.Opcodes.ACC_PRIVATE;
import static org.objectweb.asm.Opcodes.ALOAD;
import static org.objectweb.asm.Opcodes.ANEWARRAY;
import static org.objectweb.asm.Opcodes.ARETURN;
import static org.objectweb.asm.Opcodes.ASTORE;
import static org.objectweb.asm.Opcodes.BIPUSH;
import static org.objectweb.asm.Opcodes.CHECKCAST;
import static org.objectweb.asm.Opcodes.DLOAD;
import static org.objectweb.asm.Opcodes.DUP;
import static org.objectweb.asm.Opcodes.DUP_X1;
import static org.objectweb.asm.Opcodes.DUP_X2;
import static org.objectweb.asm.Opcodes.FLOAD;
import static org.objectweb.asm.Opcodes.GETFIELD;
import static org.objectweb.asm.Opcodes.IFNONNULL;
import static org.objectweb.asm.Opcodes.IFNULL;
import static org.objectweb.asm.Opcodes.ILOAD;
import static org.objectweb.asm.Opcodes.INVOKESPECIAL;
import static org.objectweb.asm.Opcodes.INVOKESTATIC;
import static org.objectweb.asm.Opcodes.INVOKEVIRTUAL;
import static org.objectweb.asm.Opcodes.LLOAD;
import static org.objectweb.asm.Opcodes.NEW;
import static org.objectweb.asm.Opcodes.POP;
import static org.objectweb.asm.Opcodes.PUTFIELD;
import static org.objectweb.asm.Opcodes.SWAP;

/**
 * Wraps the method code with caching and proxying constructs.
 *
 * <p>For a rule method carrying the cached annotation the following bytecode is woven around the
 * original method body:</p>
 * <ol>
 * <li>a cache lookup that returns immediately on a cache hit,</li>
 * <li>creation of a new proxy matcher that is put into the cache <em>before</em> the original body runs,</li>
 * <li>(the original method body, leaving the created rule on the stack),</li>
 * <li>arming of the proxy matcher with the created rule, and</li>
 * <li>replacing the proxy in the cache with the created rule.</li>
 * </ol>
 *
 * <p>Parameterless methods are cached in a plain rule field, all others in a {@code HashMap} field keyed
 * by the (boxed) argument or, for several arguments or array arguments, by an {@link Arguments} wrapper.</p>
 *
 * <p>The generator keeps per-method working state in instance fields, so one instance must not process
 * several methods concurrently.</p>
 */
class CachingGenerator implements RuleMethodProcessor {

    /** JVM-internal name of constructors; this is the name INVOKESPECIAL must reference to run one. */
    private static final String CONSTRUCTOR_NAME = "<init>";
    private static final String HASH_MAP = "java/util/HashMap";
    private static final String HASH_MAP_DESC = "Ljava/util/HashMap;";

    private ParserClassNode classNode;
    private RuleMethod method;
    private InsnList instructions;
    /** Insertion cursor: every generated instruction is inserted directly before this node. */
    private AbstractInsnNode current;
    private String cacheFieldName;

    /**
     * @param classNode the parser class being transformed, must not be null
     * @param method    the rule method in question, must not be null
     * @return true if the method carries the cached annotation
     */
    public boolean appliesTo(ParserClassNode classNode, RuleMethod method) {
        checkArgNotNull(classNode, "classNode");
        checkArgNotNull(method, "method");
        return method.hasCachedAnnotation();
    }

    /**
     * Adds a cache field to the class and weaves the caching/proxying instructions into the method.
     *
     * @param classNode the parser class being transformed, must not be null
     * @param method    the rule method to instrument, must not be null and must not be a super method
     * @throws IllegalStateException if the method is a super method or contains no ARETURN instruction
     */
    public void process(ParserClassNode classNode, RuleMethod method) throws Exception {
        checkArgNotNull(classNode, "classNode");
        checkArgNotNull(method, "method");
        checkState(!method.isSuperMethod()); // super methods have flag moved to the overriding method

        this.classNode = classNode;
        this.method = method;
        this.instructions = method.instructions;
        this.current = instructions.getFirst();

        // prologue, inserted before the first original instruction
        generateCacheHitReturn();
        generateStoreNewProxyMatcher();

        // epilogue, inserted directly before the return instruction
        seekToReturnInstruction();
        generateArmProxyMatcher();
        generateStoreInCache();
    }

    // if (<cache> != null) return <cache>;
    private void generateCacheHitReturn() {
        // stack:
        generateGetFromCache();
        // stack: <cachedValue>
        insert(new InsnNode(DUP));
        // stack: <cachedValue> :: <cachedValue>
        LabelNode cacheMissLabel = new LabelNode();
        insert(new JumpInsnNode(IFNULL, cacheMissLabel));
        // stack: <cachedValue>
        insert(new InsnNode(ARETURN));
        // stack: <null>
        insert(cacheMissLabel);
        // stack: <null>
        insert(new InsnNode(POP));
        // stack:
    }

    // Creates the cache field and pushes the currently cached value (or null) for the given arguments.
    // For methods with parameters the cache key is additionally saved in local variable method.maxLocals
    // so that generateStoreInCache() can reuse it.
    @SuppressWarnings( {"unchecked"})
    private void generateGetFromCache() {
        Type[] paramTypes = Type.getArgumentTypes(method.desc);
        boolean parameterless = paramTypes.length == 0;
        cacheFieldName = findUnusedCacheFieldName();

        // if we have no parameters we use a simple Rule field as cache, otherwise a HashMap
        String cacheFieldDesc = parameterless ? Types.RULE_DESC : HASH_MAP_DESC;
        classNode.fields.add(new FieldNode(ACC_PRIVATE, cacheFieldName, cacheFieldDesc, null, null));

        // stack:
        insert(new VarInsnNode(ALOAD, 0));
        // stack: <this>
        insert(new FieldInsnNode(GETFIELD, classNode.name, cacheFieldName, cacheFieldDesc));
        // stack: <cache>

        if (parameterless) return; // the field itself is the cached value, we are done

        generateLazyHashMapInitialization(cacheFieldDesc);
        // stack: <hashMap>

        // if we have more than one parameter or the parameter is an array we have to wrap with our Arguments class
        // since we need to unroll all inner arrays and apply custom hashCode(...) and equals(...) implementations
        if (paramTypes.length > 1 || paramTypes[0].getSort() == Type.ARRAY) {
            // generate: push new Arguments(new Object[] {<params>})

            String arguments = Type.getInternalName(Arguments.class);
            // stack: <hashMap>
            insert(new TypeInsnNode(NEW, arguments));
            // stack: <hashMap> :: <arguments>
            insert(new InsnNode(DUP));
            // stack: <hashMap> :: <arguments> :: <arguments>
            generatePushNewParameterObjectArray(paramTypes);
            // stack: <hashMap> :: <arguments> :: <arguments> :: <array>
            insert(new MethodInsnNode(INVOKESPECIAL, arguments, CONSTRUCTOR_NAME, "([Ljava/lang/Object;)V", false));
            // stack: <hashMap> :: <arguments>
        } else {
            // stack: <hashMap>
            generatePushParameterAsObject(paramTypes, 0);
            // stack: <hashMap> :: <param>
        }

        // generate: <hashMap>.get(<key>), remembering <key> in a scratch local variable

        // stack: <hashMap> :: <key>
        insert(new InsnNode(DUP));
        // stack: <hashMap> :: <key> :: <key>
        insert(new VarInsnNode(ASTORE, method.maxLocals));
        // stack: <hashMap> :: <key>
        insert(new MethodInsnNode(INVOKEVIRTUAL, HASH_MAP, "get", "(Ljava/lang/Object;)Ljava/lang/Object;", false));
        // stack: <object>
        insert(new TypeInsnNode(CHECKCAST, Types.RULE.getInternalName()));
        // stack: <rule>
    }

    // generate: if (<cache> == null) <cache> = new HashMap();
    // expects the current value of the cache field on the stack, leaves the (non-null) HashMap on the stack
    private void generateLazyHashMapInitialization(String cacheFieldDesc) {
        // stack: <hashMap>
        insert(new InsnNode(DUP));
        // stack: <hashMap> :: <hashMap>
        LabelNode alreadyInitialized = new LabelNode();
        insert(new JumpInsnNode(IFNONNULL, alreadyInitialized));
        // stack: <null>
        insert(new InsnNode(POP));
        // stack:
        insert(new VarInsnNode(ALOAD, 0));
        // stack: <this>
        insert(new TypeInsnNode(NEW, HASH_MAP));
        // stack: <this> :: <hashMap>
        insert(new InsnNode(DUP_X1));
        // stack: <hashMap> :: <this> :: <hashMap>
        insert(new InsnNode(DUP));
        // stack: <hashMap> :: <this> :: <hashMap> :: <hashMap>
        insert(new MethodInsnNode(INVOKESPECIAL, HASH_MAP, CONSTRUCTOR_NAME, "()V", false));
        // stack: <hashMap> :: <this> :: <hashMap>
        insert(new FieldInsnNode(PUTFIELD, classNode.name, cacheFieldName, cacheFieldDesc));
        // stack: <hashMap>
        insert(alreadyInitialized);
        // stack: <hashMap>
    }

    // "cache$<methodName>", followed by a counter (2, 3, ...) if that name is already taken
    private String findUnusedCacheFieldName() {
        String baseName = "cache$" + method.name;
        String candidate = baseName;
        for (int suffix = 2; hasField(candidate); suffix++) {
            candidate = baseName + suffix;
        }
        return candidate;
    }

    /**
     * @param fieldName the field name to look for
     * @return true if the class currently being processed already declares a field with the given name
     */
    public boolean hasField(String fieldName) {
        for (Object field : classNode.fields) {
            if (fieldName.equals(((FieldNode) field).name)) return true;
        }
        return false;
    }

    // generate: push new Object[] {<params>} with all primitive parameters boxed
    private void generatePushNewParameterObjectArray(Type[] paramTypes) {
        // stack: ...
        insert(new IntInsnNode(BIPUSH, paramTypes.length));
        // stack: ... :: <length>
        insert(new TypeInsnNode(ANEWARRAY, "java/lang/Object"));
        // stack: ... :: <array>

        for (int i = 0; i < paramTypes.length; i++) {
            // stack: ... :: <array>
            insert(new InsnNode(DUP));
            // stack: ... :: <array> :: <array>
            insert(new IntInsnNode(BIPUSH, i));
            // stack: ... :: <array> :: <array> :: <index>
            generatePushParameterAsObject(paramTypes, i);
            // stack: ... :: <array> :: <array> :: <index> :: <param>
            insert(new InsnNode(AASTORE));
            // stack: ... :: <array>
        }
        // stack: ... :: <array>
    }

    // Pushes the parameter with the given (zero-based) index onto the stack, boxing primitives.
    private void generatePushParameterAsObject(Type[] paramTypes, int parameterNr) {
        int slot = localSlotOf(paramTypes, parameterNr);
        Type paramType = paramTypes[parameterNr];
        switch (paramType.getSort()) {
            case Type.BOOLEAN:
                generateLoadAndBox(ILOAD, slot, "java/lang/Boolean", "Z");
                return;
            case Type.CHAR:
                generateLoadAndBox(ILOAD, slot, "java/lang/Character", "C");
                return;
            case Type.BYTE:
                generateLoadAndBox(ILOAD, slot, "java/lang/Byte", "B");
                return;
            case Type.SHORT:
                generateLoadAndBox(ILOAD, slot, "java/lang/Short", "S");
                return;
            case Type.INT:
                generateLoadAndBox(ILOAD, slot, "java/lang/Integer", "I");
                return;
            case Type.FLOAT:
                generateLoadAndBox(FLOAD, slot, "java/lang/Float", "F");
                return;
            case Type.LONG:
                generateLoadAndBox(LLOAD, slot, "java/lang/Long", "J");
                return;
            case Type.DOUBLE:
                generateLoadAndBox(DLOAD, slot, "java/lang/Double", "D");
                return;
            case Type.ARRAY:
            case Type.OBJECT:
                insert(new VarInsnNode(ALOAD, slot));
                return;
            case Type.VOID:
            default:
                throw new IllegalStateException("Unsupported parameter type: " + paramType);
        }
    }

    // Local variable slot of a parameter of an instance method: slot 0 holds 'this' and every preceding
    // parameter occupies one slot, except for long and double parameters, which occupy two.
    private static int localSlotOf(Type[] paramTypes, int parameterNr) {
        int slot = 1;
        for (int i = 0; i < parameterNr; i++) {
            slot += paramTypes[i].getSize();
        }
        return slot;
    }

    // generate: push <wrapperType>.valueOf(<primitive local variable>)
    private void generateLoadAndBox(int loadOpcode, int slot, String wrapperType, String primitiveDesc) {
        insert(new VarInsnNode(loadOpcode, slot));
        String valueOfDesc = "(" + primitiveDesc + ")L" + wrapperType + ";";
        insert(new MethodInsnNode(INVOKESTATIC, wrapperType, "valueOf", valueOfDesc, false));
    }

    // <cache> = new ProxyMatcher();
    private void generateStoreNewProxyMatcher() {
        String proxyMatcherType = Types.PROXY_MATCHER.getInternalName();

        // stack:
        insert(new TypeInsnNode(NEW, proxyMatcherType));
        // stack: <proxyMatcher>
        insert(new InsnNode(DUP));
        // stack: <proxyMatcher> :: <proxyMatcher>
        insert(new MethodInsnNode(INVOKESPECIAL, proxyMatcherType, CONSTRUCTOR_NAME, "()V", false));
        // stack: <proxyMatcher>
        generateStoreInCache();
        // stack: <proxyMatcher>
    }

    // Moves the insertion cursor to the first ARETURN instruction at or after the current position.
    private void seekToReturnInstruction() {
        while (current != null && current.getOpcode() != ARETURN) {
            current = current.getNext();
        }
        if (current == null) {
            // fail with a diagnosable error instead of a NullPointerException further down
            throw new IllegalStateException(
                    "Rule method '" + method.name + "' does not contain an ARETURN instruction");
        }
    }

    // <proxyMatcher>.arm(<rule>)
    private void generateArmProxyMatcher() {
        String proxyMatcherType = Types.PROXY_MATCHER.getInternalName();

        // stack: <proxyMatcher> :: <rule>
        insert(new InsnNode(DUP_X1));
        // stack: <rule> :: <proxyMatcher> :: <rule>
        insert(new TypeInsnNode(CHECKCAST, Types.MATCHER.getInternalName()));
        // stack: <rule> :: <proxyMatcher> :: <matcher>
        insert(new MethodInsnNode(INVOKEVIRTUAL, proxyMatcherType, "arm", '(' + Types.MATCHER_DESC + ")V", false));
        // stack: <rule>
    }

    // <cache> = <rule>   or   <hashMap>.put(<key>, <rule>)
    // expects the value to store on top of the stack and leaves it there
    private void generateStoreInCache() {
        Type[] paramTypes = Type.getArgumentTypes(method.desc);

        // stack: <rule>
        insert(new InsnNode(DUP));
        // stack: <rule> :: <rule>

        if (paramTypes.length == 0) {
            // stack: <rule> :: <rule>
            insert(new VarInsnNode(ALOAD, 0));
            // stack: <rule> :: <rule> :: <this>
            insert(new InsnNode(SWAP));
            // stack: <rule> :: <this> :: <rule>
            insert(new FieldInsnNode(PUTFIELD, classNode.name, cacheFieldName, Types.RULE_DESC));
            // stack: <rule>
            return;
        }

        // the key was saved in this scratch local variable by generateGetFromCache()
        // stack: <rule> :: <rule>
        insert(new VarInsnNode(ALOAD, method.maxLocals));
        // stack: <rule> :: <rule> :: <key>
        insert(new InsnNode(SWAP));
        // stack: <rule> :: <key> :: <rule>
        insert(new VarInsnNode(ALOAD, 0));
        // stack: <rule> :: <key> :: <rule> :: <this>
        insert(new FieldInsnNode(GETFIELD, classNode.name, cacheFieldName, HASH_MAP_DESC));
        // stack: <rule> :: <key> :: <rule> :: <hashMap>
        insert(new InsnNode(DUP_X2));
        // stack: <rule> :: <hashMap> :: <key> :: <rule> :: <hashMap>
        insert(new InsnNode(POP));
        // stack: <rule> :: <hashMap> :: <key> :: <rule>
        insert(new MethodInsnNode(INVOKEVIRTUAL, HASH_MAP, "put",
                "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", false));
        // stack: <rule> :: <previousValue>
        insert(new InsnNode(POP));
        // stack: <rule>
    }

    private void insert(AbstractInsnNode instruction) {
        instructions.insertBefore(current, instruction);
    }

    /**
     * Cache key wrapping the arguments of a rule method call.
     * All (nested) arrays among the arguments are flattened ("unrolled") into one flat list of elements,
     * on which {@link #equals(Object)} and {@link #hashCode()} operate element-wise.
     */
    public static class Arguments {
        private final Object[] params;

        /**
         * @param params the method arguments, primitives already boxed; must not be null
         */
        public Arguments(Object[] params) {
            // we need to "unroll" all inner Object arrays
            List<Object> flattened = new ArrayList<Object>();
            unroll(params, flattened);
            this.params = flattened.toArray();
        }

        // Appends all given values to the list, recursively replacing arrays with their elements.
        private static void unroll(Object[] values, List<Object> target) {
            for (Object value : values) {
                if (value == null || !value.getClass().isArray()) {
                    target.add(value);
                    continue;
                }
                switch (Type.getType(value.getClass().getComponentType()).getSort()) {
                    case Type.BOOLEAN:
                        unroll(toObjectArray((boolean[]) value), target);
                        break;
                    case Type.BYTE:
                        unroll(toObjectArray((byte[]) value), target);
                        break;
                    case Type.CHAR:
                        unroll(toObjectArray((char[]) value), target);
                        break;
                    case Type.DOUBLE:
                        unroll(toObjectArray((double[]) value), target);
                        break;
                    case Type.FLOAT:
                        unroll(toObjectArray((float[]) value), target);
                        break;
                    case Type.INT:
                        unroll(toObjectArray((int[]) value), target);
                        break;
                    case Type.LONG:
                        unroll(toObjectArray((long[]) value), target);
                        break;
                    case Type.SHORT:
                        unroll(toObjectArray((short[]) value), target);
                        break;
                    case Type.OBJECT:
                    case Type.ARRAY:
                        unroll((Object[]) value, target);
                        break;
                    default:
                        throw new IllegalStateException();
                }
            }
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof Arguments)) return false;
            Arguments that = (Arguments) o;
            return Arrays.equals(params, that.params);
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(params);
        }
    }
}