All downloads are free. Search and download functionality uses the official Maven repository.

org.minimalj.repository.sql.SqlRepository Maven / Gradle / Ivy

Go to download

A Java framework aiming for a minimal programming style. It includes a GUI and a persistence layer.

There is a newer version: 2.5.0.0
Show newest version
package org.minimalj.repository.sql;

import java.io.InputStream;
import java.io.Serializable;
import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

import javax.sql.DataSource;

import org.apache.derby.jdbc.EmbeddedDataSource;
import org.minimalj.application.Configuration;
import org.minimalj.model.Code;
import org.minimalj.model.EnumUtils;
import org.minimalj.model.Keys;
import org.minimalj.model.View;
import org.minimalj.model.ViewUtil;
import org.minimalj.model.properties.ChainedProperty;
import org.minimalj.model.properties.FieldProperty;
import org.minimalj.model.properties.FlatProperties;
import org.minimalj.model.properties.PropertyInterface;
import org.minimalj.model.test.ModelTest;
import org.minimalj.repository.TransactionalRepository;
import org.minimalj.repository.criteria.By;
import org.minimalj.repository.criteria.Criteria;
import org.minimalj.util.CloneHelper;
import org.minimalj.util.Codes;
import org.minimalj.util.Codes.CodeCacheItem;
import org.minimalj.util.CsvReader;
import org.minimalj.util.FieldUtils;
import org.minimalj.util.GenericUtils;
import org.minimalj.util.IdUtils;
import org.minimalj.util.LoggingRuntimeException;
import org.minimalj.util.StringUtils;

/**
 * Maps the persistent model classes to the tables of a relational database.
 * 
 */
public class SqlRepository implements TransactionalRepository {
	private static final Logger logger = Logger.getLogger(SqlRepository.class.getName());
	public static final boolean CREATE_TABLES = true;
	
	private final SqlDialect sqlDialect;
	
	private final List> mainClasses;
	private final Map, AbstractTable> tables = new LinkedHashMap, AbstractTable>();
	private final Map> tableByName = new HashMap>();
	private final Map, LinkedHashMap> columnsForClass = new HashMap<>(200);
	
	private final DataSource dataSource;
	
	private Connection autoCommitConnection;
	private final BlockingDeque connectionDeque = new LinkedBlockingDeque<>();
	private final ThreadLocal threadLocalTransactionConnection = new ThreadLocal<>();

	private final HashMap, CodeCacheItem> codeCache = new HashMap<>();
	
	/**
	 * Creates a repository for the given classes. Tables are only created on initialization
	 * if the data source is an embedded Derby data source with createDatabase set to "create"
	 * (see createTablesOnInitialize).
	 * 
	 * @param dataSource source of the JDBC connections
	 * @param classes the main classes of the model
	 */
	public SqlRepository(DataSource dataSource, Class... classes) {
		this(dataSource, createTablesOnInitialize(dataSource), classes);
	}

	/**
	 * Creates a repository for the given classes.
	 * 
	 * @param dataSource source of the JDBC connections
	 * @param createTablesOnInitialize if true, tables, indexes and constraints are created
	 *        and the code tables are filled immediately
	 * @param classes the main classes of the model
	 * @throws IllegalArgumentException if the model test reports problems with the classes
	 */
	public SqlRepository(DataSource dataSource, boolean createTablesOnInitialize, Class... classes) {
		this.dataSource = dataSource;
		this.mainClasses = Arrays.asList(classes);
		// findDialect is the only step that can throw a checked SQLException
		try {
			this.sqlDialect = findDialect(getAutoCommitConnection());
		} catch (SQLException x) {
			throw new LoggingRuntimeException(x, logger, "Could not determine product name of database");
		}
		for (Class mainClass : classes) {
			addClass(mainClass);
		}
		testModel(classes);
		if (!createTablesOnInitialize) {
			return;
		}
		createTables();
		createCodes();
	}

	/**
	 * Determines the SQL dialect to use.
	 * <p>
	 * A dialect configured with "MjSqlDialect" takes precedence. Otherwise the dialect is chosen
	 * by the database product name reported by the JDBC driver. MariaDB drivers report either
	 * "MySQL" or "MariaDB" depending on driver version and configuration; both map to the
	 * MariaDB dialect.
	 * 
	 * @param connection any open connection to the database
	 * @return the dialect, never null
	 * @throws SQLException if the database metadata cannot be read
	 * @throws RuntimeException if the database product is not supported
	 */
	private SqlDialect findDialect(Connection connection) throws SQLException {
		if (Configuration.available("MjSqlDialect")) {
			return Configuration.getClazz("MjSqlDialect", SqlDialect.class);
		}
		
		String databaseProductName = connection.getMetaData().getDatabaseProductName();
		if (StringUtils.equals(databaseProductName, "MySQL", "MariaDB")) {
			return new SqlDialect.MariaSqlDialect();
		} else if (StringUtils.equals(databaseProductName, "Apache Derby")) {
			return new SqlDialect.DerbySqlDialect();
		} else if (StringUtils.equals(databaseProductName, "Oracle")) {
			return new SqlDialect.OracleSqlDialect();
		} else {
			throw new RuntimeException("Only Oracle, MySQL/MariaDB and Derby DB supported at the moment. ProductName: " + databaseProductName);
		}
	}
	
	/**
	 * Returns the shared auto commit connection, creating it on first use.
	 * <p>
	 * The connection is not re-validated once created: isValid had problems in
	 * MariaDB drivers older than 1.1.8.
	 */
	private Connection getAutoCommitConnection() {
		if (autoCommitConnection != null) {
			return autoCommitConnection;
		}
		try {
			autoCommitConnection = dataSource.getConnection();
			autoCommitConnection.setAutoCommit(true);
			return autoCommitConnection;
		} catch (Exception e) {
			throw new LoggingRuntimeException(e, logger, "Not possible to create autocommit connection");
		}
	}
	
	/** @return the dialect determined when this repository was created */
	public SqlDialect getSqlDialect() {
		return this.sqlDialect;
	}

	/** @return true if clazz was one of the classes passed to the constructor */
	public boolean isMainClasses(Class clazz) {
		return this.mainClasses.contains(clazz);
	}
	
	/**
	 * Binds a transaction connection to the current thread. If the thread already has an
	 * active transaction, nothing happens and the given isolation level is ignored.
	 */
	@Override
	public void startTransaction(int transactionIsolationLevel) {
		if (!isTransactionActive()) {
			threadLocalTransactionConnection.set(allocateConnection(transactionIsolationLevel));
		}
	}

	/**
	 * Ends the transaction bound to the current thread, if there is one.
	 * <p>
	 * The connection is always detached from the thread and handed back to the pool, even
	 * when commit or rollback fails. Otherwise a failure would leak the connection and leave
	 * the thread marked as being in a transaction forever. If a commit fails, a rollback is
	 * attempted so the pooled connection does not carry a half finished transaction.
	 * 
	 * @param commit true to commit, false to roll back
	 * @throws LoggingRuntimeException if commit or rollback fails
	 */
	@Override
	public void endTransaction(boolean commit) {
		Connection transactionConnection = threadLocalTransactionConnection.get();
		if (transactionConnection == null) return;
		
		try {
			if (commit) {
				transactionConnection.commit();
			} else {
				transactionConnection.rollback();
			}
		} catch (SQLException x) {
			if (commit) {
				try {
					transactionConnection.rollback();
				} catch (SQLException rollbackFailure) {
					x.addSuppressed(rollbackFailure);
				}
			}
			throw new LoggingRuntimeException(x, logger, "Transaction failed");
		} finally {
			releaseConnection(transactionConnection);
			threadLocalTransactionConnection.remove();
		}
	}
	
	/**
	 * Returns a connection for a new transaction.
	 * <p>
	 * A valid pooled connection is preferred. Otherwise a new one is created. If no connection
	 * can be created (for example because there are already too many connections), this waits
	 * in steps of up to 10 seconds for a connection to be released to the pool. The returned
	 * connection has auto commit off and the requested transaction isolation level.
	 * 
	 * @param transactionIsolationLevel one of the java.sql.Connection TRANSACTION_* levels
	 * @return a connection, never null
	 * @throws LoggingRuntimeException if the thread is interrupted while waiting
	 */
	private Connection allocateConnection(int transactionIsolationLevel) {
		Connection connection = connectionDeque.poll();
		while (true) {
			if (prepareForReuse(connection, transactionIsolationLevel)) {
				return connection;
			}
			// an unusable pooled connection is dropped for good, so release its resources
			closeQuietly(connection);
			connection = createTransactionConnection(transactionIsolationLevel);
			if (connection != null) {
				return connection;
			}
			// no connection available and not possible to create one:
			// block and wait until a connection is put back into the deque.
			// The polled connection (or null after the timeout) is checked in the next round.
			try {
				connection = connectionDeque.poll(10, TimeUnit.SECONDS);
			} catch (InterruptedException e) {
				Thread.currentThread().interrupt();
				throw new LoggingRuntimeException(e, logger, "Interrupted while waiting for a connection");
			}
		}
	}

	/** Checks a pooled connection and aligns its isolation level. Returns false for null or unusable connections. */
	private static boolean prepareForReuse(Connection connection, int transactionIsolationLevel) {
		if (connection == null) {
			return false;
		}
		try {
			if (!connection.isValid(0)) {
				return false;
			}
			if (connection.getTransactionIsolation() != transactionIsolationLevel) {
				connection.setTransactionIsolation(transactionIsolationLevel);
			}
			return true;
		} catch (SQLException x) {
			logger.log(Level.FINE, "Discarding pooled connection", x);
			return false;
		}
	}

	/** Creates a new non auto commit connection, or returns null if that is not possible right now. */
	private Connection createTransactionConnection(int transactionIsolationLevel) {
		Connection connection = null;
		try {
			connection = dataSource.getConnection();
			connection.setTransactionIsolation(transactionIsolationLevel);
			connection.setAutoCommit(false);
			return connection;
		} catch (Exception e) {
			// this could happen if there are already too many connections
			logger.log(Level.FINE, "Not possible to create additional connection", e);
			closeQuietly(connection);
			return null;
		}
	}

	/** Closes the connection if it is not null and only logs a failure. */
	private static void closeQuietly(Connection connection) {
		if (connection == null) {
			return;
		}
		try {
			connection.close();
		} catch (SQLException x) {
			logger.log(Level.FINEST, "Closing connection failed", x);
		}
	}
	
	/**
	 * Returns a transaction connection to the pool. It goes to the head of the deque
	 * (last in, first out), in the hope that recently used connections are the fastest.
	 */
	private void releaseConnection(Connection connection) {
		connectionDeque.addFirst(connection);
	}
	
	/**
	 * Use with care. Removes all content of all tables. Should only
	 * be used for JUnit tests.
	 */
	public void clear() {
		List> tableList = new ArrayList>(tables.values());
		for (AbstractTable table : tableList) {
			table.clear();
		}
	}

	/** @return true if the current thread has a transaction connection bound to it */
	public boolean isTransactionActive() {
		return threadLocalTransactionConnection.get() != null;
	}
	
	/**
	 * Returns the transaction connection of the current thread. If there is none, the
	 * shared auto commit connection is returned instead.
	 */
	Connection getConnection() {
		Connection transactionConnection = threadLocalTransactionConnection.get();
		return transactionConnection != null ? transactionConnection : getAutoCommitConnection();
	}
	
	/**
	 * Tables are created on initialization only for an embedded Derby data source that is
	 * configured to create its database.
	 */
	private static boolean createTablesOnInitialize(DataSource dataSource) {
		if (!(dataSource instanceof EmbeddedDataSource)) {
			return false;
		}
		String createDatabase = ((EmbeddedDataSource) dataSource).getCreateDatabase();
		return "create".equals(createDatabase);
	}
	
	/**
	 * Reads the object with the given id from the table of clazz.
	 * getTable throws an IllegalArgumentException if clazz has no table in this repository.
	 */
	@Override
	public  T read(Class clazz, Object id) {
		Table table = getTable(clazz);
		return table.read(id);
	}

	/**
	 * Reads an older state of a historized object.
	 * The table of clazz is cast to HistorizedTable, so this throws a ClassCastException
	 * for classes whose table is not historized.
	 * NOTE(review): time is passed on to HistorizedTable.read; presumably a version number - confirm.
	 */
	public  T readVersion(Class clazz, Object id, Integer time) {
		HistorizedTable table = (HistorizedTable) getTable(clazz);
		return table.read(id, time);
	}

	/**
	 * Reads objects matching the criteria. maxResults is passed on to the table.
	 * If resultClass is a View, the table of the viewed class is queried and its
	 * readView maps the rows to the view class.
	 */
	@Override
	public  List read(Class resultClass, Criteria criteria, int maxResults) {
		if (View.class.isAssignableFrom(resultClass)) {
			Class viewedClass = ViewUtil.getViewedClass(resultClass);
			Table table = getTable(viewedClass);
			return table.readView(resultClass, criteria, maxResults);
		} else {
			Table table = getTable(resultClass);
			return table.read(criteria, maxResults);
		}
	}

	/**
	 * Inserts the object into the table of its runtime class and returns the result of
	 * Table.insert (presumably the id of the new row - confirm in Table).
	 * 
	 * @throws NullPointerException if object is null
	 */
	@Override
	public  Object insert(T object) {
		if (object == null) throw new NullPointerException();
		@SuppressWarnings("unchecked")
		Table table = getTable((Class) object.getClass());
		return table.insert(object);
	}

	/**
	 * Updates the object in the table of its runtime class.
	 * 
	 * @throws NullPointerException if object is null
	 */
	@Override
	public  void update(T object) {
		if (object == null) throw new NullPointerException();
		@SuppressWarnings("unchecked")
		Table table = getTable((Class) object.getClass());
		table.update(object);
	}

	/**
	 * Deletes the object, identified by its runtime class and IdUtils.getId(object).
	 * A null object results in a NullPointerException from getClass().
	 */
	public  void delete(T object) {
		delete(object.getClass(), IdUtils.getId(object));
	}

	/**
	 * Deletes the row with the given id from the table of clazz.
	 * The delete runs on the current connection and is not wrapped in its own transaction.
	 */
	@Override
	public  void delete(Class clazz, Object id) {
		// TODO do in transaction and merge with insert/update
		getTable(clazz).delete(id);
	}
	
	/** Removes all rows from the table of clazz. */
	public  void deleteAll(Class clazz) {
		getTable(clazz).clear();
	}

	/**
	 * Loads historized versions of an object, at most min(maxVersion + 1, maxResult) of them
	 * judging by the capacity of the result list.
	 * NOTE(review): this source is damaged. The remainder of this method and the signature
	 * of the following getList method were lost during extraction (see the "for" line below).
	 * Restore them from the upstream source before relying on this code.
	 */
	public  List loadHistory(Class clazz, Object id, int maxResult) {
		@SuppressWarnings("unchecked")
		Table table = (Table) getTable(clazz);
		if (table instanceof HistorizedTable) {
			HistorizedTable historizedTable = (HistorizedTable) table;
			int maxVersion = historizedTable.getMaxVersion(id);
			int maxResults = Math.min(maxVersion + 1, maxResult);
			List result = new ArrayList<>(maxResults);
			// NOTE(review): truncated line - everything between "i" and "List getList" is missing
			for (int i = 0; i List getList(LazyList list) {
		// getList: reads all elements of the lazy list from the cross table registered under its list name
		CrossTable subTable = (CrossTable) getTableByName().get(list.getListName());
		return subTable.readAll(list.getParentId());
	}
	
	/**
	 * Adds an element to a lazy list, using the cross table found by the list name in
	 * tableByName, and returns the result of CrossTable.addElement.
	 * NOTE(review): this class never puts anything into tableByName itself; presumably the
	 * tables register there - confirm, otherwise this fails with a NullPointerException.
	 */
	@Override
	public  ELEMENT add(LazyList list, ELEMENT element) {
		CrossTable subTable = (CrossTable) getTableByName().get(list.getListName());
		return subTable.addElement(list.getParentId(), element);
	}
	
	/**
	 * Removing single elements from a lazy list is not supported yet.
	 * 
	 * @throws UnsupportedOperationException always (a RuntimeException, as before)
	 */
	@Override
	public  void remove(LazyList list, int position) {
		throw new UnsupportedOperationException("Not yet implemented");
	}
	
	//
	
	/**
	 * Prepares query on the given connection and binds the parameters in order.
	 * If binding fails, the statement is closed before the exception is rethrown. The caller
	 * never receives the statement in that case and so could not close it.
	 * 
	 * @param connection the connection to prepare the statement on
	 * @param query SQL with one '?' placeholder per parameter
	 * @param parameters values bound to the placeholders, see setParameter
	 */
	private PreparedStatement createStatement(Connection connection, String query, Object[] parameters) throws SQLException {
		PreparedStatement preparedStatement = AbstractTable.createStatement(connection, query, false);
		try {
			int param = 1; // JDBC parameter indexes start at 1
			for (Object parameter : parameters) {
				setParameter(preparedStatement, param++, parameter);
			}
		} catch (SQLException | RuntimeException x) {
			try {
				preparedStatement.close();
			} catch (SQLException closeFailure) {
				x.addSuppressed(closeFailure);
			}
			throw x;
		}
		return preparedStatement;
	}
	
	/**
	 * Returns the column name to property mapping for clazz. The order follows the fields
	 * as returned by clazz.getFields(). The result is cached per class in columnsForClass.
	 * <p>
	 * Skipped are non-public, static and transient fields, the ID, VERSION and HISTORIZED
	 * fields, and list fields. A final field that is not a Set and not a Code is treated as an
	 * inline object. Its columns are collected recursively and prefixed with the field name,
	 * unless FieldUtils.hasClassName(field) holds and FlatProperties.hasCollidingFields reports
	 * no collision. Every name goes through SqlIdentifier.buildIdentifier with the dialect's
	 * maximum identifier length and the names already used.
	 * NOTE(review): columnsForClass is a plain HashMap that is read and written here without
	 * synchronization.
	 */
	public LinkedHashMap findColumns(Class clazz) {
		if (columnsForClass.containsKey(clazz)) {
			return columnsForClass.get(clazz);
		}
		
		LinkedHashMap columns = new LinkedHashMap();
		for (Field field : clazz.getFields()) {
			if (!FieldUtils.isPublic(field) || FieldUtils.isStatic(field) || FieldUtils.isTransient(field)) continue;
			String fieldName = StringUtils.toSnakeCase(field.getName()).toUpperCase();
			// these columns are handled separately when reading a row
			if (StringUtils.equals(fieldName, "ID", "VERSION", "HISTORIZED")) continue;
			if (FieldUtils.isList(field)) continue;
			if (FieldUtils.isFinal(field) && !FieldUtils.isSet(field) && !Codes.isCode(field.getType())) {
				// inline object: flatten its columns into this table
				Map inlinePropertys = findColumns(field.getType());
				boolean hasClassName = FieldUtils.hasClassName(field) && !FlatProperties.hasCollidingFields(clazz, field.getType(), field.getName());
				for (String inlineKey : inlinePropertys.keySet()) {
					String key = inlineKey;
					if (!hasClassName) {
						key = fieldName + "_" + inlineKey;
					}
					key = SqlIdentifier.buildIdentifier(key, getMaxIdentifierLength(), columns.keySet());
					columns.put(key, new ChainedProperty(new FieldProperty(field), inlinePropertys.get(inlineKey)));
				}
			} else {
				fieldName = SqlIdentifier.buildIdentifier(fieldName, getMaxIdentifierLength(), columns.keySet());
				columns.put(fieldName, new FieldProperty(field));
			}
		}
		columnsForClass.put(clazz, columns);
		return columns;
	}	
	
	/*
	 * Binds one query parameter. Enums are bound by ordinal, and java.time values are bound
	 * as their java.sql counterparts. Every other value is passed to setObject unchanged.
	 * TODO: should be merged with the setParameter in AbstractTable.
	 */
	private void setParameter(PreparedStatement preparedStatement, int param, Object value) throws SQLException {
		preparedStatement.setObject(param, toJdbcValue(value));
	}

	/* Converts a model value to the object handed to PreparedStatement.setObject. */
	private static Object toJdbcValue(Object value) {
		if (value instanceof Enum) {
			return ((Enum) value).ordinal();
		}
		if (value instanceof LocalDate) {
			return java.sql.Date.valueOf((LocalDate) value);
		}
		if (value instanceof LocalTime) {
			return java.sql.Time.valueOf((LocalTime) value);
		}
		if (value instanceof LocalDateTime) {
			return java.sql.Timestamp.valueOf((LocalDateTime) value);
		}
		return value;
	}
	
	/**
	 * Executes a query and maps at most maxResults rows with readResultSetRow.
	 * 
	 * @param clazz class of the result objects (Integer, BigDecimal and String read the first column)
	 * @param query SQL with '?' placeholders
	 * @param maxResults maximum number of rows returned; 0 or less gives an empty list
	 * @param parameters values for the placeholders
	 * @return the mapped rows
	 */
	public  List execute(Class clazz, String query, int maxResults, Serializable... parameters) {
		try (PreparedStatement preparedStatement = createStatement(getConnection(), query, parameters)) {
			try (ResultSet resultSet = preparedStatement.executeQuery()) {
				List result = new ArrayList<>();
				// check the limit first so no extra row is fetched once maxResults is reached
				while (result.size() < maxResults && resultSet.next()) {
					result.add(readResultSetRow(clazz, resultSet));
				}
				return result;
			}
		} catch (SQLException x) {
			throw new LoggingRuntimeException(x, logger, "Couldn't execute query");
		}
	}
	
	/**
	 * Executes a query and maps only its first row.
	 * 
	 * @return the first row mapped with readResultSetRow, or null if the query returned no row
	 */
	public  T execute(Class clazz, String query, Serializable... parameters) {
		try (PreparedStatement statement = createStatement(getConnection(), query, parameters);
				ResultSet resultSet = statement.executeQuery()) {
			return resultSet.next() ? readResultSetRow(clazz, resultSet) : null;
		} catch (SQLException x) {
			throw new LoggingRuntimeException(x, logger, "Couldn't execute query");
		}
	}
	
	public  R readResultSetRow(Class clazz, ResultSet resultSet) throws SQLException {
		Map, Map> loadedReferences = new HashMap<>();
		return readResultSetRow(clazz, resultSet, loadedReferences);
	}
	
	/**
	 * Maps the current row of resultSet to an object of clazz.
	 * <p>
	 * For Integer, BigDecimal and String only the first column is read. For other classes a
	 * new instance is created. ID, VERSION, POSITION and HISTORIZED columns are handled
	 * specially; all other columns are matched by name against findColumns(clazz). References
	 * (codes, views, objects with an id, dependables) are resolved only after the whole row
	 * has been read.
	 * 
	 * @param loadedReferences objects already created in this read, per class and key. The key
	 *        is the id, or "id-position" for rows that have a POSITION column. Used to return the
	 *        same instance for the same row and to break reference cycles.
	 * @return the mapped object, or the instance already loaded for the same key
	 * @throws SQLException if reading from the result set fails
	 */
	@SuppressWarnings("unchecked")
	public  R readResultSetRow(Class clazz, ResultSet resultSet, Map, Map> loadedReferences) throws SQLException {
		if (clazz == Integer.class) {
			return (R) Integer.valueOf(resultSet.getInt(1));
		} else if (clazz == BigDecimal.class) {
			return (R) resultSet.getBigDecimal(1);
		} else if (clazz == String.class) {
			return (R) resultSet.getString(1);
		}
		
		Object id = null;
		// stays null unless the row has a POSITION column. Only then does the position become
		// part of the cache key, so plain rows are cached by their id. That id is also the key
		// the reference lookup below uses.
		Integer position = null;
		R result = CloneHelper.newInstance(clazz);
		
		LinkedHashMap columns = findColumns(clazz);
		
		// first read the resultSet completely, then resolve references:
		// derby db mixes closing of resultSets.
		java.sql.ResultSetMetaData metaData = resultSet.getMetaData();
		int columnCount = metaData.getColumnCount();
		Map values = new HashMap<>(columnCount * 3);
		for (int columnIndex = 1; columnIndex <= columnCount; columnIndex++) {
			String columnName = metaData.getColumnName(columnIndex);
			if ("ID".equalsIgnoreCase(columnName)) {
				id = resultSet.getObject(columnIndex);
				IdUtils.setId(result, id);
				continue;
			} else if ("VERSION".equalsIgnoreCase(columnName)) {
				IdUtils.setVersion(result, resultSet.getInt(columnIndex));
				continue;
			} else if ("POSITION".equalsIgnoreCase(columnName)) {
				position = resultSet.getInt(columnIndex);
				continue;
			} else if ("HISTORIZED".equalsIgnoreCase(columnName)) {
				IdUtils.setHistorized(result, resultSet.getInt(columnIndex));
				continue;
			}
			
			PropertyInterface property = columns.get(columnName);
			if (property == null) continue;
			
			Class fieldClass = property.getClazz();
			boolean isByteArray = fieldClass.isArray() && fieldClass.getComponentType() == Byte.TYPE;

			Object value = isByteArray ? resultSet.getBytes(columnIndex) : resultSet.getObject(columnIndex);
			if (value == null) continue;
			values.put(property, value);
		}
		
		if (!loadedReferences.containsKey(clazz)) {
			loadedReferences.put(clazz, new HashMap<>());
		}
		Object key = position == null ? id : id + "-" + position;
		if (loadedReferences.get(clazz).containsKey(key)) {
			return (R) loadedReferences.get(clazz).get(key);
		} else {
			loadedReferences.get(clazz).put(key, result);
		}
		
		for (Map.Entry entry : values.entrySet()) {
			Object value = entry.getValue();
			PropertyInterface property = entry.getKey();
			if (value != null) {
				Class fieldClass = property.getClazz();
				if (Code.class.isAssignableFrom(fieldClass)) {
					Class codeClass = (Class) fieldClass;
					value = getCode(codeClass, value);
				} else if (View.class.isAssignableFrom(fieldClass)) {
					Class viewedClass = ViewUtil.getViewedClass(fieldClass);
					Table referenceTable = getTable(viewedClass);
					value = referenceTable.readView(fieldClass, value, loadedReferences);
				} else if (IdUtils.hasId(fieldClass)) {
					if (loadedReferences.containsKey(fieldClass) && loadedReferences.get(fieldClass).containsKey(value)) {
						value = loadedReferences.get(fieldClass).get(value);
					} else {
						Table referenceTable = getTable(fieldClass);
						value = referenceTable.read(value, loadedReferences);
					}
				} else if (AbstractTable.isDependable(property)) {
					value = getTable(fieldClass).read(value);
				} else if (fieldClass == Set.class) {
					Set set = (Set) property.getValue(result);
					Class enumClass = GenericUtils.getGenericClass(property.getType());
					EnumUtils.fillSet((int) value, enumClass, set);
					continue; // skip setValue, it's final
				} else {
					value = sqlDialect.convertToFieldClass(fieldClass, value);
				}
				property.setValue(result, value);
			}
		}
		return result;
	}
	
	//
	
	 void addClass(Class clazz) {
		if (!tables.containsKey(clazz)) {
			boolean historized = FieldUtils.hasValidHistorizedField(clazz);
			tables.put(clazz, null); // break recursion. at some point it is checked if a clazz is already in the tables map.
			Table table = historized ? new HistorizedTable(this, clazz) : new Table(this, clazz);
			tables.put(table.getClazz(), table);
		}
	}
	
	/**
	 * Creates all tables, then all indexes, then all constraints, each phase over the same
	 * snapshot of the registered tables. Presumably the phases are separated so constraints
	 * can refer to any table - confirm.
	 */
	private void createTables() {
		List> tableList = new ArrayList>(tables.values());
		tableList.forEach(table -> table.createTable(sqlDialect));
		tableList.forEach(table -> table.createIndexes(sqlDialect));
		tableList.forEach(table -> table.createConstraints(sqlDialect));
	}

	/**
	 * Fills the code tables: first with the constants declared in the code classes,
	 * then with the codes from csv resources next to the code classes.
	 */
	private void createCodes() {
		createConstantCodes();
		createCsvCodes();
	}
	
	/**
	 * Inserts the constants of every Code class (as found by Codes.getConstants)
	 * into the table of that class.
	 */
	@SuppressWarnings("unchecked")
	private void createConstantCodes() {
		for (AbstractTable table : tables.values()) {
			if (Code.class.isAssignableFrom(table.getClazz())) {
				Class codeClass = (Class) table.getClazz(); 
				List constants = Codes.getConstants(codeClass);
				for (Code code : constants) {
					((Table) table).insert(code);
				}
			}
		}
	}

	/**
	 * For every Code class, inserts the codes found in the resource "SimpleName.csv" next to
	 * the class, if such a resource exists. The resource stream is always closed.
	 * 
	 * @throws LoggingRuntimeException if closing a resource fails
	 */
	@SuppressWarnings("unchecked")
	private void createCsvCodes() {
		List> tableList = new ArrayList>(tables.values());
		for (AbstractTable table : tableList) {
			if (Code.class.isAssignableFrom(table.getClazz())) {
				Class clazz = (Class) table.getClazz();
				String resourceName = clazz.getSimpleName() + ".csv";
				// a null resource (no csv file) is allowed and simply not closed
				try (InputStream is = clazz.getResourceAsStream(resourceName)) {
					if (is != null) {
						CsvReader reader = new CsvReader(is);
						List values = reader.readValues(clazz);
						for (Code value : values) {
							((Table) table).insert(value);
						}
					}
				} catch (java.io.IOException x) {
					throw new LoggingRuntimeException(x, logger, "Could not read " + resourceName);
				}
			}
		}
	}
	
	/**
	 * Returns the table registered for clazz.
	 * While addClass is creating the table of clazz, the registered value is still the
	 * null placeholder, so null is returned in that phase.
	 * 
	 * @throws IllegalArgumentException if clazz is not registered at all
	 */
	@SuppressWarnings("unchecked")
	public  AbstractTable getAbstractTable(Class clazz) {
		if (!tables.containsKey(clazz)) {
			throw new IllegalArgumentException(clazz.getName());
		}
		return (AbstractTable) tables.get(clazz);
	}

	/**
	 * Returns the table registered for clazz as a Table.
	 * 
	 * @throws IllegalArgumentException if clazz is not registered or its table is not a
	 *         Table (this includes the null placeholder during addClass)
	 */
	public  Table getTable(Class clazz) {
		AbstractTable table = getAbstractTable(clazz);
		if (!(table instanceof Table)) throw new IllegalArgumentException(clazz.getName());
		return (Table) table;
	}

	/**
	 * Looks up a table by the fully qualified name of its class (Class.getName) with a
	 * linear search over the registered classes.
	 * 
	 * @return the table, or null if no registered class has that name
	 */
	public  Table getTable(String className) {
		for (Entry, AbstractTable> entry : tables.entrySet()) {
			if (entry.getKey().getName().equals(className)) {
				return (Table) entry.getValue();
			}
		}
		return null;
	}
	
	/**
	 * Returns the SQL name of a class (its table name) or of a key (its column name).
	 */
	public String name(Object classOrKey) {
		return classOrKey instanceof Class ? table((Class) classOrKey) : column(classOrKey);
	}

	/** @return the name of the table registered for clazz */
	public String table(Class clazz) {
		return getAbstractTable(clazz).getTableName();
	}
	
	/**
	 * Returns the column name of a key, looked up in the table of the class that
	 * declares the key's property.
	 */
	public String column(Object key) {
		PropertyInterface keyProperty = Keys.getProperty(key);
		return getAbstractTable(keyProperty.getDeclaringClass()).column(keyProperty);
	}
	
	/** @return true if a table (or the placeholder during its creation) is registered for clazz */
	public boolean tableExists(Class clazz) {
		return this.tables.containsKey(clazz);
	}
	
	/**
	 * Runs the ModelTest on the given classes. Any problems are logged and rejected.
	 * 
	 * @throws IllegalArgumentException if the model test reports problems
	 */
	private void testModel(Class[] classes) {
		ModelTest modelTest = new ModelTest(classes);
		if (modelTest.getProblems().isEmpty()) {
			return;
		}
		modelTest.logProblems();
		throw new IllegalArgumentException("The persistent classes don't apply to the given rules");
	}

	/**
	 * Returns the code of class clazz with the given id, normally from the code cache.
	 */
	public  T getCode(Class clazz, Object codeId) {
		if (isLoading(clazz)) {
			// this special case is needed to break a possible reference cycle:
			// while the codes of clazz are being loaded, read directly from the table
			return getTable(clazz).read(codeId);
		}
		List codes = getCodes(clazz);
		return Codes.findCode(codes, codeId);
	}
	
	/**
	 * @return true if a cache item for clazz exists and reports that it is loading.
	 *         Reads codeCache without any lock.
	 */
	@SuppressWarnings("unchecked")
	private  boolean isLoading(Class clazz) {
		CodeCacheItem cacheItem = (CodeCacheItem) codeCache.get(clazz);
		return cacheItem != null && cacheItem.isLoading();
	}

	/**
	 * Returns the cached codes of clazz. They are (re)loaded via updateCode if there is no
	 * cache item yet or the existing one is no longer valid.
	 * NOTE(review): the lock is the Class object of the codes, so different code classes can
	 * run here in parallel. codeCache, however, is one plain HashMap shared by all of them, and
	 * isLoading / invalidateCodeCache access it without a lock. Consider a concurrent map.
	 */
	@SuppressWarnings("unchecked")
	 List getCodes(Class clazz) {
		synchronized (clazz) {
			CodeCacheItem cacheItem = (CodeCacheItem) codeCache.get(clazz);
			if (cacheItem == null || !cacheItem.isValid()) {
				updateCode(clazz);
			}
			cacheItem = (CodeCacheItem) codeCache.get(clazz);
			List codes = cacheItem.getCodes();
			return codes;
		}
	}

	/**
	 * Reloads all codes of clazz from its table into a new cache item.
	 * The item is put into the cache before the read. Presumably a new CodeCacheItem reports
	 * isLoading() until setCodes is called - confirm in Codes.CodeCacheItem. getCode relies on
	 * that to read from the table directly during this load.
	 */
	private  void updateCode(Class clazz) {
		CodeCacheItem codeCacheItem = new CodeCacheItem();
		codeCache.put(clazz, codeCacheItem);
		List codes = getTable(clazz).read(By.all(), Integer.MAX_VALUE);
		codeCacheItem.setCodes(codes);
	}

	/** Drops the cached codes of clazz; the next getCodes call reloads them. */
	public void invalidateCodeCache(Class clazz) {
		this.codeCache.remove(clazz);
	}

	/** @return the maximum identifier length of the current SQL dialect */
	public int getMaxIdentifierLength() {
		return getSqlDialect().getMaxIdentifierLength();
	}

	/** @return the live (mutable) map of tables by name, not a copy */
	public Map> getTableByName() {
		return this.tableByName;
	}
}