All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.icij.extract.mysql.MySQLConcurrentMap Maven / Gradle / Ivy

There is a newer version: 7.4.0
Show newest version
package org.icij.extract.mysql;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Collection;
import java.util.Map;
import java.util.Set;

public class MySQLConcurrentMap extends SQLConcurrentMap {

	private final String table;

	public MySQLConcurrentMap(final DataSource dataSource, final SQLMapCodec codec, final String table) {
		super(dataSource, codec);
		this.table = table;
	}

	private void executeInsert(final Connection c, final K key, final V value) throws SQLException {
		final Map values = codec.encodeValue(value);
		values.putAll(codec.encodeKey(key));

		try (final PreparedStatement q = c.prepareStatement("INSERT " + table + " SET " +
				String.join(",", values.keySet().stream().map(k -> k + " = ?").toArray(String[]::new)) + ";")) {

			int i = 1;
			for (String k : values.keySet()) {
				q.setObject(i++, values.get(k));
			}

			q.executeUpdate();
		}
	}

	private int executeInsertOrUpdate(final Connection c, final K key, final V value) throws SQLException {
		final Map values = codec.encodeValue(value);
		values.putAll(codec.encodeKey(key));

		final Set keys = values.keySet();

		final String sql = "INSERT INTO " + table + " (" +
				String.join(", ", keys.toArray(new String[keys.size()])) +
				") VALUES(" +
				String.join(", ", keys.stream().map(k -> "?").toArray(String[]::new)) +
				") ON DUPLICATE KEY UPDATE " +
				String.join(", ", keys.stream().map(k -> k + " = VALUES(" + k + ")").toArray(String[]::new)) + ";";

		try (final PreparedStatement q = c.prepareStatement(sql)) {
			int i = 1;

			for (String k : keys) {
				q.setObject(i++, values.get(k));
			}

			return q.executeUpdate();
		}
	}

	private int executeUpdate(final Connection c, final K key, final V value) throws SQLException {
		final Map values = codec.encodeValue(value);
		final Set valuesKeySet = values.keySet();

		final Map keys = codec.encodeKey(key);
		final Set keysKeySet = keys.keySet();

		try (final PreparedStatement q = c.prepareStatement("UPDATE " + table + " SET " +
				String.join(",", valuesKeySet.stream().map(k -> k + " = ?").toArray(String[]::new)) +
				" WHERE " +
				String.join(" AND ", keysKeySet.stream().map(k -> k + " = ?").toArray(String[]::new)) + ";")) {
			int i = 1;

			for (String k: valuesKeySet) {
				q.setObject(i++, values.get(k));
			}

			for (String k: keysKeySet) {
				q.setObject(i++, keys.get(k));
			}

			return q.executeUpdate();
		}
	}

	private V executeSelectForUpdate(final Connection c, final Object key) throws SQLException {
		final Map keys = codec.encodeKey(key);
		final Set keySet = keys.keySet();

		try (final PreparedStatement q = c.prepareStatement("SELECT * FROM " + table + " WHERE " +
				String.join(" AND ", keySet.stream().map(k -> k + " = ?").toArray(String[]::new)) +
				" LIMIT 1 FOR UPDATE;")) {
			int i = 1;
			for (String k: keySet) {
				q.setObject(i++, keys.get(k));
			}

			try (final ResultSet rs = q.executeQuery()) {
				if (!rs.next()) {
					return null;
				}

				return codec.decodeValue(rs);
			}
		}
	}

	private V executeSelect(final Connection c, final Object key) throws SQLException {
		final Map keys = codec.encodeKey(key);
		final Set keySet = keys.keySet();

		try (final PreparedStatement q = c.prepareStatement("SELECT * FROM " + table + " WHERE " +
				String.join(" AND ", keySet.stream().map(k -> k + " = ?").toArray(String[]::new)) + ";")) {
			int i = 1;
			for (String k: keySet) {
				q.setObject(i++, keys.get(k));
			}

			try (final ResultSet rs = q.executeQuery()) {
				if (!rs.next()) {
					return null;
				}

				return codec.decodeValue(rs);
			}
		}
	}

	@Override
	public boolean replace(final K key, final V oldValue, final V newValue) {
		return dataSource.withConnectionUnchecked(c -> {
			final int result;
			c.setAutoCommit(false);

			try {
				if (!oldValue.equals(executeSelectForUpdate(c, key))) {
					c.rollback();
					return false;
				}

				result = executeUpdate(c, key, newValue);
			} catch (final SQLException e) {
				c.rollback();
				throw e;
			}

			c.commit();
			return result > 0;
		});
	}

	@Override
	public V replace(final K key, final V value) {
		return dataSource.withConnectionUnchecked(c -> {
			final V oldValue;
			c.setAutoCommit(false);

			try {
				oldValue = executeSelectForUpdate(c, key);
				executeUpdate(c, key, value);
			} catch (SQLException e) {
				c.rollback();
				throw e;
			}

			c.commit();
			return oldValue;
		});
	}

	@Override
	public void clear() {
		dataSource.withStatementUnchecked("DELETE FROM " + table + ";",
				(CheckedConsumer) PreparedStatement::executeUpdate);
	}

	@Override
	public boolean remove(final Object key, final Object value) {
		return dataSource.withConnectionUnchecked(c -> {
			final int result;
			c.setAutoCommit(false);

			try {
				if (!value.equals(executeSelectForUpdate(c, key))) {
					c.rollback();
					return false;
				}
			} catch (SQLException e) {
				c.rollback();
				throw e;
			}

			final Map values = codec.encodeKey(key);
			values.putAll(codec.encodeValue(value));

			final Set keySet = values.keySet();

			try (final PreparedStatement q = c.prepareStatement("DELETE FROM " + table + " WHERE " +
					String.join(" AND ", keySet.stream().map(k -> k + " = ?").toArray(String[]::new)) + ";")) {
				int i = 1;
				for (String k : keySet) {
					q.setObject(i++, values.get(k));
				}

				result = q.executeUpdate();
			} catch (SQLException e) {
				c.rollback();
				throw e;
			}

			c.commit();
			return result > 0;
		});
	}

	@Override
	public V remove(final Object key) {
		return dataSource.withConnectionUnchecked(c -> {
			final V oldValue;
			c.setAutoCommit(false);

			try {
				oldValue = executeSelectForUpdate(c, key);
			} catch (SQLException e) {
				c.rollback();
				throw e;
			}

			final Map keys = codec.encodeKey(key);
			final Set keySet = keys.keySet();

			try (final PreparedStatement q = c.prepareStatement("DELETE FROM " + table + " WHERE " +
					String.join(" AND ", keySet.stream().map(k -> k + " = ?").toArray(String[]::new)) + ";")) {
				int i = 1;
				for (String k : keySet) {
					q.setObject(i++, keys.get(k));
				}

				q.executeUpdate();
			} catch (SQLException e) {
				c.rollback();
				throw e;
			}

			c.commit();
			return oldValue;
		});
	}

	@Override
	public V put(final K key, final V value) {
		return dataSource.withConnectionUnchecked(c -> {
			final V oldValue;
			c.setAutoCommit(false);

			try {
				oldValue = executeSelectForUpdate(c, key);
			} catch (SQLException e) {
				c.rollback();
				throw e;
			}

			try {
				if (null == oldValue) {
					executeUpdate(c, key, value);
				} else {

					// There's a race condition here, like with #putIfAbsent. See below.
					// TODO: use a lock.
					executeInsert(c, key, value);
				}
			} catch (SQLException e) {
				c.rollback();
				throw e;
			}

			c.commit();
			return oldValue;
		});
	}

	@Override
	public V putIfAbsent(final K key, final V value) {
		return dataSource.withConnectionUnchecked(c -> {
			final V oldValue = executeSelect(c, key);

			if (null != oldValue) {
				return oldValue;
			}

			// There's a race condition here. An exception might be thrown if a record with the same keys is inserted
			// between the call to #get(...) and this point.
			// TODO: use a lock.
			executeInsert(c, key, value);
			return null;
		});
	}

	@Override
	public boolean fastPut(final K key, final V value) {
		return dataSource.withConnectionUnchecked(c -> executeInsertOrUpdate(c, key, value) > 0);
	}

	@Override
	public int size() {
		return dataSource.withStatementUnchecked("SELECT COUNT(*) FROM " + table + ";", q -> {
			try (final ResultSet rs = q.executeQuery()) {
				rs.next();
				return rs.getInt(1);
			}
		});
	}

	@Override
	public void putAll(final Map m) {
		dataSource.withConnectionUnchecked(c -> {
			for (Map.Entry e : m.entrySet()) {
				executeInsertOrUpdate(c, e.getKey(), e.getValue());
			}
		});
	}

	@Override
	public boolean containsKey(final Object key) {
		final Map keys = codec.encodeKey(key);
		final Set keySet = keys.keySet();

		return dataSource.withStatementUnchecked("SELECT EXISTS(SELECT * FROM " + table + " WHERE " +
				String.join(" AND ", keySet.stream().map(k -> k + " = ?").toArray(String[]::new)) + ";", q -> {
			int i = 1;
			for (String k: keySet) {
				q.setObject(i++, keys.get(k));
			}

			try (final ResultSet rs = q.executeQuery()) {
				rs.next();
				return rs.getBoolean(1);
			}
		});
	}

	@Override
	public boolean containsValue(final Object value) {
		final Map values = codec.encodeValue(value);
		final Set keySet = values.keySet();

		return dataSource.withStatementUnchecked("SELECT EXISTS(SELECT * FROM " + table + " WHERE " +
				String.join(" AND ", keySet.stream().map((k)-> k + " = ?").toArray(String[]::new)) + ");", q -> {
			int i = 1;
			for (String k: keySet) {
				q.setObject(i++, values.get(k));
			}

			try (final ResultSet rs = q.executeQuery()) {
				rs.next();
				return rs.getBoolean(1);
			}
		});
	}

	@Override
	public boolean isEmpty() {
		return dataSource.withStatementUnchecked("SELECT EXISTS(SELECT * FROM " + table + ");", q -> {
			final ResultSet rs = q.executeQuery();

			rs.next();
			return rs.getBoolean(1);
		});
	}

	@Override
	public V get(final Object key) {
		return dataSource.withConnectionUnchecked(c -> {
			return executeSelect(c, key);
		});
	}

	@Override
	public Set> entrySet() {
		return new MySQLSet<>(dataSource, new EntrySetCodec(), table);
	}

	@Override
	public Set keySet() {
		return new MySQLSet<>(dataSource, new KeySetCodec(), table);
	}

	@Override
	public Collection values() {
		return new MySQLSet<>(dataSource, new ValuesCodec(), table);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy