All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.skyway.spring.util.webservice.cxf.SpringSecurityCallbackHandler Maven / Gradle / Ivy

The newest version!
/**
* Copyright 2007 - 2011 Skyway Software, Inc.
*/
package org.skyway.spring.util.webservice.cxf;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;

import org.apache.ws.security.WSPasswordCallback;

/**
 * Handles authentication of incoming web service calls by asking the Spring
 * authentication manager to authenticate the credentials carried by a
 * {@link WSPasswordCallback}. If the credentials are valid, a new
 * SecurityContext is created, populated with the authenticated token and
 * installed through the SecurityContextHolder.
 * <p>
 * All Spring Security types are accessed reflectively; concrete subclasses
 * supply the fully-qualified names of those types through the abstract
 * getters declared below.
 * 
 * @author jperkins
 *
 */
public abstract class SpringSecurityCallbackHandler implements CallbackHandler {
	/** @return fully-qualified name of Spring's authentication exception type. */
	protected abstract String getAuthenticationExceptionClassName();
	/**
	 * @return fully-qualified name of the authentication manager type.
	 *         NOTE(review): not referenced anywhere in this class.
	 */
	protected abstract String getAuthenticationManagerClassName();
	/** @return fully-qualified name of the class exposing a static {@code setContext}. */
	protected abstract String getSecurityContextHolderClassName();
	/** @return fully-qualified name of the security context interface. */
	protected abstract String getSecurityContextInterfaceName();
	/** @return fully-qualified name of the concrete security context class (needs a public no-arg constructor). */
	protected abstract String getSecurityContextClassName();
	/** @return fully-qualified name of the authentication interface. */
	protected abstract String getAuthenticationInterfaceName();
	/** @return fully-qualified name of the token class (needs a public (Object, Object) constructor). */
	protected abstract String getAuthenticationClassName();
	
	private Object authenticationManager;
	
	public SpringSecurityCallbackHandler() {
		super();
	}

	/**
	 * Sets the authentication manager that credentials are delegated to.
	 *
	 * @param authenticationManager an object exposing a public
	 *        {@code authenticate(Authentication)} method
	 */
	public void setAuthenticationManager(Object authenticationManager){
		this.authenticationManager = authenticationManager;
	}
	
	/**
	 * Authenticates every {@link WSPasswordCallback} in {@code callbacks}.
	 * All callbacks are type-checked before any of them is authenticated, so
	 * an unsupported callback never leaves a half-established security context.
	 *
	 * @param callbacks callbacks supplied by the WS-Security layer; must be non-empty
	 * @throws IllegalArgumentException if {@code callbacks} is null or empty
	 * @throws UnsupportedCallbackException if any callback is not a WSPasswordCallback
	 * @throws RuntimeException if authentication fails
	 */
	public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
		if (callbacks == null || callbacks.length == 0) {
			// Fail closed: "nothing to check" must never count as a successful login.
			throw new IllegalArgumentException("No callbacks supplied.");
		}
		for (Callback callback : callbacks) {
			if (!(callback instanceof WSPasswordCallback)) {
				throw new UnsupportedCallbackException(callback, "Unsupported callback type.");
			}
		}
		for (Callback callback : callbacks) {
			authenticate((WSPasswordCallback) callback);
		}
	}

	/**
	 * Authenticates the credentials in {@code passwordCallback} and, on
	 * success, installs a fresh security context holding the result.
	 *
	 * @throws IllegalStateException if no authentication manager was set
	 * @throws RuntimeException if the credentials are rejected or the
	 *         reflective calls fail (see {@link #handleException(Throwable)})
	 */
	@SuppressWarnings("deprecation")
	private void authenticate(WSPasswordCallback passwordCallback){
		if (authenticationManager == null) {
			throw new IllegalStateException("No authentication manager has been configured.");
		}
		String userName = passwordCallback.getIdentifer();
		String password = passwordCallback.getPassword();
		Object authentication = constructAuthenticationInstance(userName, password);
		try {
			Object result = authenticate(authenticationManager, authentication);
			// A null result is treated like an unauthenticated token rather than
			// surfacing later as a NullPointerException.
			if (result == null || !isAuthenticated(result)) {
				throw new RuntimeException("Invalid credentials."); //$NON-NLS-1$
			}
			Object securityContext = getSecurityContext();
			setAuthentication(securityContext, result);
			setContext(securityContext);
		} catch (RuntimeException x) {
			handleException(x);
		}
	}
	
	/**
	 * Builds an unauthenticated token, reflectively equivalent to
	 * {@code new UsernamePasswordAuthenticationToken(userName, password)}.
	 */
	protected Object constructAuthenticationInstance(String userName, String password){
		Constructor<?> constructor = getConstructor(getClassByName(getAuthenticationClassName()), Object.class, Object.class);
		
		return callConstructor(constructor, userName, password);
	}

	/**
	 * Reflectively calls {@code authenticationManager.authenticate(authentication)}.
	 *
	 * @return whatever the manager returns (may be null for non-conforming managers)
	 */
	protected Object authenticate(Object authenticationManager, Object authentication){
		Method method = getMethod(authenticationManager.getClass(), getAuthenticateMethodName(), getClassByName(getAuthenticationInterfaceName()));
		
		return callMethod(authenticationManager, method, authentication);
	}
	
	/**
	 * Reflectively calls {@code authentication.isAuthenticated()}.
	 *
	 * @return true only if the call returns {@code Boolean.TRUE}
	 */
	protected boolean isAuthenticated(Object authentication){
		Method method = getMethod(getLookupType(authentication, getAuthenticationInterfaceName()), getIsAuthenticatedMethodName());
		
		return Boolean.TRUE.equals(callMethod(authentication, method));
	}
	
	/**
	 * Creates a new, empty security context, reflectively equivalent to
	 * {@code new SecurityContextImpl()}.
	 */
	protected Object getSecurityContext(){
		Constructor<?> constructor = getConstructor(getClassByName(getSecurityContextClassName()));
		
		return callConstructor(constructor);
	}
	
	/** Reflectively calls {@code securityContext.setAuthentication(authentication)}. */
	protected void setAuthentication(Object securityContext, Object authentication){
		Method method = getMethod(getLookupType(securityContext, getSecurityContextInterfaceName()), getSetAuthenticationMethodName(), getClassByName(getAuthenticationInterfaceName()));
		
		callMethod(securityContext, method, authentication);
	}
	
	/** Reflectively calls the static {@code SecurityContextHolder.setContext(securityContext)}. */
	protected void setContext(Object securityContext){
		Method method = getMethod(getClassByName(getSecurityContextHolderClassName()), getSetContextMethodName(), getClassByName(getSecurityContextInterfaceName()));
		
		callMethod(null, method, securityContext);
	}
	
	/**
	 * Translates a failure into an unchecked exception; this implementation
	 * always throws.
	 * <ul>
	 * <li>The configured authentication exception type becomes
	 * "Invalid credentials." with the original as cause.</li>
	 * <li>Other RuntimeExceptions are rethrown unchanged.</li>
	 * <li>Anything else is wrapped in a RuntimeException.</li>
	 * </ul>
	 */
	protected void handleException(Throwable x){
		Class<?> exceptionClass = getClassByName(getAuthenticationExceptionClassName());
		if (exceptionClass.isInstance(x)){
			throw new RuntimeException("Invalid credentials.", x);
		}else if (x instanceof RuntimeException){
			throw (RuntimeException)x;
		}else{
			throw new RuntimeException("Exception during authentication: ", x);
		}
	}
	
	/** Looks up a public constructor, wrapping reflection failures in a RuntimeException. */
	protected Constructor<?> getConstructor(Class<?> classToConstruct, Class<?>... parameterTypes){
		try {
			return classToConstruct.getConstructor(parameterTypes);
		} catch (SecurityException e) {
			throw new RuntimeException("Unable to get constructor for class: " + classToConstruct.getName(), e);
		} catch (NoSuchMethodException e) {
			throw new RuntimeException("Unable to get constructor for class: " + classToConstruct.getName(), e);
		}
	}
	
	/** Invokes a constructor, wrapping every reflection failure in a RuntimeException. */
	protected Object callConstructor(Constructor<?> constructor, Object... arguments){
		try {
			return constructor.newInstance(arguments);
		} catch (IllegalArgumentException e) {
			throw new RuntimeException("Unable to get instance of class: " + constructor.getDeclaringClass().getName(), e);
		} catch (IllegalAccessException e) {
			throw new RuntimeException("Unable to get instance of class: " + constructor.getDeclaringClass().getName(), e);
		} catch (InvocationTargetException e) {
			throw new RuntimeException("Unable to get instance of class: " + constructor.getDeclaringClass().getName(), e);
		} catch (InstantiationException e) {
			throw new RuntimeException("Unable to get instance of class: " + constructor.getDeclaringClass().getName(), e);
		}
	}
	
	/** Looks up a public method, wrapping reflection failures in a RuntimeException. */
	protected Method getMethod(Class<?> classToCall, String methodName, Class<?>... parameterTypes){
		try {
			return classToCall.getMethod(methodName, parameterTypes);
		} catch (SecurityException e) {
			throw new RuntimeException("Unable to get method: " + methodName, e);
		} catch (NoSuchMethodException e) {
			throw new RuntimeException("Unable to get method: " + methodName, e);
		}
	}
	
	/**
	 * Invokes {@code method} on {@code instance} (null for static methods).
	 * <ul>
	 * <li>Access and argument failures are wrapped in a RuntimeException.</li>
	 * <li>Exceptions thrown by the target method are unwrapped and passed to
	 * {@link #handleException(Throwable)}.</li>
	 * </ul>
	 */
	protected Object callMethod(Object instance, Method method, Object... arguments){
		try {
			return method.invoke(instance, arguments);
		} catch (IllegalArgumentException e) {
			throw new RuntimeException("Unable to call method: " + method, e);
		} catch (IllegalAccessException e) {
			throw new RuntimeException("Unable to call method: " + method, e);
		} catch (InvocationTargetException e) {
			handleException(e.getCause());
			// Only reached if a subclass overrides handleException to not throw.
			return null;
		}
	}
	
	/** Loads and initializes the named class using this handler's class loader. */
	protected Class<?> getClassByName(String className){
		try {
			return Class.forName(className, true, getClass().getClassLoader());
		} catch (ClassNotFoundException e) {
			throw new RuntimeException("Unable to get class: " + className, e);
		}
	}
	
	/**
	 * Chooses the type on which to look up a method for {@code instance}.
	 * <ul>
	 * <li>If {@code instance} implements the named API type, that type is used.
	 * A method looked up on a non-public implementation class (for example an
	 * anonymous or package-private one) makes {@link Method#invoke} fail with
	 * an IllegalAccessException; the public API type avoids that.</li>
	 * <li>Otherwise the instance's runtime class is used, as before.</li>
	 * </ul>
	 */
	private Class<?> getLookupType(Object instance, String apiTypeName){
		Class<?> apiType = getClassByName(apiTypeName);
		return apiType.isInstance(instance) ? apiType : instance.getClass();
	}
	
	/** @return name of the authentication manager's authenticate method. */
	protected String getAuthenticateMethodName(){
		return "authenticate";
	}

	/** @return name of the authentication token's isAuthenticated method. */
	protected String getIsAuthenticatedMethodName(){
		return "isAuthenticated";
	}

	/** @return name of the security context's setAuthentication method. */
	protected String getSetAuthenticationMethodName(){
		return "setAuthentication";
	}

	/** @return name of the security context holder's static setContext method. */
	protected String getSetContextMethodName(){
		return "setContext";
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy