All downloads are free. Search and download functionality uses the official Maven repository.

com.tectonica.collections.SqliteKeyValueStore Maven / Gradle / Ivy

package com.tectonica.collections;

import java.io.Serializable;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import com.tectonica.util.JDBC;
import com.tectonica.util.JDBC.ConnListener;
import com.tectonica.util.JDBC.ExecutionContext;
import com.tectonica.util.JDBC.ResultSetIterator;
import com.tectonica.util.STR;
import com.tectonica.util.SerializeUtil;
import com.tectonica.util.SqliteUtil;

/**
 * A {@link KeyValueStore} persisted in a single SQLite table.
 * <p>
 * The table is named after the value class's simple name and has a {@code K} column (the primary key) and a {@code V} BLOB
 * column holding the value as serialized by the {@link Serializer} (Java serialization here). Every index registered through
 * {@link #createIndex} adds an extra {@code _i_<indexName>} column storing the indexed field's {@code toString()}, plus an SQL
 * index over that column.
 */
public class SqliteKeyValueStore extends KeyValueStore
{
	/** class the stored values are deserialized into; its simple name doubles as the table name */
	private final Class valueClass;
	private final String table;
	private final Serializer serializer;
	/** indexes registered via createIndex(); entry i feeds the column named by indexeCols entry i */
	private final List> indexes;
	private final List indexeCols;
	/** per-key modification locks; an entry stays here only while some caller obtained its lock and has not released it yet */
	private final ConcurrentHashMap locks;
	private final JDBC jdbc;

	/**
	 * Creates a new key-value store backed by SQLite, creating its table if it does not exist yet.
	 * 
	 * @param valueClass
	 *            class of the stored values; its simple name is used as the table name
	 * @param connStr
	 *            connection string handed to {@code SqliteUtil.connect()}
	 * @param keyMapper
	 *            this optional parameter is suitable in situations where the key of an entry can be inferred from its value directly
	 *            (as opposed to when the key and value are stored separately). When provided, several convenience methods become
	 *            applicable
	 */
	public SqliteKeyValueStore(Class valueClass, String connStr, KeyMapper keyMapper)
	{
		super(keyMapper);
		this.valueClass = valueClass;
		this.table = valueClass.getSimpleName();
		this.serializer = new JavaSerializer();
		this.indexes = new ArrayList<>();
		this.indexeCols = new ArrayList<>();
		this.locks = new ConcurrentHashMap<>();
		this.jdbc = SqliteUtil.connect(connStr);
		createTable();
	}

	/** Uses a bounded in-memory (Guava-based) cache in front of the database. */
	@Override
	protected Cache createCache()
	{
		return new InMemCache();
	}

	/***********************************************************************************
	 * 
	 * GETTERS
	 * 
	 ***********************************************************************************/

	/**
	 * Loads and deserializes the value stored under {@code key}. A missing row is handed to the serializer as {@code null}
	 * bytes.
	 */
	@Override
	protected V dbGet(final String key)
	{
		return jdbc.execute(new ConnListener()
		{
			@Override
			public V onConnection(Connection conn) throws SQLException
			{
				try (PreparedStatement stmt = conn.prepareStatement(sqlSelectSingle()))
				{
					stmt.setString(1, key);
					try (ResultSet rs = stmt.executeQuery())
					{
						byte[] bytes = (rs.next()) ? rs.getBytes(1) : null;
						return serializer.bytesToObj(bytes, valueClass);
					}
				}
			}
		});
	}

	/** Iterates lazily over all (key, value) entries of the table. */
	@Override
	public Iterator> iterator()
	{
		return entryIteratorOfResultSet(startSelectAll());
	}

	/**
	 * Returns the entries for {@code keys}, in the iteration order of {@code keys}; keys that have no row are skipped.
	 * <p>
	 * NOTE(review): all keys are bound in a single {@code IN (...)} clause, and SQLite caps the number of host parameters per
	 * statement (SQLITE_MAX_VARIABLE_NUMBER), so very large key collections may need chunking.
	 */
	@Override
	protected Iterator> dbOrderedIterator(final Collection keys)
	{
		return jdbc.execute(new ConnListener>>()
		{
			@Override
			public Iterator> onConnection(Connection conn) throws SQLException
			{
				try (PreparedStatement stmt = conn.prepareStatement(sqlSelectKeys(keys)))
				{
					int i = 0;
					for (String key : keys)
						stmt.setString(++i, key);
					// all rows are copied into memory here, so closing the statement afterwards is safe
					List ordered = byKeyOrder(stmt.executeQuery(), keys);
					return entryIteratorOfRawIter(ordered.iterator());
				}
			}
		});
	}

	/** Iterates lazily over all keys of the table. */
	@Override
	public Iterator keyIterator()
	{
		return keyIteratorOfResultSet(startSelectAll());
	}

	/** Iterates lazily over all (deserialized) values of the table. */
	@Override
	public Iterator valueIterator()
	{
		return valueIteratorOfResultSet(startSelectAll());
	}

	/**
	 * Starts a lazily-consumed {@code SELECT K,V} over the whole table. The statement is deliberately not closed here, because
	 * its ResultSet is read after this method returns (presumably released by the ExecutionContext / ResultSetIterator — TODO
	 * confirm).
	 */
	private ExecutionContext startSelectAll()
	{
		return jdbc.startExecute(new ConnListener()
		{
			@Override
			public ResultSet onConnection(Connection conn) throws SQLException
			{
				return conn.createStatement().executeQuery(sqlSelectAll());
			}
		});
	}

	/***********************************************************************************
	 * 
	 * SETTERS (UTILS)
	 * 
	 ***********************************************************************************/

	/**
	 * Returns a modifier for {@code key} whose value is a private copy: a duplicate of the cached instance when one is
	 * cached, otherwise a fresh load from the database.
	 */
	@Override
	protected Modifier getModifier(final String key, ModificationType purpose)
	{
		return new Modifier()
		{
			@Override
			public V getModifiableValue()
			{
				V value = usingCache ? cache.get(key) : null;
				if (value != null) // if we get a (local, in-memory) copy from the cache, we have to return a duplicate
					return serializer.copyOf(value);
				return dbGet(key);
			}

			@Override
			public void dbPut(final V value)
			{
				int updated = upsertRow(key, value, false);
				if (updated != 1)
					throw new RuntimeException("Unexpected dbUpdate() count: " + updated);
			}
		};
	}

	/**
	 * Returns the lock guarding modifications of {@code key}, creating it on demand.
	 * <p>
	 * Each call registers one "user" of the lock. The matching {@link SelfRemoveLock#unlock()} unregisters it, and the map
	 * entry is dropped only when no user remains. Therefore a thread that obtained the lock but has not locked it yet can never
	 * be handed an orphaned lock while another thread creates a fresh one for the same key.
	 * <p>
	 * NOTE(review): this assumes every call is paired with exactly one {@code unlock()}. A caller that obtains the lock and
	 * never unlocks it (for example after a failed tryLock) leaves the entry in the map — TODO confirm against KeyValueStore.
	 */
	@Override
	protected Lock getModificationLock(String key)
	{
		synchronized (locks)
		{
			SelfRemoveLock lock = (SelfRemoveLock) locks.get(key);
			if (lock == null)
			{
				lock = new SelfRemoveLock(key);
				locks.put(key, lock);
			}
			lock.users++;
			return lock;
		}
	}

	/** A per-key lock that unregisters itself from {@code locks} when its last registered user unlocks it. */
	private class SelfRemoveLock extends ReentrantLock
	{
		private static final long serialVersionUID = 1L;

		private final String key;

		/** getModificationLock() calls not yet matched by unlock(); guarded by the monitor of {@code locks} */
		private int users;

		public SelfRemoveLock(String key)
		{
			this.key = key;
		}

		@Override
		public void unlock()
		{
			// release first: if this thread does not hold the lock, this throws and the user count stays untouched
			super.unlock();
			synchronized (locks)
			{
				if (--users <= 0)
					locks.remove(key, this); // never remove a different lock that may have replaced this one
			}
		}
	}

	/***********************************************************************************
	 * 
	 * SETTERS
	 * 
	 ***********************************************************************************/

	/** Inserts a new row (plain INSERT, so an existing key makes the statement fail). */
	@Override
	protected void dbInsert(final String key, final V value)
	{
		int inserted = upsertRow(key, value, true);
		if (inserted != 1)
			throw new RuntimeException("Unexpected dbInsert() count: " + inserted);
	}

	/***********************************************************************************
	 * 
	 * DELETERS
	 * 
	 ***********************************************************************************/

	/** Deletes the row of {@code key}, returning whether a row was actually deleted. */
	@Override
	protected boolean dbDelete(final String key)
	{
		int deleted = jdbc.execute(new ConnListener()
		{
			@Override
			protected Integer onConnection(Connection conn) throws SQLException
			{
				try (PreparedStatement stmt = conn.prepareStatement(sqlDeleteSingle()))
				{
					stmt.setString(1, key);
					return stmt.executeUpdate();
				}
			}
		});
		return (deleted != 0);
	}

	/** Deletes every row, returning the number of rows deleted. */
	@Override
	protected int dbDeleteAll()
	{
		return jdbc.execute(new ConnListener()
		{
			@Override
			protected Integer onConnection(Connection conn) throws SQLException
			{
				try (PreparedStatement stmt = conn.prepareStatement(sqlDeleteAll()))
				{
					return stmt.executeUpdate();
				}
			}
		});
	}

	/***********************************************************************************
	 * 
	 * INDEXES
	 * 
	 ***********************************************************************************/

	/**
	 * Registers an index named {@code indexName}. This adds (if missing) the backing column and an SQL index over it. From then
	 * on, every upsert stores {@code mapFunc}'s result, as a string, in that column.
	 * <p>
	 * NOTE(review): rows written before the column existed are not back-filled, so their column value is NULL and index
	 * lookups will not find them.
	 */
	@Override
	public  Index createIndex(final String indexName, IndexMapper mapFunc)
	{
		jdbc.execute(new ConnListener()
		{
			@Override
			protected Void onConnection(Connection conn) throws SQLException
			{
				try (PreparedStatement addColumn = conn.prepareStatement(sqlAddColumn(indexName)))
				{
					addColumn.executeUpdate();
				}
				catch (SQLException e)
				{
					// probably 'duplicate column name' (the column survives from an earlier run); if the column is genuinely
					// missing, the CREATE INDEX below fails on it
					System.out.println(e.toString());
				}
				try (PreparedStatement createIndex = conn.prepareStatement(sqlCreateIndex(indexName)))
				{
					createIndex.executeUpdate();
				}
				return null;
			}
		});
		SqliteIndexImpl index = new SqliteIndexImpl<>(mapFunc, indexName);
		indexes.add(index);
		indexeCols.add(colOfIndex(indexName));
		return index;
	}

	/** An index whose lookups run an equality query on its dedicated column, matching on {@code f.toString()}. */
	private class SqliteIndexImpl extends Index
	{
		public SqliteIndexImpl(IndexMapper mapFunc, String name)
		{
			super(mapFunc, name);
		}

		@Override
		public Iterator> iteratorOf(final F f)
		{
			return entryIteratorOfResultSet(selectByIndex(f));
		}

		@Override
		public Iterator keyIteratorOf(F f)
		{
			return keyIteratorOfResultSet(selectByIndex(f));
		}

		@Override
		public Iterator valueIteratorOf(F f)
		{
			return valueIteratorOfResultSet(selectByIndex(f));
		}

		/** Starts a lazily-consumed query for rows whose index column equals {@code f.toString()}. */
		private ExecutionContext selectByIndex(final F f)
		{
			ExecutionContext ctx = jdbc.startExecute(new ConnListener()
			{
				@Override
				public ResultSet onConnection(Connection conn) throws SQLException
				{
					PreparedStatement stmt = conn.prepareStatement(sqlSelectByIndex(name));
					stmt.setString(1, f.toString());
					return stmt.executeQuery();
				}
			});
			return ctx;
		}

		/** The string stored in this index's column for {@code value}, or null when the mapper yields null. */
		private String getIndexedFieldOf(V value)
		{
			F idx = mapper.getIndexedFieldOf(value);
			return (idx == null) ? null : idx.toString();
		}
	}

	/***********************************************************************************
	 * 
	 * SQL QUERIES
	 * 
	 ***********************************************************************************/

	private String sqlCreateTable()
	{
		return String.format("CREATE TABLE IF NOT EXISTS %s (K VARCHAR2 PRIMARY KEY, V BLOB)", table);
	}

	private String sqlSelectSingle()
	{
		return String.format("SELECT V FROM %s WHERE K=?", table);
	}

	private String sqlSelectAll()
	{
		return String.format("SELECT K,V FROM %s", table);
	}

	/** SELECT with one '?' placeholder per key inside an IN clause. */
	private String sqlSelectKeys(Collection keys)
	{
		return String.format("SELECT K,V FROM %s WHERE K IN (%s)", table, STR.implode("?", ",", keys.size()));
	}

	/**
	 * INSERT (strict) or REPLACE statement whose parameters are K, V and then one parameter per registered index column, in
	 * registration order.
	 */
	private String sqlUpsert(boolean strictInsert)
	{
		String statement = strictInsert ? "INSERT" : "REPLACE";
		String pfx = (indexeCols.size() > 0) ? "," : "";
		String cols = pfx + STR.implode(indexeCols, ",", false);
		String qm = pfx + STR.implode("?", ",", indexeCols.size());
		return String.format("%s INTO %s (K,V %s) VALUES (?,? %s)", statement, table, cols, qm);
	}

	private String sqlDeleteSingle()
	{
		return String.format("DELETE FROM %s WHERE K=?", table);
	}

	private String sqlDeleteAll()
	{
		return String.format("DELETE FROM %s", table); // not TRUNCATE, as we want the deleted-count
	}

	private String sqlAddColumn(String indexName)
	{
		// TODO: we save the indexes values as strings instead of their possibly other native data type, can be improved
		return String.format("ALTER TABLE %s ADD COLUMN %s VARCHAR2", table, colOfIndex(indexName));
	}

	private String sqlCreateIndex(String indexName)
	{
		return String.format("CREATE INDEX IF NOT EXISTS IDX_%s ON %s (%s)", indexName, table, colOfIndex(indexName));
	}

	private String sqlSelectByIndex(String indexName)
	{
		return String.format("SELECT K,V FROM %s WHERE %s=?", table, colOfIndex(indexName));
	}

	/** Column name backing the index {@code indexName}. */
	private String colOfIndex(String indexName)
	{
		return "_i_" + indexName;
	}

	/***********************************************************************************
	 * 
	 * DATABASE UTILS
	 * 
	 ***********************************************************************************/

	private void createTable()
	{
		jdbc.execute(new ConnListener()
		{
			@Override
			public Void onConnection(Connection conn) throws SQLException
			{
				try (PreparedStatement stmt = conn.prepareStatement(sqlCreateTable()))
				{
					stmt.execute();
				}
				return null;
			}
		});
	}

	/**
	 * Adapts a lazily-consumed (K, V) result set into an entry iterator. Each entry copies its row's key and raw bytes when
	 * {@code next()} is called, because the ResultSet is a single moving cursor. An entry that read it lazily would see
	 * whatever row the cursor points to later (or a closed set). Deserialization still happens only on {@code getValue()}.
	 */
	private Iterator> entryIteratorOfResultSet(final ExecutionContext ctx)
	{
		final ResultSetIterator iter = new ResultSetIterator(ctx);
		return new Iterator>()
		{
			@Override
			public boolean hasNext()
			{
				return iter.hasNext();
			}

			@Override
			public KeyValue next()
			{
				final ResultSet rs = iter.next();
				final RawKeyValue row = new RawKeyValue(rsGetKey(rs), rsGetBytes(rs));
				return new KeyValue()
				{
					@Override
					public String getKey()
					{
						return row.key;
					}

					@Override
					public V getValue()
					{
						return serializer.bytesToObj(row.bytes, valueClass);
					}
				};
			}

			@Override
			public void remove()
			{
				throw new UnsupportedOperationException();
			}
		};
	}

	/** Adapts a lazily-consumed (K, V) result set into a key iterator. */
	private Iterator keyIteratorOfResultSet(final ExecutionContext ctx)
	{
		final ResultSetIterator iter = new ResultSetIterator(ctx);
		return new Iterator()
		{
			@Override
			public boolean hasNext()
			{
				return iter.hasNext();
			}

			@Override
			public String next()
			{
				return rsGetKey(iter.next());
			}

			@Override
			public void remove()
			{
				throw new UnsupportedOperationException();
			}
		};
	}

	/** Adapts a lazily-consumed (K, V) result set into an iterator of deserialized values. */
	private Iterator valueIteratorOfResultSet(final ExecutionContext ctx)
	{
		final ResultSetIterator iter = new ResultSetIterator(ctx);
		return new Iterator()
		{
			@Override
			public boolean hasNext()
			{
				return iter.hasNext();
			}

			@Override
			public V next()
			{
				return rsGetValue(iter.next());
			}

			@Override
			public void remove()
			{
				throw new UnsupportedOperationException();
			}
		};
	}

	/** Reads the key (column 1) of the current row, wrapping SQLException as RuntimeException. */
	private String rsGetKey(ResultSet rs)
	{
		try
		{
			return rs.getString(1);
		}
		catch (SQLException e)
		{
			throw new RuntimeException(e);
		}
	}

	/** Reads the serialized value (column 2) of the current row, wrapping SQLException as RuntimeException. */
	private byte[] rsGetBytes(ResultSet rs)
	{
		try
		{
			return rs.getBytes(2);
		}
		catch (SQLException e)
		{
			throw new RuntimeException(e);
		}
	}

	/** Reads and deserializes the value (column 2) of the current row. */
	private V rsGetValue(ResultSet rs)
	{
		return serializer.bytesToObj(rsGetBytes(rs), valueClass);
	}

	/** A row detached from its ResultSet: the key plus the still-serialized value. */
	private static class RawKeyValue
	{
		final String key;
		final byte[] bytes;

		RawKeyValue(String key, byte[] bytes)
		{
			this.key = key;
			this.bytes = bytes;
		}
	}

	/**
	 * Drains {@code rs} into memory and returns its rows in the iteration order of {@code keys}. Keys with no row (or a NULL
	 * value) are skipped, and a key listed twice yields two entries.
	 */
	private List byKeyOrder(ResultSet rs, final Collection keys) throws SQLException
	{
		Map prefetch = new HashMap<>();
		while (rs.next())
			prefetch.put(rs.getString(1), rs.getBytes(2));
		List ordered = new ArrayList<>();
		for (String key : keys)
		{
			byte[] bytes = prefetch.get(key);
			if (bytes != null)
				ordered.add(new RawKeyValue(key, bytes));
		}
		return ordered;
	}

	/** Wraps in-memory raw rows as entries whose values are deserialized on each {@code getValue()} call. */
	private Iterator> entryIteratorOfRawIter(final Iterator iter)
	{
		return new Iterator>()
		{
			@Override
			public boolean hasNext()
			{
				return iter.hasNext();
			}

			@Override
			public KeyValue next()
			{
				final RawKeyValue rkv = iter.next();
				return new KeyValue()
				{
					@Override
					public String getKey()
					{
						return rkv.key;
					}

					@Override
					public V getValue()
					{
						return serializer.bytesToObj(rkv.bytes, valueClass);
					}
				};
			}

			@Override
			public void remove()
			{
				throw new UnsupportedOperationException();
			}
		};
	}

	/**
	 * Writes one row with INSERT ({@code strictInsert}) or REPLACE, including the current value of every registered index
	 * column. Returns the update count reported by the driver.
	 */
	private Integer upsertRow(final String key, final V value, final boolean strictInsert)
	{
		return jdbc.execute(new ConnListener()
		{
			@Override
			protected Integer onConnection(Connection conn) throws SQLException
			{
				try (PreparedStatement stmt = conn.prepareStatement(sqlUpsert(strictInsert)))
				{
					stmt.setString(1, key);
					stmt.setBytes(2, serializer.objToBytes(value));
					// parameters 3.. are the index columns, in the same order sqlUpsert() lists them
					for (int i = 0; i < indexes.size(); i++)
					{
						SqliteIndexImpl index = indexes.get(i);
						String field = (value == null) ? null : index.getIndexedFieldOf(value);
						stmt.setString(3 + i, field);
					}
					return stmt.executeUpdate();
				}
			}
		});
	}

	/***********************************************************************************
	 * 
	 * SERIALIZATION
	 * 
	 ***********************************************************************************/

	/** Converts values to and from the bytes stored in the V column, and produces independent copies of values. */
	public static interface Serializer
	{
		V bytesToObj(byte[] bytes, Class clz);

		byte[] objToBytes(V obj);

		V copyOf(V obj);
	}

	/** Serializer delegating to {@code SerializeUtil} (Java serialization). */
	private static final class JavaSerializer implements Serializer
	{
		@Override
		public V bytesToObj(byte[] bytes, Class clz)
		{
			return SerializeUtil.bytesToObj(bytes, clz);
		}

		@Override
		public byte[] objToBytes(V obj)
		{
			return SerializeUtil.objToBytes(obj);
		}

		@Override
		public V copyOf(V obj)
		{
			return SerializeUtil.copyOf(obj);
		}
	}

	/***********************************************************************************
	 * 
	 * CACHE IMPLEMENTATION
	 * 
	 ***********************************************************************************/

	/** Cache backed by a Guava cache holding at most 1000 entries. */
	private class InMemCache implements Cache
	{
		// based on Guava cache
		private final com.google.common.cache.Cache cache = com.google.common.cache.CacheBuilder.newBuilder().maximumSize(1000)
				.build();

		@Override
		public V get(String key)
		{
			return cache.getIfPresent(key);
		}

		@Override
		public Map get(Collection keys)
		{
			return cache.getAllPresent(keys);
		}

		@Override
		public void put(String key, V value)
		{
			cache.put(key, value);
		}

		@Override
		public void put(Map values)
		{
			cache.putAll(values);
		}

		@Override
		public void delete(String key)
		{
			cache.invalidate(key);
		}

		@Override
		public void deleteAll()
		{
			cache.invalidateAll();
		}
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy