// com.github.shawven.security.connect.MyConnectionRepository (Maven / Gradle / Ivy)
/*
* Copyright 2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.shawven.security.connect;
import org.springframework.dao.DuplicateKeyException;
import org.springframework.dao.EmptyResultDataAccessException;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcTemplate;
import org.springframework.security.crypto.encrypt.TextEncryptor;
import org.springframework.social.connect.*;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.*;
import java.util.Map.Entry;
public class MyConnectionRepository implements ConnectionRepository {
private final String userId;
private final JdbcTemplate jdbcTemplate;
private final ConnectionFactoryLocator connectionFactoryLocator;
private final TextEncryptor textEncryptor;
private final String table;
public MyConnectionRepository(String userId, JdbcTemplate jdbcTemplate, ConnectionFactoryLocator connectionFactoryLocator, TextEncryptor textEncryptor, String table) {
this.userId = userId;
this.jdbcTemplate = jdbcTemplate;
this.connectionFactoryLocator = connectionFactoryLocator;
this.textEncryptor = textEncryptor;
this.table = table;
}
@Override
public MultiValueMap> findAllConnections() {
List> resultList = jdbcTemplate.query(selectFromUserConnection() + " where user_id = ? order by provider_id, rank", connectionMapper, userId);
MultiValueMap> connections = new LinkedMultiValueMap>();
Set registeredProviderIds = connectionFactoryLocator.registeredProviderIds();
for (String registeredProviderId : registeredProviderIds) {
connections.put(registeredProviderId, Collections.>emptyList());
}
for (Connection> connection : resultList) {
String providerId = connection.getKey().getProviderId();
if (connections.get(providerId).size() == 0) {
connections.put(providerId, new LinkedList>());
}
connections.add(providerId, connection);
}
return connections;
}
@Override
public List> findConnections(String providerId) {
return jdbcTemplate.query(selectFromUserConnection() + " where user_id = ? and provider_id = ? order by rank", connectionMapper, userId, providerId);
}
@Override
@SuppressWarnings("unchecked")
public List> findConnections(Class apiType) {
List> connections = findConnections(getProviderId(apiType));
return (List>) connections;
}
@Override
public MultiValueMap> findConnectionsToUsers(MultiValueMap providerUsers) {
if (providerUsers == null || providerUsers.isEmpty()) {
throw new IllegalArgumentException("Unable to execute find: no providerUsers provided");
}
StringBuilder providerUsersCriteriaSql = new StringBuilder();
MapSqlParameterSource parameters = new MapSqlParameterSource();
parameters.addValue("userId", userId);
for (Iterator>> it = providerUsers.entrySet().iterator(); it.hasNext();) {
Entry> entry = it.next();
String providerId = entry.getKey();
providerUsersCriteriaSql.append("providerId = :providerId_").append(providerId).append(" and providerUserId in (:providerUserIds_").append(providerId).append(")");
parameters.addValue("providerId_" + providerId, providerId);
parameters.addValue("providerUserIds_" + providerId, entry.getValue());
if (it.hasNext()) {
providerUsersCriteriaSql.append(" or " );
}
}
List> resultList = new NamedParameterJdbcTemplate(jdbcTemplate).query(selectFromUserConnection() + " where user_id = :userId and " + providerUsersCriteriaSql + " order by providerId, rank", parameters, connectionMapper);
MultiValueMap> connectionsForUsers = new LinkedMultiValueMap>();
for (Connection> connection : resultList) {
String providerId = connection.getKey().getProviderId();
List userIds = providerUsers.get(providerId);
List> connections = connectionsForUsers.get(providerId);
if (connections == null) {
connections = new ArrayList<>(userIds.size());
for (int i = 0; i < userIds.size(); i++) {
connections.add(null);
}
connectionsForUsers.put(providerId, connections);
}
String providerUserId = connection.getKey().getProviderUserId();
int connectionIndex = userIds.indexOf(providerUserId);
connections.set(connectionIndex, connection);
}
return connectionsForUsers;
}
@Override
public Connection> getConnection(ConnectionKey connectionKey) {
try {
return jdbcTemplate.queryForObject(selectFromUserConnection() + " where user_id = ? and provider_id = ? and provider_user_id = ?", connectionMapper, userId, connectionKey.getProviderId(), connectionKey.getProviderUserId());
} catch (EmptyResultDataAccessException e) {
throw new NoSuchConnectionException(connectionKey);
}
}
@Override
@SuppressWarnings("unchecked")
public Connection getConnection(Class apiType, String providerUserId) {
String providerId = getProviderId(apiType);
return (Connection) getConnection(new ConnectionKey(providerId, providerUserId));
}
@Override
@SuppressWarnings("unchecked")
public Connection getPrimaryConnection(Class apiType) {
String providerId = getProviderId(apiType);
Connection connection = (Connection) findPrimaryConnection(providerId);
if (connection == null) {
throw new NotConnectedException(providerId);
}
return connection;
}
@Override
@SuppressWarnings("unchecked")
public Connection findPrimaryConnection(Class apiType) {
String providerId = getProviderId(apiType);
return (Connection) findPrimaryConnection(providerId);
}
@Override
@Transactional
public void addConnection(Connection> connection) {
try {
ConnectionData data = connection.createData();
int rank = jdbcTemplate.queryForObject("select coalesce(max(rank) + 1, 1) as rank from " + table + " where user_id = ? and provider_id = ?", new Object[]{ userId, data.getProviderId() }, Integer.class);
jdbcTemplate.update("insert into " + table + " (user_id, provider_id, provider_user_id, rank, display_name, profile_url, image_url, access_token, secret, refresh_token, expire_time) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
userId, data.getProviderId(), data.getProviderUserId(), rank, data.getDisplayName(), data.getProfileUrl(), data.getImageUrl(), encrypt(data.getAccessToken()), encrypt(data.getSecret()), encrypt(data.getRefreshToken()), data.getExpireTime());
} catch (DuplicateKeyException e) {
throw new DuplicateConnectionException(connection.getKey());
}
}
@Override
@Transactional
public void updateConnection(Connection> connection) {
ConnectionData data = connection.createData();
jdbcTemplate.update("update " + table + " set display_name = ?, profile_url = ?, image_url = ?, access_token = ?, secret = ?, refresh_token = ?, expire_time = ? where user_id = ? and provider_id = ? and provider_user_id = ?",
data.getDisplayName(), data.getProfileUrl(), data.getImageUrl(), encrypt(data.getAccessToken()), encrypt(data.getSecret()), encrypt(data.getRefreshToken()), data.getExpireTime(), userId, data.getProviderId(), data.getProviderUserId());
}
@Transactional
public void updateConnection(ConnectionData data) {
this.jdbcTemplate.update("update " + this.table + " set display_name = ?, profile_url = ?, image_url = ?, access_token = ?, secret = ?, refresh_token = ?, expire_time = ? where user_id = ? and provider_id = ? and provider_user_id = ?", new Object[]{data.getDisplayName(), data.getProfileUrl(), data.getImageUrl(), this.encrypt(data.getAccessToken()), this.encrypt(data.getSecret()), this.encrypt(data.getRefreshToken()), data.getExpireTime(), this.userId, data.getProviderId(), data.getProviderUserId()});
}
@Override
@Transactional
public void removeConnections(String providerId) {
jdbcTemplate.update("delete from " + table + " where user_id = ? and provider_id = ?", userId, providerId);
}
@Override
@Transactional
public void removeConnection(ConnectionKey connectionKey) {
jdbcTemplate.update("delete from " + table + " where user_id = ? and provider_id = ? and provider_user_id = ?", userId, connectionKey.getProviderId(), connectionKey.getProviderUserId());
}
// internal helpers
private String selectFromUserConnection() {
return "select userId, providerId, providerUserId, displayName, profileUrl, imageUrl, accessToken, secret, refreshToken, expireTime from " + table ;
}
private Connection> findPrimaryConnection(String providerId) {
List> connections = jdbcTemplate.query(selectFromUserConnection() + " where user_id = ? and provider_id = ? order by rank", connectionMapper, userId, providerId);
if (connections.size() > 0) {
return connections.get(0);
} else {
return null;
}
}
private final ServiceProviderConnectionMapper connectionMapper = new ServiceProviderConnectionMapper();
private final class ServiceProviderConnectionMapper implements RowMapper> {
@Override
public Connection> mapRow(ResultSet rs, int rowNum) throws SQLException {
ConnectionData connectionData = mapConnectionData(rs);
ConnectionFactory> connectionFactory = connectionFactoryLocator.getConnectionFactory(connectionData.getProviderId());
return connectionFactory.createConnection(connectionData);
}
private ConnectionData mapConnectionData(ResultSet rs) throws SQLException {
return new ConnectionData(
rs.getString("provider_id"),
rs.getString("provider_user_id"),
rs.getString("display_name"),
rs.getString("profile_url"),
rs.getString("image_url"),
decrypt(rs.getString("access_token")),
decrypt(rs.getString("secret")),
decrypt(rs.getString("refresh_token")),
expireTime(rs.getLong("expire_time")));
}
private String decrypt(String encryptedText) {
return encryptedText != null ? textEncryptor.decrypt(encryptedText) : encryptedText;
}
private Long expireTime(long expireTime) {
return expireTime == 0 ? null : expireTime;
}
}
private String getProviderId(Class apiType) {
return connectionFactoryLocator.getConnectionFactory(apiType).getProviderId();
}
private String encrypt(String text) {
return text != null ? textEncryptor.encrypt(text) : text;
}
}