All downloads are free. Search and download functionality uses the official Maven repository.

org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter Maven / Gradle / Ivy

/*
 * Copyright 2002-2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.security.web.authentication.ui;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;
import java.util.function.Function;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;

import org.springframework.security.core.AuthenticationException;
import org.springframework.security.web.WebAttributes;
import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices;
import org.springframework.util.Assert;
import org.springframework.web.filter.GenericFilterBean;
import org.springframework.web.util.HtmlUtils;

/**
 * For internal use with namespace configuration in the case where a user doesn't
 * configure a login page. The configuration code will insert this filter in the chain
 * instead.
 *
 * <p>
 * For a {@code GET} request whose URI (ignoring anything from the first {@code ;}, and
 * with the query string appended when present) equals the context path followed by the
 * login page URL, the failure URL or the logout success URL, this filter writes a
 * generated HTML sign-in page and does not invoke the rest of the chain. Every other
 * request is passed on unchanged. The page contains a section for each enabled
 * mechanism: form login, OpenID, OAuth2 and SAML2.
 *
 * <p>
 * Will only work if a redirect is used to the login page.
 *
 * @author Luke Taylor
 * @since 2.0
 */
public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {

	/**
	 * Login page URL applied by the filter-based constructors; the default logout
	 * success URL ({@code /login?logout}) and failure URL are derived from it.
	 */
	public static final String DEFAULT_LOGIN_PAGE_URL = "/login";

	/**
	 * Query parameter appended to {@link #DEFAULT_LOGIN_PAGE_URL} to build the default
	 * failure URL ({@code /login?error}).
	 */
	public static final String ERROR_PARAMETER_NAME = "error";

	// URL (relative to the context path) at which the login page is rendered.
	private String loginPageUrl;

	// URL at which the page is rendered with the "signed out" notice; null disables it.
	private String logoutSuccessUrl;

	// URL at which the page is rendered with the authentication error message.
	private String failureUrl;

	// Flags selecting which login sections are rendered; see isEnabled().
	private boolean formLoginEnabled;

	private boolean openIdEnabled;

	private boolean oauth2LoginEnabled;

	private boolean saml2LoginEnabled;

	// Form-login settings (authentication URL and request parameter names).
	private String authenticationUrl;

	private String usernameParameter;

	private String passwordParameter;

	// Remember-me parameter for the form-login section; null omits the remember-me option.
	private String rememberMeParameter;

	// OpenID settings (authentication URL and request parameter names).
	private String openIDauthenticationUrl;

	private String openIDusernameParameter;

	// Remember-me parameter for the OpenID section; null omits the remember-me option.
	private String openIDrememberMeParameter;

	// OAuth2 section: each key is used as a URL and each value, HTML-escaped, as the
	// displayed client name.
	// NOTE(review): type arguments look stripped from this copy (raw Map); the usage in
	// generateLoginPageHtml suggests Map<String, String> — confirm against upstream.
	private Map oauth2AuthenticationUrlToClientName;

	// SAML2 section: each key is used as a URL and each value, HTML-escaped, as the
	// displayed relying-party name.
	// NOTE(review): same stripped type arguments as the OAuth2 map; presumably
	// Map<String, String> — confirm against upstream.
	private Map saml2AuthenticationUrlToProviderName;

	// Produces the hidden inputs (name -> value, typically the CSRF token) rendered into
	// the form-login and OpenID forms; by default there are none.
	// NOTE(review): the declared type is garbled ("Function>"); presumably
	// Function<HttpServletRequest, Map<String, String>> — confirm against upstream.
	private Function> resolveHiddenInputs = (request) -> Collections.emptyMap();

	/**
	 * Creates a filter with every login mechanism disabled. Unlike the other
	 * constructors this does not apply the default URLs, so the login page, logout
	 * success and failure URLs remain {@code null} until they are set.
	 */
	public DefaultLoginPageGeneratingFilter() {
	}

	/**
	 * Creates a filter configured from a single authentication filter.
	 * @param filter a {@link UsernamePasswordAuthenticationFilter} enables form login;
	 * any other filter type is treated as the OpenID filter
	 */
	public DefaultLoginPageGeneratingFilter(AbstractAuthenticationProcessingFilter filter) {
		if (filter instanceof UsernamePasswordAuthenticationFilter) {
			init((UsernamePasswordAuthenticationFilter) filter, null);
		}
		else {
			init(null, filter);
		}
	}

	/**
	 * Creates a filter configured from form-login and OpenID filters.
	 * @param authFilter the form-login filter, or {@code null} to leave form login
	 * disabled
	 * @param openIDFilter the OpenID filter, or {@code null} to leave OpenID disabled
	 */
	public DefaultLoginPageGeneratingFilter(UsernamePasswordAuthenticationFilter authFilter,
			AbstractAuthenticationProcessingFilter openIDFilter) {
		init(authFilter, openIDFilter);
	}

	/**
	 * Applies the default URLs ({@code /login}, {@code /login?logout},
	 * {@code /login?error}) and then copies settings from whichever of the given filters
	 * are non-null.
	 */
	private void init(UsernamePasswordAuthenticationFilter authFilter,
			AbstractAuthenticationProcessingFilter openIDFilter) {
		this.loginPageUrl = DEFAULT_LOGIN_PAGE_URL;
		this.logoutSuccessUrl = DEFAULT_LOGIN_PAGE_URL + "?logout";
		this.failureUrl = DEFAULT_LOGIN_PAGE_URL + "?" + ERROR_PARAMETER_NAME;
		if (authFilter != null) {
			initAuthFilter(authFilter);
		}
		if (openIDFilter != null) {
			initOpenIdFilter(openIDFilter);
		}
	}

	/**
	 * Enables form login and copies the username and password parameter names from the
	 * filter, plus the remember-me parameter when the filter's remember-me services are
	 * an {@link AbstractRememberMeServices}.
	 */
	private void initAuthFilter(UsernamePasswordAuthenticationFilter authFilter) {
		this.formLoginEnabled = true;
		this.usernameParameter = authFilter.getUsernameParameter();
		this.passwordParameter = authFilter.getPasswordParameter();
		if (authFilter.getRememberMeServices() instanceof AbstractRememberMeServices) {
			this.rememberMeParameter = ((AbstractRememberMeServices) authFilter.getRememberMeServices()).getParameter();
		}
	}

	/**
	 * Enables OpenID login with the fixed {@code openid_identifier} username parameter,
	 * copying the remember-me parameter when the filter's remember-me services are an
	 * {@link AbstractRememberMeServices}.
	 */
	private void initOpenIdFilter(AbstractAuthenticationProcessingFilter openIDFilter) {
		this.openIdEnabled = true;
		this.openIDusernameParameter = "openid_identifier";
		if (openIDFilter.getRememberMeServices() instanceof AbstractRememberMeServices) {
			this.openIDrememberMeParameter = ((AbstractRememberMeServices) openIDFilter.getRememberMeServices())
					.getParameter();
		}
	}

	/**
	 * Sets a Function used to resolve a Map of the hidden inputs where the key is the
	 * name of the input and the value is the value of the input. Typically this is used
	 * to resolve the CSRF token.
	 * @param resolveHiddenInputs the function to resolve the inputs; must not be
	 * {@code null}
	 */
	public void setResolveHiddenInputs(Function> resolveHiddenInputs) {
		Assert.notNull(resolveHiddenInputs, "resolveHiddenInputs cannot be null");
		this.resolveHiddenInputs = resolveHiddenInputs;
	}

	/**
	 * Indicates whether at least one login mechanism is enabled.
	 * @return {@code true} if form, OpenID, OAuth2 or SAML2 login is enabled
	 */
	public boolean isEnabled() {
		return this.formLoginEnabled || this.openIdEnabled || this.oauth2LoginEnabled || this.saml2LoginEnabled;
	}

	/**
	 * Sets the URL at which the page is rendered with a signed-out notice.
	 * @param logoutSuccessUrl the URL relative to the context path, or {@code null} to
	 * disable the logout-success match
	 */
	public void setLogoutSuccessUrl(String logoutSuccessUrl) {
		this.logoutSuccessUrl = logoutSuccessUrl;
	}

	/**
	 * @return the URL (relative to the context path) at which the login page is rendered
	 */
	public String getLoginPageUrl() {
		return this.loginPageUrl;
	}

	/**
	 * Sets the URL (relative to the context path) at which the login page is rendered.
	 */
	public void setLoginPageUrl(String loginPageUrl) {
		this.loginPageUrl = loginPageUrl;
	}

	/**
	 * Sets the URL at which the page is rendered together with the authentication error.
	 */
	public void setFailureUrl(String failureUrl) {
		this.failureUrl = failureUrl;
	}

	/** Enables or disables the form-login section. */
	public void setFormLoginEnabled(boolean formLoginEnabled) {
		this.formLoginEnabled = formLoginEnabled;
	}

	/** Enables or disables the OpenID section. */
	public void setOpenIdEnabled(boolean openIdEnabled) {
		this.openIdEnabled = openIdEnabled;
	}

	/**
	 * Enables or disables the OAuth2 section. When enabled, the OAuth2 URL-to-name map
	 * is iterated without a null check, so it should be set as well.
	 */
	public void setOauth2LoginEnabled(boolean oauth2LoginEnabled) {
		this.oauth2LoginEnabled = oauth2LoginEnabled;
	}

	/**
	 * Enables or disables the SAML2 section. When enabled, the SAML2 URL-to-name map is
	 * iterated without a null check, so it should be set as well.
	 */
	public void setSaml2LoginEnabled(boolean saml2LoginEnabled) {
		this.saml2LoginEnabled = saml2LoginEnabled;
	}

	/** Sets the authentication URL for the form-login section. */
	public void setAuthenticationUrl(String authenticationUrl) {
		this.authenticationUrl = authenticationUrl;
	}

	/** Sets the username request parameter name for the form-login section. */
	public void setUsernameParameter(String usernameParameter) {
		this.usernameParameter = usernameParameter;
	}

	/** Sets the password request parameter name for the form-login section. */
	public void setPasswordParameter(String passwordParameter) {
		this.passwordParameter = passwordParameter;
	}

	/**
	 * Sets the remember-me parameter name for both the form-login and the OpenID
	 * sections; {@code null} omits the remember-me option from both.
	 */
	public void setRememberMeParameter(String rememberMeParameter) {
		this.rememberMeParameter = rememberMeParameter;
		this.openIDrememberMeParameter = rememberMeParameter;
	}

	/** Sets the authentication URL for the OpenID section. */
	public void setOpenIDauthenticationUrl(String openIDauthenticationUrl) {
		this.openIDauthenticationUrl = openIDauthenticationUrl;
	}

	/** Sets the username (identifier) request parameter name for the OpenID section. */
	public void setOpenIDusernameParameter(String openIDusernameParameter) {
		this.openIDusernameParameter = openIDusernameParameter;
	}

	/**
	 * Sets the OAuth2 entries to render: each key is used as a URL and each value,
	 * HTML-escaped, as the displayed client name.
	 */
	public void setOauth2AuthenticationUrlToClientName(Map oauth2AuthenticationUrlToClientName) {
		this.oauth2AuthenticationUrlToClientName = oauth2AuthenticationUrlToClientName;
	}

	/**
	 * Sets the SAML2 entries to render: each key is used as a URL and each value,
	 * HTML-escaped, as the displayed relying-party name.
	 */
	public void setSaml2AuthenticationUrlToProviderName(Map saml2AuthenticationUrlToProviderName) {
		this.saml2AuthenticationUrlToProviderName = saml2AuthenticationUrlToProviderName;
	}

	/**
	 * Casts the request and response to their HTTP variants and delegates to the
	 * HTTP-specific overload.
	 */
	@Override
	public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
			throws IOException, ServletException {
		doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
	}

	/**
	 * Writes the generated page (UTF-8, with an explicit content length) and returns
	 * without invoking the chain when the request matches the login page, failure or
	 * logout success URL; otherwise continues the chain.
	 */
	private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
			throws IOException, ServletException {
		boolean loginError = isErrorPage(request);
		boolean logoutSuccess = isLogoutSuccess(request);
		if (isLoginUrlRequest(request) || loginError || logoutSuccess) {
			String loginPageHtml = generateLoginPageHtml(request, loginError, logoutSuccess);
			response.setContentType("text/html;charset=UTF-8");
			// Length is measured in UTF-8 bytes to match the declared charset.
			response.setContentLength(loginPageHtml.getBytes(StandardCharsets.UTF_8).length);
			response.getWriter().write(loginPageHtml);
			return;
		}
		chain.doFilter(request, response);
	}

	/**
	 * Builds the HTML of the sign-in page.
	 * @param request the current request; its session (if any), context path and hidden
	 * inputs are read
	 * @param loginError whether the request matched the failure URL; if so, the message of
	 * the {@link AuthenticationException} stored in the session under
	 * {@link WebAttributes#AUTHENTICATION_EXCEPTION} is shown, or "Invalid credentials"
	 * when there is no session or no stored exception
	 * @param logoutSuccess whether the request matched the logout success URL
	 * @return the page HTML
	 */
	private String generateLoginPageHtml(HttpServletRequest request, boolean loginError, boolean logoutSuccess) {
		String errorMsg = "Invalid credentials";
		if (loginError) {
			// getSession(false): read an existing session only, never create one here.
			HttpSession session = request.getSession(false);
			if (session != null) {
				AuthenticationException ex = (AuthenticationException) session
						.getAttribute(WebAttributes.AUTHENTICATION_EXCEPTION);
				// NOTE(review): ex.getMessage() may be null; confirm that the
				// HtmlUtils.htmlEscape call in createError tolerates a null message.
				errorMsg = (ex != null) ? ex.getMessage() : "Invalid credentials";
			}
		}
		String contextPath = request.getContextPath();
		StringBuilder sb = new StringBuilder();
		// NOTE(review): from here to the end of the class, this copy appears to have been
		// damaged when it was extracted from a web page:
		// - the HTML markup inside the string literals is missing (contextPath is no
		//   longer referenced as a result);
		// - several literals span physical lines, which is not legal Java;
		// - Map.Entry is used without type arguments;
		// - in matches(), the "// strip everything after the first semi-colon" comment
		//   now comments out the rest of its line.
		// It will not compile as-is. Restore it from the upstream Spring Security source
		// rather than editing it here.
		// NOTE(review): the OAuth2 and SAML2 sections iterate their URL-to-name maps
		// without a null check. Enabling either flag without setting the corresponding
		// map would presumably fail with a NullPointerException.
		sb.append("\n");
		sb.append("\n");
		sb.append("  \n");
		sb.append("    \n");
		sb.append("    \n");
		sb.append("    \n");
		sb.append("    \n");
		sb.append("    Please sign in\n");
		sb.append("    \n");
		sb.append("    \n");
		sb.append("  \n");
		sb.append("  \n");
		sb.append("     
\n"); if (this.formLoginEnabled) { sb.append("
\n"); sb.append(" \n"); sb.append(createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + "

\n"); sb.append(" \n"); sb.append(" \n"); sb.append("

\n"); sb.append("

\n"); sb.append(" \n"); sb.append(" \n"); sb.append("

\n"); sb.append(createRememberMe(this.rememberMeParameter) + renderHiddenInputs(request)); sb.append(" \n"); sb.append("
\n"); } if (this.openIdEnabled) { sb.append("
\n"); sb.append(" \n"); sb.append(createError(loginError, errorMsg) + createLogoutSuccess(logoutSuccess) + "

\n"); sb.append(" \n"); sb.append(" \n"); sb.append("

\n"); sb.append(createRememberMe(this.openIDrememberMeParameter) + renderHiddenInputs(request)); sb.append(" \n"); sb.append("
\n"); } if (this.oauth2LoginEnabled) { sb.append(""); sb.append(createError(loginError, errorMsg)); sb.append(createLogoutSuccess(logoutSuccess)); sb.append("\n"); for (Map.Entry clientAuthenticationUrlToClientName : this.oauth2AuthenticationUrlToClientName .entrySet()) { sb.append(" \n"); } sb.append("
"); String url = clientAuthenticationUrlToClientName.getKey(); sb.append(""); String clientName = HtmlUtils.htmlEscape(clientAuthenticationUrlToClientName.getValue()); sb.append(clientName); sb.append(""); sb.append("
\n"); } if (this.saml2LoginEnabled) { sb.append(""); sb.append(createError(loginError, errorMsg)); sb.append(createLogoutSuccess(logoutSuccess)); sb.append("\n"); for (Map.Entry relyingPartyUrlToName : this.saml2AuthenticationUrlToProviderName .entrySet()) { sb.append(" \n"); } sb.append("
"); String url = relyingPartyUrlToName.getKey(); sb.append(""); String partyName = HtmlUtils.htmlEscape(relyingPartyUrlToName.getValue()); sb.append(partyName); sb.append(""); sb.append("
\n"); } sb.append("
\n"); sb.append(""); return sb.toString(); } private String renderHiddenInputs(HttpServletRequest request) { StringBuilder sb = new StringBuilder(); for (Map.Entry input : this.resolveHiddenInputs.apply(request).entrySet()) { sb.append("\n"); } return sb.toString(); } private String createRememberMe(String paramName) { if (paramName == null) { return ""; } return "

Remember me on this computer.

\n"; } private boolean isLogoutSuccess(HttpServletRequest request) { return this.logoutSuccessUrl != null && matches(request, this.logoutSuccessUrl); } private boolean isLoginUrlRequest(HttpServletRequest request) { return matches(request, this.loginPageUrl); } private boolean isErrorPage(HttpServletRequest request) { return matches(request, this.failureUrl); } private static String createError(boolean isError, String message) { if (!isError) { return ""; } return "
" + HtmlUtils.htmlEscape(message) + "
"; } private static String createLogoutSuccess(boolean isLogoutSuccess) { if (!isLogoutSuccess) { return ""; } return "
You have been signed out
"; } private boolean matches(HttpServletRequest request, String url) { if (!"GET".equals(request.getMethod()) || url == null) { return false; } String uri = request.getRequestURI(); int pathParamIndex = uri.indexOf(';'); if (pathParamIndex > 0) { // strip everything after the first semi-colon uri = uri.substring(0, pathParamIndex); } if (request.getQueryString() != null) { uri += "?" + request.getQueryString(); } if ("".equals(request.getContextPath())) { return uri.equals(url); } return uri.equals(request.getContextPath() + url); } }




© 2015 - 2024 Weber Informatics LLC | Privacy Policy