All Downloads are FREE. Search and download functionalities are using the official Maven repository.

net.sandius.rembulan.compiler.gen.asm.RunMethod Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2016 Miroslav Janíček
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.sandius.rembulan.compiler.gen.asm;

import net.sandius.rembulan.compiler.gen.CodeSegmenter;
import net.sandius.rembulan.compiler.gen.SegmentedCode;
import net.sandius.rembulan.compiler.gen.asm.helpers.ASMUtils;
import net.sandius.rembulan.compiler.ir.BasicBlock;
import net.sandius.rembulan.compiler.ir.Label;
import net.sandius.rembulan.impl.DefaultSavedState;
import net.sandius.rembulan.runtime.ExecutionContext;
import net.sandius.rembulan.runtime.ResolvedControlThrowable;
import net.sandius.rembulan.runtime.Resumable;
import net.sandius.rembulan.runtime.UnresolvedControlThrowable;
import net.sandius.rembulan.util.Check;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;

import java.util.ArrayList;
import java.util.List;

import static org.objectweb.asm.Opcodes.*;

class RunMethod {

	// Local variable slots used by the generated run methods (slot 0 holds "this").
	// These are compile-time constants, so they are declared static like the ST_SHIFT_* values.
	public static final int LV_CONTEXT = 1;  // index of the ExecutionContext argument
	public static final int LV_RESUME = 2;   // index of the int resumption-point argument
	public static final int LV_VARARGS = 3;  // index of the varargs argument, if present

	// Bit positions used when packing a segment index and a label index into one int state.
	public static final int ST_SHIFT_SEGMENT  = 24;
	public static final int ST_SHIFT_LABELIDX = 16;

	private final ASMBytecodeEmitter context;
	private final List<MethodNode> methodNodes;
	private final boolean resumable;

	private final List<ClosureFieldInstance> closureFields;
	private final List<ConstFieldInstance> constFields;

	/**
	 * Tells the bytecode emitter how a reference to an IR label is to be treated
	 * from the point of view of the method currently being emitted.
	 */
	interface LabelResolver {

		/**
		 * Returns {@code true} if {@code label} is local to the code being emitted,
		 * {@code false} if it lives elsewhere (for segmented code: in another segment).
		 */
		boolean isLocalLabel(Label label);

		/**
		 * Returns the int state value that identifies {@code label}.
		 * Implementations for which every label is local may throw
		 * {@link IllegalStateException} instead.
		 */
		int labelStateIndex(Label label);

	}

	/**
	 * Packs the segment index and the label index of {@code le} into a single int:
	 * the segment index is shifted left by {@link #ST_SHIFT_SEGMENT} bits, the label
	 * index by {@link #ST_SHIFT_LABELIDX} bits, and the two are OR-ed together.
	 */
	static int labelStateIdx(SegmentedCode.LabelEntry le) {
		final int segmentBits = le.segmentIdx << ST_SHIFT_SEGMENT;
		final int labelBits = le.idx << ST_SHIFT_LABELIDX;
		return segmentBits | labelBits;
	}

	/**
	 * Segments the code of the function held by {@code context} and emits the
	 * corresponding method nodes.
	 *
	 * <p>When the segmented code is a singleton, a single run method is emitted.
	 * Otherwise one sub-run method is emitted per segment, followed by a run method
	 * that dispatches to them.</p>
	 *
	 * @param context  the bytecode emitter this run method belongs to, must not be null
	 */
	public RunMethod(ASMBytecodeEmitter context) {
		this.context = Check.notNull(context);

		final SegmentedCode segmentedCode = CodeSegmenter.segment(
				context.fn.code(),
				context.compilerSettings.nodeSizeLimit());

		this.methodNodes = new ArrayList<>();
		this.closureFields = new ArrayList<>();
		this.constFields = new ArrayList<>();

		final boolean anyResumable;

		if (!segmentedCode.isSingleton()) {
			// several segments: one sub-method per segment plus a dispatching run method
			final int segmentCount = segmentedCode.segments().size();

			boolean seenResumable = false;
			for (int segmentIdx = 0; segmentIdx < segmentCount; segmentIdx++) {
				final int ownSegmentIdx = segmentIdx;

				LabelResolver resolver = new LabelResolver() {
					@Override
					public boolean isLocalLabel(Label l) {
						// local iff the label was placed into the segment being emitted
						return segmentedCode.labelEntry(l).segmentIdx == ownSegmentIdx;
					}

					@Override
					public int labelStateIndex(Label l) {
						return labelStateIdx(segmentedCode.labelEntry(l));
					}
				};

				BytecodeEmitVisitor visitor = new BytecodeEmitVisitor(
						context, this, context.slots, context.types, closureFields, constFields,
						segmentIdx, resolver);

				this.methodNodes.add(emitSegmentedSubRunMethod(
						segmentIdx, visitor, segmentedCode.segments().get(segmentIdx)));
				seenResumable |= visitor.isResumable();
			}
			anyResumable = seenResumable;

			// the dispatcher goes last, after all sub-methods
			this.methodNodes.add(emitSegmentedRunMethod(segmentedCode.segments().size()));
		}
		else {
			// a single segment: every label is local and no state indices are needed
			LabelResolver resolver = new LabelResolver() {
				@Override
				public boolean isLocalLabel(Label l) {
					return true;
				}

				@Override
				public int labelStateIndex(Label l) {
					throw new IllegalStateException();
				}
			};

			BytecodeEmitVisitor visitor = new BytecodeEmitVisitor(
					context, this, context.slots, context.types, closureFields, constFields,
					-1, resolver);

			this.methodNodes.add(emitSingletonRunMethod(visitor, segmentedCode.segments().get(0)));
			anyResumable = visitor.isResumable();
		}

		this.resumable = anyResumable;
	}

	/**
	 * Returns the number of register slots, as reported by the emitter's slot
	 * allocation ({@code context.slots.numSlots()}).
	 */
	public int numOfRegisters() {
		final int slotCount = context.slots.numSlots();
		return slotCount;
	}

	/**
	 * Returns the local variable index of the first register slot: the slot right
	 * after the varargs argument for vararg functions, or {@code LV_VARARGS} itself
	 * when there is no varargs argument.
	 */
	public int slotOffset() {
		if (context.isVararg()) {
			return LV_VARARGS + 1;
		}
		return LV_VARARGS;
	}

	/**
	 * Returns {@code true} if at least one of the bytecode visitors used to emit
	 * this run method reported itself as resumable.
	 */
	public boolean isResumable() {
		return this.resumable;
	}

	/**
	 * Returns the internal names of the exceptions declared by the generated run
	 * methods: a fresh one-element array naming {@link ResolvedControlThrowable}.
	 */
	public String[] throwsExceptions() {
		final String controlThrowable = Type.getInternalName(ResolvedControlThrowable.class);
		return new String[] { controlThrowable };
	}

	/**
	 * Returns {@code true} if the snapshot method is used, which is the case
	 * exactly when this run method is resumable.
	 */
	public boolean usesSnapshotMethod() {
		final boolean needed = isResumable();
		return needed;
	}

	/** Returns the name of the generated snapshot method. */
	private String snapshotMethodName() {
		final String name = "snapshot";
		return name;
	}

	/**
	 * Returns the method type of the snapshot method: it takes the int resumption
	 * point, the {@code Object[]} varargs (vararg functions only) and one
	 * {@code Object} per register, and returns the saved-state class type.
	 */
	private Type snapshotMethodType() {
		// A parameterised list is required: toArray(T[]) on a raw list yields Object[],
		// which is not accepted by Type.getMethodType(Type, Type...).
		List<Type> args = new ArrayList<>();

		args.add(Type.INT_TYPE);
		if (context.isVararg()) {
			args.add(ASMUtils.arrayTypeFor(Object.class));
		}
		final Type objectType = Type.getType(Object.class);
		final int numRegs = numOfRegisters();
		for (int i = 0; i < numRegs; i++) {
			args.add(objectType);
		}
		return Type.getMethodType(context.savedStateClassType(), args.toArray(new Type[0]));
	}

	/**
	 * Returns an {@code INVOKESPECIAL} instruction that calls the snapshot method
	 * on the class being generated.
	 */
	public MethodInsnNode snapshotMethodInvokeInsn() {
		final String owner = context.thisClassType().getInternalName();
		final String name = snapshotMethodName();
		final String descriptor = snapshotMethodType().getDescriptor();
		return new MethodInsnNode(INVOKESPECIAL, owner, name, descriptor, false);
	}

	/**
	 * Builds the private snapshot method.
	 *
	 * <p>The generated code creates a {@link DefaultSavedState} from the resumption
	 * point (local 1) and an {@code Object[]} filled with the remaining arguments
	 * in order: the varargs array first (vararg functions only), then the registers.</p>
	 *
	 * @return the method node of the snapshot method
	 */
	public MethodNode snapshotMethodNode() {
		MethodNode node = new MethodNode(
				ACC_PRIVATE,
				snapshotMethodName(),
				snapshotMethodType().getDescriptor(),
				null,
				null);

		InsnList il = node.instructions;
		LabelNode begin = new LabelNode();
		LabelNode end = new LabelNode();

		// slot 0 is "this", slot 1 the resumption point; saved values start at slot 2
		final int firstSavedSlot = 2;
		final boolean vararg = context.isVararg();

		// number of saved values: the registers plus the varargs array, if present
		final int numRegs = numOfRegisters() + (vararg ? 1 : 0);
		// slot of the first register proper (after the varargs array, if present)
		final int regOffset = vararg ? firstSavedSlot + 1 : firstSavedSlot;

		il.add(begin);

		il.add(new TypeInsnNode(NEW, Type.getInternalName(DefaultSavedState.class)));
		il.add(new InsnNode(DUP));

		// resumption point
		il.add(new VarInsnNode(ILOAD, 1));

		// registers
		il.add(ASMUtils.loadInt(numRegs));
		il.add(new TypeInsnNode(ANEWARRAY, Type.getInternalName(Object.class)));
		for (int i = 0; i < numRegs; i++) {
			il.add(new InsnNode(DUP));
			il.add(ASMUtils.loadInt(i));
			il.add(new VarInsnNode(ALOAD, firstSavedSlot + i));
			il.add(new InsnNode(AASTORE));
		}

		il.add(ASMUtils.ctor(
				Type.getType(DefaultSavedState.class),
				Type.INT_TYPE,
				ASMUtils.arrayTypeFor(Object.class)));

		il.add(new InsnNode(ARETURN));

		il.add(end);

		List<LocalVariableNode> locals = node.localVariables;

		locals.add(new LocalVariableNode("this", context.thisClassType().getDescriptor(), null, begin, end, 0));
		locals.add(new LocalVariableNode("rp", Type.INT_TYPE.getDescriptor(), null, begin, end, 1));
		if (vararg) {
			locals.add(new LocalVariableNode("varargs", ASMUtils.arrayTypeFor(Object.class).getDescriptor(), null, begin, end, firstSavedSlot));
		}
		for (int i = 0; i < numOfRegisters(); i++) {
			locals.add(new LocalVariableNode("r_" + i, Type.getDescriptor(Object.class), null, begin, end, regOffset + i));
		}

		// The varargs array occupies a local slot too, so count every saved value
		// (the previous value, 2 + numOfRegisters(), was one short for vararg functions).
		node.maxLocals = firstSavedSlot + numRegs;
		node.maxStack = 4 + 3;  // 4 to get register array at top, +3 to add element to it

		return node;
	}

	/** Returns the name of the generated run method. */
	public String methodName() {
		final String name = "run";
		return name;
	}

	/**
	 * Returns the type of a run-like method with the given return type. The
	 * arguments are the {@link ExecutionContext}, the int resumption point, the
	 * {@code Object[]} varargs (vararg functions only) and one {@code Object} per
	 * register.
	 *
	 * @param returnType  the return type of the method
	 * @return the method type
	 */
	private Type methodType(Type returnType) {
		// A parameterised list is required: toArray(T[]) on a raw list yields Object[],
		// which is not accepted by Type.getMethodType(Type, Type...).
		List<Type> args = new ArrayList<>();

		args.add(Type.getType(ExecutionContext.class));
		args.add(Type.INT_TYPE);
		if (context.isVararg()) {
			args.add(ASMUtils.arrayTypeFor(Object.class));
		}
		final Type objectType = Type.getType(Object.class);
		final int numRegs = numOfRegisters();
		for (int i = 0; i < numRegs; i++) {
			args.add(objectType);
		}
		return Type.getMethodType(returnType, args.toArray(new Type[0]));
	}

	/** Returns the type of the run method, which returns {@code void}. */
	public Type methodType() {
		final Type runType = methodType(Type.VOID_TYPE);
		return runType;
	}

	/**
	 * Returns the type of a per-segment sub-run method: same arguments as the run
	 * method, but returning the saved-state class type.
	 */
	private Type subMethodType() {
		final Type savedStateType = context.savedStateClassType();
		return methodType(savedStateType);
	}

	/**
	 * Returns an {@code INVOKESPECIAL} instruction that calls the run method on
	 * the class being generated.
	 */
	public AbstractInsnNode methodInvokeInsn() {
		final String owner = context.thisClassType().getInternalName();
		final String name = methodName();
		final String descriptor = methodType().getDescriptor();
		return new MethodInsnNode(INVOKESPECIAL, owner, name, descriptor, false);
	}

	/**
	 * Emits the error-state code: at {@code label}, create a new
	 * {@link IllegalStateException} with its no-argument constructor and throw it.
	 *
	 * @param label  the label to place at the start of the emitted code
	 * @return the emitted instructions
	 */
	private InsnList errorState(LabelNode label) {
		final String exceptionName = Type.getInternalName(IllegalStateException.class);

		InsnList code = new InsnList();
		code.add(label);
		code.add(ASMUtils.frameSame());
		code.add(new TypeInsnNode(NEW, exceptionName));
		code.add(new InsnNode(DUP));
		code.add(ASMUtils.ctor(IllegalStateException.class));
		code.add(new InsnNode(ATHROW));
		return code;
	}

	/**
	 * Emits a table switch on the resumption-point local ({@code LV_RESUME}).
	 *
	 * <p>The switch keys run from {@code 1 - extLabels.size()} up to
	 * {@code resumptionLabels.size()}: the non-positive keys select the entries of
	 * {@code extLabels} in list order (the last one has key 0), the positive keys
	 * select the entries of {@code resumptionLabels} in list order (the first one
	 * has key 1). Any other value jumps to {@code errorStateLabel}.</p>
	 *
	 * @param extLabels  labels reachable with keys {@code <= 0}, must not be empty
	 * @param resumptionLabels  labels reachable with keys {@code >= 1}
	 * @param errorStateLabel  the default target of the switch
	 * @return the emitted instructions
	 */
	private InsnList dispatchTable(List<? extends LabelNode> extLabels, List<? extends LabelNode> resumptionLabels, LabelNode errorStateLabel) {
		InsnList il = new InsnList();

		assert (!extLabels.isEmpty());

		// A parameterised list is required here: toArray(T[]) on a raw list returns
		// Object[], which cannot be assigned to LabelNode[].
		List<LabelNode> labels = new ArrayList<>(extLabels.size() + resumptionLabels.size());
		labels.addAll(extLabels);
		labels.addAll(resumptionLabels);
		LabelNode[] labelArray = labels.toArray(new LabelNode[0]);

		int min = 1 - extLabels.size();
		int max = resumptionLabels.size();

		il.add(new VarInsnNode(ILOAD, LV_RESUME));
		il.add(new TableSwitchInsnNode(min, max, errorStateLabel, labelArray));
		return il;
	}

	/**
	 * Emits code that calls the snapshot method with the current resumption point,
	 * the varargs array (vararg functions only) and all register slots.
	 *
	 * <p>{@code this} is loaded twice: the second reference is the receiver of the
	 * snapshot call, the first one stays on the stack underneath the call's result.</p>
	 *
	 * @return the emitted instructions
	 */
	InsnList createSnapshot() {
		InsnList snapshot = new InsnList();

		snapshot.add(new VarInsnNode(ALOAD, 0));  // this (stays on the stack)
		snapshot.add(new VarInsnNode(ALOAD, 0));  // this (receiver of the snapshot call)
		snapshot.add(new VarInsnNode(ILOAD, LV_RESUME));
		if (context.isVararg()) {
			snapshot.add(new VarInsnNode(ALOAD, LV_VARARGS));
		}

		final int registerCount = numOfRegisters();
		final int firstSlot = slotOffset();
		for (int slot = firstSlot; slot < firstSlot + registerCount; slot++) {
			snapshot.add(new VarInsnNode(ALOAD, slot));
		}

		snapshot.add(snapshotMethodInvokeInsn());
		return snapshot;
	}

	/**
	 * Emits the handler for {@link UnresolvedControlThrowable}: it takes a
	 * snapshot, passes {@code this} and the snapshot to the throwable's
	 * {@code resolve} method, and throws the value that {@code resolve} returns.
	 *
	 * @param label  the label to place at the start of the handler
	 * @return the emitted instructions
	 */
	protected InsnList resumptionHandler(LabelNode label) {
		final String throwableName = Type.getInternalName(UnresolvedControlThrowable.class);
		final String resolveDescriptor = Type.getMethodType(
				Type.getType(ResolvedControlThrowable.class),
				Type.getType(Resumable.class),
				Type.getType(Object.class)).getDescriptor();

		InsnList handler = new InsnList();

		// on entry, the stack holds exactly the caught throwable
		handler.add(label);
		handler.add(ASMUtils.frameSame1(UnresolvedControlThrowable.class));

		handler.add(createSnapshot());

		// register the snapshot with the control throwable
		handler.add(new MethodInsnNode(INVOKEVIRTUAL, throwableName, "resolve", resolveDescriptor, false));

		// throw the resolved throwable
		handler.add(new InsnNode(ATHROW));

		return handler;
	}

	/**
	 * An immutable pairing of a field node with the instructions used to
	 * instantiate it. Neither component may be null.
	 */
	static class ClosureFieldInstance {

		private final FieldNode fieldNode;
		private final InsnList instantiateInsns;

		/**
		 * @param fieldNode  the field, must not be null
		 * @param instantiateInsns  the instantiating instructions, must not be null
		 */
		public ClosureFieldInstance(FieldNode fieldNode, InsnList instantiateInsns) {
			Check.notNull(fieldNode);
			this.fieldNode = fieldNode;
			Check.notNull(instantiateInsns);
			this.instantiateInsns = instantiateInsns;
		}

		/** Returns the field node. */
		public FieldNode fieldNode() {
			return this.fieldNode;
		}

		/** Returns the instantiating instructions. */
		public InsnList instantiateInsns() {
			return this.instantiateInsns;
		}

	}

	/**
	 * Returns the list of closure field instances held by this run method.
	 * The returned list is the internal list itself, not a copy.
	 */
	@SuppressWarnings("unchecked")  // the list only ever holds ClosureFieldInstance elements
	public List<ClosureFieldInstance> closureFields() {
		return closureFields;
	}

	/**
	 * Describes a constant held in a private static final field of an owner class.
	 * Subclasses supply, via {@link #doInstantiate(InsnList)}, the instructions
	 * that are emitted in front of the {@code PUTSTATIC} that stores the field.
	 */
	abstract static class ConstFieldInstance {

		private final Object value;
		private final String fieldName;
		private final Type ownerClassType;
		private final Type fieldType;

		/**
		 * @param value  the constant value, must not be null
		 * @param fieldName  the name of the field, must not be null
		 * @param ownerClassType  the type of the class declaring the field, must not be null
		 * @param fieldType  the type of the field, must not be null
		 */
		public ConstFieldInstance(Object value, String fieldName, Type ownerClassType, Type fieldType) {
			this.value = Check.notNull(value);
			this.fieldName = Check.notNull(fieldName);
			this.ownerClassType = Check.notNull(ownerClassType);
			this.fieldType = Check.notNull(fieldType);
		}

		/** Returns the constant value. */
		public Object value() {
			return this.value;
		}

		/** Returns a new private static final field node for this constant. */
		public FieldNode fieldNode() {
			final int access = ACC_PRIVATE | ACC_STATIC | ACC_FINAL;
			return new FieldNode(access, fieldName, fieldType.getDescriptor(), null, null);
		}

		/** Appends to {@code il} the instructions that precede the store into the field. */
		public abstract void doInstantiate(InsnList il);

		/**
		 * Returns the instructions added by {@link #doInstantiate(InsnList)}
		 * followed by a {@code PUTSTATIC} into the field.
		 */
		public InsnList instantiateInsns() {
			InsnList il = new InsnList();
			doInstantiate(il);
			il.add(staticFieldInsn(PUTSTATIC));
			return il;
		}

		/** Returns a single {@code GETSTATIC} instruction reading the field. */
		public InsnList accessInsns() {
			InsnList il = new InsnList();
			il.add(staticFieldInsn(GETSTATIC));
			return il;
		}

		// Builds a field instruction with the given opcode that targets this field.
		private FieldInsnNode staticFieldInsn(int opcode) {
			return new FieldInsnNode(
					opcode,
					ownerClassType.getInternalName(),
					fieldName,
					fieldType.getDescriptor());
		}

	}

	/**
	 * Returns the list of constant field instances held by this run method.
	 * The returned list is the internal list itself, not a copy.
	 */
	@SuppressWarnings("unchecked")  // the list only ever holds ConstFieldInstance elements
	public List<ConstFieldInstance> constFields() {
		return constFields;
	}

	/**
	 * Creates the local variable entries shared by all run-like methods:
	 * {@code this}, {@code context}, {@code rp}, {@code varargs} (vararg functions
	 * only) and one {@code s_<i>} entry per register slot, each spanning
	 * {@code l_begin} to {@code l_end}.
	 *
	 * @param l_begin  start of the variables' scope
	 * @param l_end  end of the variables' scope
	 * @return a new mutable list with the entries
	 */
	private List<LocalVariableNode> baseLocals(LabelNode l_begin, LabelNode l_end) {
		List<LocalVariableNode> locals = new ArrayList<>();

		locals.add(new LocalVariableNode("this", context.thisClassType().getDescriptor(), null, l_begin, l_end, 0));
		locals.add(new LocalVariableNode("context", Type.getDescriptor(ExecutionContext.class), null, l_begin, l_end, LV_CONTEXT));
		locals.add(new LocalVariableNode("rp", Type.INT_TYPE.getDescriptor(), null, l_begin, l_end, LV_RESUME));

		if (context.isVararg()) {
			locals.add(new LocalVariableNode(
					"varargs",
					ASMUtils.arrayTypeFor(Object.class).getDescriptor(),
					null,
					l_begin,
					l_end,
					LV_VARARGS));
		}

		final String objectDescriptor = Type.getDescriptor(Object.class);
		final int firstSlot = slotOffset();
		final int numRegs = numOfRegisters();
		for (int i = 0; i < numRegs; i++) {
			locals.add(new LocalVariableNode("s_" + i, objectDescriptor, null, l_begin, l_end, firstSlot + i));
		}

		return locals;
	}

	/**
	 * Appends to {@code node}'s local variable table the base locals (spanning
	 * {@code l_begin} to {@code l_end}) followed by the locals reported by
	 * {@code visitor}.
	 */
	@SuppressWarnings("unchecked")  // baseLocals() yields LocalVariableNode elements only
	private void addLocals(MethodNode node, LabelNode l_begin, LabelNode l_end, BytecodeEmitVisitor visitor) {
		List<LocalVariableNode> locals = node.localVariables;
		locals.addAll(baseLocals(l_begin, l_end));
		locals.addAll(visitor.locals());
	}

	/**
	 * Emits a run-like method from the given basic blocks.
	 *
	 * <p>The visitor first emits the code of {@code blocks}. If the visitor then
	 * reports the code as resumable, the method body is prefixed with a dispatch
	 * table on the resumption point, suffixed with the error state and the
	 * resumption handler, and a try/catch entry for
	 * {@link UnresolvedControlThrowable} is registered.</p>
	 *
	 * <p>For a sub-method ({@code sub == true}) the external entry points of the
	 * dispatch table are the labels of {@code blocks}, passed in reverse order, so
	 * that key 0 selects the first block and key {@code -k} the block at index
	 * {@code k}. Otherwise there is a single entry point: a fresh label placed
	 * right after the dispatch table.</p>
	 *
	 * @param methodName  the name of the emitted method
	 * @param returnType  the return type of the emitted method
	 * @param visitor  the visitor used to emit the code of {@code blocks}
	 * @param blocks  the basic blocks to emit; must not be empty when {@code sub} is set
	 * @param sub  {@code true} when emitting a per-segment sub-method
	 * @return the method node
	 */
	private MethodNode emitRunMethod(String methodName, Type returnType, BytecodeEmitVisitor visitor, List<BasicBlock> blocks, boolean sub) {
		MethodNode node = new MethodNode(
				ACC_PRIVATE,
				methodName,
				methodType(returnType).getDescriptor(),
				null,
				throwsExceptions());

		InsnList insns = node.instructions;

		LabelNode l_begin = new LabelNode();
		LabelNode l_end = new LabelNode();

		visitor.visitBlocks(blocks);

		InsnList prefix = new InsnList();
		InsnList suffix = new InsnList();

		final LabelNode l_head;
		final List<LabelNode> els = new ArrayList<>();
		if (sub) {
			assert (!blocks.isEmpty());
			// reverse order: the dispatch table maps key 0 to the last entry of els
			for (int i = blocks.size() - 1; i >= 0; i--) {
				BasicBlock blk = blocks.get(i);
				LabelNode l = visitor.labels.get(blk.label());
				assert (l != null);
				els.add(l);
			}
			l_head = visitor.labels.get(blocks.get(0).label());
		}
		else {
			l_head = new LabelNode();
			els.add(l_head);
		}

		assert (l_head != null);

		if (visitor.isResumable()) {
			LabelNode l_error_state = new LabelNode();
			LabelNode l_handler_begin = new LabelNode();

			List<LabelNode> rls = visitor.resumptionLabels();

			assert (!rls.isEmpty() || !els.isEmpty());

			prefix.add(dispatchTable(els, rls, l_error_state));

			final LabelNode l_entry = l_head;

			if (!sub) {
				// the fresh entry label is not part of the visitor's code, so place it here
				prefix.add(l_entry);
				prefix.add(ASMUtils.frameSame());
			}

			suffix.add(errorState(l_error_state));
			suffix.add(resumptionHandler(l_handler_begin));

			// protected range: from the entry label up to the error state
			node.tryCatchBlocks.add(new TryCatchBlockNode(l_entry, l_error_state, l_handler_begin, Type.getInternalName(UnresolvedControlThrowable.class)));
		}

		insns.add(l_begin);
		insns.add(prefix);
		insns.add(visitor.instructions());
		insns.add(suffix);
		insns.add(l_end);

		addLocals(node, l_begin, l_end, visitor);

		return node;
	}

	/**
	 * Emits the {@code void} run method for code that consists of a single segment.
	 *
	 * @param visitor  the visitor used to emit the code of {@code blocks}
	 * @param blocks  the basic blocks of the only segment
	 * @return the method node
	 */
	@SuppressWarnings("unchecked")  // emitRunMethod reads the blocks as BasicBlock elements
	private MethodNode emitSingletonRunMethod(BytecodeEmitVisitor visitor, List<BasicBlock> blocks) {
		return emitRunMethod(methodName(), Type.VOID_TYPE, visitor, blocks, false);
	}

	/**
	 * Returns the name of the sub-run method for the given segment:
	 * {@code "run_"} followed by the decimal segment index.
	 */
	private String subRunMethodName(int segmentIdx) {
		final StringBuilder name = new StringBuilder("run_");
		name.append(segmentIdx);
		return name.toString();
	}

	/**
	 * Emits the sub-run method for one segment. The method is named after the
	 * segment index and returns the saved-state class type.
	 *
	 * @param segmentIdx  the index of the segment
	 * @param visitor  the visitor used to emit the code of {@code blocks}
	 * @param blocks  the basic blocks of the segment
	 * @return the method node
	 */
	@SuppressWarnings("unchecked")  // emitRunMethod reads the blocks as BasicBlock elements
	private MethodNode emitSegmentedSubRunMethod(int segmentIdx, BytecodeEmitVisitor visitor, List<BasicBlock> blocks) {
		return emitRunMethod(subRunMethodName(segmentIdx), context.savedStateClassType(), visitor, blocks, true);
	}

	/**
	 * Emits the run method for segmented code: a loop that dispatches to the
	 * per-segment sub-run methods.
	 *
	 * <p>Each iteration decodes the resumption point {@code rp} into a method
	 * (segment) index and a state index, invokes the selected sub-run method with
	 * the state index in place of {@code rp}, and stores the returned saved state.
	 * A {@code null} saved state ends the method. Otherwise {@code rp} and the
	 * varargs/register locals are reloaded from the saved state and the loop
	 * starts over.</p>
	 *
	 * @param numOfSegments  the number of sub-run methods to dispatch to
	 * @return the method node
	 */
	private MethodNode emitSegmentedRunMethod(int numOfSegments) {
		MethodNode node = new MethodNode(
				ACC_PRIVATE,
				methodName(),
				methodType().getDescriptor(),
				null,
				throwsExceptions());

		InsnList il = node.instructions;

		// the helper locals are placed right after the last register slot
		int lvOffset = slotOffset() + numOfRegisters();

		int lv_rpp        = lvOffset + 0;
		int lv_methodIdx  = lvOffset + 1;
		int lv_jmpIdx     = lvOffset + 2;
		int lv_stateIdx   = lvOffset + 3;
		int lv_savedState = lvOffset + 4;

		LabelNode l_top = new LabelNode();
		LabelNode l_ret = new LabelNode();
		LabelNode l_end = new LabelNode();

		// labels marking the start of scope of each helper local
		LabelNode l_rpp = new LabelNode();
		LabelNode l_methodIdx = new LabelNode();
		LabelNode l_jmpIdx = new LabelNode();
		LabelNode l_stateIdx = new LabelNode();
		LabelNode l_savedState = new LabelNode();

		// loop head: jumped back to after the locals were restored from a saved state
		il.add(l_top);
		il.add(new FrameNode(F_SAME, 0, null, 0, null));

		// rpp = rp & ((1 << ST_SHIFT_SEGMENT) - 1)
		il.add(new VarInsnNode(ILOAD, LV_RESUME));
		il.add(ASMUtils.loadInt((1 << ST_SHIFT_SEGMENT) - 1));
		il.add(new InsnNode(IAND));
		il.add(new VarInsnNode(ISTORE, lv_rpp));
		il.add(l_rpp);
		il.add(new FrameNode(F_APPEND, 1, new Object[] { Opcodes.INTEGER }, 0, null));

		// methodIdx = rp >>> ST_SHIFT_SEGMENT
		il.add(new VarInsnNode(ILOAD, LV_RESUME));
		il.add(ASMUtils.loadInt(ST_SHIFT_SEGMENT));
		il.add(new InsnNode(IUSHR));
		il.add(new VarInsnNode(ISTORE, lv_methodIdx));
		il.add(l_methodIdx);
		il.add(new FrameNode(F_APPEND, 1, new Object[] { Opcodes.INTEGER }, 0, null));

		// jmpIdx = rpp >>> ST_SHIFT_LABELIDX
		il.add(new VarInsnNode(ILOAD, lv_rpp));
		il.add(ASMUtils.loadInt(ST_SHIFT_LABELIDX));
		il.add(new InsnNode(IUSHR));
		il.add(new VarInsnNode(ISTORE, lv_jmpIdx));
		il.add(l_jmpIdx);
		il.add(new FrameNode(F_APPEND, 1, new Object[] { Opcodes.INTEGER }, 0, null));

		// stateIdx = (rp & ((1 << ST_SHIFT_LABELIDX) - 1)) - jmpIdx
		il.add(new VarInsnNode(ILOAD, LV_RESUME));
		il.add(ASMUtils.loadInt((1 << ST_SHIFT_LABELIDX) - 1));
		il.add(new InsnNode(IAND));
		il.add(new VarInsnNode(ILOAD, lv_jmpIdx));
		il.add(new InsnNode(ISUB));
		il.add(new VarInsnNode(ISTORE, lv_stateIdx));
		il.add(l_stateIdx);
		il.add(new FrameNode(F_APPEND, 1, new Object[] { Opcodes.INTEGER }, 0, null));

		// savedState = null
		il.add(new InsnNode(ACONST_NULL));
		il.add(new VarInsnNode(ASTORE, lv_savedState));
		il.add(l_savedState);
		il.add(new FrameNode(F_APPEND, 1, new Object[] { context.savedStateClassType().getInternalName() }, 0, null));

		// switch on methodIdx

		LabelNode l_after = new LabelNode();

		// l_error is the default switch target; l_invokes[i] is the call site of sub-method i
		LabelNode l_error = new LabelNode();
		LabelNode[] l_invokes = new LabelNode[numOfSegments];
		for (int i = 0; i < numOfSegments; i++) {
			l_invokes[i] = new LabelNode();
		}

		il.add(new VarInsnNode(ILOAD, lv_methodIdx));
		il.add(new TableSwitchInsnNode(0, numOfSegments - 1, l_error, l_invokes));

		// one call site per segment: savedState = this.run_<i>(context, stateIdx, [varargs,] registers...)
		for (int i = 0; i < numOfSegments; i++) {
			il.add(l_invokes[i]);
			il.add(new FrameNode(F_SAME, 0, null, 0, null));
			// push arguments to stack
			il.add(new VarInsnNode(ALOAD, 0));
			il.add(new VarInsnNode(ALOAD, LV_CONTEXT));
			il.add(new VarInsnNode(ILOAD, lv_stateIdx));  // pass stateIdx to the sub-method
			if (context.isVararg()) {
				il.add(new VarInsnNode(ALOAD, LV_VARARGS));
			}
			for (int j = 0; j < numOfRegisters(); j++) {
				il.add(new VarInsnNode(ALOAD, slotOffset() + j));
			}

			il.add(new MethodInsnNode(INVOKESPECIAL,
					context.thisClassType().getInternalName(),
					subRunMethodName(i),
					subMethodType().getDescriptor(),
					false));

			il.add(new VarInsnNode(ASTORE, lv_savedState));
			il.add(new JumpInsnNode(GOTO, l_after));
		}

		// error state
		il.add(errorState(l_error));

		// all call sites continue here with the sub-method's result in savedState
		il.add(l_after);
		il.add(new FrameNode(F_SAME, 0, null, 0, null));

		il.add(new VarInsnNode(ALOAD, lv_savedState));
		il.add(new JumpInsnNode(IFNULL, l_ret));  // savedState == null ?

		// continuing: savedState != null

		// FIXME: taken from ResumeMethod -- beware of code duplication!

		// rp = savedState.resumptionPoint()
		il.add(new VarInsnNode(ALOAD, lv_savedState));  // saved state
		il.add(new MethodInsnNode(
				INVOKEVIRTUAL,
				Type.getInternalName(DefaultSavedState.class),
				"resumptionPoint",
				Type.getMethodDescriptor(
						Type.INT_TYPE),
				false
		));  // resumption point
		il.add(new VarInsnNode(ISTORE, LV_RESUME));

		// registers
		if (context.isVararg() || numOfRegisters() > 0) {
			il.add(new VarInsnNode(ALOAD, lv_savedState));
			il.add(new MethodInsnNode(
					INVOKEVIRTUAL,
					Type.getInternalName(DefaultSavedState.class),
					"registers",
					Type.getMethodDescriptor(
							ASMUtils.arrayTypeFor(Object.class)),
					false
			));

			// element 0 is the varargs array for vararg functions; the registers follow
			int numRegs = numOfRegisters() + (context.isVararg() ? 1 : 0);

			for (int i = 0; i < numRegs; i++) {
				// keep the array on the stack for all but the last element
				if (i + 1 < numRegs) {
					il.add(new InsnNode(DUP));
				}
				il.add(ASMUtils.loadInt(i));
				il.add(new InsnNode(AALOAD));
				if (i == 0 && context.isVararg()) {
					il.add(new TypeInsnNode(CHECKCAST, ASMUtils.arrayTypeFor(Object.class).getInternalName()));
				}
				// array element i goes into the i-th local counted from LV_VARARGS
				il.add(new VarInsnNode(ASTORE, LV_VARARGS + i));
			}
		}

		// loop back to the beginning
		il.add(new JumpInsnNode(GOTO, l_top));

		// got a null, that's the end
		il.add(l_ret);
		il.add(new FrameNode(F_SAME, 0, null, 0, null));
		il.add(new InsnNode(RETURN));

		il.add(l_end);

		// add local variables
		node.localVariables.addAll(baseLocals(l_top, l_end));
		node.localVariables.add(new LocalVariableNode("rpp", Type.INT_TYPE.getDescriptor(), null, l_rpp, l_ret, lv_rpp));
		node.localVariables.add(new LocalVariableNode("methodIdx", Type.INT_TYPE.getDescriptor(), null, l_methodIdx, l_ret, lv_methodIdx));
		node.localVariables.add(new LocalVariableNode("jmpIdx", Type.INT_TYPE.getDescriptor(), null, l_jmpIdx, l_ret, lv_jmpIdx));
		node.localVariables.add(new LocalVariableNode("stateIdx", Type.INT_TYPE.getDescriptor(), null, l_stateIdx, l_ret, lv_stateIdx));
		node.localVariables.add(new LocalVariableNode("savedState", context.savedStateClassType().getDescriptor(), null, l_savedState, l_ret, lv_savedState));


		return node;
	}

	/**
	 * Returns the method nodes emitted by the constructor, in emission order:
	 * either the single run method, or the per-segment sub-run methods followed
	 * by the dispatching run method. The returned list is the internal list
	 * itself, not a copy.
	 */
	@SuppressWarnings("unchecked")  // the list only ever holds MethodNode elements
	public List<MethodNode> methodNodes() {
		return methodNodes;
	}

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy