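// NOTE: the lines below are web-page residue from the Maven source browser this file was copied from; they are not part of the Java source.
// All Downloads are FREE. Search and download functionalities are using the official Maven repository.
//
// com.arcadedb.postgres.PostgresNetworkExecutor Maven / Gradle / Ivy
//
// There is a newer version: 25.1.1
// Show newest version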
/*
 * Copyright © 2021-present Arcade Data Ltd (info@arcadedata.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * SPDX-FileCopyrightText: 2021-present Arcade Data Ltd (info@arcadedata.com)
 * SPDX-License-Identifier: Apache-2.0
 */
package com.arcadedb.postgres;

import com.arcadedb.Constants;
import com.arcadedb.GlobalConfiguration;
import com.arcadedb.database.Database;
import com.arcadedb.database.DatabaseContext;
import com.arcadedb.database.DatabaseFactory;
import com.arcadedb.database.DatabaseInternal;
import com.arcadedb.database.Document;
import com.arcadedb.exception.CommandParsingException;
import com.arcadedb.exception.DatabaseOperationException;
import com.arcadedb.graph.Edge;
import com.arcadedb.graph.Vertex;
import com.arcadedb.log.LogManager;
import com.arcadedb.network.binary.ChannelBinaryServer;
import com.arcadedb.query.sql.SQLQueryEngine;
import com.arcadedb.query.sql.executor.IteratorResultSet;
import com.arcadedb.query.sql.executor.Result;
import com.arcadedb.query.sql.executor.ResultInternal;
import com.arcadedb.query.sql.executor.ResultSet;
import com.arcadedb.schema.DocumentType;
import com.arcadedb.server.ArcadeDBServer;
import com.arcadedb.server.security.ServerSecurityException;
import com.arcadedb.server.security.ServerSecurityUser;
import com.arcadedb.utility.DateUtils;
import com.arcadedb.utility.FileUtils;
import com.arcadedb.utility.Pair;

import java.io.EOFException;
import java.io.IOException;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.logging.Level;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Postgres Reference for Protocol Messages: https://www.postgresql.org/docs/9.6/protocol-message-formats.html
 *
 * @author Luca Garulli (l.garulli@arcadedata.com)
 */
public class PostgresNetworkExecutor extends Thread {
  public enum ERROR_SEVERITY {FATAL, ERROR}

  public static final  String                                         PG_SERVER_VERSION          = "10.5";
  private static final int                                            BUFFER_LENGTH              = 32 * 1024;
  private final        ArcadeDBServer                                 server;
  private              Database                                       database;
  private final        ChannelBinaryServer                            channel;
  private volatile     boolean                                        shutdown                   = false;
  private final        byte[]                                         buffer                     = new byte[BUFFER_LENGTH];
  private              int                                            nextByte                   = 0;
  private              boolean                                        reuseLastByte              = false;
  private              String                                         userName                   = null;
  private              String                                         databaseName               = null;
  private              String                                         userPassword               = null;
  private              int                                            consecutiveErrors          = 0;
  private              long                                           processIdSequence          = 0;
  private static final Map> ACTIVE_SESSIONS            = new ConcurrentHashMap<>();
  private final        Map                    portals                    = new HashMap<>();
  private final        boolean                                        DEBUG                      = GlobalConfiguration.POSTGRES_DEBUG.getValueAsBoolean();
  private final        Map                            connectionProperties       = new HashMap<>();
  private              boolean                                        explicitTransactionStarted = false;
  private              boolean                                        errorInTransaction         = false;
  private final        Set                                    ignoreQueriesAppNames      = new HashSet<>(//
      List.of("dbvis", "Database Navigator - Pool"));
  private final        Set                                    ignoreQueries              = new HashSet<>(//
      List.of(//
          "select distinct PRIVILEGE_TYPE as PRIVILEGE_NAME from INFORMATION_SCHEMA.USAGE_PRIVILEGES order by PRIVILEGE_TYPE asc",//
          "SELECT oid, typname FROM pg_type"));

  /**
   * Callback handed to {@code readMessage}: it receives the message type character and a length value
   * and is expected to consume the message content from the channel.
   */
  @FunctionalInterface
  private interface ReadMessageCallback {
    void read(char type, long length) throws IOException;
  }

  /**
   * Callback handed to {@code writeMessage} to emit the message payload; {@code null} is passed by
   * callers when the message has no payload.
   */
  @FunctionalInterface
  private interface WriteMessageCallback {
    void write() throws IOException;
  }

  public PostgresNetworkExecutor(final ArcadeDBServer server, final Socket socket, final Database database) throws IOException {
    setName(Constants.PRODUCT + "-postgres/" + socket.getInetAddress());
    this.server = server;
    this.channel = new ChannelBinaryServer(socket, server.getConfiguration());
    this.database = database;
  }

  /**
   * Flags the executor as shut down, which stops the message loop in {@link #run()}, and closes the
   * client channel when one exists.
   */
  public void close() {
    shutdown = true;

    if (channel == null)
      return;

    channel.close();
  }

  /**
   * Serves one client connection: startup message, clear-text password authentication, database
   * opening, BackendKeyData and server parameters, then the message loop that dispatches each incoming
   * message to its handler until the client terminates, the executor is closed or too many errors occur.
   * <p>
   * The session is registered in {@code ACTIVE_SESSIONS} under a process id that is guaranteed to be
   * unused, and is always unregistered before the channel is closed.
   */
  @Override
  public void run() {
    try {
      if (!readStartupMessage(true))
        return;

      // Authentication request with code 3: ask the client for a clear-text password
      writeMessage("request for password", () -> channel.writeUnsignedInt(3), 'R', 8);

      waitForAMessage();

      if (!readMessage("password", (type, length) -> userPassword = readString(), 'p'))
        return;

      if (!openDatabase())
        return;

      writeMessage("authentication ok", () -> channel.writeUnsignedInt(0), 'R', 8);

      // BackendKeyData
      // The secret is handed to the client together with the pid, so it must not be predictable.
      final long secret = new java.security.SecureRandom().nextInt(10000000);

      // processIdSequence is per instance, so its value alone cannot distinguish concurrent connections:
      // probe the shared session map until a free process id is atomically reserved.
      long candidatePid = processIdSequence++;
      while (ACTIVE_SESSIONS.putIfAbsent(candidatePid, new Pair<>(secret, this)) != null)
        candidatePid++;
      final long pid = candidatePid;

      try {
        writeMessage("backend key data", () -> {
          channel.writeUnsignedInt((int) pid);
          channel.writeUnsignedInt((int) secret);
        }, 'K', 12);

        sendServerParameter("server_version", PG_SERVER_VERSION);
        sendServerParameter("server_encoding", "UTF8");
        sendServerParameter("client_encoding", "UTF8");

        writeReadyForQueryMessage();

        while (!shutdown) {
          try {

            readMessage("any", (type, length) -> {
              consecutiveErrors = 0;

              switch (type) {
              case 'P' -> parseCommand();
              case 'B' -> bindCommand();
              case 'E' -> executeCommand();
              case 'Q' -> queryCommand();
              case 'S' -> syncCommand();
              case 'D' -> describeCommand();
              case 'C' -> closeCommand();
              case 'H' -> flushCommand();
              case 'X' -> {
                // TERMINATE
                shutdown = true;
                return;
              }
              default -> throw new PostgresProtocolException("Message '" + type + "' not managed");
              }

            }, 'P', 'B', 'E', 'Q', 'S', 'D', 'C', 'H', 'X');

          } catch (final Exception e) {
            setErrorInTx();

            if (e instanceof PostgresProtocolException) {
              // A protocol violation cannot be recovered: drop the connection
              LogManager.instance().log(this, Level.SEVERE, e.getMessage(), e);
              LogManager.instance().log(this, Level.SEVERE, "PSQL: Closing connection with client");
              return;
            } else {
              LogManager.instance().log(this, Level.SEVERE, "PSQL: Error on reading request: %s", e, e.getMessage());
              if (++consecutiveErrors > 3) {
                LogManager.instance().log(this, Level.SEVERE, "PSQL: Closing connection with client");
                return;
              }
            }
          }
        }
      } finally {
        // Covers every exit path after the pid has been reserved
        ACTIVE_SESSIONS.remove(pid);
      }

    } finally {
      close();
    }
  }

  /**
   * Handles the Sync message: rolls back and clears the error flag when a previous command failed,
   * otherwise commits the active transaction unless the client opened it explicitly. A
   * ready-for-query message is always sent afterwards.
   */
  private void syncCommand() {
    if (DEBUG) {
      final long threadId = Thread.currentThread().getId();
      LogManager.instance().log(this, Level.INFO, "PSQL: sync (thread=%s)", threadId);
    }

    if (errorInTransaction) {
      // Everything received since the failure was skipped: undo the partial work
      database.rollback();
      errorInTransaction = false;
    } else if (!explicitTransactionStarted && database.isTransactionActive()) {
      database.commit();
    }

    writeReadyForQueryMessage();
  }

  /**
   * Handles the Flush message by answering with a ready-for-query message.
   */
  private void flushCommand() {
    if (DEBUG) {
      final long threadId = Thread.currentThread().getId();
      LogManager.instance().log(this, Level.INFO, "PSQL: flush (thread=%s)", threadId);
    }

    writeReadyForQueryMessage();
  }

  /**
   * Handles the Close message. The target kind and name are always read from the channel; nothing
   * else happens while a previous error is pending. Otherwise, for kind {@code 'P'} the portal is
   * looked up with the removal flag set, and a close-complete message is sent.
   *
   * @throws IOException if the message cannot be read from the channel
   */
  private void closeCommand() throws IOException {
    final byte targetKind = channel.readByte();
    final String targetName = readString();

    if (!errorInTransaction) {
      if (targetKind == 'P')
        getPortal(targetName, true);

      if (DEBUG) {
        final long threadId = Thread.currentThread().getId();
        LogManager.instance().log(this, Level.INFO, "PSQL: close '%s' type=%s (thread=%s)", targetName, (char) targetKind, threadId);
      }

      writeMessage("close complete", null, '3', 4);
    }
  }

  /**
   * Handles the Describe message.
   * <p>
   * For type {@code 'P'} a portal holding a parsed SQL statement is executed right away so that the
   * row description can be derived from the actual result, which is cached on the portal for the
   * following Execute; a portal without a statement only reports its pre-computed columns, if any.
   * For type {@code 'S'} a no-data message is sent. Nothing is written while a previous error is
   * pending.
   *
   * @throws IOException               if the message cannot be read from the channel
   * @throws PostgresProtocolException if the describe type is neither {@code 'P'} nor {@code 'S'}
   */
  private void describeCommand() throws IOException {
    final byte type = channel.readByte();
    final String portalName = readString();

    if (DEBUG)
      LogManager.instance()
          .log(this, Level.INFO, "PSQL: describe '%s' type=%s (errorInTransaction=%s thread=%s)", portalName, (char) type,
              errorInTransaction, Thread.currentThread().getId());

    if (errorInTransaction)
      return;

    final PostgresPortal portal = getPortal(portalName, false);
    if (portal == null) {
      writeNoData();
      return;
    }

    if (type == 'P') {
      if (portal.sqlStatement != null) {
        final Object[] parameters = portal.parameterValues != null ? portal.parameterValues.toArray() : new Object[0];
        final ResultSet resultSet = portal.sqlStatement.execute(database, parameters);
        // Remember the execution so that executeCommand() does not run the statement twice
        portal.executed = true;
        if (portal.isExpectingResult) {
          portal.cachedResultset = browseAndCacheResultSet(resultSet, 0);
          portal.columns = getColumns(portal.cachedResultset);
          writeRowDescription(portal.columns);
        } else
          writeNoData();
      } else {
        if (portal.columns != null)
          writeRowDescription(portal.columns);
      }
    } else if (type == 'S') {
      writeNoData();
    } else
      // The cast is required: concatenating the byte would print its numeric value
      throw new PostgresProtocolException("Unexpected describe type '" + (char) type + "'");
  }

  /**
   * Handles the Execute message: runs the portal's statement unless it was already executed (at parse
   * or describe time), then writes the cached rows followed by the command-complete message.
   * <p>
   * The portal name and the row limit are always read from the channel; nothing else happens while a
   * previous error is pending. Any failure flags the transaction as failed and is reported to the
   * client as an error message instead of being thrown.
   */
  private void executeCommand() {
    try {
      final String portalName = readString();
      final int limit = (int) channel.readUnsignedInt();

      if (errorInTransaction)
        return;

      final PostgresPortal portal = getPortal(portalName, true);
      if (portal == null) {
        writeNoData();
        return;
      }

      if (DEBUG)
        LogManager.instance()
            .log(this, Level.INFO, "PSQL: execute (portal=%s) (limit=%d)-> %s (thread=%s)", portalName, limit, portal,
                Thread.currentThread().getId());

      if (portal.ignoreExecution)
        writeNoData();
      else {
        if (!portal.executed) {
          final Object[] parameters = portal.parameterValues != null ? portal.parameterValues.toArray() : new Object[0];
          final ResultSet resultSet = portal.sqlStatement.execute(database, parameters);
          portal.executed = true;
          if (portal.isExpectingResult) {
            portal.cachedResultset = browseAndCacheResultSet(resultSet, limit);
            portal.columns = getColumns(portal.cachedResultset);
            writeRowDescription(portal.columns);
          }
        }

        if (portal.isExpectingResult) {
          if (portal.columns == null)
            portal.columns = getColumns(portal.cachedResultset);

          writeDataRows(portal.cachedResultset, portal.columns);
          writeCommandComplete(portal.query, portal.cachedResultset == null ? 0 : portal.cachedResultset.size());
        } else
          writeNoData();
      }
    } catch (final CommandParsingException e) {
      setErrorInTx();
      // The parsing exception may carry no cause: fall back to its own message instead of failing with a NPE
      final Throwable reason = e.getCause() != null ? e.getCause() : e;
      writeError(ERROR_SEVERITY.ERROR, "Syntax error on executing query: " + reason.getMessage(), "42601");
    } catch (final Exception e) {
      setErrorInTx();
      writeError(ERROR_SEVERITY.ERROR, "Error on executing query: " + e.getMessage(), "XX000");
    }
  }

  /**
   * Handles the simple Query message: reads the query text, answers a few commands locally
   * ({@code SET}, {@code SELECT VERSION()}, {@code SELECT CURRENT_SCHEMA()}, {@code BEGIN} and the
   * ignored queries) and delegates everything else to the database, then writes row description,
   * data rows and command-complete.
   * <p>
   * The locally answered {@code SET}/{@code SELECT} commands are matched case-insensitively, as
   * {@code parseCommand()} does for the extended protocol. A ready-for-query message is always sent at
   * the end, even when the query is skipped or fails; failures are reported to the client as error
   * messages.
   */
  private void queryCommand() {
    try {
      String queryText = readString().trim();
      if (queryText.endsWith(";"))
        queryText = queryText.substring(0, queryText.length() - 1);

      if (errorInTransaction)
        return;

      if (queryText.isEmpty()) {
        emptyQueryResponse();
        return;
      }

      if (DEBUG)
        LogManager.instance().log(this, Level.INFO, "PSQL: query -> %s (thread=%s)", queryText, Thread.currentThread().getId());

      final Query query = getLanguageAndQuery(queryText);

      final String upperCaseText = queryText.toUpperCase(Locale.ENGLISH);

      final ResultSet resultSet;
      if (upperCaseText.startsWith("SET ")) {
        setConfiguration(queryText);
        resultSet = new IteratorResultSet(createResultSet("STATUS", "Setting ignored").iterator());
      } else if (upperCaseText.equals("SELECT VERSION()"))
        resultSet = new IteratorResultSet(createResultSet("VERSION", "11.0.0").iterator());
      else if (upperCaseText.equals("SELECT CURRENT_SCHEMA()"))
        resultSet = new IteratorResultSet(createResultSet("CURRENT_SCHEMA", database.getName()).iterator());
      else if (queryText.equalsIgnoreCase("BEGIN") || queryText.equalsIgnoreCase("BEGIN TRANSACTION")) {
        explicitTransactionStarted = true;
        database.begin();
        resultSet = new IteratorResultSet(Collections.emptyIterator());
      } else if (ignoreQueries.contains(queryText))
        resultSet = new IteratorResultSet(Collections.emptyIterator());
      else
        resultSet = database.command(query.language, query.query);

      final List<Result> cachedResultset = browseAndCacheResultSet(resultSet, 0);

      final Map<String, PostgresType> columns = getColumns(cachedResultset);
      writeRowDescription(columns);
      writeDataRows(cachedResultset, columns);
      writeCommandComplete(queryText, cachedResultset.size());

    } catch (final CommandParsingException e) {
      setErrorInTx();
      // The parsing exception may carry no cause: fall back to its own message instead of failing with a NPE
      final Throwable reason = e.getCause() != null ? e.getCause() : e;
      writeError(ERROR_SEVERITY.ERROR, "Syntax error on executing query: " + reason.getMessage(), "42601");
    } catch (final Exception e) {
      setErrorInTx();
      writeError(ERROR_SEVERITY.ERROR, "Error on executing query: " + e.getMessage(), "XX000");
    } finally {
      writeReadyForQueryMessage();
    }
  }

  /**
   * Sends the ready-for-query message carrying the transaction status: {@code 'T'} when the client
   * started a transaction explicitly, {@code 'I'} otherwise.
   */
  private void writeReadyForQueryMessage() {
    final byte status = (byte) (explicitTransactionStarted ? 'T' : 'I');

    writeMessage("ready for query", () -> channel.writeByte(status), 'Z', 5);
  }

  /**
   * Copies the rows of a result set into a list, skipping {@code null} rows.
   *
   * @param resultSet result set to browse
   * @param limit     maximum number of rows to keep; zero or a negative value means no limit. When the
   *                  limit is reached {@code portalSuspendedResponse()} is invoked and browsing stops
   *
   * @return the cached rows, never {@code null}
   */
  private List<Result> browseAndCacheResultSet(final ResultSet resultSet, final int limit) {
    final List<Result> cachedResultSet = new ArrayList<>();
    while (resultSet.hasNext()) {
      final Result row = resultSet.next();
      if (row == null)
        continue;

      cachedResultSet.add(row);

      if (limit > 0 && cachedResultSet.size() >= limit) {
        portalSuspendedResponse();
        break;
      }
    }
    return cachedResultSet;
  }

  /**
   * Derives the columns of a result from its rows. The type of a column is taken from the first
   * non-null value found for that property: the first {@link PostgresType} whose class is assignable
   * from the value class, or {@code VARCHAR} when none matches. When at least one row is a record
   * element, the {@code @rid}, {@code @type} and {@code @cat} columns are added.
   *
   * @param resultSet rows to inspect, may be {@code null}
   *
   * @return column name to type, in discovery order; empty when {@code resultSet} is {@code null}
   */
  private Map<String, PostgresType> getColumns(final List<Result> resultSet) {
    final Map<String, PostgresType> columns = new LinkedHashMap<>();

    if (resultSet == null)
      return columns;

    boolean atLeastOneElement = false;
    for (final Result row : resultSet) {
      if (row.isElement())
        atLeastOneElement = true;

      final Set<String> propertyNames = row.getPropertyNames();
      for (final String p : propertyNames) {
        if (columns.containsKey(p))
          // THE TYPE HAS BEEN ALREADY FOUND FROM A PREVIOUS ROW
          continue;

        final Object value = row.getProperty(p);
        if (value == null)
          // NO TYPE CAN BE INFERRED FROM A NULL: WAIT FOR A FOLLOWING ROW
          continue;

        // FIND THE VALUE TYPE AND WRITE IT IN THE DATA DESCRIPTION
        final Class<?> valueClass = value.getClass();

        PostgresType valueType = PostgresType.VARCHAR;
        for (final PostgresType t : PostgresType.values()) {
          if (t.cls.isAssignableFrom(valueClass)) {
            valueType = t;
            break;
          }
        }

        columns.put(p, valueType);
      }
    }

    if (atLeastOneElement) {
      columns.put("@rid", PostgresType.VARCHAR);
      columns.put("@type", PostgresType.VARCHAR);
      columns.put("@cat", PostgresType.CHAR);
    }

    return columns;
  }

  /**
   * Sends the row description message ({@code 'T'}) listing name and type of every column. Nothing
   * is sent when {@code columns} is {@code null}.
   *
   * @param columns column name to type, in the order the columns have to be described
   */
  private void writeRowDescription(final Map<String, PostgresType> columns) {
    if (columns == null)
      return;

    // NOTE(review): fixed 64KB buffer, a description larger than that would overflow it
    final ByteBuffer bufferDescription = ByteBuffer.allocate(64 * 1024).order(ByteOrder.BIG_ENDIAN);

    for (final Map.Entry<String, PostgresType> col : columns.entrySet()) {
      final String columnName = col.getKey();
      final PostgresType columnType = col.getValue();

      bufferDescription.put(columnName.getBytes(DatabaseFactory.getDefaultCharset()));//The field name.
      bufferDescription.put((byte) 0);

      bufferDescription.putInt(
          0); //If the field can be identified as a column of a specific table, the object ID of the table; otherwise zero.
      bufferDescription.putShort(
          (short) 0); //If the field can be identified as a column of a specific table, the attribute number of the column; otherwise zero.
      bufferDescription.putInt(columnType.code);// The object ID of the field's data type.
      bufferDescription.putShort(
          (short) columnType.size);// The data type size (see pg_type.typlen). Note that negative values denote variable-width types.
      bufferDescription.putInt(
          columnType.modifier);// The type modifier (see pg_attribute.atttypmod). The meaning of the modifier is type-specific.
      bufferDescription.putShort(
          (short) 0); // The format code being used for the field. Currently will be zero (text) or one (binary). In a RowDescription returned from the statement variant of Describe, the format code is not yet known and will always be zero.
    }

    bufferDescription.flip();
    // Declared length: 4 bytes for the length itself + 2 for the column count + the column entries
    writeMessage("row description", () -> {
      channel.writeUnsignedShort((short) columns.size());
      channel.writeBuffer(bufferDescription);
    }, 'T', 4 + 2 + bufferDescription.limit());
  }

  /**
   * Writes one data row message ({@code 'D'}) per result row and flushes the channel. Every column
   * value is serialized as text by its {@link PostgresType}. The {@code @rid}, {@code @type},
   * {@code @out}, {@code @in} and {@code @cat} columns are computed from the record of the row and
   * are {@code null} for rows that are not record elements. Nothing is written for an empty result.
   *
   * @param resultSet rows to send
   * @param columns   columns to send for every row, in order
   *
   * @throws IOException if writing to the channel fails
   */
  private void writeDataRows(final List<Result> resultSet, final Map<String, PostgresType> columns) throws IOException {
    if (resultSet.isEmpty())
      return;

    // NOTE(review): fixed 128KB buffers, a single row serialized to more than that would overflow them
    final ByteBuffer bufferData = ByteBuffer.allocate(128 * 1024).order(ByteOrder.BIG_ENDIAN);
    final ByteBuffer bufferValues = ByteBuffer.allocate(128 * 1024).order(ByteOrder.BIG_ENDIAN);

    for (final Result row : resultSet) {
      bufferData.clear();
      bufferValues.clear();
      bufferValues.putShort((short) columns.size()); // Int16 The number of column values that follow (possibly zero).

      for (final Map.Entry<String, PostgresType> entry : columns.entrySet()) {
        final String propertyName = entry.getKey();

        Object value = null;
        if (propertyName.equals("@rid"))
          value = row.isElement() ? row.getElement().get().getIdentity() : null;
        else if (propertyName.equals("@type"))
          value = row.isElement() ? row.getElement().get().getTypeName() : null;
        else if (propertyName.equals("@out")) {
          // VERTEX: NUMBER OF OUTGOING EDGES, EDGE: THE OUT SIDE OF THE EDGE
          if (row.isElement()) {
            final Document record = row.getElement().get();
            if (record instanceof Vertex)
              value = ((Vertex) record).countEdges(Vertex.DIRECTION.OUT, null);
            else if (record instanceof Edge)
              value = ((Edge) record).getOut();
          }
        } else if (propertyName.equals("@in")) {
          // VERTEX: NUMBER OF INCOMING EDGES, EDGE: THE IN SIDE OF THE EDGE
          if (row.isElement()) {
            final Document record = row.getElement().get();
            if (record instanceof Vertex)
              value = ((Vertex) record).countEdges(Vertex.DIRECTION.IN, null);
            else if (record instanceof Edge)
              value = ((Edge) record).getIn();
          }
        } else if (propertyName.equals("@cat")) {
          // RECORD CATEGORY: (v)ERTEX, (e)DGE OR (d)OCUMENT
          if (row.isElement()) {
            final Document record = row.getElement().get();
            if (record instanceof Vertex)
              value = "v";
            else if (record instanceof Edge)
              value = "e";
            else
              value = "d";
          }
        } else
          value = row.getProperty(propertyName);

        final PostgresType columnType = entry.getValue();
        columnType.serializeAsText(columnType.code, bufferValues, value);
      }

      // MESSAGE = TYPE BYTE + LENGTH (THE 4 BYTES OF THE LENGTH ITSELF + THE VALUES) + VALUES
      bufferValues.flip();
      bufferData.put((byte) 'D');
      bufferData.putInt(4 + bufferValues.limit());
      bufferData.put(bufferValues);

      bufferData.flip();
      channel.writeBuffer(bufferData);
    }

    channel.flush();

    if (DEBUG)
      // THE REPORTED SIZE IS THE SIZE OF THE LAST ROW WRITTEN
      LogManager.instance().log(this, Level.INFO, "PSQL:-> %d row data (%s) (thread=%s)", resultSet.size(),
          FileUtils.getSizeAsString(bufferData.limit()), Thread.currentThread().getId());
  }

  /**
   * Handles the Bind message: reads the parameter format codes, the parameter values and the result
   * format codes, stores them on the portal and answers with bind-complete.
   * <p>
   * The whole message is read from the channel before any early exit, so the stream stays aligned on
   * the next message even when the portal is unknown. Parameters follow the protocol rules: a length
   * of -1 is a NULL value (no bytes follow), no format code means text for all the parameters and a
   * single format code applies to all the parameters. Failures flag the transaction as failed and are
   * reported to the client as error messages.
   */
  private void bindCommand() {
    try {
      // BIND
      final String portalName = readString();
      final String sourcePreparedStatement = readString();

      final PostgresPortal portal = getPortal(portalName, false);

      if (DEBUG)
        LogManager.instance().log(this, Level.INFO, "PSQL: bind (portal=%s) -> %s (thread=%s)", portalName, sourcePreparedStatement,
            Thread.currentThread().getId());

      // PARAMETER FORMAT CODES
      final int paramFormatCount = channel.readShort();
      final int[] paramFormatCodes = new int[Math.max(paramFormatCount, 0)];
      for (int i = 0; i < paramFormatCodes.length; i++)
        paramFormatCodes[i] = channel.readUnsignedShort();

      // PARAMETER VALUES, KEPT RAW UNTIL THE PORTAL IS KNOWN TO EXIST
      final int paramValuesCount = channel.readShort();
      final List<byte[]> rawParamValues = new ArrayList<>(Math.max(paramValuesCount, 0));
      for (int i = 0; i < paramValuesCount; i++) {
        final int paramSize = (int) channel.readUnsignedInt();
        if (paramSize == -1) {
          // NULL PARAMETER: NO VALUE BYTES FOLLOW
          rawParamValues.add(null);
          continue;
        }

        final byte[] paramValue = new byte[paramSize];
        channel.readBytes(paramValue);
        rawParamValues.add(paramValue);
      }

      // RESULT FORMAT CODES
      final int resultFormatCount = channel.readShort();
      final int[] resultFormatCodes = new int[Math.max(resultFormatCount, 0)];
      for (int i = 0; i < resultFormatCodes.length; i++)
        resultFormatCodes[i] = channel.readUnsignedShort();

      if (portal == null) {
        writeMessage("bind complete", null, '2', 4);
        return;
      }

      if (paramFormatCodes.length > 0) {
        portal.parameterFormats = new ArrayList<>(paramFormatCodes.length);
        for (final int formatCode : paramFormatCodes)
          portal.parameterFormats.add(formatCode);
      }

      if (paramValuesCount > 0) {
        portal.parameterValues = new ArrayList<>(paramValuesCount);
        for (int i = 0; i < paramValuesCount; i++) {
          final byte[] paramValue = rawParamValues.get(i);
          if (paramValue == null) {
            portal.parameterValues.add(null);
            continue;
          }

          // 0 CODES = TEXT (0) FOR ALL, 1 CODE = THAT CODE FOR ALL, OTHERWISE ONE CODE PER PARAMETER
          final int formatCode = paramFormatCodes.length == 0 ? 0 : paramFormatCodes[paramFormatCodes.length == 1 ? 0 : i];

          // NOTE(review): parameterTypes is only set when the Parse message declared the types - confirm how untyped parameters should be handled
          portal.parameterValues.add(//
              PostgresType.deserialize(portal.parameterTypes.get(i), formatCode, paramValue)//
          );
        }
      }

      if (resultFormatCodes.length > 0) {
        portal.resultFormats = new ArrayList<>(resultFormatCodes.length);
        for (final int resultFormat : resultFormatCodes)
          portal.resultFormats.add(resultFormat);
      }

      if (errorInTransaction)
        return;

      writeMessage("bind complete", null, '2', 4);

    } catch (final Exception e) {
      setErrorInTx();
      writeError(ERROR_SEVERITY.ERROR, "Error on parsing bind message: " + e.getMessage(), "XX000");
    }
  }

  /**
   * Handles the Parse message: reads name, query text and declared parameter types, prepares a portal
   * for the query and registers it under the received name, then answers with parse-complete.
   * <p>
   * Several commands and client metadata queries are answered without involving the query engine:
   * their result is built here and cached on the portal. SQL statements are only parsed (the execution
   * happens on Describe/Execute), while commands in any other language are executed immediately and
   * their result cached. Nothing is prepared while a previous error is pending. Failures flag the
   * transaction as failed and are reported to the client as error messages.
   */
  private void parseCommand() {
    try {
      // PARSE
      final String portalName = readString();
      final PostgresPortal portal = new PostgresPortal(readString());
      final int paramCount = channel.readShort();

      if (paramCount > 0) {
        portal.parameterTypes = new ArrayList<>(paramCount);
        for (int i = 0; i < paramCount; i++) {
          final long param = channel.readUnsignedInt();
          portal.parameterTypes.add(param);
        }
      }

      if (DEBUG)
        LogManager.instance()
            .log(this, Level.INFO, "PSQL: parse (portal=%s) -> %s (params=%d) (errorInTransaction=%s thread=%s)", portalName,
                portal.query, paramCount, errorInTransaction, Thread.currentThread().getId());

      if (errorInTransaction)
        return;

      if (portal.query.isEmpty()) {
        emptyQueryResponse();
        return;
      }

      final String upperCaseText = portal.query.toUpperCase(Locale.ENGLISH);

      if (ignoreQueriesAppNames.contains(connectionProperties.get("application_name")) &&//
          ignoreQueries.contains(portal.query)) {
        // RETURN EMPTY RESULT
        portal.executed = true;
        portal.cachedResultset = new ArrayList<>();
      } else if (upperCaseText.startsWith("SAVEPOINT ")) {
        portal.ignoreExecution = true;
      } else if (upperCaseText.startsWith("SET ")) {
        setConfiguration(portal.query);
        portal.ignoreExecution = true;
      } else if (upperCaseText.equals("SELECT VERSION()")) {
        createResultSet(portal, "VERSION", "11.0.0");

      } else if (upperCaseText.equals("SELECT CURRENT_SCHEMA()")) {
        createResultSet(portal, "CURRENT_SCHEMA", database.getName());

      } else if (upperCaseText.equals("SHOW TRANSACTION ISOLATION LEVEL")) {
        final Database.TRANSACTION_ISOLATION_LEVEL dbIsolationLevel = database.getTransactionIsolationLevel();
        final String level = dbIsolationLevel.name().replace('_', ' ');
        createResultSet(portal, "LEVEL", level);

      } else if (upperCaseText.startsWith("SHOW ")) {
        createResultSet(portal, "CURRENT_SCHEMA", database.getName());

      } else if ("dbvis".equals(connectionProperties.get("application_name"))) {
        // SPECIAL CASES
        // NOTE(review): a query from this application matching none of the cases below gets neither a statement nor a cached result - confirm it is intended
        if (portal.query.equals(
            "SELECT nspname AS TABLE_SCHEM, NULL AS TABLE_CATALOG FROM pg_catalog.pg_namespace  WHERE nspname <> 'pg_toast' AND (nspname !~ '^pg_temp_'  OR nspname = (pg_catalog.current_schemas(true))[1]) AND (nspname !~ '^pg_toast_temp_'  OR nspname = replace((pg_catalog.current_schemas(true))[1], 'pg_temp_', 'pg_toast_temp_'))  ORDER BY TABLE_SCHEM")
            || portal.query.equals("SELECT     COLLATION_SCHEMA,     COLLATION_NAME FROM     INFORMATION_SCHEMA.COLLATIONS")) {
          // SPECIAL CASE DB VISUALIZER

          portal.executed = true;
          portal.cachedResultset = new ArrayList<>();

          final Map<String, Object> map = new HashMap<>();
          map.put("TABLE_CATALOG", null);
          map.put("TABLE_SCHEM", "");

          final Result result = new ResultInternal(map);
          portal.cachedResultset.add(result);

          portal.columns = new HashMap<>();
          portal.columns.put("TABLE_CATALOG", PostgresType.VARCHAR);
          portal.columns.put("TABLE_SCHEM", PostgresType.VARCHAR);

        } else if (portal.query.contains("ORDER BY TABLE_TYPE,TABLE_SCHEM,TABLE_NAME ")) {
          // LIST OF TABLES: ONE ROW PER TYPE DEFINED IN THE SCHEMA
          portal.executed = true;
          portal.cachedResultset = new ArrayList<>();

          // THE COLUMNS DO NOT DEPEND ON THE ROWS: DEFINE THEM ONCE, EVEN WHEN THE SCHEMA HAS NO TYPES
          portal.columns = new HashMap<>();
          portal.columns.put("TABLE_CAT", PostgresType.VARCHAR);
          portal.columns.put("TABLE_SCHEM", PostgresType.VARCHAR);
          portal.columns.put("TABLE_TYPE", PostgresType.VARCHAR);
          portal.columns.put("TABLE_NAME", PostgresType.VARCHAR);
          portal.columns.put("REMARKS", PostgresType.VARCHAR);
          portal.columns.put("TYPE_CAT", PostgresType.VARCHAR);
          portal.columns.put("TYPE_SCHEM", PostgresType.VARCHAR);
          portal.columns.put("TYPE_NAME", PostgresType.VARCHAR);
          portal.columns.put("SELF_REFERENCING_COL_NAME", PostgresType.VARCHAR);
          portal.columns.put("REF_GENERATION", PostgresType.VARCHAR);

          for (final DocumentType t : database.getSchema().getTypes()) {
            final Map<String, Object> map = new HashMap<>();
            map.put("TABLE_CAT", "");
            map.put("TABLE_SCHEM", "");
            map.put("TABLE_TYPE", "TABLE");
            map.put("TABLE_NAME", t.getName());
            map.put("REMARKS", "");
            map.put("TYPE_CAT", "");
            map.put("TYPE_SCHEM", "");
            map.put("TYPE_NAME", "");
            map.put("SELF_REFERENCING_COL_NAME", "");
            map.put("REF_GENERATION", "");

            final Result result = new ResultInternal(map);
            portal.cachedResultset.add(result);
          }
        }
      } else if (portal.query.equals(
          "select distinct GRANTEE as USER_NAME, 'N' as IS_EXPIRED, 'N' as IS_LOCKED from INFORMATION_SCHEMA.USAGE_PRIVILEGES order by GRANTEE asc")) {
        // LIST OF USERS: A SINGLE FIXED ROW
        portal.executed = true;
        portal.cachedResultset = new ArrayList<>();
        final Map<String, Object> map = new HashMap<>();
        map.put("USER_NAME", "root");
        map.put("IS_EXPIRED", 'N');
        map.put("IS_LOCKED", 'N');
        final Result result = new ResultInternal(map);
        portal.cachedResultset.add(result);

        portal.columns = new HashMap<>();
        portal.columns.put("USER_NAME", PostgresType.VARCHAR);
        portal.columns.put("IS_EXPIRED", PostgresType.CHAR);
        portal.columns.put("IS_LOCKED", PostgresType.CHAR);

      } else if (portal.query.equals(
          "select CHARACTER_SET_NAME as CHARSET_NAME, -1 as MAX_LENGTH from INFORMATION_SCHEMA.CHARACTER_SETS order by CHARACTER_SET_NAME asc")) {
        // LIST OF CHARACTER SETS: A SINGLE FIXED ROW
        portal.executed = true;
        portal.cachedResultset = new ArrayList<>();
        final Map<String, Object> map = new HashMap<>();
        map.put("CHARSET_NAME", "UTF-8");
        map.put("MAX_LENGTH", -1);
        final Result result = new ResultInternal(map);
        portal.cachedResultset.add(result);

        portal.columns = new HashMap<>();
        portal.columns.put("CHARSET_NAME", PostgresType.VARCHAR);
        portal.columns.put("MAX_LENGTH", PostgresType.INTEGER);
      } else if (//
          portal.query.equals(
              "select NSPNAME as SCHEMA_NAME, case when lower(NSPNAME)='pg_catalog' then 'Y' else 'N' end as IS_PUBLIC, case when lower(NSPNAME)='information_schema' then 'Y' else 'N' end as IS_SYSTEM, 'N' as IS_EMPTY from PG_CATALOG.PG_NAMESPACE order by NSPNAME asc")
              || portal.query.equals(
              "select SCHEMA_NAME, case when lower(SCHEMA_NAME)='pg_catalog' then 'Y' else 'N' end as IS_PUBLIC, case when lower(SCHEMA_NAME)='information_schema' then 'Y' else 'N' end as IS_SYSTEM, 'N' as IS_EMPTY from INFORMATION_SCHEMA.SCHEMATA order by SCHEMA_NAME asc")) {
        // LIST OF SCHEMAS: ONE ROW PER DATABASE OF THE SERVER

        portal.executed = true;
        portal.cachedResultset = new ArrayList<>();

        for (final String dbName : server.getDatabaseNames()) {
          final Map<String, Object> map = new HashMap<>();
          map.put("SCHEMA_NAME", dbName);
          map.put("IS_PUBLIC", "Y");
          map.put("IS_SYSTEM", "N");
          map.put("IS_EMPTY", "N");
          final Result result = new ResultInternal(map);
          portal.cachedResultset.add(result);
        }

        portal.columns = new HashMap<>();
        portal.columns.put("SCHEMA_NAME", PostgresType.VARCHAR);
        portal.columns.put("IS_PUBLIC", PostgresType.CHAR);
        portal.columns.put("IS_SYSTEM", PostgresType.CHAR);
        portal.columns.put("IS_EMPTY", PostgresType.CHAR);
      } else if (portal.query.equals(
          "SELECT nspname AS TABLE_SCHEM, NULL AS TABLE_CATALOG FROM pg_catalog.pg_namespace  WHERE nspname <> 'pg_toast' AND (nspname !~ '^pg_temp_'  OR nspname = (pg_catalog.current_schemas(true))[1]) AND (nspname !~ '^pg_toast_temp_'  OR nspname = replace((pg_catalog.current_schemas(true))[1], 'pg_temp_', 'pg_toast_temp_'))  AND nspname LIKE E'%' ORDER BY TABLE_SCHEM")) {
        // ONE ROW PER TYPE DEFINED IN THE SCHEMA OF THE CURRENT DATABASE

        portal.executed = true;
        portal.cachedResultset = new ArrayList<>();

        for (final DocumentType t : database.getSchema().getTypes()) {
          final Map<String, Object> map = new HashMap<>();
          map.put("TABLE_SCHEM", t.getName());
          map.put("TABLE_CATALOG", database.getName());
          final Result result = new ResultInternal(map);
          portal.cachedResultset.add(result);
        }

        portal.columns = new HashMap<>();
        portal.columns.put("TABLE_SCHEM", PostgresType.VARCHAR);
        portal.columns.put("TABLE_CATALOG", PostgresType.VARCHAR);
      } else {
        final Query query = getLanguageAndQuery(portal.query);

        switch (query.language) {
        case "sql":
          // ONLY PARSE THE STATEMENT: THE EXECUTION HAPPENS ON DESCRIBE/EXECUTE
          final SQLQueryEngine sqlEngine = (SQLQueryEngine) database.getQueryEngine("sql");
          portal.sqlStatement = sqlEngine.parse(query.query, (DatabaseInternal) database);

          if (portal.query.equalsIgnoreCase("BEGIN") || portal.query.equalsIgnoreCase("BEGIN TRANSACTION")) {
            explicitTransactionStarted = true;
            setEmptyResultSet(portal);
          } else if (portal.query.equalsIgnoreCase("COMMIT")) {
            explicitTransactionStarted = false;
            setEmptyResultSet(portal);
          }
          break;

        default:
          // ANY OTHER LANGUAGE: EXECUTE NOW AND CACHE THE RESULT
          portal.executed = true;
          final ResultSet resultSet = database.command(query.language, query.query);
          portal.cachedResultset = browseAndCacheResultSet(resultSet, 0);
          portal.columns = getColumns(portal.cachedResultset);
        }
      }

      portals.put(portalName, portal);

      // ParseComplete
      writeMessage("parse complete", null, '1', 4);

    } catch (final CommandParsingException e) {
      setErrorInTx();
      // The parsing exception may carry no cause: fall back to its own message instead of failing with a NPE
      final Throwable reason = e.getCause() != null ? e.getCause() : e;
      writeError(ERROR_SEVERITY.ERROR, "Syntax error on parsing query: " + reason.getMessage(), "42601");
    } catch (final Exception e) {
      setErrorInTx();
      writeError(ERROR_SEVERITY.ERROR, "Error on parsing query: " + e.getMessage(), "XX000");
    }
  }

  /**
   * Applies a {@code SET <name> = <value>} or {@code SET <name> TO <value>} statement to this connection.
   * <p>
   * The name and the value are trimmed and a pair of matching surrounding quotes (single or double) is removed from
   * the value. The pair is then stored in {@code connectionProperties}. The {@code datestyle} setting is additionally
   * interpreted: the value {@code ISO} switches the schema date-time format to
   * {@code DateUtils.DATE_TIME_ISO_8601_FORMAT}, any other value is only logged as not supported.
   * <p>
   * Statements that carry neither {@code =} nor the {@code TO} keyword are logged and ignored instead of failing
   * with an {@link ArrayIndexOutOfBoundsException}.
   *
   * @param query the full statement text; it is expected to start with the 4 characters {@code "SET "}
   */
  private void setConfiguration(final String query) {
    final String assignment = query.substring("SET ".length());

    // LIMIT 2: A VALUE THAT CONTAINS '=' MUST NOT BE TRUNCATED
    String[] parts = assignment.split("=", 2);
    if (parts.length < 2)
      // SQL KEYWORDS ARE CASE-INSENSITIVE: ACCEPT " TO ", " to ", ...
      parts = assignment.split("(?i)\\s+TO\\s+", 2);

    if (parts.length < 2) {
      // NOTHING TO ASSIGN (E.G. "SET TIME ZONE 'UTC'"): IGNORE THE STATEMENT RATHER THAN CRASHING
      LogManager.instance().log(this, Level.WARNING, "PSQL: ignoring SET statement without a value: '%s'", query);
      return;
    }

    final String name = parts[0].trim();
    String value = parts[1].trim();

    // STRIP THE QUOTES ONLY WHEN THE VALUE IS REALLY ENCLOSED BY A MATCHING PAIR
    if (value.length() >= 2) {
      final char quote = value.charAt(0);
      if ((quote == '\'' || quote == '"') && value.charAt(value.length() - 1) == quote)
        value = value.substring(1, value.length() - 1);
    }

    if (name.equals("datestyle")) {
      if (value.equals("ISO"))
        database.getSchema().setDateTimeFormat(DateUtils.DATE_TIME_ISO_8601_FORMAT);
      else
        LogManager.instance().log(this, Level.INFO, "datestyle '%s' not supported", value);
    }

    connectionProperties.put(name, value);
  }

  /**
   * Flags the portal as executed and expecting a result, and attaches an empty cached result set together with the
   * columns computed by {@code getColumns} for that empty set.
   *
   * @param portal the portal to update in place
   */
  private void setEmptyResultSet(final PostgresPortal portal) {
    portal.executed = true;
    portal.isExpectingResult = true;

    final List<Result> noRows = Collections.emptyList();
    portal.cachedResultset = noRows;
    portal.columns = getColumns(noRows);
  }

  /**
   * Sends a "parameter status" message (code {@code 'S'}) carrying a name/value pair to the client.
   *
   * @param name  parameter name, written as a zero-terminated UTF-8 string
   * @param value parameter value, written as a zero-terminated UTF-8 string
   */
  private void sendServerParameter(final String name, final String value) {
    final int encodedNameSize = name.getBytes(StandardCharsets.UTF_8).length;
    final int encodedValueSize = value.getBytes(StandardCharsets.UTF_8).length;

    // 4 LEADING BYTES + EACH STRING FOLLOWED BY ITS ZERO TERMINATOR
    final int messageLength = 4 + (encodedNameSize + 1) + (encodedValueSize + 1);

    writeMessage("parameter status", () -> {
      writeString(name);
      writeString(value);
    }, 'S', messageLength);
  }

  /**
   * Authenticates the connecting user and opens the database selected during startup.
   * <p>
   * On success the database is bound to the current {@code DatabaseContext} with the authenticated user and auto
   * transaction is enabled. On failure a FATAL error is written to the client.
   *
   * @return {@code true} if the database was opened; {@code false} if no database was selected, the credentials
   * were rejected or the database could not be opened
   */
  private boolean openDatabase() {
    if (databaseName == null) {
      writeError(ERROR_SEVERITY.FATAL, "Database not selected", "HV00Q");
      return false;
    }

    try {
      final ServerSecurityUser authenticatedUser = server.getSecurity().authenticate(userName, userPassword, databaseName);

      database = server.getDatabase(databaseName);

      DatabaseContext.INSTANCE //
          .init((DatabaseInternal) database) //
          .setCurrentUser(authenticatedUser.getDatabaseUser(database));

      database.setAutoTransaction(true);

      return true;

    } catch (final ServerSecurityException e) {
      writeError(ERROR_SEVERITY.FATAL, "Credentials not valid", "28P01");
      return false;
    } catch (final DatabaseOperationException e) {
      writeError(ERROR_SEVERITY.FATAL, "Database does not exist", "HV00Q");
      return false;
    }
  }

  /**
   * Reads the first message sent by the client on a new connection and acts on it.
   * <p>
   * The message starts with two unsigned 32-bit integers: the total length and a protocol/request code. Three cases
   * are handled:
   * <ul>
   * <li>code {@code 80877103} (SSL request): if {@code no2ssl} is {@code true} a single {@code 'N'} byte is sent back
   * and the startup message is read again with {@code no2ssl = false}; otherwise a
   * {@code PostgresProtocolException} is thrown</li>
   * <li>code {@code 80877102} (cancel request): the pid and secret are read, the matching entry of
   * {@code ACTIVE_SESSIONS} is closed when the secret matches, then this connection is closed</li>
   * <li>any other code: when the length is greater than 8, zero-terminated name/value string pairs are read until a
   * zero byte is found. {@code user} and {@code database} set {@code userName} and {@code databaseName}; every pair
   * is also stored in {@code connectionProperties}</li>
   * </ul>
   *
   * @param no2ssl {@code true} to answer an SSL request with a refusal and retry once, {@code false} to fail on it
   *
   * @return {@code true} when a regular startup message was consumed, {@code false} after a cancel request
   *
   * @throws PostgresProtocolException on a repeated SSL request or when reading from the channel fails
   */
  private boolean readStartupMessage(final boolean no2ssl) {
    try {
      final long len = channel.readUnsignedInt();
      final long protocolVersion = channel.readUnsignedInt();
      if (protocolVersion == 80877103) {
        // REQUEST FOR SSL, NOT SUPPORTED
        if (no2ssl) {
          // 'N' TELLS THE CLIENT THE REQUEST IS REFUSED
          channel.writeByte((byte) 'N');
          channel.flush();

          LogManager.instance().log(this, Level.INFO,
              "PSQL: received not supported SSL connection request. Sending back error message to the client");

          // REPEAT, THIS TIME WITHOUT TOLERATING ANOTHER SSL REQUEST
          return readStartupMessage(false);
        }

        throw new PostgresProtocolException("SSL authentication is not supported");
      } else if (protocolVersion == 80877102) {
        // CANCEL REQUEST: THE PAYLOAD IS THE TARGET SESSION PID FOLLOWED BY ITS SECRET
        final long pid = channel.readUnsignedInt();
        final long secret = channel.readUnsignedInt();

        LogManager.instance().log(this, Level.INFO, "PSQL: Received cancel request pid %d", pid);

        final Pair session = ACTIVE_SESSIONS.get(pid);
        if (session != null) {
          // THE SESSION IS CLOSED ONLY IF THE CALLER KNOWS ITS SECRET
          if (session.getFirst() == secret) {
            LogManager.instance().log(this, Level.INFO, "PSQL: Canceling session " + pid);
            session.getSecond().close();
          } else
            LogManager.instance().log(this, Level.INFO, "PSQL: Blocked unauthorized canceling session " + pid);
        } else
          LogManager.instance().log(this, Level.INFO, "PSQL: Session " + pid + " not found");

        // THE CONNECTION THAT CARRIED THE CANCEL REQUEST IS CLOSED IN ANY CASE
        close();
        return false;
      }

      if (len > 8) {
        // MORE THAN THE 8 BYTES ALREADY READ (LENGTH + CODE): PARAMETERS FOLLOW, A ZERO BYTE ENDS THE LIST
        while (readNextByte() != 0) {
          // THE BYTE JUST READ IS THE FIRST CHARACTER OF THE NAME: PUSH IT BACK SO readString() SEES IT
          reuseLastByte();

          final String paramName = readString();
          final String paramValue = readString();

          switch (paramName) {
          case "user":
            userName = paramValue;
            break;
          case "database":
            databaseName = paramValue;
            break;
          case "options":
            // DEPRECATED, IGNORE IT
            break;
          case "replication":
            // NOT SUPPORTED, IGNORE IT
            break;
          }

          // EVERY PARAMETER, KNOWN OR NOT, IS KEPT IN THE CONNECTION PROPERTIES
          connectionProperties.put(paramName, paramValue);
        }
      }
    } catch (final IOException e) {
      setErrorInTx();
      throw new PostgresProtocolException("Error on parsing startup message", e);
    }
    return true;
  }

  /**
   * Sends an error message (code {@code 'E'}) to the client and flushes the channel.
   * <p>
   * The payload is made of three fields, each one a type byte followed by a zero-terminated string: {@code 'M'}
   * (message), {@code 'S'} (severity) and {@code 'C'} (error code), followed by a final zero byte.
   *
   * @param severity     severity, sent as its {@code toString()} form
   * @param errorMessage human readable message
   * @param errorCode    error code string
   *
   * @throws PostgresProtocolException if writing to the channel fails
   */
  private void writeError(final ERROR_SEVERITY severity, final String errorMessage, final String errorCode) {
    final String severityText = severity.toString();

    final int messageSize = errorMessage.getBytes(StandardCharsets.UTF_8).length;
    final int severitySize = severityText.getBytes(StandardCharsets.UTF_8).length;
    final int codeSize = errorCode.getBytes(StandardCharsets.UTF_8).length;

    // 4 LEADING BYTES, THEN PER FIELD: TYPE BYTE + CONTENT + ZERO TERMINATOR, THEN THE CLOSING ZERO BYTE
    final int length = 4 //
        + (1 + messageSize + 1) //
        + (1 + severitySize + 1) //
        + (1 + codeSize + 1) //
        + 1;

    try {
      channel.writeByte((byte) 'E');
      channel.writeUnsignedInt(length);

      channel.writeByte((byte) 'M');
      writeString(errorMessage);

      channel.writeByte((byte) 'S');
      writeString(severityText);

      channel.writeByte((byte) 'C');
      writeString(errorCode);

      channel.writeByte((byte) 0);
      channel.flush();
    } catch (final IOException ioError) {
      setErrorInTx();
      throw new PostgresProtocolException("Error on sending error '" + errorMessage + "' to the client", ioError);
    }
  }

  /**
   * Writes a message to the client: the message code byte, the length as an unsigned int, the optional payload
   * produced by the callback, then flushes the channel. When {@code DEBUG} is on the sent message is logged.
   *
   * @param messageName human readable name, used only for logging and error reporting
   * @param callback    writes the payload; may be {@code null} when the message has no payload
   * @param messageCode one-character message code
   * @param length      value written in the length field (narrowed to {@code int})
   *
   * @throws PostgresProtocolException if writing to the channel fails
   */
  private void writeMessage(final String messageName, final WriteMessageCallback callback, final char messageCode,
      final long length) {
    try {
      channel.writeByte((byte) messageCode);
      channel.writeUnsignedInt((int) length);

      final boolean hasPayload = callback != null;
      if (hasPayload)
        callback.write();

      channel.flush();
    } catch (final IOException ioError) {
      setErrorInTx();
      throw new PostgresProtocolException("Error on sending '" + messageName + "' message", ioError);
    }

    // REACHED ONLY WHEN THE MESSAGE HAS BEEN SENT SUCCESSFULLY
    if (DEBUG)
      LogManager.instance().log(this, Level.INFO, "PSQL:-> %s (%s - %s) (thread=%s)", null, messageName, messageCode,
          FileUtils.getSizeAsString(length), Thread.currentThread().getId());
  }

  /**
   * Reads the header of the next client message (type byte and unsigned int length), optionally checks the type
   * against a list of accepted codes and hands the message to the callback with the length minus 4.
   * <p>
   * When the type is not among the accepted codes, the rest of the message is consumed before failing.
   *
   * @param messageName          human readable name, used in error messages
   * @param callback             receives the message type and the length minus 4
   * @param expectedMessageCodes accepted type codes; {@code null} or empty disables the check
   *
   * @return {@code true} when the message was read, {@code false} when the end of the stream was reached
   *
   * @throws PostgresProtocolException on an unexpected message type or on any other I/O failure
   */
  private boolean readMessage(final String messageName, final ReadMessageCallback callback, final char... expectedMessageCodes) {
    try {
      final char messageType = (char) readNextByte();
      final long messageLength = channel.readUnsignedInt();

      final boolean mustValidate = expectedMessageCodes != null && expectedMessageCodes.length > 0;
      if (mustValidate) {
        boolean accepted = false;
        for (final char expectedCode : expectedMessageCodes) {
          if (expectedCode == messageType) {
            accepted = true;
            break;
          }
        }

        if (!accepted) {
          // DRAIN THE REST OF THE MESSAGE BEFORE REPORTING THE ERROR
          if (messageLength > 4)
            readBytes((int) (messageLength - 4));
          throw new PostgresProtocolException("Unexpected message type '" + messageType + "' for message " + messageName);
        }
      }

      // THE CALLBACK IS INVOKED EVEN WHEN NO BYTES FOLLOW THE HEADER
      callback.read(messageType, messageLength - 4);

      return true;

    } catch (final EOFException endOfStream) {
      // CLIENT CLOSES THE CONNECTION
      setErrorInTx();
      return false;
    } catch (final IOException ioError) {
      setErrorInTx();
      throw new PostgresProtocolException("Error on reading " + messageName + " message: " + ioError.getMessage(), ioError);
    }
  }

  /**
   * Returns the next byte of the input. If a byte was pushed back with {@code reuseLastByte()}, that byte is
   * returned again (once) instead of reading from the channel.
   *
   * @return the byte, as read by {@code channel.readUnsignedByte()}
   *
   * @throws IOException if reading from the channel fails
   */
  private int readNextByte() throws IOException {
    if (reuseLastByte)
      // USE THE BYTE ALREADY READ, ONLY ONCE
      reuseLastByte = false;
    else
      nextByte = channel.readUnsignedByte();

    return nextByte;
  }

  /**
   * Blocks until the channel reports available input, polling every 100 milliseconds.
   * <p>
   * If the thread is interrupted while waiting, the interrupt flag is restored (so that callers and the thread
   * owner can still observe the interruption) and the wait is aborted.
   *
   * @throws PostgresProtocolException if the thread is interrupted while waiting; the {@link InterruptedException}
   *                                   is attached as cause
   */
  private void waitForAMessage() {
    while (!channel.inputHasData()) {
      try {
        Thread.sleep(100);
      } catch (final InterruptedException interruptedException) {
        // Thread.sleep() CLEARS THE INTERRUPT FLAG WHEN IT THROWS: SET IT AGAIN BEFORE PROPAGATING
        Thread.currentThread().interrupt();
        throw new PostgresProtocolException("Error on reading from the channel", interruptedException);
      }
    }
  }

  /**
   * Pushes back the last byte read: the next call to {@code readNextByte()} returns it again instead of reading
   * from the channel.
   */
  private void reuseLastByte() {
    this.reuseLastByte = true;
  }

  /**
   * Reads a zero-terminated string from the input, using the shared {@code buffer} as temporary storage and
   * {@code DatabaseFactory.getDefaultCharset()} to decode it.
   * <p>
   * A string that fills the buffer exactly is accepted. Only a string with more bytes than the buffer can hold is
   * rejected: in that case the remaining bytes are consumed up to the terminator, so the stream stays aligned on
   * the next field, and an exception reporting the real length is thrown.
   *
   * @return the decoded string, without the terminator
   *
   * @throws IOException               if reading from the channel fails
   * @throws PostgresProtocolException if the string is longer than the buffer
   */
  private String readString() throws IOException {
    int len = 0;
    for (; len < buffer.length; len++) {
      final int b = readNextByte();
      if (b == 0)
        return new String(buffer, 0, len, DatabaseFactory.getDefaultCharset());

      buffer[len] = (byte) b;
    }

    // BUFFER FULL: THE STRING STILL FITS IF THE VERY NEXT BYTE IS THE TERMINATOR
    if (readNextByte() == 0)
      return new String(buffer, 0, len, DatabaseFactory.getDefaultCharset());

    // ONE EXTRA BYTE HAS ALREADY BEEN CONSUMED ABOVE: COUNT IT, THEN DISCARD THE REST
    len = readUntilTerminator(len + 1);

    // NOTE(review): the message reports BUFFER_LENGTH, presumably the size of `buffer` - confirm
    throw new PostgresProtocolException("String content (" + len + ") too long (>" + BUFFER_LENGTH + ")");
  }

  /**
   * Writes the text to the channel as UTF-8 bytes followed by a zero byte terminator.
   *
   * @param text the text to send
   *
   * @throws IOException if writing to the channel fails
   */
  private void writeString(final String text) throws IOException {
    final byte[] encoded = text.getBytes(StandardCharsets.UTF_8);
    channel.writeBytes(encoded);

    // STRING TERMINATOR
    channel.writeByte((byte) 0);
  }

  /**
   * Consumes and discards input bytes up to and including the next zero byte.
   *
   * @param len starting value of the counter
   *
   * @return {@code len} increased by the number of non-zero bytes discarded
   *
   * @throws IOException if reading from the channel fails
   */
  private int readUntilTerminator(int len) throws IOException {
    // OUT OF BUFFER SIZE, CONTINUE READING AND DISCARD THE CONTENT
    while (readNextByte() != 0)
      len++;

    return len;
  }

  /**
   * Reads and discards {@code len} bytes from the input. Nothing is read when {@code len} is zero or negative.
   *
   * @param len number of bytes to skip
   *
   * @throws IOException if reading from the channel fails
   */
  private void readBytes(final int len) throws IOException {
    int remaining = len;
    while (remaining > 0) {
      readNextByte();
      remaining--;
    }
  }

  /**
   * Sends a "command complete" message (code {@code 'C'}) whose tag is derived from the beginning of the query text
   * (compared in upper case):
   * <ul>
   * <li>{@code CREATE VERTEX} / {@code INSERT INTO}: {@code "INSERT 0 <count>"}</li>
   * <li>{@code SELECT} / {@code MATCH}: {@code "SELECT <count>"}</li>
   * <li>{@code UPDATE}: {@code "UPDATE <count>"}</li>
   * <li>{@code DELETE}: {@code "DELETE <count>"}</li>
   * <li>exactly {@code BEGIN} or {@code BEGIN TRANSACTION}: {@code "BEGIN"}</li>
   * <li>anything else: an empty tag</li>
   * </ul>
   *
   * @param queryText      text of the executed command
   * @param resultSetCount number reported in the tag
   */
  private void writeCommandComplete(final String queryText, final int resultSetCount) {
    final String normalized = queryText.toUpperCase(Locale.ENGLISH);

    final String tag;
    if (normalized.startsWith("CREATE VERTEX") || normalized.startsWith("INSERT INTO")) {
      tag = "INSERT 0 " + resultSetCount;
    } else if (normalized.startsWith("SELECT") || normalized.startsWith("MATCH")) {
      tag = "SELECT " + resultSetCount;
    } else if (normalized.startsWith("UPDATE")) {
      tag = "UPDATE " + resultSetCount;
    } else if (normalized.startsWith("DELETE")) {
      tag = "DELETE " + resultSetCount;
    } else if (normalized.equals("BEGIN") || normalized.equals("BEGIN TRANSACTION")) {
      tag = "BEGIN";
    } else {
      tag = "";
    }

    // 4 LEADING BYTES + TAG + ZERO TERMINATOR
    final int messageLength = 4 + tag.length() + 1;
    writeMessage("command complete", () -> writeString(tag), 'C', messageLength);
  }

  /**
   * Sends a "no data" message (code {@code 'n'}) without payload.
   */
  private void writeNoData() {
    final int lengthWithoutPayload = 4;
    writeMessage("no data", null, 'n', lengthWithoutPayload);
  }

  /**
   * Looks up a portal by name in {@code portals}.
   *
   * @param name   name of the portal
   * @param remove {@code true} to also remove the portal from the map, {@code false} to leave it registered
   *
   * @return the result of the map lookup/removal for that name
   */
  private PostgresPortal getPortal(final String name, final boolean remove) {
    return remove ? portals.remove(name) : portals.get(name);
  }

  /**
   * Marks the portal as executed and stores in it a cached result set built from the given name/value pairs,
   * together with the columns computed by {@code getColumns} for that result set.
   *
   * @param portal   the portal to update in place
   * @param elements alternating names and values, see {@code createResultSet(Object...)}
   */
  private void createResultSet(final PostgresPortal portal, final Object... elements) {
    portal.executed = true;

    final List<Result> rows = createResultSet(elements);
    portal.cachedResultset = rows;
    portal.columns = getColumns(rows);
  }

  /**
   * Builds a result set from alternating name/value elements. Every pair produces its own row (a
   * {@code ResultInternal}) holding that single property.
   *
   * @param elements alternating property names (must be {@code String}) and values
   *
   * @return a new mutable list with one row per pair, in the order given
   *
   * @throws IllegalArgumentException if the number of elements is odd
   * @throws ClassCastException       if an element in name position is not a {@code String}
   */
  private List<Result> createResultSet(final Object... elements) {
    if (elements.length % 2 != 0)
      throw new IllegalArgumentException("Resultset elements must be in pairs");

    // ONE ROW PER PAIR
    final List<Result> resultSet = new ArrayList<>(elements.length / 2);
    for (int i = 0; i < elements.length; i += 2) {
      final Map<String, Object> map = new HashMap<>(2);
      map.put((String) elements[i], elements[i + 1]);
      resultSet.add(new ResultInternal(map));
    }
    return resultSet;
  }

  /**
   * Optional language prefix of a query: {@code {language}}, possibly preceded by whitespace. Compiled once
   * because {@link Pattern} instances are immutable and thread-safe.
   */
  private static final Pattern LANGUAGE_PREFIX = Pattern.compile("\\s*\\{(\\w+)\\}");

  /**
   * Splits a query into the language to use and the text to execute.
   * <p>
   * A query that begins with {@code {language}} (for example {@code {cypher}MATCH ...}) selects that language and
   * the text after the closing brace is the query. Any other query is returned unchanged with the language
   * {@code "sql"}. The prefix is recognized only at the beginning of the query, so a {@code {word}} that appears
   * later, e.g. inside a string literal, is left alone.
   *
   * @param query the raw query text received from the client
   *
   * @return the language and the query text without the prefix
   */
  private Query getLanguageAndQuery(final String query) {
    String language = "sql";
    String queryText = query;

    final Matcher matcher = LANGUAGE_PREFIX.matcher(query);

    // lookingAt() ANCHORS THE MATCH AT THE BEGINNING OF THE INPUT, find() WOULD ACCEPT A MATCH ANYWHERE
    if (matcher.lookingAt()) {
      language = matcher.group(1);
      queryText = query.substring(matcher.end());
    }

    return new Query(language, queryText);
  }

  /**
   * Sends an "empty query response" message (code {@code 'I'}) without payload.
   */
  private void emptyQueryResponse() {
    final int lengthWithoutPayload = 4;
    writeMessage("empty query response", null, 'I', lengthWithoutPayload);
  }

  /**
   * Sends a "portal suspended response" message (code {@code 's'}) without payload.
   */
  private void portalSuspendedResponse() {
    final int lengthWithoutPayload = 4;
    writeMessage("portal suspended response", null, 's', lengthWithoutPayload);
  }

  /**
   * Records that an error happened inside an explicit transaction. Does nothing when no explicit transaction has
   * been started.
   */
  private void setErrorInTx() {
    if (!explicitTransactionStarted)
      return;

    errorInTransaction = true;
  }

  /**
   * Outcome of {@code getLanguageAndQuery}: the query language to use and the query text to execute.
   *
   * @param language name of the query language, e.g. {@code "sql"}
   * @param query    text of the query, without any language prefix
   */
  private record Query(String language, String query) {
  }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy