All downloads are free. The search and download functionality uses the official Maven repository.

org.mockserver.proxy.relay.RelayConnectHandler Maven / Gradle / Ivy

package org.mockserver.proxy.relay;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.util.AttributeKey;
import org.mockserver.lifecycle.LifeCycle;
import org.mockserver.logging.LoggingHandler;
import org.mockserver.logging.MockServerLogger;

import java.net.InetSocketAddress;
import java.util.List;
import java.util.concurrent.Future;

import static org.mockserver.exception.ExceptionHandler.shouldNotIgnoreException;
import static org.mockserver.mock.action.ActionHandler.REMOTE_SOCKET;
import static org.mockserver.mockserver.MockServerHandler.PROXYING;
import static org.mockserver.socket.NettySslContextFactory.nettySslContextFactory;
import static org.mockserver.unification.PortUnificationHandler.isSslEnabledDownstream;
import static org.mockserver.unification.PortUnificationHandler.isSslEnabledUpstream;
import static org.slf4j.event.Level.TRACE;

@ChannelHandler.Sharable
public abstract class RelayConnectHandler extends SimpleChannelInboundHandler {

    public static final AttributeKey>> HTTP_CONNECT_SOCKET = AttributeKey.valueOf("HTTP_CONNECT_SOCKET");
    private final LifeCycle server;
    private final MockServerLogger mockServerLogger;
    private final String host;
    private final int port;

    public RelayConnectHandler(LifeCycle server, MockServerLogger mockServerLogger, String host, int port) {
        this.server = server;
        this.mockServerLogger = mockServerLogger;
        this.host = host;
        this.port = port;
    }

    @Override
    public void channelRead0(final ChannelHandlerContext serverCtx, final T request) throws Exception {
        Bootstrap bootstrap = new Bootstrap()
            .group(serverCtx.channel().eventLoop())
            .channel(NioSocketChannel.class)
            .handler(new ChannelInboundHandlerAdapter() {
                @Override
                public void channelActive(final ChannelHandlerContext clientCtx) throws Exception {
                    serverCtx.channel()
                        .writeAndFlush(successResponse(request))
                        .addListener(new ChannelFutureListener() {
                            @Override
                            public void operationComplete(ChannelFuture channelFuture) throws Exception {
                                removeCodecSupport(serverCtx);
                                serverCtx.channel().attr(PROXYING).set(Boolean.TRUE);

                                // downstream
                                ChannelPipeline downstreamPipeline = clientCtx.channel().pipeline();

                                if (isSslEnabledDownstream(serverCtx.channel())) {
                                    downstreamPipeline.addLast(nettySslContextFactory().createClientSslContext().newHandler(clientCtx.alloc(), host, port));
                                }

                                if (mockServerLogger.isEnabled(TRACE)) {
                                    downstreamPipeline.addLast(new LoggingHandler("downstream                -->"));
                                }

                                downstreamPipeline.addLast(new HttpClientCodec());

                                downstreamPipeline.addLast(new HttpContentDecompressor());

                                downstreamPipeline.addLast(new HttpObjectAggregator(Integer.MAX_VALUE));

                                downstreamPipeline.addLast(new DownstreamProxyRelayHandler(mockServerLogger, serverCtx.channel()));


                                // upstream
                                ChannelPipeline upstreamPipeline = serverCtx.channel().pipeline();

                                if (isSslEnabledUpstream(serverCtx.channel())) {
                                    upstreamPipeline.addLast(nettySslContextFactory().createServerSslContext().newHandler(serverCtx.alloc()));
                                }

                                if (mockServerLogger.isEnabled(TRACE)) {
                                    upstreamPipeline.addLast(new LoggingHandler("upstream <-- "));
                                }

                                upstreamPipeline.addLast(new HttpServerCodec(8192, 8192, 8192));

                                upstreamPipeline.addLast(new HttpContentDecompressor());

                                upstreamPipeline.addLast(new HttpObjectAggregator(Integer.MAX_VALUE));

                                upstreamPipeline.addLast(new UpstreamProxyRelayHandler(mockServerLogger, serverCtx.channel(), clientCtx.channel()));
                            }
                        });
                }
            });

        final InetSocketAddress remoteSocket = getDownstreamSocket(serverCtx.channel());
        bootstrap.connect(remoteSocket).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (!future.isSuccess()) {
                    failure("Connection failed to " + remoteSocket, future.cause(), serverCtx, failureResponse(request));
                }
            }
        });
    }

    private InetSocketAddress getDownstreamSocket(Channel channel) {
        if (channel.attr(REMOTE_SOCKET).get() != null) {
            return channel.attr(REMOTE_SOCKET).get();
        } else {
            return new InetSocketAddress(server.getLocalPort());
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        failure("Exception caught by CONNECT proxy handler -> closing pipeline ", cause, ctx, failureResponse(null));
    }

    private void failure(String message, Throwable cause, ChannelHandlerContext ctx, Object response) {
        if (shouldNotIgnoreException(cause)) {
            mockServerLogger.error(message, cause);
        }
        Channel channel = ctx.channel();
        channel.writeAndFlush(response);
        if (channel.isActive()) {
            channel.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
        }
    }

    protected abstract void removeCodecSupport(ChannelHandlerContext ctx);

    protected abstract Object successResponse(Object request);

    protected abstract Object failureResponse(Object request);

    protected void removeHandler(ChannelPipeline pipeline, Class handlerType) {
        if (pipeline.get(handlerType) != null) {
            pipeline.remove(handlerType);
        }
    }

    protected void removeHandler(ChannelPipeline pipeline, ChannelHandler channelHandler) {
        if (pipeline.toMap().containsValue(channelHandler)) {
            pipeline.remove(channelHandler);
        }
    }

}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy