All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.apache.tinkerpop.gremlin.driver.Handler Maven / Gradle / Ivy

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.tinkerpop.gremlin.driver;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import io.netty.util.AttributeMap;
import io.netty.util.ReferenceCountUtil;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.tinkerpop.gremlin.driver.exception.ResponseException;
import org.apache.tinkerpop.gremlin.driver.message.RequestMessage;
import org.apache.tinkerpop.gremlin.driver.message.ResponseMessage;
import org.apache.tinkerpop.gremlin.driver.message.ResponseStatusCode;
import org.apache.tinkerpop.gremlin.driver.ser.SerializationException;
import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetSocketAddress;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentMap;

import javax.security.auth.Subject;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.login.LoginContext;
import javax.security.auth.login.LoginException;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;

/**
 * Holder for internal handler classes used in constructing the channel pipeline.
 *
 * @author Stephen Mallette (http://stephen.genoprime.com)
 */
final class Handler {

    /**
     * Generic SASL handler that will authenticate against the gremlin server.
     */
    static class GremlinSaslAuthenticationHandler extends SimpleChannelInboundHandler<ResponseMessage> implements CallbackHandler {
        private static final Logger logger = LoggerFactory.getLogger(GremlinSaslAuthenticationHandler.class);

        /**
         * Attribute key for the JAAS {@link Subject} returned by {@link #login()}. The attribute holds {@code null}
         * when no JAAS entry is configured.
         */
        private static final AttributeKey<Subject> subjectKey = AttributeKey.valueOf("subject");

        /** Attribute key for the {@link SaslClient} created on the first AUTHENTICATE challenge. */
        private static final AttributeKey<SaslClient> saslClientKey = AttributeKey.valueOf("saslclient");

        /**
         * Properties passed to {@link Sasl#createSaslClient}: {@link Sasl#SERVER_AUTH} asks that the server
         * authenticate itself to the client. Immutable because this shared static is handed to SASL client factories.
         */
        private static final Map<String, String> SASL_PROPERTIES = Collections.singletonMap(Sasl.SERVER_AUTH, "true");

        /** Empty challenge used to compute the client's initial response. */
        private static final byte[] NULL_CHALLENGE = new byte[0];

        private static final Base64.Encoder BASE64_ENCODER = Base64.getEncoder();
        private static final Base64.Decoder BASE64_DECODER = Base64.getDecoder();

        private final AuthProperties authProps;

        /**
         * @param authProps supplies the username, password, JAAS entry and SASL protocol used to authenticate
         */
        public GremlinSaslAuthenticationHandler(final AuthProperties authProps) {
            this.authProps = authProps;
        }

        /**
         * Answers AUTHENTICATE challenges from the server and forwards every other response down the pipeline.
         * <p>
         * On the first challenge the JAAS subject (if configured) and the {@link SaslClient} are created. The reply
         * then carries the chosen mechanism plus the initial response, if the mechanism has one. Later challenges are
         * evaluated by the stored client, and the reply carries the base64-encoded result. If the SASL client cannot
         * be created, a FORBIDDEN response is fired down the pipeline instead of replying to the server.
         */
        @Override
        protected void channelRead0(final ChannelHandlerContext channelHandlerContext, final ResponseMessage response) throws Exception {
            // We are only interested in AUTHENTICATE responses here. Everything else can
            // get passed down the pipeline
            if (response.getStatus().getCode() == ResponseStatusCode.AUTHENTICATE) {
                final Attribute<SaslClient> saslClient = ((AttributeMap) channelHandlerContext).attr(saslClientKey);
                final Attribute<Subject> subject = ((AttributeMap) channelHandlerContext).attr(subjectKey);
                final RequestMessage.Builder messageBuilder = RequestMessage.build(Tokens.OPS_AUTHENTICATION);
                // First time through we don't have a sasl client
                if (saslClient.get() == null) {
                    subject.set(login());
                    try {
                        saslClient.set(saslClient(getHostName(channelHandlerContext)));
                    } catch (SaslException saslException) {
                        // push the sasl error into a failure response from the server. this ensures that standard
                        // processing for the ResultQueue is kept. without this SaslException trap and subsequent
                        // conversion to an authentication failure, the close() of the connection might not
                        // succeed as it will appear as though pending messages remain present in the queue on the
                        // connection and the shutdown won't proceed
                        final ResponseMessage clientSideError = ResponseMessage.build(response.getRequestId())
                                .code(ResponseStatusCode.FORBIDDEN).statusMessage(saslException.getMessage()).create();
                        channelHandlerContext.fireChannelRead(clientSideError);
                        return;
                    }

                    messageBuilder.addArg(Tokens.ARGS_SASL_MECHANISM, getMechanism());
                    messageBuilder.addArg(Tokens.ARGS_SASL, saslClient.get().hasInitialResponse() ?
                            BASE64_ENCODER.encodeToString(evaluateChallenge(subject, saslClient, NULL_CHALLENGE)) : null);
                } else {
                    // the server sends base64 encoded sasl as well as the byte array. the byte array will eventually be
                    // phased out, but is present now for backward compatibility in 3.2.x
                    final String base64sasl = response.getStatus().getAttributes().containsKey(Tokens.ARGS_SASL) ?
                        response.getStatus().getAttributes().get(Tokens.ARGS_SASL).toString() :
                        BASE64_ENCODER.encodeToString((byte[]) response.getResult().getData());

                    messageBuilder.addArg(Tokens.ARGS_SASL, BASE64_ENCODER.encodeToString(
                        evaluateChallenge(subject, saslClient, BASE64_DECODER.decode(base64sasl))));
                }
                channelHandlerContext.writeAndFlush(messageBuilder.create());
            } else {
                // SimpleChannelInboundHandler will release the frame if we don't retain it explicitly.
                ReferenceCountUtil.retain(response);
                channelHandlerContext.fireChannelRead(response);
            }
        }

        /**
         * Supplies the configured username and password to name and password callbacks. Callbacks of any other type
         * are logged at warn level and otherwise ignored rather than rejected.
         *
         * @param callbacks the callbacks requested by the login module or SASL mechanism
         */
        @Override
        public void handle(final Callback[] callbacks) {
            for (final Callback callback : callbacks) {
                if (callback instanceof NameCallback) {
                    final String username = authProps.get(AuthProperties.Property.USERNAME);
                    if (username != null) {
                        ((NameCallback) callback).setName(username);
                    }
                } else if (callback instanceof PasswordCallback) {
                    final String password = authProps.get(AuthProperties.Property.PASSWORD);
                    if (password != null) {
                        ((PasswordCallback) callback).setPassword(password.toCharArray());
                    }
                } else {
                    logger.warn("SASL handler got a callback of type {}", callback.getClass().getCanonicalName());
                }
            }
        }

        /**
         * Evaluates {@code challenge} with the stored SASL client. When a JAAS subject is present, evaluation runs
         * inside {@link Subject#doAs} so the mechanism sees that subject's credentials.
         *
         * @param subject   attribute holding the JAAS subject, whose value may be {@code null}
         * @param saslClient attribute holding the SASL client created for this channel
         * @param challenge the challenge bytes from the server (empty for the initial response)
         * @return the response bytes to send back to the server
         * @throws SaslException if the SASL client fails to evaluate the challenge
         */
        private byte[] evaluateChallenge(final Attribute<Subject> subject, final Attribute<SaslClient> saslClient,
                                         final byte[] challenge) throws SaslException {

            if (subject.get() == null) {
                return saslClient.get().evaluateChallenge(challenge);
            } else {
                // If we have a subject then run this as a privileged action using the subject
                try {
                    return Subject.doAs(subject.get(),
                            (PrivilegedExceptionAction<byte[]>) () -> saslClient.get().evaluateChallenge(challenge));
                } catch (PrivilegedActionException e) {
                    // the action's only checked exception is SaslException, so this cast is safe
                    throw (SaslException) e.getException();
                }
            }
        }

        /**
         * Logs in through the JAAS configuration entry named by {@link AuthProperties.Property#JAAS_ENTRY}, if any.
         *
         * @return the authenticated subject, or {@code null} when no JAAS entry is configured
         * @throws LoginException if the JAAS login fails
         */
        private Subject login() throws LoginException {
            // Login if the user provided us with an entry into the JAAS config file
            if (authProps.get(AuthProperties.Property.JAAS_ENTRY) != null) {
                final LoginContext login = new LoginContext(authProps.get(AuthProperties.Property.JAAS_ENTRY));
                login.login();
                return login.getSubject();
            }
            return null;
        }

        /**
         * Creates a SASL client with these settings:
         * <ul>
         *   <li>the mechanism chosen by {@link #getMechanism()};</li>
         *   <li>the configured protocol;</li>
         *   <li>the given server host name;</li>
         *   <li>this handler as the callback handler.</li>
         * </ul>
         *
         * @param hostname the server name handed to the SASL client
         * @throws SaslException if the client cannot be created
         */
        private SaslClient saslClient(final String hostname) throws SaslException {
            return Sasl.createSaslClient(new String[] { getMechanism() }, null, authProps.get(AuthProperties.Property.PROTOCOL),
                                         hostname, SASL_PROPERTIES, this);
        }

        /**
         * Returns the canonical host name of the channel's remote address. This assumes the remote address is a
         * resolved {@link InetSocketAddress}.
         */
        private String getHostName(final ChannelHandlerContext channelHandlerContext) {
            return ((InetSocketAddress) channelHandlerContext.channel().remoteAddress()).getAddress().getCanonicalHostName();
        }

        /**
         * Work out the Sasl mechanism based on the user supplied parameters.
         * If we have a username and password use PLAIN otherwise GSSAPI
         * ToDo: have gremlin-server provide the mechanism(s) it is configured with, so that additional mechanisms can
         * be supported in the driver and confusing GSSException messages from the driver are avoided
         */
        private String getMechanism() {
            if ((authProps.get(AuthProperties.Property.USERNAME) != null) &&
                (authProps.get(AuthProperties.Property.PASSWORD) != null)) {
                return "PLAIN";
            } else {
                return "GSSAPI";
            }
        }
    }

    /**
     * Takes a map of requests pending responses and writes responses to the {@link ResultQueue} of a request
     * as the {@link ResponseMessage} objects are deserialized.
     */
    static class GremlinResponseHandler extends SimpleChannelInboundHandler<ResponseMessage> {
        private static final Logger logger = LoggerFactory.getLogger(GremlinResponseHandler.class);

        /** Result queues of in-flight requests keyed by request id; the caller's map is used directly, not copied. */
        private final ConcurrentMap<UUID, ResultQueue> pending;

        /**
         * @param pending map of request id to the {@link ResultQueue} awaiting that request's responses
         */
        public GremlinResponseHandler(final ConcurrentMap<UUID, ResultQueue> pending) {
            this.pending = pending;
        }

        /**
         * Marks every pending request as errored and clears the map, since an inactive channel receives no more
         * responses.
         */
        @Override
        public void channelInactive(final ChannelHandlerContext ctx) throws Exception {
            // occurs when the server shuts down in a disorderly fashion, otherwise in an orderly shutdown the server
            // should fire off a close message which will properly release the driver.
            super.channelInactive(ctx);

            // the channel isn't going to get anymore results as it is closed so release all pending requests
            pending.values().forEach(val -> val.markError(new IllegalStateException("Connection to server is no longer active")));
            pending.clear();
        }

        /**
         * Routes a response to the {@link ResultQueue} of its request.
         * <ul>
         *   <li>{@code SUCCESS} and {@code PARTIAL_CONTENT} payloads are added to the queue, with lists unrolled
         *   into individual items. They become side-effects when the meta-data contains
         *   {@link Tokens#ARGS_SIDE_EFFECT_KEY}, and results otherwise.</li>
         *   <li>{@code NO_CONTENT} adds nothing.</li>
         *   <li>Any other code marks the queue as errored.</li>
         * </ul>
         * Every code except {@code PARTIAL_CONTENT} ends the stream: the request is removed from {@code pending} and
         * its queue is completed. Responses whose request id has no pending queue are logged and dropped.
         */
        @Override
        protected void channelRead0(final ChannelHandlerContext channelHandlerContext, final ResponseMessage response) throws Exception {
            final ResponseStatusCode statusCode = response.getStatus().getCode();
            final ResultQueue queue = pending.get(response.getRequestId());

            // the request may already have been failed and removed by exceptionCaught() or channelInactive(), e.g.
            // the rest of a stream arriving after a SerializationException. dereferencing the missing queue would
            // throw an NPE, and exceptionCaught() would turn that NPE into a failure of every other pending request
            // plus a closed channel.
            if (queue == null) {
                logger.warn("Discarding {} response for request {} as it has no pending result queue",
                        statusCode, response.getRequestId());
                return;
            }

            if (statusCode == ResponseStatusCode.SUCCESS || statusCode == ResponseStatusCode.PARTIAL_CONTENT) {
                final Object data = response.getResult().getData();
                final Map<String, Object> meta = response.getResult().getMeta();

                if (!meta.containsKey(Tokens.ARGS_SIDE_EFFECT_KEY)) {
                    // this is a "result" from the server which is either the result of a script or a
                    // serialized traversal
                    if (data instanceof List) {
                        // unrolls the collection into individual results to be handled by the queue.
                        final List<?> listToUnroll = (List<?>) data;
                        listToUnroll.forEach(item -> queue.add(new Result(item)));
                    } else {
                        // since this is not a list it can just be added to the queue
                        queue.add(new Result(data));
                    }
                } else {
                    // this is the side-effect from the server which is generated from a serialized traversal
                    final String aggregateTo = meta.getOrDefault(Tokens.ARGS_AGGREGATE_TO, Tokens.VAL_AGGREGATE_TO_NONE).toString();
                    if (data instanceof List) {
                        // unrolls the collection into individual results to be handled by the queue.
                        final List<?> listOfSideEffects = (List<?>) data;
                        listOfSideEffects.forEach(sideEffect -> queue.addSideEffect(aggregateTo, sideEffect));
                    } else {
                        // since this is not a list it can just be added to the queue. this likely shouldn't occur
                        // however as the protocol will typically push everything to list first.
                        queue.addSideEffect(aggregateTo, data);
                    }
                }
            } else if (statusCode != ResponseStatusCode.NO_CONTENT) {
                // NO_CONTENT is a "success" that carries no results - anything else here is an error
                final Map<String, Object> attributes = response.getStatus().getAttributes();
                final String stackTrace = (String) attributes.get(Tokens.STATUS_ATTRIBUTE_STACK_TRACE);
                // the server sends the exception hierarchy as a list of class names
                @SuppressWarnings("unchecked")
                final List<String> exceptions = (List<String>) attributes.get(Tokens.STATUS_ATTRIBUTE_EXCEPTIONS);
                queue.markError(new ResponseException(statusCode, response.getStatus().getMessage(),
                        exceptions, stackTrace, cleanStatusAttributes(attributes)));
            }

            // as this is a non-PARTIAL_CONTENT code - the stream is done.
            if (statusCode != ResponseStatusCode.PARTIAL_CONTENT) {
                pending.remove(response.getRequestId());
                queue.markComplete(response.getStatus().getAttributes());
            }
        }

        /**
         * Logs the failure and marks every pending request as errored. The channel is then closed, unless a
         * {@link SerializationException} appears anywhere in the cause chain.
         */
        @Override
        public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) throws Exception {
            // if this happens enough times (like the client is unable to deserialize a response) the pending
            // messages queue will not clear.  wonder if there is some way to cope with that.  of course, if
            // there are that many failures someone would take notice and hopefully stop the client.
            logger.error("Could not process the response", cause);

            // the channel took an error because of something pretty bad so release all the futures out there
            pending.values().forEach(val -> val.markError(cause));
            pending.clear();

            // serialization exceptions should not close the channel - that's worth a retry
            final boolean serializationFailure = IteratorUtils.anyMatch(
                    ExceptionUtils.getThrowableList(cause).iterator(), t -> t instanceof SerializationException);
            if (!serializationFailure && ctx.channel().isActive()) {
                ctx.close();
            }
        }

        /**
         * Copies the status attributes, leaving out the exception hierarchy and stack trace entries because they are
         * passed to {@link ResponseException} separately.
         */
        private Map<String, Object> cleanStatusAttributes(final Map<String, Object> statusAttributes) {
            final Map<String, Object> m = new HashMap<>();
            statusAttributes.forEach((k, v) -> {
                if (!k.equals(Tokens.STATUS_ATTRIBUTE_EXCEPTIONS) && !k.equals(Tokens.STATUS_ATTRIBUTE_STACK_TRACE)) {
                    m.put(k, v);
                }
            });
            return m;
        }
    }

}