org.jboss.security.auth.callback.DatabaseCallbackHandler Maven / Gradle / Ivy
/*
* JBoss, Home of Professional Open Source.
* Copyright 2011, Red Hat Middleware LLC, and individual contributors
* as indicated by the @author tags. See the copyright.txt file in the
* distribution for a full listing of individual contributors.
*
* This is free software; you can redistribute it and/or modify it
* under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of
* the License, or (at your option) any later version.
*
* This software is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this software; if not, write to the Free
* Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
* 02110-1301 USA, or see the FSF site: http://www.fsf.org.
*/
package org.jboss.security.auth.callback;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Map;
import javax.naming.InitialContext;
import javax.naming.NamingException;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.LoginException;
import javax.sql.DataSource;
import org.jboss.security.PicketBoxMessages;
import org.jboss.security.vault.SecurityVaultException;
import org.jboss.security.vault.SecurityVaultUtil;
/**
 * A {@code CallbackHandler} that looks up a user's password in a database.
 *
 * <p>Configuration can be supplied in two ways:</p>
 * <ul>
 * <li>through {@link #setConfiguration(Map)}, which takes {@code String} based key/value pairs;</li>
 * <li>through the individual setters, such as {@link #setConnectionUrl(String)} and
 * {@link #setDbDriverName(String)}.</li>
 * </ul>
 *
 * <p>Either specify a connection url (optionally with a driver class name), or provide the JNDI name of a
 * {@code DataSource}. A non-empty connection url takes precedence over the JNDI name.</p>
 *
 * @author Anil Saldhana
 * @since Oct 31, 2011
 */
public class DatabaseCallbackHandler extends AbstractCallbackHandler implements CallbackHandler
{
   /** Configuration key for the DB specific connection url. */
   public static final String CONNECTION_URL = "connectionURL";
   /** Configuration key for the JNDI name of the {@code DataSource}. */
   public static final String DS_JNDI_NAME = "dsJndiName";
   /** Configuration key for the fully qualified JDBC driver class name. */
   public static final String DB_DRIVERNAME = "dbDriverName";
   /** Configuration key for the DB user name used to connect. */
   public static final String DB_USERNAME = "dbUserName";
   /** Configuration key for the DB password used to connect (may be in vault format). */
   public static final String DB_USERPASS = "dbUserPass";
   /** Configuration key for the SQL query that returns a principal's password. */
   public static final String PRINCIPALS_QUERY = "principalsQuery";

   /**
    * A DB specific connection url
    */
   protected String connectionUrl;

   /**
    * JNDI Name of the Datasource
    */
   protected String dsJndiName;

   /**
    * A DB username to connect
    */
   protected String dsUserName;

   /**
    * A DB password to connect
    */
   protected String dsUserPass;

   /**
    * A DB Driver Class Name
    */
   protected String dbDriverName;

   /** The sql query to obtain the user password; its single parameter is bound to the user name. */
   protected String principalsQuery = "select Password from Principals where PrincipalID=?";

   public DatabaseCallbackHandler()
   {
   }

   /**
    * Get the DB specific connection URL.
    * Eg: "jdbc:hsqldb:mem:unit_test"
    * @return the connection url, possibly {@code null}
    */
   public String getConnectionUrl()
   {
      return connectionUrl;
   }

   /**
    * Set the DB specific connection URL.
    * @param connectionUrl the connection url
    */
   public void setConnectionUrl(String connectionUrl)
   {
      this.connectionUrl = connectionUrl;
   }

   /**
    * Get the JNDI name of the SQL Datasource.
    * @return the JNDI name, possibly {@code null}
    */
   public String getDsJndiName()
   {
      return dsJndiName;
   }

   /**
    * Set the JNDI name of the SQL Datasource.
    * @param dsJndiName the JNDI name
    */
   public void setDsJndiName(String dsJndiName)
   {
      this.dsJndiName = dsJndiName;
   }

   /**
    * Get the DB user name.
    * @return the DB user name, possibly {@code null}
    */
   public String getDsUserName()
   {
      return dsUserName;
   }

   /**
    * Set the DB user name.
    * @param dsUserName the DB user name
    */
   public void setDsUserName(String dsUserName)
   {
      this.dsUserName = dsUserName;
   }

   /**
    * Get the DB user pass.
    * @return the DB password, possibly {@code null}
    */
   public String getDsUserPass()
   {
      return dsUserPass;
   }

   /**
    * Set the DB user pass.
    * @param dsUserPass the DB password
    */
   public void setDsUserPass(String dsUserPass)
   {
      this.dsUserPass = dsUserPass;
   }

   /**
    * Get the fully qualified name of sql driver class.
    * Eg: org.hsqldb.jdbc.JDBCDriver
    * @return the driver class name, possibly {@code null}
    */
   public String getDbDriverName()
   {
      return dbDriverName;
   }

   /**
    * Set the fully qualified name of sql driver class.
    * @param dbDriverName the driver class name
    */
   public void setDbDriverName(String dbDriverName)
   {
      this.dbDriverName = dbDriverName;
   }

   /**
    * Get the SQL query used to fetch a principal's password.
    * @return the principals query
    */
   public String getPrincipalsQuery()
   {
      return principalsQuery;
   }

   /**
    * Set the SQL query used to fetch a principal's password.
    * @param principalsQuery the query; its single parameter is bound to the user name
    */
   public void setPrincipalsQuery(String principalsQuery)
   {
      this.principalsQuery = principalsQuery;
   }

   /**
    * Get the user name whose password is looked up.
    * @return the user name, possibly {@code null} if not yet set
    */
   public String getUserName()
   {
      return userName;
   }

   /**
    * Set the user name whose password is looked up.
    * @param theUserName the user name; must not be {@code null}
    */
   public void setUserName(String theUserName)
   {
      if (theUserName == null)
      {
         throw PicketBoxMessages.MESSAGES.invalidNullArgument("userName");
      }
      userName = theUserName;
   }

   /**
    * Configure this handler from a {@code Map} whose keys and values are strings.
    *
    * <p>Recognized keys are {@link #DB_DRIVERNAME}, {@link #CONNECTION_URL}, {@link #DS_JNDI_NAME} (read only
    * when no non-empty connection url is given), {@link #DB_USERNAME}, {@link #DB_USERPASS} (resolved through
    * the security vault when in vault format) and {@link #PRINCIPALS_QUERY} (the current query is kept when
    * the key is absent).</p>
    *
    * @param config the configuration map
    * @throws RuntimeException wrapping a {@link SecurityVaultException} if a vaulted password cannot be resolved
    */
   public void setConfiguration(Map<String, String> config)
   {
      dbDriverName = config.get(DB_DRIVERNAME);
      connectionUrl = config.get(CONNECTION_URL);
      if (connectionUrl == null || connectionUrl.length() == 0)
      {
         dsJndiName = config.get(DS_JNDI_NAME);
      }
      dsUserName = config.get(DB_USERNAME);
      dsUserPass = config.get(DB_USERPASS);
      if (dsUserPass != null && SecurityVaultUtil.isVaultFormat(dsUserPass))
      {
         try
         {
            dsUserPass = SecurityVaultUtil.getValueAsString(dsUserPass);
         }
         catch (SecurityVaultException e)
         {
            throw new RuntimeException(e);
         }
      }
      String query = config.get(PRINCIPALS_QUERY);
      if (query != null)
      {
         principalsQuery = query;
      }
   }

   /*
    * @see javax.security.auth.callback.CallbackHandler#handle(javax.security.auth.callback.Callback[])
    */
   public void handle(Callback[] callbacks) throws IOException,
         UnsupportedCallbackException
   {
      // An explicitly configured user name wins; otherwise derive it from the callbacks.
      if (userName == null)
      {
         userName = getUserName(callbacks);
      }
      for (int i = 0; i < callbacks.length; i++)
      {
         this.handleCallBack(callbacks[i]);
      }
   }

   /**
    * Handle a single {@code Callback}.
    *
    * <p>A {@code VerifyPasswordCallback} is verified against the DB password, and a {@code PasswordCallback}
    * receives the DB password. Any other callback type is ignored.</p>
    *
    * @param c callback
    * @throws UnsupportedCallbackException declared for subclasses; not thrown by this implementation
    * @throws IOException if password verification fails (wrapping the {@code LoginException})
    */
   protected void handleCallBack( Callback c ) throws UnsupportedCallbackException, IOException
   {
      if (c instanceof VerifyPasswordCallback)
      {
         VerifyPasswordCallback vpc = (VerifyPasswordCallback) c;
         try
         {
            handleVerification(vpc);
         }
         catch (LoginException e)
         {
            throw new IOException(e);
         }
      }
      if (c instanceof PasswordCallback == false)
      {
         return;
      }
      PasswordCallback passwdCallback = (PasswordCallback) c;
      String password = getPassword();
      // A NULL password column is passed through as null (permitted by PasswordCallback.setPassword)
      passwdCallback.setPassword(password == null ? null : password.toCharArray());
   }

   /**
    * Verify the value carried by the callback against the password stored in the DB.
    *
    * @param vpc the verification callback; marked verified on success
    * @throws LoginException if the supplied or stored password is missing or they do not match
    */
   protected void handleVerification(VerifyPasswordCallback vpc) throws LoginException
   {
      String userPass = vpc.getValue();
      String passwordFromDB = getPassword();
      if (userPass != null && passwordFromDB != null && constantTimeEquals(userPass, passwordFromDB))
      {
         vpc.setVerified(true);
      }
      else
      {
         throw new LoginException(PicketBoxMessages.MESSAGES.authenticationFailedMessage());
      }
   }

   /**
    * Compare two non-null strings without stopping at the first differing character, so the running time does
    * not reveal the length of a matching prefix. Differing lengths are rejected immediately.
    */
   private static boolean constantTimeEquals(String a, String b)
   {
      if (a.length() != b.length())
      {
         return false;
      }
      int diff = 0;
      for (int i = 0; i < a.length(); i++)
      {
         diff |= a.charAt(i) ^ b.charAt(i);
      }
      return diff == 0;
   }

   /**
    * Run {@link #principalsQuery} for the current user and return the first column of the first row.
    *
    * @return the stored password, possibly {@code null} if the column is NULL
    * @throws RuntimeException if no row is found or the DB cannot be queried
    */
   private String getPassword()
   {
      Connection conn = null;
      PreparedStatement ps = null;
      ResultSet rs = null;
      try
      {
         conn = getConnection();
         ps = conn.prepareStatement(principalsQuery);
         ps.setString(1, userName);
         rs = ps.executeQuery();
         if (rs.next() == false)
         {
            throw PicketBoxMessages.MESSAGES.unableToFindPrincipalInDB(userName);
         }
         return rs.getString(1);
      }
      catch (RuntimeException e)
      {
         // Already unchecked: propagate unchanged rather than wrapping it a second time
         throw e;
      }
      catch (Exception e)
      {
         throw new RuntimeException(e);
      }
      finally
      {
         safeClose(rs);
         safeClose(ps);
         safeClose(conn);
      }
   }

   /**
    * Obtain a DB connection, either from {@link #connectionUrl} via {@code DriverManager} or, when no
    * non-empty url is set, from the {@code DataSource} bound at {@link #dsJndiName}.
    *
    * @return an open connection, never {@code null}
    * @throws SQLException if the connection cannot be opened
    * @throws NamingException if the JNDI lookup fails
    */
   private Connection getConnection() throws SQLException, NamingException
   {
      if (dbDriverName != null)
      {
         try
         {
            // Ensures the driver class is loaded before DriverManager is consulted
            Class.forName(dbDriverName);
         }
         catch (ClassNotFoundException e)
         {
            throw new RuntimeException(e);
         }
      }

      // An empty url means "not configured", matching the rule used by setConfiguration
      if (connectionUrl != null && connectionUrl.length() > 0)
      {
         if (dsUserName != null)
         {
            return DriverManager.getConnection(connectionUrl, dsUserName, dsUserPass);
         }
         return DriverManager.getConnection(connectionUrl);
      }

      if (dsJndiName == null)
      {
         throw PicketBoxMessages.MESSAGES.unableToLookupDataSource();
      }
      InitialContext ic = new InitialContext();
      try
      {
         DataSource ds = (DataSource) ic.lookup(dsJndiName);
         if (ds == null)
         {
            throw PicketBoxMessages.MESSAGES.unableToLookupDataSource();
         }
         return ds.getConnection();
      }
      finally
      {
         try
         {
            ic.close();
         }
         catch (NamingException ignored)
         {
            // Best-effort release of the naming context; the connection is already obtained or failed
         }
      }
   }

   /**
    * Close a {@code ResultSet}, ignoring {@code null} and any {@code SQLException} (best-effort cleanup).
    * @param rs the result set, may be {@code null}
    */
   protected void safeClose(ResultSet rs)
   {
      if (rs != null)
      {
         try
         {
            rs.close();
         }
         catch (SQLException ignored)
         {
            // Intentionally ignored: cleanup must not mask the primary outcome
         }
      }
   }

   /**
    * Close a {@code Connection}, ignoring {@code null} and any {@code SQLException} (best-effort cleanup).
    * @param conn the connection, may be {@code null}
    */
   protected void safeClose(Connection conn)
   {
      if (conn != null)
      {
         try
         {
            conn.close();
         }
         catch (SQLException ignored)
         {
            // Intentionally ignored: cleanup must not mask the primary outcome
         }
      }
   }

   /**
    * Close a {@code Statement}, ignoring {@code null} and any {@code SQLException} (best-effort cleanup).
    * @param stat the statement, may be {@code null}
    */
   protected void safeClose(Statement stat)
   {
      if (stat != null)
      {
         try
         {
            stat.close();
         }
         catch (SQLException ignored)
         {
            // Intentionally ignored: cleanup must not mask the primary outcome
         }
      }
   }
}