org.apache.commons.weaver.privilizer.BlueprintingVisitor Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of commons-weaver-privilizer Show documentation
Implements the Apache Commons Weaver SPI for the Privilizer module.
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.commons.weaver.privilizer;
import java.io.InputStream;
import java.lang.reflect.Modifier;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.commons.lang3.tuple.Pair;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.commons.AdviceAdapter;
import org.objectweb.asm.commons.GeneratorAdapter;
import org.objectweb.asm.commons.Method;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.MethodNode;
/**
* {@link ClassVisitor} to import so-called "blueprint methods".
*/
class BlueprintingVisitor extends Privilizer.PrivilizerClassVisitor {
private final Set blueprintTypes = new HashSet();
private final Map, MethodNode> blueprintRegistry = new HashMap, MethodNode>();
private final Map, String> importedMethods = new HashMap, String>();
private final Map> methodCache = new HashMap>();
private final Map, FieldAccess> fieldAccessMap = new HashMap, FieldAccess>();
private final ClassVisitor next;
/**
* Create a new {@link BlueprintingVisitor}.
* @param privilizer owner
* @param next wrapped
* @param config annotation
*/
BlueprintingVisitor(@SuppressWarnings("PMD.UnusedFormalParameter") final Privilizer privilizer, //false positive
final ClassVisitor next,
final Privilizing config) {
privilizer.super(new ClassNode(Opcodes.ASM5));
this.next = next;
// load up blueprint methods:
for (final Privilizing.CallTo callTo : config.value()) {
final Type blueprintType = Type.getType(callTo.value());
blueprintTypes.add(blueprintType);
for (final Map.Entry entry : getMethods(blueprintType).entrySet()) {
boolean found = false;
if (callTo.methods().length == 0) {
found = true;
} else {
for (final String name : callTo.methods()) {
if (entry.getKey().getName().equals(name)) {
found = true;
break;
}
}
}
if (found) {
blueprintRegistry.put(Pair.of(blueprintType, entry.getKey()), entry.getValue());
}
}
}
}
private Map getMethods(final Type type) {
if (methodCache.containsKey(type)) {
return methodCache.get(type);
}
final ClassNode classNode = read(type.getClassName());
final Map result = new HashMap();
@SuppressWarnings("unchecked")
final List methods = classNode.methods;
for (final MethodNode methodNode : methods) {
if (Modifier.isStatic(methodNode.access) && !"".equals(methodNode.name)) {
result.put(new Method(methodNode.name, methodNode.desc), methodNode);
}
}
methodCache.put(type, result);
return result;
}
private ClassNode read(final String className) {
final ClassNode result = new ClassNode(Opcodes.ASM5);
InputStream bytecode = null;
try {
bytecode = privilizer().env.getClassfile(className).getInputStream();
new ClassReader(bytecode).accept(result, ClassReader.SKIP_DEBUG | ClassReader.EXPAND_FRAMES);
} catch (final Exception e) {
throw new RuntimeException(e);
} finally {
IOUtils.closeQuietly(bytecode);
}
return result;
}
@Override
@SuppressWarnings("PMD.UseVarargs") //overridden method
public void visit(final int version, final int access, final String name, final String signature,
final String superName, final String[] interfaces) {
Validate.isTrue(!blueprintTypes.contains(Type.getObjectType(name)),
"Class %s cannot declare itself as a blueprint!", name);
super.visit(version, access, name, signature, superName, interfaces);
}
@Override
@SuppressWarnings("PMD.UseVarargs") //overridden method
public MethodVisitor visitMethod(final int access, final String name, final String desc, final String signature,
final String[] exceptions) {
final MethodVisitor toWrap = super.visitMethod(access, name, desc, signature, exceptions);
return new MethodInvocationHandler(toWrap) {
@Override
boolean shouldImport(final Pair methodKey) {
return blueprintRegistry.containsKey(methodKey);
}
};
}
private String importMethod(final Pair key) {
if (importedMethods.containsKey(key)) {
return importedMethods.get(key);
}
final String result =
new StringBuilder(key.getLeft().getInternalName().replace('/', '_')).append("$$")
.append(key.getRight().getName()).toString();
importedMethods.put(key, result);
privilizer().env.debug("importing %s#%s as %s", key.getLeft().getClassName(), key.getRight(), result);
final int access = Opcodes.ACC_PRIVATE + Opcodes.ACC_STATIC + Opcodes.ACC_SYNTHETIC;
final MethodNode source = getMethods(key.getLeft()).get(key.getRight());
@SuppressWarnings("unchecked")
final String[] exceptions = (String[]) source.exceptions.toArray(ArrayUtils.EMPTY_STRING_ARRAY);
// non-public fields accessed
final Set fieldAccesses = new LinkedHashSet();
source.accept(new MethodVisitor(Opcodes.ASM5) {
@Override
public void visitFieldInsn(final int opcode, final String owner, final String name, final String desc) {
final FieldAccess fieldAccess = fieldAccess(Type.getObjectType(owner), name);
super.visitFieldInsn(opcode, owner, name, desc);
if (!Modifier.isPublic(fieldAccess.access)) {
fieldAccesses.add(fieldAccess);
}
}
});
final MethodNode withAccessibleAdvice =
new MethodNode(access, result, source.desc, source.signature, exceptions);
// spider own methods:
MethodVisitor mv = new NestedMethodInvocationHandler(withAccessibleAdvice, key.getLeft()); //NOPMD
if (!fieldAccesses.isEmpty()) {
// accessesNonPublicFields = true;
mv = new AccessibleAdvisor(mv, access, result, source.desc, new ArrayList(fieldAccesses));
}
source.accept(mv);
// private can only be called by other privileged methods, so no need to mark as privileged
if (!Modifier.isPrivate(source.access)) {
withAccessibleAdvice.visitAnnotation(Type.getType(Privileged.class).getDescriptor(), false).visitEnd();
}
withAccessibleAdvice.accept(this.cv);
return result;
}
private FieldAccess fieldAccess(final Type owner, final String name) {
final Pair key = Pair.of(owner, name);
if (!fieldAccessMap.containsKey(key)) {
try {
final MutableObject next = new MutableObject(owner);
final Deque stk = new ArrayDeque();
while (next.getValue() != null) {
stk.push(next.getValue());
InputStream bytecode = null;
try {
bytecode = privilizer().env.getClassfile(next.getValue().getInternalName()).getInputStream();
new ClassReader(bytecode).accept(privilizer().new PrivilizerClassVisitor() {
@Override
@SuppressWarnings("PMD.UseVarargs") //overridden method
public void visit(final int version, final int access, final String name,
final String signature, final String superName, final String[] interfaces) {
super.visit(version, access, name, signature, superName, interfaces);
next.setValue(Type.getObjectType(superName));
}
@Override
public FieldVisitor visitField(final int access, final String name, final String desc,
final String signature, final Object value) {
for (final Type type : stk) {
final Pair key = Pair.of(type, name);
// skip shadowed fields:
if (!fieldAccessMap.containsKey(key)) {
fieldAccessMap.put(key,
new FieldAccess(access, target, name, Type.getType(desc)));
}
}
return null;
}
}, ClassReader.SKIP_CODE);
} finally {
IOUtils.closeQuietly(bytecode);
}
if (fieldAccessMap.containsKey(key)) {
break;
}
}
} catch (final Exception e) {
throw new RuntimeException(e);
}
Validate.isTrue(fieldAccessMap.containsKey(key), "Could not locate %s.%s", owner.getClassName(), name);
}
return fieldAccessMap.get(key);
}
@Override
public void visitEnd() {
super.visitEnd();
((ClassNode) cv).accept(next);
}
private abstract class MethodInvocationHandler extends MethodVisitor {
MethodInvocationHandler(final MethodVisitor mvr) {
super(Opcodes.ASM5, mvr);
}
@Override
public void visitMethodInsn(final int opcode, final String owner, final String name, final String desc,
final boolean itf) {
if (opcode == Opcodes.INVOKESTATIC) {
final Method methd = new Method(name, desc);
final Pair methodKey = Pair.of(Type.getObjectType(owner), methd);
if (shouldImport(methodKey)) {
final String importedName = importMethod(methodKey);
super.visitMethodInsn(opcode, className, importedName, desc, itf);
return;
}
}
super.visitMethodInsn(opcode, owner, name, desc, itf);
}
abstract boolean shouldImport(Pair methodKey);
}
class NestedMethodInvocationHandler extends MethodInvocationHandler {
final Type owner;
NestedMethodInvocationHandler(final MethodVisitor mvr, final Type owner) {
super(mvr);
this.owner = owner;
}
@Override
boolean shouldImport(final Pair methodKey) {
// call anything called within a class hierarchy:
final Type called = methodKey.getLeft();
// "I prefer the short cut":
if (called.equals(owner)) {
return true;
}
try {
final Class> inner = load(called);
final Class> outer = load(owner);
return inner.isAssignableFrom(outer);
} catch (final ClassNotFoundException e) {
return false;
}
}
private Class> load(final Type type) throws ClassNotFoundException {
return privilizer().env.classLoader.loadClass(type.getClassName());
}
}
/**
* For every non-public referenced field of an imported method, replaces with reflective calls. Additionally, for
* every such field that is not accessible, sets the field's accessibility and clears it as the method exits.
*/
private class AccessibleAdvisor extends AdviceAdapter {
final Type bitSetType = Type.getType(BitSet.class);
final Type classType = Type.getType(Class.class);
final Type fieldType = Type.getType(java.lang.reflect.Field.class);
final Type fieldArrayType = Type.getType(java.lang.reflect.Field[].class);
final Type stringType = Type.getType(String.class);
final List fieldAccesses;
final Label begin = new Label();
int localFieldArray;
int bitSet;
int fieldCounter;
AccessibleAdvisor(final MethodVisitor mvr, final int access, final String name, final String desc,
final List fieldAccesses) {
super(ASM5, mvr, access, name, desc);
this.fieldAccesses = fieldAccesses;
}
@Override
protected void onMethodEnter() {
localFieldArray = newLocal(fieldArrayType);
bitSet = newLocal(bitSetType);
fieldCounter = newLocal(Type.INT_TYPE);
// create localFieldArray
push(fieldAccesses.size());
newArray(fieldArrayType.getElementType());
storeLocal(localFieldArray);
// create bitSet
newInstance(bitSetType);
dup();
push(fieldAccesses.size());
invokeConstructor(bitSetType, Method.getMethod("void (int)"));
storeLocal(bitSet);
// populate localFieldArray
push(0);
storeLocal(fieldCounter);
for (final FieldAccess access : fieldAccesses) {
prehandle(access);
iinc(fieldCounter, 1);
}
mark(begin);
}
private void prehandle(final FieldAccess access) {
// push owner.class literal
visitLdcInsn(access.owner);
push(access.name);
final Label next = new Label();
invokeVirtual(classType, new Method("getDeclaredField", fieldType, new Type[] { stringType }));
dup();
// store the field at localFieldArray[fieldCounter]:
loadLocal(localFieldArray);
swap();
loadLocal(fieldCounter);
swap();
arrayStore(fieldArrayType.getElementType());
dup();
invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("boolean isAccessible()"));
final Label setAccessible = new Label();
// if false, setAccessible:
ifZCmp(EQ, setAccessible);
// else pop field instance
pop();
// and record that he was already accessible:
loadLocal(bitSet);
loadLocal(fieldCounter);
invokeVirtual(bitSetType, Method.getMethod("void set(int)"));
goTo(next);
mark(setAccessible);
push(true);
invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("void setAccessible(boolean)"));
mark(next);
}
@Override
public void visitFieldInsn(final int opcode, final String owner, final String name, final String desc) {
final Pair key = Pair.of(Type.getObjectType(owner), name);
final FieldAccess fieldAccess = fieldAccessMap.get(key);
Validate.isTrue(fieldAccesses.contains(fieldAccess), "Cannot find field %s", key);
final int fieldIndex = fieldAccesses.indexOf(fieldAccess);
visitInsn(NOP);
loadLocal(localFieldArray);
push(fieldIndex);
arrayLoad(fieldArrayType.getElementType());
checkCast(fieldType);
final Method access;
if (opcode == PUTSTATIC) {
// value should have been at top of stack on entry; position the field under the value:
swap();
// add null object for static field deref and swap under value:
push((String) null);
swap();
if (fieldAccess.type.getSort() < Type.ARRAY) {
// box value:
valueOf(fieldAccess.type);
}
access = Method.getMethod("void set(Object, Object)");
} else {
access = Method.getMethod("Object get(Object)");
// add null object for static field deref:
push((String) null);
}
invokeVirtual(fieldType, access);
if (opcode == GETSTATIC) {
checkCast(privilizer().wrap(fieldAccess.type));
if (fieldAccess.type.getSort() < Type.ARRAY) {
unbox(fieldAccess.type);
}
}
}
@Override
public void visitMaxs(final int maxStack, final int maxLocals) {
// put try-finally around the whole method
final Label fny = mark();
// null exception type signifies finally block:
final Type exceptionType = null;
catchException(begin, fny, exceptionType);
onFinally();
throwException();
super.visitMaxs(maxStack, maxLocals);
}
@Override
protected void onMethodExit(final int opcode) {
if (opcode != ATHROW) {
onFinally();
}
}
private void onFinally() {
// loop over fields and return any non-null element to being inaccessible:
push(0);
storeLocal(fieldCounter);
final Label test = mark();
final Label increment = new Label();
final Label endFinally = new Label();
loadLocal(fieldCounter);
push(fieldAccesses.size());
ifCmp(Type.INT_TYPE, GeneratorAdapter.GE, endFinally);
loadLocal(bitSet);
loadLocal(fieldCounter);
invokeVirtual(bitSetType, Method.getMethod("boolean get(int)"));
// if true, increment:
ifZCmp(NE, increment);
loadLocal(localFieldArray);
loadLocal(fieldCounter);
arrayLoad(fieldArrayType.getElementType());
push(false);
invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("void setAccessible(boolean)"));
mark(increment);
iinc(fieldCounter, 1);
goTo(test);
mark(endFinally);
}
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy