
org.skyway.spring.util.webservice.cxf.SpringSecurityCallbackHandler Maven / Gradle / Ivy
The newest version!
/**
* Copyright 2007 - 2011 Skyway Software, Inc.
*/
package org.skyway.spring.util.webservice.cxf;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import org.apache.ws.security.WSPasswordCallback;
/**
* Handles authentication of incoming web service calls by asking the spring
* authentication manager to authenticate the credentials. If the credentials
* are valid, the SecurityContext is initialized and set.
*
* @author jperkins
*
*/
public abstract class SpringSecurityCallbackHandler implements CallbackHandler {
protected abstract String getAuthenticationExceptionClassName();
protected abstract String getAuthenticationManagerClassName();
protected abstract String getSecurityContextHolderClassName();
protected abstract String getSecurityContextInterfaceName();
protected abstract String getSecurityContextClassName();
protected abstract String getAuthenticationInterfaceName();
protected abstract String getAuthenticationClassName();
private Object authenticationManager;
public SpringSecurityCallbackHandler() {
super();
}
public void setAuthenticationManager(Object authenticationManager){
this.authenticationManager = authenticationManager;
}
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
WSPasswordCallback passwordCallback = (WSPasswordCallback) callbacks[0];
authenticate(passwordCallback);
}
@SuppressWarnings("deprecation")
private void authenticate(WSPasswordCallback passwordCallback){
String userName = passwordCallback.getIdentifer();
String password = passwordCallback.getPassword();
Object securityContext;
Object authentication;
authentication = constructAuthenticationInstance(userName, password);
try {
authentication = authenticate(authenticationManager, authentication);
if (isAuthenticated(authentication)) {
// Set the security context
securityContext = getSecurityContext();
setAuthentication(securityContext, authentication);
setContext(securityContext);
} else {
throw new RuntimeException("Invalid credentials."); //$NON-NLS-1$
}
} catch (RuntimeException x) {
handleException(x);
}
}
protected Object constructAuthenticationInstance(String userName, String password){
// new UsernamePasswordAuthenticationToken(userName, password)
Constructor> constructor = getConstructor(getClassByName(getAuthenticationClassName()), Object.class, Object.class);
return callConstructor(constructor, userName, password);
}
protected Object authenticate(Object authenticationManager, Object authentication){
//return authenticationManager.authenticate(authentication);
Method method = getMethod(authenticationManager.getClass(), getAuthenticateMethodName(), getClassByName(getAuthenticationInterfaceName()));
return callMethod(authenticationManager, method, authentication);
}
protected boolean isAuthenticated(Object authentication){
// authentication.isAuthenticated()
Method method = getMethod(authentication.getClass(), getIsAuthenticatedMethodName());
return (Boolean)callMethod(authentication, method);
}
protected Object getSecurityContext(){
// SecurityContext securityContext = new SecurityContextImpl();
Constructor> constructor = getConstructor(getClassByName(getSecurityContextClassName()));
return callConstructor(constructor);
}
protected void setAuthentication(Object securityContext, Object authentication){
// securityContext.setAuthentication(authentication);
Method method = getMethod(securityContext.getClass(), getSetAuthenticationMethodName(), getClassByName(getAuthenticationInterfaceName()));
callMethod(securityContext, method, authentication);
}
protected void setContext(Object securityContext){
// SecurityContextHolder.setContext(securityContext);
Method method = getMethod(getClassByName(getSecurityContextHolderClassName()), getSetContextMethodName(), getClassByName(getSecurityContextInterfaceName()));
callMethod(null, method, securityContext);
}
protected void handleException(Throwable x){
Class> exceptionClass;
exceptionClass = getClassByName(getAuthenticationExceptionClassName());
if (exceptionClass.isInstance(x)){
throw new RuntimeException("Invalid credentials.", x);
}else if (x instanceof RuntimeException){
throw (RuntimeException)x;
}else{
throw new RuntimeException("Exception during authentication: ", x);
}
}
protected Constructor> getConstructor(Class> classToConstruct, Class>... parameterTypes){
try {
return classToConstruct.getConstructor(parameterTypes);
} catch (SecurityException e) {
throw new RuntimeException("Unable to get constructor for class: " + classToConstruct.getName(), e);
} catch (NoSuchMethodException e) {
throw new RuntimeException("Unable to get constructor for class: " + classToConstruct.getName(), e);
}
}
protected Object callConstructor(Constructor> constructor, Object... arguments){
try {
return constructor.newInstance(arguments);
} catch (IllegalArgumentException e) {
throw new RuntimeException("Unable to get instance of class: " + constructor.getDeclaringClass().getName(), e);
} catch (IllegalAccessException e) {
throw new RuntimeException("Unable to get instance of class: " + constructor.getDeclaringClass().getName(), e);
} catch (InvocationTargetException e) {
throw new RuntimeException("Unable to get instance of class: " + constructor.getDeclaringClass().getName(), e);
} catch (InstantiationException e) {
throw new RuntimeException("Unable to get instance of class: " + constructor.getDeclaringClass().getName(), e);
}
}
protected Method getMethod(Class> classToCall, String methodName, Class>... parameterTypes){
try {
return classToCall.getMethod(methodName, parameterTypes);
} catch (SecurityException e) {
throw new RuntimeException("Unable to get method: " + methodName, e);
} catch (NoSuchMethodException e) {
throw new RuntimeException("Unable to get method: " + methodName, e);
}
}
protected Object callMethod(Object instance, Method method, Object... arguments){
try {
return method.invoke(instance, arguments);
} catch (IllegalArgumentException e) {
throw new RuntimeException("Unable to call method: " + method, e);
} catch (IllegalAccessException e) {
throw new RuntimeException("Unable to call method: " + method, e);
} catch (InvocationTargetException e) {
handleException(e.getCause());
return null;
}
}
protected Class> getClassByName(String className){
try {
return Class.forName(className, true, getClass().getClassLoader());
} catch (ClassNotFoundException e) {
throw new RuntimeException("Unable to get class: " + className, e);
}
}
protected String getAuthenticateMethodName(){
return "authenticate";
}
protected String getIsAuthenticatedMethodName(){
return "isAuthenticated";
}
protected String getSetAuthenticationMethodName(){
return "setAuthentication";
}
protected String getSetContextMethodName(){
return "setContext";
}
}
© 2015 - 2025 Weber Informatics LLC | Privacy Policy