All downloads are free. The search and download functionality uses the official Maven repository.

org.adeptnet.auth.sso.SSOAuthModule Maven / Gradle / Ivy

/*
 * Copyright 2015 Francois Steyn - Adept Internet (PTY) LTD ([email protected]).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.adeptnet.auth.sso;

import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.LoginContext;
import javax.security.auth.login.LoginException;
import javax.security.auth.message.AuthException;
import javax.security.auth.message.AuthStatus;
import javax.security.auth.message.MessageInfo;
import javax.security.auth.message.MessagePolicy;
import javax.security.auth.message.config.ServerAuthContext;
import javax.security.auth.message.module.ServerAuthModule;
import javax.servlet.RequestDispatcher;
import javax.servlet.ServletException;
import javax.servlet.ServletRequestWrapper;
import javax.servlet.ServletResponseWrapper;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import javax.servlet.http.HttpSession;
import org.adeptnet.auth.kerberos.Krb5;
import org.adeptnet.auth.saml.SAMLClient;
import org.adeptnet.auth.saml.SAMLException;
import org.adeptnet.auth.sso.common.Krb5Credentials;
import org.adeptnet.auth.sso.common.SAMLCredentials;
import org.adeptnet.auth.sso.common.SSOCallback;
import org.adeptnet.auth.sso.common.SSOCredentials;
import org.opensaml.ws.message.encoder.MessageEncodingException;

/**
 * JASPIC server authentication module that authenticates HTTP requests with
 * either a {@code Negotiate} token taken from the {@code Authorization} header
 * or a SAML response request parameter, delegating the actual credential check
 * to the JAAS login context named by the {@code jaas-context} option.
 * <p>
 * After a successful login the populated {@link Subject} is stored in the HTTP
 * session and restored on subsequent requests.
 * <p>
 * Required options: {@code login-page} and {@code jaas-context} (validated in
 * {@link #initialize}). {@code login-error-page} is only read when an error
 * has to be displayed.
 * <p>
 * Per-login credentials are never stored on the instance: they are handed to
 * the login modules through a handler created for that single login, so
 * concurrent requests cannot observe each other's credentials.
 *
 * @author Francois Steyn - Adept Internet (PTY) LTD ([email protected])
 */
public class SSOAuthModule implements ServerAuthModule, ServerAuthContext, CallbackHandler {

    private static final Logger LOG = Logger.getLogger(SSOAuthModule.class.getName());

    // Raw Class[] / Map are kept where they mirror the signatures this class overrides.
    private static final Class[] SUPPORTED_MESSAGE_TYPES = new Class[]{HttpServletRequest.class, HttpServletResponse.class};
    private static final String IS_MANDATORY_INFO_KEY = "javax.security.auth.message.MessagePolicy.isMandatory";
    private static final String SESSION_SAVED_SUBJECT_KEY = "ServerAuthModule.SAVED.SUBJECT";
    private static final String PARAM_JAAS_CONTEXT_PARAM = "jaas-context";
    private static final String PARAM_LOGIN_PAGE = "login-page";
    private static final String PARAM_LOGIN_ERROR = "login-error-page";
    private static final String J_SECURITY_CHECK = "/j_security_check";
    private static final String J_SECURITY_LOGOUT = "/j_security_logout";

    /** Container supplied handler; receives every callback this module does not answer itself. */
    private CallbackHandler handler;
    /** Lazily resolved option values; reset by {@link #initialize}. */
    private String _jaasCtx;
    private String _loginPage;
    private String _loginErrorPage;
    /** Provider options as passed to {@link #initialize}. */
    private Map options;

    /**
     * Reads a mandatory String option.
     *
     * @param options provider options
     * @param name option key
     * @return the option value
     * @throws AuthException if the option is missing or is not a String
     */
    private String getStringOption(final Map<?, ?> options, final String name) throws AuthException {
        final Object val = options.get(name);
        if (val instanceof String) {
            return (String) val;
        }
        throw new AuthException(String.format("'%s' must be supplied as a property in the provider-config in the domain.xml file!", name));
    }

    /**
     * @return the configured {@code login-page} option (cached after first read)
     * @throws AuthException if the option is not configured
     */
    private String getLoginPage() throws AuthException {
        if (_loginPage == null) {
            _loginPage = getStringOption(options, PARAM_LOGIN_PAGE);
        }
        return _loginPage;
    }

    /**
     * @return the configured {@code login-error-page} option (cached after first read)
     * @throws AuthException if the option is not configured
     */
    private String getLoginErrorPage() throws AuthException {
        if (_loginErrorPage == null) {
            _loginErrorPage = getStringOption(options, PARAM_LOGIN_ERROR);
        }
        return _loginErrorPage;
    }

    /**
     * @return the configured {@code jaas-context} option (cached after first read)
     * @throws AuthException if the option is not configured
     */
    private String getJaasCtx() throws AuthException {
        if (_jaasCtx == null) {
            _jaasCtx = getStringOption(options, PARAM_JAAS_CONTEXT_PARAM);
        }
        return _jaasCtx;
    }

    /**
     * Stores the handler and options and eagerly validates the
     * {@code login-page} and {@code jaas-context} options.
     *
     * @param requestPolicy not used by this module
     * @param responsePolicy not used by this module
     * @param handler container callback handler, must not be {@code null}
     * @param options provider options, must not be {@code null}
     * @throws AuthException if {@code handler} or {@code options} is
     * {@code null} or a required option is missing
     */
    @Override
    public void initialize(final MessagePolicy requestPolicy, final MessagePolicy responsePolicy, final CallbackHandler handler, final Map options) throws AuthException {
        if (LOG.isLoggable(Level.FINER)) {
            LOG.finer("initialize");
        }
        if (options == null) {
            throw new AuthException("options is null");
        }
        if (handler == null) {
            throw new AuthException("handler is null");
        }
        this.handler = handler;
        this.options = options;
        // Forget values cached from an earlier initialize() so the new options take effect.
        _loginPage = null;
        _jaasCtx = null;
        _loginErrorPage = null;
        getLoginPage();
        getJaasCtx();
    }

    /**
     * @return the message types this module handles (servlet request and
     * response); a copy, so callers cannot alter the shared constant
     */
    @Override
    public Class[] getSupportedMessageTypes() {
        if (LOG.isLoggable(Level.FINER)) {
            LOG.finer("getSupportedMessageTypes");
        }
        return SUPPORTED_MESSAGE_TYPES.clone();
    }

    /**
     * Reads the "is mandatory" flag from the message info map. Any non-null
     * value is interpreted through its string form, so a {@code Boolean} value
     * does not cause a {@code ClassCastException}.
     */
    private boolean isMandatory(final MessageInfo messageInfo) {
        final Object flag = messageInfo.getMap().get(IS_MANDATORY_INFO_KEY);
        return flag != null && Boolean.parseBoolean(flag.toString());
    }

    /**
     * Returns the context path to redirect to, substituting {@code "/"} for the
     * empty path of a root-context deployment. An empty redirect location
     * would be resolved relative to the current request rather than to the
     * application root.
     */
    private static String getContextRoot(final HttpServletRequest request) {
        final String contextPath = request.getContextPath();
        return (contextPath == null || contextPath.isEmpty()) ? "/" : contextPath;
    }

    /**
     * Forwards to the configured error page with {@code messageText} and
     * {@code errorText} exposed as request attributes. A failing forward is
     * logged, not propagated.
     *
     * @return always {@link AuthStatus#SEND_FAILURE}
     * @throws AuthException if {@code login-error-page} is not configured
     */
    private AuthStatus redirectToErrorScreen(final HttpServletRequest request, final HttpServletResponse response, final String messageText, final String errorText) throws AuthException {
        if (LOG.isLoggable(Level.FINER)) {
            LOG.finer(String.format("redirectToErrorScreen messageText=%s errorText=%s", messageText, errorText));
        }
        request.setAttribute("messageText", messageText);
        request.setAttribute("errorText", errorText);
        final String errorPage = getLoginErrorPage();
        final RequestDispatcher rDispatcher = request.getRequestDispatcher(errorPage);
        if (rDispatcher == null) {
            // The container could not provide a dispatcher for the page; nothing to forward to.
            LOG.log(Level.SEVERE, "No RequestDispatcher available for error page: {0}", errorPage);
            return AuthStatus.SEND_FAILURE;
        }
        try {
            rDispatcher.forward(request, response);
        } catch (ServletException | IOException ex) {
            LOG.log(Level.SEVERE, ex.getMessage(), ex);
        }
        return AuthStatus.SEND_FAILURE;
    }

    /**
     * Logs {@code throwable} and forwards to the error page, using its class
     * name as message text and its message as error text.
     */
    private AuthStatus redirectToErrorScreen(final HttpServletRequest request, final HttpServletResponse response, final Throwable throwable) throws AuthException {
        LOG.log(Level.SEVERE, throwable.getMessage(), throwable);
        return redirectToErrorScreen(request, response, throwable.getClass().getName(), throwable.getMessage());
    }

    /**
     * When the request carries no {@code Authorization} header, answers with
     * 401 and a {@code Negotiate} challenge. Always records the request URI in
     * the {@code doSSO} request attribute.
     */
    private void doSSOHeader(final HttpServletRequest request, final HttpServletResponse response) {
        if (request.getHeader(Krb5.AUTHORIZATION) == null) {
            response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
            response.setHeader(Krb5.WWW_AUTHENTICATE, Krb5.NEGOTIATE);
        }
        request.setAttribute("doSSO", request.getRequestURI());
    }

    /**
     * Issues the SSO challenge (see {@link #doSSOHeader}) and forwards to the
     * login page.
     *
     * @return {@link AuthStatus#SEND_CONTINUE}
     * @throws AuthException if the forward fails or no dispatcher is available
     */
    private AuthStatus redirectToLoginScreen(final HttpServletRequest request, final HttpServletResponse response) throws AuthException {
        doSSOHeader(request, response);
        final RequestDispatcher rDispatcher = request.getRequestDispatcher(getLoginPage());
        if (rDispatcher == null) {
            throw new AuthException(String.format("Redirect to loginPage: %s", getLoginPage()));
        }
        try {
            rDispatcher.forward(request, response);
        } catch (ServletException | IOException ex) {
            LOG.log(Level.SEVERE, ex.getMessage(), ex);
            final AuthException ae = new AuthException(String.format("Redirect to loginPage: %s", getLoginPage()));
            ae.initCause(ex);
            throw ae;
        }

        return AuthStatus.SEND_CONTINUE;
    }

    /**
     * Logs the caller out, invalidates an existing session (best effort) and
     * forwards to the login page.
     *
     * @param clientSubject not used; kept for signature symmetry
     */
    private AuthStatus logoutSession(final HttpServletRequest request, final HttpServletResponse response, final Subject clientSubject) throws AuthException {
        try {
            request.logout();
            if (LOG.isLoggable(Level.FINE)) {
                LOG.log(Level.FINE, String.format("validateRequest %s ==> logout redirect ", request.getRequestURI()));
            }
            try {
                final HttpSession session = request.getSession(false);
                if (session != null) {
                    session.invalidate();
                }
            } catch (RuntimeException e) {
                // Best effort only: e.g. IllegalStateException when logout() already invalidated the session.
                LOG.log(Level.WARNING, "Session was already invalid ", e);
            }
            return redirectToLoginScreen(request, response);
        } catch (ServletException ex) {
            return redirectToErrorScreen(request, response, ex);
        }
    }

    /**
     * Copies principals and credentials of the subject saved in the session,
     * if any, into {@code clientSubject}.
     *
     * @return {@code true} if a saved subject was found and restored
     */
    private boolean restoreSavedCredentials(final Subject clientSubject, final HttpSession session, final HttpServletRequest request, final HttpServletResponse response) {
        final Object saved = session.getAttribute(SESSION_SAVED_SUBJECT_KEY);
        if (!(saved instanceof Subject)) {
            return false;
        }
        final Subject savedClientSubject = (Subject) saved;
        clientSubject.getPrincipals().addAll(savedClientSubject.getPrincipals());
        clientSubject.getPublicCredentials().addAll(savedClientSubject.getPublicCredentials());
        clientSubject.getPrivateCredentials().addAll(savedClientSubject.getPrivateCredentials());
        if (LOG.isLoggable(Level.FINE)) {
            LOG.log(Level.FINE, "CustomFormAuthModule validateRequest {0} ==> restored pricipals ==> SUCCESS  {1}\n\n\n\n", new Object[]{request.getRequestURI(), response.getStatus()});
        }
        return true;
    }

    /**
     * Appends every value of {@code headerName} to {@code target} as
     * {@code "<headerName>:<value>"}, separated by {@code ", "}.
     */
    private static void appendHeaderValues(final StringBuilder target, final HttpServletRequest request, final String headerName) {
        final Enumeration<?> values = request.getHeaders(headerName);
        if (values == null) {
            // getHeaders may return null when the container withholds header access.
            return;
        }
        while (values.hasMoreElements()) {
            if (target.length() > 0) {
                target.append(", ");
            }
            target.append(headerName).append(':').append(values.nextElement());
        }
    }

    /**
     * Builds a description of where the request came from: context path,
     * remote address and any {@code Via} / {@code x-forwarded-for} headers.
     */
    private String getOrigin(final HttpServletRequest request) {
        final StringBuilder stringBuilder = new StringBuilder();
        appendHeaderValues(stringBuilder, request, "Via");
        appendHeaderValues(stringBuilder, request, "x-forwarded-for");
        return String.format("%s:%s %s", request.getContextPath(), request.getRemoteAddr(), stringBuilder.toString());
    }

    /** @return {@code true} if {@code uri} ends with the configured login page */
    private boolean isLoginPage(final String uri) throws AuthException {
        return uri.endsWith(getLoginPage());
    }

    /**
     * @return {@code true} if {@code uri} addresses the login page or one of
     * the {@code j_security_*} pseudo resources
     */
    private boolean isRedirectUrl(final String uri) throws AuthException {
        return (isLoginPage(uri)) || (uri.endsWith(J_SECURITY_CHECK)) || (uri.endsWith(J_SECURITY_LOGOUT));
    }

    /**
     * Strips a trailing login page or {@code j_security_*} suffix from
     * {@code uri}; any other value is returned unchanged.
     */
    private String getRedirectUrl(final String uri) throws AuthException {
        if (isLoginPage(uri)) {
            return uri.substring(0, uri.length() - getLoginPage().length());
        }
        if (uri.endsWith(J_SECURITY_CHECK)) {
            return uri.substring(0, uri.length() - J_SECURITY_CHECK.length());
        }
        if (uri.endsWith(J_SECURITY_LOGOUT)) {
            return uri.substring(0, uri.length() - J_SECURITY_LOGOUT.length());
        }
        return uri;
    }

    /**
     * Sends an HTTP redirect to {@code url}.
     *
     * @return {@link AuthStatus#SEND_CONTINUE}, or the result of showing the
     * error page if the redirect could not be written
     */
    private AuthStatus doRedirect(final HttpServletRequest request, final HttpServletResponse response, final String url) throws AuthException {
        try {
            if (LOG.isLoggable(Level.FINER)) {
                LOG.finer(String.format("sendRedirect: %s", url));
            }
            response.sendRedirect(url);
            return AuthStatus.SEND_CONTINUE;
        } catch (IOException ex) {
            return redirectToErrorScreen(request, response, ex);
        }
    }

    /**
     * Extracts the token from a {@code Negotiate <token>} Authorization value.
     * The scheme is matched case-insensitively.
     *
     * @return the token, or {@code null} if the header is absent, uses another
     * scheme, or carries no token
     */
    private static String extractNegotiateToken(final String authorization) {
        if (authorization == null) {
            return null;
        }
        final String prefix = Krb5.NEGOTIATE + " ";
        if (!authorization.regionMatches(true, 0, prefix, 0, prefix.length())) {
            return null;
        }
        final String token = authorization.substring(prefix.length()).trim();
        return token.isEmpty() ? null : token;
    }

    /**
     * Decodes a Base64 relay state into the URL to continue with after login.
     *
     * @return the decoded value, or {@code null} if {@code relayState} is
     * {@code null} or is not valid Base64
     */
    private static String decodeRelayState(final String relayState) {
        if (relayState == null) {
            return null;
        }
        try {
            return new String(Base64.getDecoder().decode(relayState), StandardCharsets.UTF_8);
        } catch (IllegalArgumentException ex) {
            // Client supplied value: treat garbage as "no relay state" instead of failing the request.
            LOG.log(Level.WARNING, "Ignoring malformed relay state", ex);
            return null;
        }
    }

    /**
     * Starts a SAML login: encodes the return target (the stripped
     * {@code j_url} plus the optional fragment) as relay state and hands over
     * to {@code Common.doSAMLRedirect}.
     *
     * @param url value of the {@code j_url} parameter
     * @param fragment value of the {@code j_fragment} parameter, may be {@code null}
     */
    private AuthStatus startSamlLogin(final HttpServletRequest request, final HttpServletResponse response, final String url, final String fragment) throws AuthException {
        // A missing fragment must contribute nothing; formatting null would append the text "null".
        final String target = getRedirectUrl(url) + (fragment == null ? "" : fragment);
        final String relayState = Base64.getEncoder().encodeToString(target.getBytes(StandardCharsets.UTF_8));
        try {
            Common.getInstance(options).doSAMLRedirect(request, response, relayState);
        } catch (SAMLException | MessageEncodingException ex) {
            return redirectToErrorScreen(request, response, ex);
        }
        return AuthStatus.SEND_CONTINUE;
    }

    /**
     * Authenticates the request. In order:
     * <ol>
     * <li>a logout URI without Authorization header logs the caller out;</li>
     * <li>a subject saved in the session is restored;</li>
     * <li>non-mandatory requests without credentials pass through;</li>
     * <li>a Negotiate token or a SAML response is verified through the JAAS
     * login context and, on success, the subject is saved in the session;</li>
     * <li>a {@code j_url} parameter starts a SAML login;</li>
     * <li>anything else is sent to the login page.</li>
     * </ol>
     *
     * @param messageInfo carries the servlet request and response
     * @param clientSubject subject to populate
     * @param serviceSubject not used
     * @return the resulting status
     * @throws AuthException on configuration or dispatch problems
     */
    @Override
    public AuthStatus validateRequest(final MessageInfo messageInfo, final Subject clientSubject, final Subject serviceSubject) throws AuthException {
        try {
            final HttpServletRequest request = (HttpServletRequest) messageInfo.getRequestMessage();
            final HttpServletResponse response = (HttpServletResponse) messageInfo.getResponseMessage();
            final String requestUri = request.getRequestURI();

            if (LOG.isLoggable(Level.FINE)) {
                LOG.fine(String.format("validateRequest: %s - %s", handler, requestUri));
            }

            final String auth = request.getHeader(Krb5.AUTHORIZATION);
            if (requestUri.endsWith(J_SECURITY_LOGOUT) && (auth == null)) {
                return logoutSession(request, response, clientSubject);
            }

            final HttpSession session = request.getSession(true);
            if (restoreSavedCredentials(clientSubject, session, request, response)) {
                if (isRedirectUrl(requestUri)) {
                    // Already logged in: the login/check/logout URIs have nothing to show.
                    return doRedirect(request, response, getContextRoot(request));
                }
                return AuthStatus.SUCCESS;
            }

            if (!isMandatory(messageInfo) && !requestUri.endsWith(J_SECURITY_CHECK) && (auth == null)) {
                if (LOG.isLoggable(Level.FINE)) {
                    LOG.log(Level.FINE, "CustomFormAuthModule validateRequest notMandatory {0} ==> SUCCESS {1}\n\n\n\n", new Object[]{requestUri, response.getStatus()});
                }
                if (isLoginPage(requestUri)) {
                    doSSOHeader(request, response);
                }
                return AuthStatus.SUCCESS;
            }

            final String negotiateToken = extractNegotiateToken(auth);
            final String saml = request.getParameter(SAMLClient.SAML_RESPONSE);
            final String url = request.getParameter("j_url");
            // Request-local on purpose: instance state would leak credentials between concurrent logins.
            final SSOCredentials credentials;
            final String redirectUrl;
            if (negotiateToken != null) {
                credentials = new Krb5Credentials(request.getServerName(), negotiateToken, getOrigin(request));
                redirectUrl = isRedirectUrl(requestUri) ? getContextRoot(request) : null;
            } else if (saml != null) {
                credentials = new SAMLCredentials(request.getServerName(), "GET".equalsIgnoreCase(request.getMethod()), saml, request.getQueryString(), getOrigin(request));
                // NOTE(review): the relay state is client controlled and is used as redirect target
                // without restricting it to this application (open redirect) - consider validating it.
                redirectUrl = decodeRelayState(request.getParameter(SAMLClient.SAML_RELAYSTATE));
            } else if (url != null) {
                return startSamlLogin(request, response, url, request.getParameter("j_fragment"));
            } else {
                return redirectToLoginScreen(request, response);
            }

            try {
                final LoginContext lc = new LoginContext(getJaasCtx(), clientSubject, callbacks -> handleCallbacks(callbacks, credentials));
                lc.login();
                // NOTE(review): the session id is not changed after login; consider rotating it
                // if the target servlet API level supports that.
                session.setAttribute(SESSION_SAVED_SUBJECT_KEY, clientSubject);// Save the Subject...
            } catch (LoginException ex) {
                return redirectToErrorScreen(request, response, ex);
            }
            if (redirectUrl == null) {
                return AuthStatus.SUCCESS;
            }
            return doRedirect(request, response, redirectUrl);
        } catch (Throwable ex) {
            // Log everything that escapes, then rethrow unchanged.
            LOG.log(Level.SEVERE, ex.getMessage(), ex);
            throw ex;
        }
    }

    /**
     * Replaces wrapped request/response messages in {@code messageInfo} by the
     * innermost wrapped instances.
     *
     * @return always {@link AuthStatus#SEND_SUCCESS}
     */
    @Override
    public AuthStatus secureResponse(final MessageInfo messageInfo, final Subject serviceSubject) throws AuthException {
        if (LOG.isLoggable(Level.FINER)) {
            LOG.finer("secureResponse");
        }
        boolean wrapped = false;
        HttpServletRequest r = (HttpServletRequest) messageInfo.getRequestMessage();
        while (r instanceof HttpServletRequestWrapper) {
            r = (HttpServletRequest) ((ServletRequestWrapper) r).getRequest();
            wrapped = true;
        }
        if (wrapped) {
            messageInfo.setRequestMessage(r);
        }
        wrapped = false;
        HttpServletResponse s = (HttpServletResponse) messageInfo.getResponseMessage();
        while (s instanceof HttpServletResponseWrapper) {
            s = (HttpServletResponse) ((ServletResponseWrapper) s).getResponse();
            wrapped = true;
        }
        if (wrapped) {
            messageInfo.setResponseMessage(s);
        }

        return AuthStatus.SEND_SUCCESS;
    }

    /**
     * Removes the saved subject from an existing session, calls JAAS logout
     * for {@code subject} (failures are logged) and clears the subject's
     * principals and credentials. Does nothing if {@code subject} is
     * {@code null}.
     */
    @Override
    public void cleanSubject(final MessageInfo messageInfo, final Subject subject) throws AuthException {
        if (LOG.isLoggable(Level.FINER)) {
            LOG.finer("cleanSubject");
        }
        if (subject == null) {
            return;
        }
        final Object o = messageInfo.getRequestMessage();
        if (o instanceof HttpServletRequest) {
            final HttpSession session = ((HttpServletRequest) o).getSession(false);
            if (session != null) {
                if (LOG.isLoggable(Level.FINER)) {
                    LOG.finer("session.removeAttribute");
                }
                session.removeAttribute(SESSION_SAVED_SUBJECT_KEY);
            }
        }
        try {
            if (LOG.isLoggable(Level.FINER)) {
                LOG.finer("lc.logout");
            }
            final LoginContext lc = new LoginContext(getJaasCtx(), subject, this);
            lc.logout();
        } catch (LoginException ex) {
            LOG.log(Level.SEVERE, ex.getMessage(), ex);
        }

        if (LOG.isLoggable(Level.FINER)) {
            LOG.finer("clear subject");
        }
        subject.getPrincipals().clear();
        subject.getPrivateCredentials().clear();
        subject.getPublicCredentials().clear();
    }

    /**
     * Handles callbacks outside of a login performed by
     * {@link #validateRequest}: no SSO credentials exist in that situation, so
     * an {@link SSOCallback} receives {@code null}; all other callbacks go to
     * the container handler.
     *
     * @param callbacks callbacks to process
     * @throws IOException declared by the interface
     * @throws UnsupportedCallbackException declared by the interface
     */
    @Override
    public void handle(final Callback[] callbacks) throws IOException, UnsupportedCallbackException {
        handleCallbacks(callbacks, null);
    }

    /**
     * Answers every {@link SSOCallback} with {@code credentials} and passes
     * the remaining callbacks to the container handler.
     *
     * @param callbacks callbacks raised by the login modules
     * @param credentials credentials of the login in progress, may be {@code null}
     */
    private void handleCallbacks(final Callback[] callbacks, final SSOCredentials credentials) {
        if (LOG.isLoggable(Level.FINER)) {
            LOG.finer(String.format("handle %d", callbacks.length));
        }
        final List<Callback> toHandle = new ArrayList<>();
        for (final Callback cb : callbacks) {
            if (LOG.isLoggable(Level.FINER)) {
                LOG.finer(String.format("handle %s: %s", cb.getClass(), cb));
            }
            if (cb instanceof SSOCallback) {
                ((SSOCallback) cb).setCredentials(credentials);
                continue;
            }
            toHandle.add(cb);
        }
        if ((handler != null) && (!toHandle.isEmpty())) {
            try {
                handler.handle(toHandle.toArray(new Callback[toHandle.size()]));
            } catch (IOException | UnsupportedCallbackException ex) {
                // NOTE(review): existing behaviour - delegate failures are logged and not
                // propagated to the login module; confirm this best-effort handling is intended.
                LOG.log(Level.SEVERE, ex.getMessage(), ex);
            }
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy