All downloads are free. Search and download functionality uses the official Maven repository.

com.impossibl.postgres.system.BasicContext Maven / Gradle / Ivy

Go to download

A new JDBC driver for PostgreSQL aimed at supporting the advanced features of JDBC and Postgres.

There is a newer version: 0.7.0-89abc52
Show newest version
/**
 * Copyright (c) 2013, impossibl.com
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 *  * Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *  * Neither the name of impossibl.com nor the names of its contributors may
 *    be used to endorse or promote products derived from this software
 *    without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 */
package com.impossibl.postgres.system;

import com.impossibl.postgres.datetime.DateTimeFormat;
import com.impossibl.postgres.datetime.ISODateFormat;
import com.impossibl.postgres.datetime.ISOTimeFormat;
import com.impossibl.postgres.datetime.ISOTimestampFormat;
import com.impossibl.postgres.mapper.Mapper;
import com.impossibl.postgres.mapper.PropertySetter;
import com.impossibl.postgres.protocol.BindExecCommand;
import com.impossibl.postgres.protocol.DataRow;
import com.impossibl.postgres.protocol.PrepareCommand;
import com.impossibl.postgres.protocol.Protocol;
import com.impossibl.postgres.protocol.QueryCommand;
import com.impossibl.postgres.protocol.ResultField;
import com.impossibl.postgres.protocol.v30.ProtocolFactoryImpl;
import com.impossibl.postgres.system.tables.PgAttribute;
import com.impossibl.postgres.system.tables.PgProc;
import com.impossibl.postgres.system.tables.PgType;
import com.impossibl.postgres.types.Registry;
import com.impossibl.postgres.types.Type;
import com.impossibl.postgres.types.Type.Category;
import com.impossibl.postgres.utils.Converter;
import com.impossibl.postgres.utils.Factory;
import com.impossibl.postgres.utils.Timer;

import static com.impossibl.postgres.system.Settings.FIELD_DATETIME_FORMAT_CLASS;
import static com.impossibl.postgres.system.Settings.STANDARD_CONFORMING_STRINGS;
import static com.impossibl.postgres.utils.guava.Strings.nullToEmpty;

import java.io.IOException;
import java.net.SocketAddress;
import java.nio.charset.Charset;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.IllformedLocaleException;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.TimeZone;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Logger;
import java.util.regex.Pattern;

import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;
import static java.util.logging.Level.WARNING;


public class BasicContext implements Context {

  private static final Logger logger = Logger.getLogger(BasicContext.class.getName());

  private static class PreparedQuery {

    String name;
    List parameterTypes;
    List resultFields;

    PreparedQuery(String name, List parameterTypes, List resultFields) {
      this.name = name;
      this.parameterTypes = parameterTypes;
      this.resultFields = resultFields;
    }
  }

  private static class NotificationKey {

    public String name;
    public Pattern channelNameFilter;

    NotificationKey(String name, Pattern channelNameFilter) {
      this.name = name;
      this.channelNameFilter = channelNameFilter;
    }

  }


  protected Registry registry;
  protected Map> targetTypeMap;
  protected Charset charset;
  protected TimeZone timeZone;
  protected DateTimeFormat dateFormatter;
  protected DateTimeFormat timeFormatter;
  protected DateTimeFormat timestampFormatter;
  protected DecimalFormat decimalFormatter;
  protected DecimalFormat currencyFormatter;
  protected Properties settings;
  protected Version serverVersion;
  protected KeyData keyData;
  protected Protocol protocol;
  protected Map notificationListeners;
  protected Map utilQueries;


  public BasicContext(SocketAddress address, Properties settings, Map> targetTypeMap) throws IOException, NoticeException {
    this.targetTypeMap = new HashMap<>(targetTypeMap);
    this.settings = settings;
    this.charset = UTF_8;
    this.timeZone = TimeZone.getTimeZone("UTC");
    this.dateFormatter = new ISODateFormat();
    this.timeFormatter = new ISOTimeFormat();
    this.timestampFormatter = new ISOTimestampFormat();
    this.notificationListeners = new ConcurrentHashMap<>();
    this.registry = new Registry(this);
    this.protocol = new ProtocolFactoryImpl().connect(address, this);
    this.utilQueries = new HashMap<>();
  }

  protected void shutdown() {
    protocol.shutdown();
  }

  public Version getServerVersion() {
    return serverVersion;
  }

  public void setServerVersion(Version serverVersion) {
    this.serverVersion = serverVersion;
  }

  @Override
  public Registry getRegistry() {
    return registry;
  }

  @Override
  public Protocol getProtocol() {
    return protocol;
  }

  @Override
  public Object getSetting(String name) {
    return settings.get(name);
  }

  @Override
  public  T getSetting(String name, Class type) {
    return type.cast(settings.get(name));
  }

  public  T getSetting(String name, Converter converter) {
    return converter.apply(settings.get(name));
  }

  @Override
  public  T getSetting(String name, T defaultValue) {
    Object val = settings.get(name);
    if (val == null)
      return defaultValue;
    if ((defaultValue.getClass() == int.class || defaultValue.getClass() == Integer.class) && val instanceof String) {
      return (T) defaultValue.getClass().cast(Integer.valueOf((String) val));
    }
    if ((defaultValue.getClass() == long.class || defaultValue.getClass() == Long.class) && val instanceof String) {
      return (T) defaultValue.getClass().cast(Long.valueOf((String) val));
    }
    if ((defaultValue.getClass() == boolean.class || defaultValue.getClass() == Boolean.class) && val instanceof String) {
      return (T) defaultValue.getClass().cast(Boolean.valueOf((String) val));
    }
    return (T) defaultValue.getClass().cast(val);
  }

  @Override
  public boolean isSettingEnabled(String name) {
    Object val = getSetting(name);
    if (val instanceof String)
      return ((String) val).equalsIgnoreCase("on");
    if (val instanceof Boolean)
      return (Boolean) val;
    return false;
  }

  @Override
  public Class lookupInstanceType(Type type) {

    Class cls = targetTypeMap.get(type.getName());
    if (cls == null) {
      if (type.getCategory() == Category.Array)
        return Object[].class;
      else
        cls = HashMap.class;
    }

    return cls;
  }

  @Override
  public Charset getCharset() {
    return charset;
  }

  @Override
  public TimeZone getTimeZone() {
    return timeZone;
  }

  @Override
  public KeyData getKeyData() {
    return keyData;
  }

  @Override
  public DateTimeFormat getDateFormatter() {
    return dateFormatter;
  }

  @Override
  public DateTimeFormat getTimeFormatter() {
    return timeFormatter;
  }

  @Override
  public DateTimeFormat getTimestampFormatter() {
    return timestampFormatter;
  }

  @Override
  public DecimalFormat getDecimalFormatter() {
    return decimalFormatter;
  }

  @Override
  public DecimalFormat getCurrencyFormatter() {
    return currencyFormatter;
  }

  protected void init() throws IOException, NoticeException {

    loadTypes();

    prepareRefreshTypeQueries();

    loadLocale();
  }

  private void loadLocale() throws IOException, NoticeException {

    for (DataRow row : queryResults("SELECT name, setting FROM pg_settings WHERE name IN ('lc_numeric', 'lc_time')")) {

      String localeSpec = row.getColumn(1).toString();

      switch (localeSpec.toUpperCase(Locale.US)) {
        case "C":
        case "POSIX":
          localeSpec = "en_US";
          break;
      }

      String[] localeIds = localeSpec.split("_|\\.");

      switch (row.getColumn(0).toString()) {
        case "lc_numeric":
          Locale numLocale = new Locale.Builder().setLanguageTag(localeIds[0]).setRegion(localeIds[1]).build();
          decimalFormatter = (DecimalFormat) DecimalFormat.getNumberInstance(numLocale);
          decimalFormatter.setParseBigDecimal(true);
          currencyFormatter = (DecimalFormat) NumberFormat.getCurrencyInstance(numLocale);
          currencyFormatter.setParseBigDecimal(true);
          break;
        case "lc_time":
          Locale timeLocale = new Locale.Builder().setLanguageTag(localeIds[0]).setRegion(localeIds[1]).build();
      }

      row.release();
    }
  }

  private void loadTypes() throws IOException, NoticeException {

    Timer timer = new Timer();

    //Load types
    String typeSQL = PgType.INSTANCE.getSQL(serverVersion);
    List pgTypes = queryResults(typeSQL, PgType.Row.class);

    //Load attributes
    String attrsSQL = PgAttribute.INSTANCE.getSQL(serverVersion);
    List pgAttrs = queryResults(attrsSQL, PgAttribute.Row.class);

    //Load procs
    String procsSQL = PgProc.INSTANCE.getSQL(serverVersion);
    List pgProcs = queryResults(procsSQL, PgProc.Row.class);

    logger.fine("query time: " + timer.getLap() + "ms");

    //Update the registry with known types
    registry.update(pgTypes, pgAttrs, pgProcs);

    logger.fine("load time: " + timer.getLap() + "ms");
  }

  private void prepareRefreshTypeQueries() throws IOException {

    prepareUtilQuery("refresh-type", PgType.INSTANCE.getSQL(serverVersion) + " where t.oid = $1");

    prepareUtilQuery("refresh-type-attrs", PgAttribute.INSTANCE.getSQL(serverVersion) + " and a.attrelid = $1", "int4");

    prepareUtilQuery("refresh-types", PgType.INSTANCE.getSQL(serverVersion) + " where t.oid > $1", "int4");

    prepareUtilQuery("refresh-types-attrs", PgAttribute.INSTANCE.getSQL(serverVersion) + " and a.attrelid = any( $1 )", "int4[]");

    prepareUtilQuery("refresh-reltype", PgType.INSTANCE.getSQL(serverVersion) + " where t.typrelid = $1", "int4");

  }

  @Override
  public void refreshType(int typeId) {

    int latestKnownTypeId = registry.getLatestKnownTypeId();
    if (latestKnownTypeId >= typeId) {
      //Refresh this specific type
      refreshSpecificType(typeId);
    }
    else {
      //Load all new types we haven't seent
      refreshTypes(latestKnownTypeId);
    }

  }

  void refreshSpecificType(int typeId) {

    try {

      //Load types
      List pgTypes = queryResults("@refresh-type", PgType.Row.class, typeId);

      if (pgTypes.isEmpty()) {
        return;
      }

      //Load attributes
      List pgAttrs = queryResults("@refresh-type-attrs", PgAttribute.Row.class, pgTypes.get(0).relationId);

      registry.update(pgTypes, pgAttrs, Collections.emptyList());
    }
    catch (IOException | NoticeException e) {
      //Ignore errors
    }

  }

  void refreshTypes(int latestTypeId) {

    try {

      //Load types
      List pgTypes = queryResults("@refresh-types", PgType.Row.class, latestTypeId);

      if (pgTypes.isEmpty()) {
        return;
      }

      Integer[] typeIds = new Integer[pgTypes.size()];
      for (int c = 0; c < pgTypes.size(); ++c)
        typeIds[c] = pgTypes.get(c).relationId;

      //Load attributes
      List pgAttrs = queryResults("@refresh-types-attrs", PgAttribute.Row.class, (Object) typeIds);

      registry.update(pgTypes, pgAttrs, Collections.emptyList());
    }
    catch (IOException | NoticeException e) {
      logger.log(WARNING, "Error refreshing types", e);
    }

  }

  @Override
  public void refreshRelationType(int relationId) {

    try {

      //Load types
      List pgTypes = queryResults("@refresh-reltype", PgType.Row.class, relationId);

      if (pgTypes.isEmpty()) {
        return;
      }

      //Load attributes
      List pgAttrs = queryResults("@refresh-type-attrs", PgAttribute.Row.class, relationId);

      registry.update(pgTypes, pgAttrs, Collections.emptyList());
    }
    catch (IOException | NoticeException e) {
      //Ignore errors
    }

  }

  public boolean isUtilQueryPrepared(String name) {
    return utilQueries.containsKey(name);
  }

  public PreparedQuery prepareUtilQuery(String name, String sql, String... parameterTypeNames) throws IOException {

    List parameterTypes = new ArrayList<>(parameterTypeNames.length);
    for (String parameterTypeName : parameterTypeNames) {
      parameterTypes.add(registry.loadType(parameterTypeName));
    }

    return prepareUtilQuery(name, sql, parameterTypes);
  }

  public PreparedQuery prepareUtilQuery(String name, String sql, List parameterTypes) throws IOException {

    PrepareCommand prep = protocol.createPrepare(name, sql, parameterTypes);
    protocol.execute(prep);

    if (prep.getError() != null) {
      throw new IOException("unable to prepare query: " + prep.getError().getMessage());
    }

    PreparedQuery pq = new PreparedQuery(name, prep.getDescribedParameterTypes(), prep.getDescribedResultFields());
    utilQueries.put(name, pq);
    return pq;
  }

  private PreparedQuery prepareQuery(String queryTxt) throws NoticeException, IOException {

    if (queryTxt.charAt(0) == '@') {
      PreparedQuery util = utilQueries.get(queryTxt.substring(1));
      if (util == null) {
        throw new IOException("invalid utility query");
      }
      return util;
    }

    PrepareCommand prepare = protocol.createPrepare(null, queryTxt, Collections.emptyList());

    protocol.execute(prepare);

    if (prepare.getError() != null) {
      throw new NoticeException("Error preparing query", prepare.getError());
    }

    return new PreparedQuery(null, prepare.getDescribedParameterTypes(), prepare.getDescribedResultFields());
  }

  private  List convertResults(Class rowType, List columnFields, List dataRows) throws IOException {

    List columnSetters = Mapper.buildMapping(rowType, columnFields);

    List results = new ArrayList<>(dataRows.size());
    for (int r = 0; r < dataRows.size(); ++r) {

      DataRow dataRow = dataRows.get(r);
      T row = Factory.createInstance(rowType, columnFields.size());

      for (int c = 0; c < columnSetters.size(); ++c) {

        Object columnValue = dataRow.getColumn(c);

        columnSetters.get(c).set(row, columnValue);
      }

      dataRow.release();

      results.add(row);
    }

    return results;
  }

  public  List queryResults(String queryTxt, Class rowType, Object... params) throws IOException, NoticeException {

    QueryCommand.ResultBatch resultBatch = queryBatch(queryTxt, params);

    return convertResults(rowType, resultBatch.fields, resultBatch.results);
  }

  public List queryResults(String queryTxt) throws IOException, NoticeException {

    QueryCommand.ResultBatch resultBatch;

    if (queryTxt.charAt(0) == '@') {

      PreparedQuery pq = prepareQuery(queryTxt);

      resultBatch = preparedQuery(null, pq.name, Collections.emptyList(), Collections.emptyList(), pq.resultFields);
    }
    else {

      QueryCommand query = protocol.createQuery(queryTxt);

      protocol.execute(query);

      if (query.getError() != null) {
        throw new NoticeException("Error querying", query.getError());
      }

      List resultBatches = query.getResultBatches();

      if (resultBatches.isEmpty()) {
        resultBatch = null;
      }
      else {
        resultBatch = query.getResultBatches().get(0);
      }

    }

    if (resultBatch == null) {
      return Collections.emptyList();
    }

    return resultBatch.results;
  }

  public void query(String queryTxt) throws IOException, NoticeException {

    if (queryTxt.charAt(0) == '@') {

      PreparedQuery pq = prepareQuery(queryTxt);

      preparedQuery(null, pq.name, Collections.emptyList(), Collections.emptyList(), pq.resultFields);
    }

    QueryCommand query = protocol.createQuery(queryTxt);

    protocol.execute(query);

    if (query.getError() != null) {
      throw new NoticeException("Error querying", query.getError());
    }

  }

  public String queryFirstResultString(String queryTxt) throws IOException, NoticeException {

    List res = queryResults(queryTxt);
    if (res.isEmpty()) {
      return "";
    }

    Object val = res.get(0).getColumn(0);

    for (DataRow row : res) {
      row.release();
    }

    if (val == null)
      return "";

    return val.toString();
  }

  public QueryCommand.ResultBatch queryBatch(String queryTxt, Object... params) throws IOException, NoticeException {

    PreparedQuery pq = prepareQuery(queryTxt);

    return preparedQuery(null, pq.name, pq.parameterTypes, asList(params), pq.resultFields);
  }

  private QueryCommand.ResultBatch preparedQuery(String portalName, String statementName, List paramTypes, List paramValues,
                                                 List resultFields) throws IOException, NoticeException {

    BindExecCommand query = protocol.createBindExec(portalName, statementName, paramTypes, paramValues, resultFields);

    protocol.execute(query);

    if (query.getError() != null) {
      throw new NoticeException("Error executing query", query.getError());
    }

    List resultBatches = query.getResultBatches();
    if (resultBatches.isEmpty())
      return null;

    return resultBatches.get(0);
  }

  public void setKeyData(int processId, int secretKey) {

    keyData = new KeyData();
    keyData.processId = processId;
    keyData.secretKey = secretKey;
  }

  public void updateSystemParameter(String name, String value) {

    logger.config("system parameter: " + name + "=" + value);

    switch (name) {

      case "server_version":

        serverVersion = Version.parse(value);
        break;

      case "DateStyle":

        String[] parsedDateStyle = DateStyle.parse(value);

        if (parsedDateStyle == null) {
          logger.warning("Invalid DateStyle encountered");
        }
        else {

          dateFormatter = DateStyle.getDateFormatter(parsedDateStyle);
          if (dateFormatter == null) {
            logger.warning("Unknown Date format, reverting to default");
            dateFormatter = new ISODateFormat();
          }

          timeFormatter = DateStyle.getTimeFormatter(parsedDateStyle);
          if (timeFormatter == null) {
            logger.warning("Unknown Time format, reverting to default");
            timeFormatter = new ISOTimeFormat();
          }

          timestampFormatter = DateStyle.getTimestampFormatter(parsedDateStyle);
          if (timestampFormatter == null) {
            logger.warning("Unknown Timestamp format, reverting to default");
            timestampFormatter = new ISOTimestampFormat();
          }
        }
        break;

      case "TimeZone":

        timeZone = TimeZone.getTimeZone(value);
        break;

      case "integer_datetimes":

        settings.put(FIELD_DATETIME_FORMAT_CLASS, Integer.class);
        break;

      case "client_encoding":

        charset = Charset.forName(value);
        break;

      case STANDARD_CONFORMING_STRINGS:

        settings.put(STANDARD_CONFORMING_STRINGS, value.equals("on"));
        break;

      default:
        break;
    }

  }

  public void addNotificationListener(String name, String channelNameFilter, NotificationListener listener) {

    name = nullToEmpty(name);
    channelNameFilter = channelNameFilter != null ? channelNameFilter : ".*";

    Pattern channelNameFilterPattern = Pattern.compile(channelNameFilter);

    NotificationKey key = new NotificationKey(name, channelNameFilterPattern);

    notificationListeners.put(key, listener);
  }

  public synchronized void removeNotificationListener(NotificationListener listener) {

    Iterator> iter = notificationListeners.entrySet().iterator();
    while (iter.hasNext()) {

      Map.Entry entry = iter.next();

      NotificationListener iterListener = entry.getValue();
      if (iterListener == null || iterListener.equals(listener)) {

        iter.remove();
      }

    }
  }

  public synchronized void removeNotificationListener(String listenerName) {

    Iterator> iter = notificationListeners.entrySet().iterator();
    while (iter.hasNext()) {

      Map.Entry entry = iter.next();

      String iterListenerName = entry.getKey().name;
      NotificationListener iterListener = entry.getValue();
      if (iterListenerName.equals(listenerName) || iterListener == null) {

        iter.remove();
      }

    }
  }

  @Override
  public synchronized void reportNotification(int processId, String channelName, String payload) {

    Iterator> iter = notificationListeners.entrySet().iterator();
    while (iter.hasNext()) {

      Map.Entry entry = iter.next();

      NotificationListener listener = entry.getValue();
      if (entry.getKey().channelNameFilter.matcher(channelName).matches()) {

        listener.notification(processId, channelName, payload);
      }

    }

  }

  @Override
  public Context unwrap() {
    return this;
  }

}