// org.springframework.social.connect.jdbc.JdbcConnectionRepository (Maven / Gradle / Ivy)
/*
* Copyright 2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.social.connect.jdbc;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import org.springframework.dao.DuplicateKeyException;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.security.crypto.encrypt.TextEncryptor;
import org.springframework.social.connect.Connection;
import org.springframework.social.connect.ConnectionData;
import org.springframework.social.connect.ConnectionFactory;
import org.springframework.social.connect.ConnectionFactoryLocator;
import org.springframework.social.connect.ConnectionKey;
import org.springframework.social.connect.ConnectionRepository;
import org.springframework.social.connect.DuplicateConnectionException;
import org.springframework.social.connect.NoSuchConnectionException;
import org.springframework.social.connect.NotConnectedException;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
class JdbcConnectionRepository implements ConnectionRepository {
private final String userId;
private final JdbcTemplate jdbcTemplate;
private final ConnectionFactoryLocator connectionFactoryLocator;
private final TextEncryptor textEncryptor;
private final String tablePrefix;
public JdbcConnectionRepository(String userId, JdbcTemplate jdbcTemplate, ConnectionFactoryLocator connectionFactoryLocator, TextEncryptor textEncryptor, String tablePrefix) {
this.userId = userId;
this.jdbcTemplate = jdbcTemplate;
this.connectionFactoryLocator = connectionFactoryLocator;
this.textEncryptor = textEncryptor;
this.tablePrefix = tablePrefix;
}
public MultiValueMap> findAllConnections() {
List> resultList = jdbcTemplate.query(selectFromUserConnection() + " where userId = ? order by providerId, rank", connectionMapper, userId);
MultiValueMap> connections = new LinkedMultiValueMap>();
Set registeredProviderIds = connectionFactoryLocator.registeredProviderIds();
for (String registeredProviderId : registeredProviderIds) {
connections.put(registeredProviderId, Collections.>emptyList());
}
for (Connection connection : resultList) {
String providerId = connection.getKey().getProviderId();
if (connections.get(providerId).size() == 0) {
connections.put(providerId, new LinkedList>());
}
connections.add(providerId, connection);
}
return connections;
}
public List> findConnections(String providerId) {
return jdbcTemplate.query(selectFromUserConnection() + " where userId = ? and providerId = ? order by rank", connectionMapper, userId, providerId);
}
@SuppressWarnings("unchecked")
public List> findConnections(Class apiType) {
List connections = findConnections(getProviderId(apiType));
return (List>) connections;
}
public MultiValueMap> findConnectionsToUsers(MultiValueMap providerUsers) {
if (providerUsers == null || providerUsers.isEmpty()) {
throw new IllegalArgumentException("Unable to execute find: no providerUsers provided");
}
StringBuilder providerUsersCriteriaSql = new StringBuilder();
MapSqlParameterSource parameters = new MapSqlParameterSource();
parameters.addValue("userId", userId);
for (Iterator>> it = providerUsers.entrySet().iterator(); it.hasNext();) {
Entry> entry = it.next();
String providerId = entry.getKey();
providerUsersCriteriaSql.append("providerId = :providerId_").append(providerId).append(" and providerUserId in (:providerUserIds_").append(providerId).append(")");
parameters.addValue("providerId_" + providerId, providerId);
parameters.addValue("providerUserIds_" + providerId, entry.getValue());
if (it.hasNext()) {
providerUsersCriteriaSql.append(" or " );
}
}
List> resultList = new NamedParameterJdbcTemplate(jdbcTemplate).query(selectFromUserConnection() + " where userId = :userId and " + providerUsersCriteriaSql + " order by providerId, rank", parameters, connectionMapper);
MultiValueMap> connectionsForUsers = new LinkedMultiValueMap>();
for (Connection connection : resultList) {
String providerId = connection.getKey().getProviderId();
List userIds = providerUsers.get(providerId);
List> connections = connectionsForUsers.get(providerId);
if (connections == null) {
connections = new ArrayList>(userIds.size());
for (int i = 0; i < userIds.size(); i++) {
connections.add(null);
}
connectionsForUsers.put(providerId, connections);
}
String providerUserId = connection.getKey().getProviderUserId();
int connectionIndex = userIds.indexOf(providerUserId);
connections.set(connectionIndex, connection);
}
return connectionsForUsers;
}
public Connection getConnection(ConnectionKey connectionKey) {
try {
return jdbcTemplate.queryForObject(selectFromUserConnection() + " where userId = ? and providerId = ? and providerUserId = ?", connectionMapper, userId, connectionKey.getProviderId(), connectionKey.getProviderUserId());
} catch (EmptyResultDataAccessException e) {
throw new NoSuchConnectionException(connectionKey);
}
}
@SuppressWarnings("unchecked")
public Connection getConnection(Class apiType, String providerUserId) {
String providerId = getProviderId(apiType);
return (Connection) getConnection(new ConnectionKey(providerId, providerUserId));
}
@SuppressWarnings("unchecked")
public Connection getPrimaryConnection(Class apiType) {
String providerId = getProviderId(apiType);
Connection connection = (Connection) findPrimaryConnection(providerId);
if (connection == null) {
throw new NotConnectedException(providerId);
}
return connection;
}
@SuppressWarnings("unchecked")
public Connection findPrimaryConnection(Class apiType) {
String providerId = getProviderId(apiType);
return (Connection) findPrimaryConnection(providerId);
}
@Transactional
public void addConnection(Connection connection) {
try {
ConnectionData data = connection.createData();
int rank = jdbcTemplate.queryForObject("select coalesce(max(rank) + 1, 1) as rank from " + tablePrefix + "UserConnection where userId = ? and providerId = ?", new Object[]{ userId, data.getProviderId() }, Integer.class);
jdbcTemplate.update("insert into " + tablePrefix + "UserConnection (userId, providerId, providerUserId, rank, displayName, profileUrl, imageUrl, accessToken, secret, refreshToken, expireTime) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
userId, data.getProviderId(), data.getProviderUserId(), rank, data.getDisplayName(), data.getProfileUrl(), data.getImageUrl(), encrypt(data.getAccessToken()), encrypt(data.getSecret()), encrypt(data.getRefreshToken()), data.getExpireTime());
} catch (DuplicateKeyException e) {
throw new DuplicateConnectionException(connection.getKey());
}
}
@Transactional
public void updateConnection(Connection connection) {
ConnectionData data = connection.createData();
jdbcTemplate.update("update " + tablePrefix + "UserConnection set displayName = ?, profileUrl = ?, imageUrl = ?, accessToken = ?, secret = ?, refreshToken = ?, expireTime = ? where userId = ? and providerId = ? and providerUserId = ?",
data.getDisplayName(), data.getProfileUrl(), data.getImageUrl(), encrypt(data.getAccessToken()), encrypt(data.getSecret()), encrypt(data.getRefreshToken()), data.getExpireTime(), userId, data.getProviderId(), data.getProviderUserId());
}
@Transactional
public void removeConnections(String providerId) {
jdbcTemplate.update("delete from " + tablePrefix + "UserConnection where userId = ? and providerId = ?", userId, providerId);
}
@Transactional
public void removeConnection(ConnectionKey connectionKey) {
jdbcTemplate.update("delete from " + tablePrefix + "UserConnection where userId = ? and providerId = ? and providerUserId = ?", userId, connectionKey.getProviderId(), connectionKey.getProviderUserId());
}
// internal helpers
private String selectFromUserConnection() {
return "select userId, providerId, providerUserId, displayName, profileUrl, imageUrl, accessToken, secret, refreshToken, expireTime from " + tablePrefix + "UserConnection";
}
private Connection findPrimaryConnection(String providerId) {
List> connections = jdbcTemplate.query(selectFromUserConnection() + " where userId = ? and providerId = ? order by rank", connectionMapper, userId, providerId);
if (connections.size() > 0) {
return connections.get(0);
} else {
return null;
}
}
private final ServiceProviderConnectionMapper connectionMapper = new ServiceProviderConnectionMapper();
private final class ServiceProviderConnectionMapper implements RowMapper> {
public Connection mapRow(ResultSet rs, int rowNum) throws SQLException {
ConnectionData connectionData = mapConnectionData(rs);
ConnectionFactory connectionFactory = connectionFactoryLocator.getConnectionFactory(connectionData.getProviderId());
return connectionFactory.createConnection(connectionData);
}
private ConnectionData mapConnectionData(ResultSet rs) throws SQLException {
return new ConnectionData(rs.getString("providerId"), rs.getString("providerUserId"), rs.getString("displayName"), rs.getString("profileUrl"), rs.getString("imageUrl"),
decrypt(rs.getString("accessToken")), decrypt(rs.getString("secret")), decrypt(rs.getString("refreshToken")), expireTime(rs.getLong("expireTime")));
}
private String decrypt(String encryptedText) {
return encryptedText != null ? textEncryptor.decrypt(encryptedText) : encryptedText;
}
private Long expireTime(long expireTime) {
return expireTime == 0 ? null : expireTime;
}
}
private String getProviderId(Class apiType) {
return connectionFactoryLocator.getConnectionFactory(apiType).getProviderId();
}
private String encrypt(String text) {
return text != null ? textEncryptor.encrypt(text) : text;
}
}