All downloads are free. The search and download functionality uses the official Maven repository.

org.neo4j.shell.state.BoltStateHandler Maven / Gradle / Ivy

There is a newer version: 5.26.0
Show newest version
/*
 * Copyright (c) "Neo4j"
 * Neo4j Sweden AB [http://neo4j.com]
 *
 * This file is part of Neo4j.
 *
 * Neo4j is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package org.neo4j.shell.state;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import org.neo4j.driver.AccessMode;
import org.neo4j.driver.AuthToken;
import org.neo4j.driver.AuthTokens;
import org.neo4j.driver.Bookmark;
import org.neo4j.driver.Config;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.neo4j.driver.Query;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.Transaction;
import org.neo4j.driver.Value;
import org.neo4j.driver.Values;
import org.neo4j.driver.exceptions.ClientException;
import org.neo4j.driver.exceptions.Neo4jException;
import org.neo4j.driver.exceptions.ServiceUnavailableException;
import org.neo4j.driver.exceptions.SessionExpiredException;
import org.neo4j.driver.internal.Scheme;
import org.neo4j.driver.summary.DatabaseInfo;
import org.neo4j.driver.summary.ResultSummary;
import org.neo4j.shell.ConnectionConfig;
import org.neo4j.shell.Connector;
import org.neo4j.shell.DatabaseManager;
import org.neo4j.shell.QueryRunner;
import org.neo4j.shell.TransactionHandler;
import org.neo4j.shell.TriFunction;
import org.neo4j.shell.build.Build;
import org.neo4j.shell.exception.CommandException;
import org.neo4j.shell.exception.ThrowingAction;
import org.neo4j.shell.log.NullLogging;

import static org.neo4j.shell.util.Versions.isPasswordChangeRequiredException;

/**
 * Handles interactions with the driver
 */
public class BoltStateHandler implements TransactionHandler, Connector, DatabaseManager, QueryRunner
{
    private static final String USER_AGENT = "neo4j-cypher-shell/v" + Build.version();
    private final TriFunction driverProvider;
    private final boolean isInteractive;
    private final Map bookmarks = new HashMap<>();
    protected Driver driver;
    Session session;
    private String protocolVersion;
    private String activeDatabaseNameAsSetByUser;
    private String actualDatabaseNameAsReportedByServer;
    private Transaction tx;

    /**
     * Creates a handler that obtains its drivers from {@code GraphDatabase.driver}.
     *
     * @param isInteractive whether the shell is interactive; consulted by {@code setActiveDatabase} when a
     *                      database switch fails
     */
    public BoltStateHandler( boolean isInteractive )
    {
        this( GraphDatabase::driver, isInteractive );
    }

    /**
     * Creates a handler with a caller-supplied driver factory. No database is selected initially
     * (the active database name starts as {@code ABSENT_DB_NAME}).
     *
     * @param driverProvider factory invoked with the driver URL, the auth token and the driver config
     * @param isInteractive  whether the shell is interactive
     */
    BoltStateHandler( TriFunction driverProvider,
                      boolean isInteractive )
    {
        this.isInteractive = isInteractive;
        this.driverProvider = driverProvider;
        this.activeDatabaseNameAsSetByUser = ABSENT_DB_NAME;
    }

    /**
     * Switches the active database. When a session is open it is re-established against the new database.
     * If that fails with a {@link ClientException} in interactive mode, the previous database name is restored
     * and a reconnect to it is attempted before the original failure is rethrown.
     *
     * @param databaseName name of the database to switch to
     * @throws CommandException if a transaction is currently open, or if reconnecting raises one
     */
    @Override
    public void setActiveDatabase( String databaseName ) throws CommandException
    {
        if ( isTransactionOpen() )
        {
            throw new CommandException( "There is an open transaction. You need to close it before you can switch database." );
        }

        final String databaseBeforeSwitch = activeDatabaseNameAsSetByUser;
        activeDatabaseNameAsSetByUser = databaseName;
        try
        {
            if ( isConnected() )
            {
                reconnect( databaseName, databaseBeforeSwitch );
            }
        }
        catch ( ClientException switchFailure )
        {
            if ( isInteractive )
            {
                fallBackToDatabase( databaseBeforeSwitch, switchFailure );
            }
            throw switchFailure;
        }
    }

    /**
     * Restores {@code databaseName} as the active database and tries to reconnect to it. A failure to do so is
     * attached to {@code switchFailure} as a suppressed exception rather than thrown.
     */
    private void fallBackToDatabase( String databaseName, ClientException switchFailure )
    {
        // We want to try to connect to the previous database
        activeDatabaseNameAsSetByUser = databaseName;
        try
        {
            reconnect( databaseName, databaseName );
        }
        catch ( Exception reconnectFailure )
        {
            switchFailure.addSuppressed( reconnectFailure );
        }
    }

    /**
     * @return the database name last requested through this handler (it is stored by {@code setActiveDatabase},
     * {@code connect} and {@code changePassword}); initially {@code ABSENT_DB_NAME}
     */
    @Override
    public String getActiveDatabaseAsSetByUser()
    {
        return activeDatabaseNameAsSetByUser;
    }

    /**
     * @return the database name taken from the most recent result summary handed to {@code updateActualDbName},
     * or {@code null} if it has been reset since (reconnecting and disconnecting both reset it)
     */
    @Override
    public String getActualDatabaseAsReportedByServer()
    {
        return actualDatabaseNameAsReportedByServer;
    }

    /**
     * Begins an explicit transaction on the current session and remembers it as the open transaction.
     *
     * @throws CommandException if there is no open session, or if a transaction is already open
     */
    @Override
    public void beginTransaction() throws CommandException
    {
        // The connection check comes first so that a disconnected handler always reports "Not connected".
        if ( !isConnected() )
        {
            throw new CommandException( "Not connected to Neo4j" );
        }
        if ( isTransactionOpen() )
        {
            throw new CommandException( "There is already an open transaction" );
        }
        tx = session.beginTransaction();
    }

    /**
     * Commits and closes the open transaction. The handler forgets the transaction afterwards,
     * whether or not the commit succeeded.
     *
     * @throws CommandException if there is no open session or no open transaction
     */
    @Override
    public void commitTransaction() throws CommandException
    {
        if ( !isConnected() )
        {
            throw new CommandException( "Not connected to Neo4j" );
        }
        if ( !isTransactionOpen() )
        {
            throw new CommandException( "There is no open transaction to commit" );
        }

        final Transaction openTransaction = tx;
        try
        {
            openTransaction.commit();
            openTransaction.close();
        }
        finally
        {
            // Drop the reference even on failure so a new transaction can be started.
            tx = null;
        }
    }

    /**
     * Rolls back and closes the open transaction. The handler forgets the transaction afterwards,
     * whether or not the rollback succeeded.
     *
     * @throws CommandException if there is no open session or no open transaction
     */
    @Override
    public void rollbackTransaction() throws CommandException
    {
        if ( !isConnected() )
        {
            throw new CommandException( "Not connected to Neo4j" );
        }
        if ( !isTransactionOpen() )
        {
            throw new CommandException( "There is no open transaction to rollback" );
        }

        final Transaction openTransaction = tx;
        try
        {
            openTransaction.rollback();
            openTransaction.close();
        }
        finally
        {
            // Drop the reference even on failure so a new transaction can be started.
            tx = null;
        }
    }

    /**
     * Handle an exception while getting or consuming the result. If not in TX, return the given exception. If in a TX, terminate the TX and return a more
     * verbose error message.
     * <p>
     * The open transaction is always forgotten, even if closing it fails. A failure to close is attached to
     * {@code e} as a suppressed exception so that the error that triggered the handling is never lost.
     *
     * @param e the thrown exception.
     * @return a suitable exception to rethrow.
     */
    public Neo4jException handleException( Neo4jException e )
    {
        if ( !isTransactionOpen() )
        {
            return e;
        }

        try
        {
            tx.close();
        }
        catch ( RuntimeException closeFailure )
        {
            // Throwable#addSuppressed rejects self-suppression, so guard against the same instance coming back.
            if ( closeFailure != e )
            {
                e.addSuppressed( closeFailure );
            }
        }
        finally
        {
            // Without this, a failing close() would leave the handler stuck with a dead transaction.
            tx = null;
        }

        return new ErrorWhileInTransactionException(
                "An error occurred while in an open transaction. The transaction will be rolled back and terminated. Error: " + e.getMessage(), e );
    }

    /**
     * @return {@code true} while this handler holds a transaction reference, i.e. after {@code beginTransaction}
     * and until the transaction is committed, rolled back or otherwise discarded by this handler
     */
    @Override
    public boolean isTransactionOpen()
    {
        return tx != null;
    }

    /**
     * @return {@code true} if a session exists and reports itself as open
     */
    @Override
    public boolean isConnected()
    {
        return session != null && session.isOpen();
    }

    /**
     * Connects to the server described by {@code connectionConfig}.
     * <p>
     * If the first attempt fails with a {@link ServiceUnavailableException} or {@link SessionExpiredException}
     * and the scheme is one of the {@code neo4j} schemes, a second attempt is made with the corresponding
     * {@code bolt} scheme. The driver created for the first attempt is closed once it has been replaced, so it
     * does not leak. If the fallback also fails, the original exception is thrown with the fallback failure
     * attached as a suppressed exception.
     * <p>
     * On any failure the session and driver are torn down before the exception propagates.
     *
     * @param connectionConfig where and how to connect
     * @param command          action to run once the session exists; when {@code null} a ping is run instead
     * @return the configuration that was actually used (it carries the fallback scheme if the fallback was taken)
     * @throws CommandException if already connected, or if {@code command} raises one
     */
    @Override
    public ConnectionConfig connect( ConnectionConfig connectionConfig, ThrowingAction command ) throws CommandException
    {
        if ( isConnected() )
        {
            throw new CommandException( "Already connected" );
        }
        final AuthToken authToken = AuthTokens.basic( connectionConfig.username(), connectionConfig.password() );
        try
        {
            String previousDatabaseName = activeDatabaseNameAsSetByUser;
            try
            {
                activeDatabaseNameAsSetByUser = connectionConfig.database();
                driver = getDriver( connectionConfig, authToken );
                reconnect( activeDatabaseNameAsSetByUser, previousDatabaseName, command );
            }
            catch ( ServiceUnavailableException | SessionExpiredException e )
            {
                String scheme = connectionConfig.scheme();
                String fallbackScheme;
                switch ( scheme )
                {
                case Scheme.NEO4J_URI_SCHEME:
                    fallbackScheme = Scheme.BOLT_URI_SCHEME;
                    break;
                case Scheme.NEO4J_LOW_TRUST_URI_SCHEME:
                    fallbackScheme = Scheme.BOLT_LOW_TRUST_URI_SCHEME;
                    break;
                case Scheme.NEO4J_HIGH_TRUST_URI_SCHEME:
                    fallbackScheme = Scheme.BOLT_HIGH_TRUST_URI_SCHEME;
                    break;
                default:
                    throw e;
                }
                connectionConfig = new ConnectionConfig(
                        fallbackScheme,
                        connectionConfig.host(),
                        connectionConfig.port(),
                        connectionConfig.username(),
                        connectionConfig.password(),
                        connectionConfig.encryption(),
                        connectionConfig.database() );

                // Keep a handle on the driver of the failed attempt: the field is about to be overwritten.
                final Driver failedDriver = driver;
                try
                {
                    driver = getDriver( connectionConfig, authToken );
                    // reconnect() also closes the session that was opened on the failed driver.
                    reconnect( activeDatabaseNameAsSetByUser, previousDatabaseName, command );
                }
                catch ( Throwable fallbackThrowable )
                {
                    // Throw the original exception to not cause confusion, but keep the fallback failure for diagnostics.
                    if ( fallbackThrowable != e )
                    {
                        e.addSuppressed( fallbackThrowable );
                    }
                    throw e;
                }
                finally
                {
                    // If the field still points at the failed driver, the outer handler's silentDisconnect() closes it.
                    if ( failedDriver != null && failedDriver != driver )
                    {
                        try
                        {
                            failedDriver.close();
                        }
                        catch ( RuntimeException closeFailure )
                        {
                            // Best effort: a failure to release the abandoned driver must not change the
                            // outcome of the fallback attempt, so it is only recorded on the original error.
                            e.addSuppressed( closeFailure );
                        }
                    }
                }
            }
        }
        catch ( Throwable t )
        {
            try
            {
                silentDisconnect();
            }
            catch ( Exception e )
            {
                t.addSuppressed( e );
            }
            throw t;
        }
        return connectionConfig;
    }

    /**
     * Re-establishes the session against {@code databaseToConnectTo} and verifies it with a ping.
     *
     * @param databaseToConnectTo database the new session should target
     * @param previousDatabase    database the current session (if any) was connected to
     * @throws CommandException if running the verification command raises one
     */
    private void reconnect( String databaseToConnectTo, String previousDatabase ) throws CommandException
    {
        reconnect( databaseToConnectTo, previousDatabase, null );
    }

    /**
     * Closes the current session (saving its bookmark under {@code previousDatabase}), opens a new write session
     * for {@code databaseToConnectTo} that starts from any bookmark saved for that database, and then runs
     * {@code command}, or a ping when {@code command} is {@code null}.
     *
     * @param databaseToConnectTo database the new session should target; {@code ABSENT_DB_NAME} leaves the
     *                            database unspecified in the session config
     * @param previousDatabase    database the current session (if any) was connected to
     * @param command             action to run once the session exists, or {@code null} for a ping
     * @throws CommandException if {@code command} raises one
     */
    private void reconnect( String databaseToConnectTo,
                            String previousDatabase,
                            ThrowingAction command ) throws CommandException
    {
        final SessionConfig.Builder sessionSettings = SessionConfig.builder();
        sessionSettings.withDefaultAccessMode( AccessMode.WRITE );

        final boolean databaseSpecified = !ABSENT_DB_NAME.equals( databaseToConnectTo );
        if ( databaseSpecified )
        {
            sessionSettings.withDatabase( databaseToConnectTo );
        }

        // The old session must be closed first: doing so records the bookmark looked up below when
        // reconnecting to the same database.
        closeSession( previousDatabase );

        final Bookmark resumeFrom = bookmarks.get( databaseToConnectTo );
        if ( resumeFrom != null )
        {
            sessionSettings.withBookmarks( resumeFrom );
        }

        session = driver.session( sessionSettings.build() );

        // Set this to null first in case run throws an exception
        resetActualDbName();
        connect( command );
    }

    /**
     * Closes the current session, if one exists, after remembering its last bookmark under the given
     * database name. Does nothing when there is no session.
     *
     * @param databaseName the name of the database currently connected to
     */
    private void closeSession( String databaseName )
    {
        if ( session == null )
        {
            return;
        }

        // Read the bookmark before closing, then store it once the session is closed.
        final Bookmark lastSeenBookmark = session.lastBookmark();
        session.close();
        bookmarks.put( databaseName, lastSeenBookmark );
    }

    /**
     * Runs the first command on a freshly opened session.
     * <p>
     * With no command a ping is executed. With a command, the command is executed. If it fails because a
     * password change is required, a ping is executed before the failure is rethrown.
     *
     * @param command the action to run, or {@code null} to just ping
     * @throws CommandException if the executed action raises one
     */
    private void connect( ThrowingAction command ) throws CommandException
    {
        if ( command == null )
        {
            getPing().apply();
            return;
        }

        try
        {
            command.apply();
        }
        catch ( Neo4jException e )
        {
            //If we need to update password we need to call the apply
            //to set the server version and such.
            if ( isPasswordChangeRequiredException( e ) )
            {
                getPing().apply();
            }
            throw e;
        }
    }

    /**
     * Builds the action used to verify a session: it runs {@code CALL db.ping()} and records the protocol
     * version and database name from the result summary. If the server reports that the procedure does not
     * exist, a legacy query is used instead ({@code CALL db.indexes()} on the system database, {@code RETURN 1}
     * otherwise).
     *
     * @return the ping action
     */
    private ThrowingAction getPing()
    {
        return () ->
        {
            try
            {
                recordServerInfoFrom( "CALL db.ping()" );
            }
            catch ( ClientException e )
            {
                if ( !procedureNotFound( e ) )
                {
                    throw e;
                }
                //In older versions there is no db.ping procedure, use legacy method.
                recordServerInfoFrom( isSystemDb() ? "CALL db.indexes()" : "RETURN 1" );
            }
        };
    }

    /**
     * Runs {@code query} on the current session, consumes the result and stores the protocol version and
     * database name found in its summary.
     */
    private void recordServerInfoFrom( String query )
    {
        final ResultSummary summary = session.run( query ).consume();
        protocolVersion = summary.server().protocolVersion();
        updateActualDbName( summary );
    }

    /**
     * Asks the server for its version via {@code dbms.components()}.
     *
     * @return the first entry of the {@code versions} column of the first record, or an empty string when
     * there is no result, no record, a null value, or the query raises a {@link CommandException}
     */
    @Override
    public String getServerVersion()
    {
        final String componentsQuery = "CALL dbms.components() YIELD versions";
        try
        {
            return runCypher( componentsQuery, Collections.emptyMap() )
                    .flatMap( result -> result.getRecords().stream().findFirst() )
                    .map( firstRecord -> firstRecord.get( "versions" ) )
                    .filter( versions -> !versions.isNull() )
                    .map( versions -> versions.get( 0 ).asString() )
                    .orElse( "" );
        }
        catch ( CommandException e )
        {
            // On versions before 3.1.0-M09
            return "";
        }
    }

    /**
     * @return the protocol version recorded by the last ping, or an empty string when not connected or when
     * no version has been recorded
     */
    @Override
    public String getProtocolVersion()
    {
        if ( !isConnected() )
        {
            return "";
        }

        if ( protocolVersion == null )
        {
            // On versions before 3.1.0-M09
            protocolVersion = "";
        }
        return protocolVersion;
    }

    /**
     * Runs a Cypher statement, inside the open transaction if there is one, otherwise directly on the session.
     * Outside a transaction, a {@link SessionExpiredException} triggers one reconnect and one retry.
     *
     * @param cypher      the statement to run
     * @param queryParams the statement parameters
     * @return the result wrapper, or an empty optional if the driver returned no result
     * @throws CommandException if not connected, or if reconnecting raises one
     */
    @Override
    public Optional runCypher( String cypher, Map queryParams ) throws CommandException
    {
        if ( !isConnected() )
        {
            throw new CommandException( "Not connected to Neo4j" );
        }

        if ( isTransactionOpen() )
        {
            // If this fails, don't try any funny business - just let it die
            return getBoltResult( cypher, queryParams );
        }

        try
        {
            // Note that PERIODIC COMMIT can't execute in a transaction, so if the user has not typed BEGIN, then
            // the statement should NOT be executed in a transaction.
            return getBoltResult( cypher, queryParams );
        }
        catch ( SessionExpiredException e )
        {
            // Server is no longer accepting writes, reconnect and try again.
            // If it still fails, leave it up to the user
            final String currentDatabase = activeDatabaseNameAsSetByUser;
            reconnect( currentDatabase, currentDatabase );
            return getBoltResult( cypher, queryParams );
        }
    }

    /**
     * Records the database name carried by the given result summary as the database actually in use.
     * When the summary's database info has no name, {@code ABSENT_DB_NAME} is recorded instead.
     *
     * @param resultSummary summary of a consumed result
     */
    public void updateActualDbName( ResultSummary resultSummary )
    {
        actualDatabaseNameAsReportedByServer = getActualDbName( resultSummary );
    }

    /**
     * Changes the password of the user described by {@code connectionConfig}.
     * <p>
     * Any existing connection is dropped, a new session is opened against the system database without running
     * a ping, and {@code ALTER CURRENT USER SET PASSWORD} is executed. If that fails with a
     * password-change-required error, the legacy {@code dbms.security.changePassword} procedure is used instead.
     * The connection is always torn down again before returning or throwing.
     *
     * @param connectionConfig connection details, including the current password
     * @param newPassword      the password to set
     */
    public void changePassword( ConnectionConfig connectionConfig, String newPassword )
    {
        if ( isConnected() )
        {
            silentDisconnect();
        }

        final AuthToken currentCredentials = AuthTokens.basic( connectionConfig.username(), connectionConfig.password() );

        try
        {
            driver = getDriver( connectionConfig, currentCredentials );

            activeDatabaseNameAsSetByUser = SYSTEM_DB_NAME;
            // Supply empty command, so that we do not run ping.
            final ThrowingAction noPing = () ->
            {
            };
            reconnect( SYSTEM_DB_NAME, SYSTEM_DB_NAME, noPing );

            try
            {
                final Value alterParameters = Values.parameters( "o", connectionConfig.password(), "n", newPassword );
                session.run( "ALTER CURRENT USER SET PASSWORD FROM $o TO $n", alterParameters ).consume();
            }
            catch ( Neo4jException e )
            {
                if ( !isPasswordChangeRequiredException( e ) )
                {
                    throw e;
                }
                // In < 4.0 versions use legacy method.
                final Value legacyParameters = Values.parameters( "n", newPassword );
                session.run( "CALL dbms.security.changePassword($n)", legacyParameters ).consume();
            }

            silentDisconnect();
        }
        catch ( Throwable t )
        {
            try
            {
                silentDisconnect();
            }
            catch ( Exception e )
            {
                t.addSuppressed( e );
            }
            if ( t instanceof RuntimeException )
            {
                throw (RuntimeException) t;
            }
            // The only checked exception is CommandException and we know that
            // we cannot get that since we supply an empty command.
            throw new RuntimeException( t );
        }
    }

    /**
     * Runs the statement in the open transaction when there is one, otherwise directly on the session, and
     * wraps the driver result.
     *
     * @param cypher      the statement to run
     * @param queryParams the statement parameters
     * @return the wrapped result, or an empty optional if the driver returned {@code null}
     * @throws SessionExpiredException when server no longer serves writes anymore
     */
    private Optional getBoltResult( String cypher, Map queryParams ) throws SessionExpiredException
    {
        final Query query = new Query( cypher, queryParams );
        final Result driverResult = isTransactionOpen() ? tx.run( query ) : session.run( query );

        if ( driverResult != null )
        {
            return Optional.of( new StatementBoltResult( driverResult ) );
        }
        return Optional.empty();
    }

    /**
     * Extracts the database name from a result summary.
     *
     * @param resultSummary summary of a consumed result
     * @return the name from the summary's database info, or {@code ABSENT_DB_NAME} when that name is {@code null}
     */
    private static String getActualDbName( ResultSummary resultSummary )
    {
        final String reportedName = resultSummary.database().name();
        if ( reportedName == null )
        {
            return ABSENT_DB_NAME;
        }
        return reportedName;
    }

    /**
     * Forgets the database name reported by the server, so that a stale name is not returned
     * if the next command fails before a new name is recorded.
     */
    private void resetActualDbName()
    {
        actualDatabaseNameAsReportedByServer = null;
    }

    /**
     * Disconnect from Neo4j, clearing up any session resources, but don't give any output. Intended only to be used if connect fails.
     * <p>
     * The driver is closed even if closing the session fails, and the session, driver and reported database
     * name are always cleared, so a failure while closing cannot leak the driver.
     */
    void silentDisconnect()
    {
        try
        {
            closeSession( activeDatabaseNameAsSetByUser );
        }
        finally
        {
            try
            {
                if ( driver != null )
                {
                    driver.close();
                }
            }
            finally
            {
                session = null;
                driver = null;
                resetActualDbName();
            }
        }
    }

    /**
     * Reset the current session. This rolls back any open transactions.
     * <p>
     * Does nothing when not connected. The open transaction is forgotten even if rolling it back or closing
     * it fails, so the handler is never left pointing at a dead transaction.
     */
    public void reset()
    {
        if ( !isConnected() )
        {
            return;
        }

        session.reset();

        // Clear current state
        if ( isTransactionOpen() )
        {
            try
            {
                // Bolt has already rolled back the transaction but it doesn't close it properly
                tx.rollback();
                tx.close();
            }
            finally
            {
                tx = null;
            }
        }
    }

    /**
     * Resets the session and then tears down the session and driver.
     * <p>
     * The teardown runs even if the reset fails, and the transaction reference and protocol version are
     * always cleared, so that a failed disconnect does not leave stale state behind. An exception from the
     * reset or the teardown still propagates to the caller.
     */
    @Override
    public void disconnect()
    {
        try
        {
            reset();
        }
        finally
        {
            try
            {
                silentDisconnect();
            }
            finally
            {
                // No transaction can outlive its session.
                tx = null;
                protocolVersion = null;
            }
        }
    }

    /**
     * Creates a driver for the given connection through the configured driver provider. Driver logging is
     * routed to the null logger and the shell's user agent is set. Encryption is switched on or off only when
     * the connection config says so explicitly; any other setting leaves the builder's encryption untouched.
     *
     * @param connectionConfig supplies the driver URL and the encryption setting
     * @param authToken        credentials handed to the driver provider
     * @return the driver returned by the provider
     */
    private Driver getDriver( ConnectionConfig connectionConfig, AuthToken authToken )
    {
        Config.ConfigBuilder driverSettings = Config.builder();
        driverSettings = driverSettings.withLogging( NullLogging.NULL_LOGGING );
        driverSettings = driverSettings.withUserAgent( USER_AGENT );

        switch ( connectionConfig.encryption() )
        {
        case TRUE:
            driverSettings = driverSettings.withEncryption();
            break;
        case FALSE:
            driverSettings = driverSettings.withoutEncryption();
            break;
        default:
            // Do nothing
            break;
        }

        return driverProvider.apply( connectionConfig.driverUrl(), authToken, driverSettings.build() );
    }

    /**
     * Applies {@code biFunction} to every statement inside a single write transaction function of the session
     * and collects the results in order.
     *
     * @param transactionStatements the statements to process
     * @param biFunction            called with each statement and the transaction it should run in
     * @return the values returned by {@code biFunction}, one per statement
     */
    private List executeWithRetry( List transactionStatements, BiFunction biFunction )
    {
        // The lambda parameter is deliberately not called "tx", which would shadow the field of that name.
        return session.writeTransaction( transaction ->
        {
            return transactionStatements.stream()
                                        .map( statement -> biFunction.apply( statement, transaction ) )
                                        .collect( Collectors.toList() );
        } );
    }

    /**
     * @return {@code true} if the active database name equals {@code SYSTEM_DB_NAME}, ignoring case;
     * {@code false} otherwise, including when no name is set
     */
    private boolean isSystemDb()
    {
        // Calling equalsIgnoreCase on the constant keeps this null-safe if the active name is null.
        return SYSTEM_DB_NAME.equalsIgnoreCase( activeDatabaseNameAsSetByUser );
    }

    /**
     * @param e the error raised by the server
     * @return {@code true} if the error code is {@code Neo.ClientError.Procedure.ProcedureNotFound}, ignoring
     * case; {@code false} otherwise, including when the error carries no code
     */
    private static boolean procedureNotFound( ClientException e )
    {
        // Calling equalsIgnoreCase on the literal keeps this null-safe if the exception has no code.
        return "Neo.ClientError.Procedure.ProcedureNotFound".equalsIgnoreCase( e.code() );
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy