All Downloads are FREE. Search and download functionalities are using the official Maven repository.

io.mantisrx.api.tunnel.CrossRegionHandler Maven / Gradle / Ivy

The newest version!
/*
 * Copyright 2020 Netflix, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.mantisrx.api.tunnel;

import com.google.common.annotations.VisibleForTesting;
import com.netflix.config.DynamicIntProperty;
import com.netflix.config.DynamicStringProperty;
import com.netflix.spectator.api.Counter;
import com.netflix.zuul.netty.SpectatorUtils;
import io.mantisrx.api.Constants;
import io.mantisrx.api.Util;
import io.mantisrx.api.push.ConnectionBroker;
import io.mantisrx.api.push.PushConnectionDetails;
import io.mantisrx.shaded.com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.QueryStringDecoder;
import io.netty.handler.codec.http.QueryStringEncoder;
import lombok.extern.slf4j.Slf4j;
import mantis.io.reactivex.netty.channel.StringTransformer;
import mantis.io.reactivex.netty.protocol.http.client.HttpClient;
import mantis.io.reactivex.netty.protocol.http.client.HttpClientRequest;
import mantis.io.reactivex.netty.protocol.http.client.HttpClientResponse;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import rx.Observable;
import rx.Scheduler;
import rx.Subscription;

import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import static io.mantisrx.api.Constants.OriginRegionTagName;
import static io.mantisrx.api.Constants.TagNameValDelimiter;
import static io.mantisrx.api.Constants.TagsParamName;
import static io.mantisrx.api.Constants.TunnelPingMessage;
import static io.mantisrx.api.Constants.TunnelPingParamName;
import static io.mantisrx.api.Util.getLocalRegion;

@Slf4j
public class CrossRegionHandler extends SimpleChannelInboundHandler {

    private final List pushPrefixes;
    private final MantisCrossRegionalClient mantisCrossRegionalClient;
    private final ConnectionBroker connectionBroker;
    private final Scheduler scheduler;

    private Subscription subscription = null;
    private String uriForLogging = null;
    private ScheduledExecutorService scheduledExecutorService = new ScheduledThreadPoolExecutor(1,
            new ThreadFactoryBuilder().setNameFormat("cross-region-handler-drainer-%d").build());
    private ScheduledFuture drainFuture;
    private final DynamicIntProperty queueCapacity = new DynamicIntProperty("io.mantisrx.api.push.queueCapacity", 1000);
    private final DynamicIntProperty writeIntervalMillis = new DynamicIntProperty("io.mantisrx.api.push.writeIntervalMillis", 50);
    private final DynamicStringProperty tunnelRegionsProperty = new DynamicStringProperty("io.mantisrx.api.tunnel.regions", Util.getLocalRegion());

    public CrossRegionHandler(
            List pushPrefixes,
            MantisCrossRegionalClient mantisCrossRegionalClient,
            ConnectionBroker connectionBroker,
            Scheduler scheduler) {
        super(true);
        this.pushPrefixes = pushPrefixes;
        this.mantisCrossRegionalClient = mantisCrossRegionalClient;
        this.connectionBroker = connectionBroker;
        this.scheduler = scheduler;
    }

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) throws Exception {
        uriForLogging = request.uri();
        if (HttpUtil.is100ContinueExpected(request)) {
            send100Contine(ctx);
        }

        if (isCrossRegionStreamingPath(request.uri())) {
            handleRemoteSse(ctx, request);
        } else { // REST
            if (request.method() == HttpMethod.HEAD) {
                handleHead(ctx, request);
            } else if (request.method() == HttpMethod.GET) {
                handleRestGet(ctx, request);
            } else if(request.method() == HttpMethod.POST) {
                handleRestPost(ctx, request);
            } else {
                ctx.fireChannelRead(request.retain());
            }
        }
    }

    //
    // REST Implementations
    //

    private void handleHead(ChannelHandlerContext ctx, FullHttpRequest request) {

        HttpHeaders headers = new DefaultHttpHeaders();
        headers.add(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON);
        headers.add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, "*");
        headers.add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS,
                "Origin, X-Requested-With, Accept, Content-Type, Cache-Control");
        headers.add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS,
                "GET, OPTIONS, PUT, POST, DELETE, CONNECT");

        HttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                HttpResponseStatus.OK,
                Unpooled.copiedBuffer("", Charset.defaultCharset()),
                headers,
                new DefaultHttpHeaders());
        ctx.writeAndFlush(response)
                .addListener(__ -> ctx.close());
    }

    @VisibleForTesting
    List getTunnelRegions() {
        return parseRegionCsv(tunnelRegionsProperty.get());
    }

    private static List parseRegionCsv(String regionCsv) {
        return Arrays.stream(regionCsv.split(","))
                .map(String::trim)
                .map(String::toLowerCase)
                .collect(Collectors.toList());
    }

    @VisibleForTesting
    List parseRegionsInUri(String uri) {
        final String regionString = getRegion(uri);
        if (isAllRegion(regionString)) {
            return getTunnelRegions();
        } else if (regionString.contains(",")) {
            return parseRegionCsv(regionString);
        } else {
            return Collections.singletonList(regionString);
        }
    }

    private void handleRestGet(ChannelHandlerContext ctx, FullHttpRequest request) {
        List regions = parseRegionsInUri(request.uri());

        String uri = getTail(request.uri());
        log.info("Relaying GET URI {} to {} (original uri {}).", uri, regions, request.uri());
        Observable.from(regions)
                .flatMap(region -> {
                    final AtomicReference ref = new AtomicReference<>();
                    HttpClientRequest rq = HttpClientRequest.create(HttpMethod.GET, uri);

                    return Observable
                            .create((Observable.OnSubscribe>) subscriber ->
                                    subscriber.onNext(mantisCrossRegionalClient.getSecureRestClient(region)))
                            .flatMap(client -> {
                                ref.set(null);
                                return client.submit(rq)
                                        .flatMap(resp -> {
                                            final int code = resp.getStatus().code();
                                            if (code >= 500) {
                                                throw new RuntimeException(resp.getStatus().toString());
                                            }
                                            return responseToRegionData(region, resp);
                                        })
                                        .onErrorReturn(t -> {
                                            log.warn("Error getting response from remote master: " + t.getMessage());
                                            ref.set(t);
                                            return new RegionData(region, false, t.getMessage(), 0);
                                        });
                            })
                            .map(data -> {
                                final Throwable t = ref.get();
                                if (t != null)
                                    throw new RuntimeException(t);
                                return data;
                            })
                            .retryWhen(Util.getRetryFunc(log, uri + " in " + region))
                            .take(1)
                            .onErrorReturn(t -> new RegionData(region, false, t.getMessage(), 0));
                })
                .reduce(new ArrayList(3), (regionDatas, regionData) -> {
                    regionDatas.add(regionData);
                    return regionDatas;
                })
                .observeOn(scheduler)
                .subscribeOn(scheduler)
                .take(1)
                .subscribe(result -> writeDataAndCloseChannel(ctx, result));
    }

    private void handleRestPost(ChannelHandlerContext ctx, FullHttpRequest request) {
        String uri = getTail(request.uri());
        List regions = parseRegionsInUri(request.uri());

        log.info("Relaying POST URI {} to {} (original uri {}).", uri, regions, request.uri());
        final AtomicReference ref = new AtomicReference<>();

        String content = request.content().toString(Charset.defaultCharset());
        Observable.from(regions)
                .flatMap(region -> {
                    HttpClientRequest rq = HttpClientRequest.create(HttpMethod.POST, uri);
                    rq.withRawContent(content, StringTransformer.DEFAULT_INSTANCE);

                    return Observable
                            .create((Observable.OnSubscribe>) subscriber ->
                                    subscriber.onNext(mantisCrossRegionalClient.getSecureRestClient(region)))
                            .flatMap(client -> client.submit(rq)
                                    .flatMap(resp -> {
                                        final int code = resp.getStatus().code();
                                        if (code >= 500) {
                                            throw new RuntimeException(resp.getStatus().toString() + "in " + region );
                                        }
                                        return responseToRegionData(region, resp);
                                    })
                                    .onErrorReturn(t -> {
                                        log.warn("Error getting response from remote master: " + t.getMessage());
                                        ref.set(t);
                                        return new RegionData(region, false, t.getMessage(), 0);
                                    }))
                            .map(data -> {
                                final Throwable t = ref.get();
                                if (t != null)
                                    throw new RuntimeException(t);
                                return data;
                            })
                            .retryWhen(Util.getRetryFunc(log, uri + " in " + region))
                            .take(1)
                            .onErrorReturn(t -> new RegionData(region, false, t.getMessage(), 0));
                })
                .reduce(new ArrayList(), (regionDatas, regionData) -> {
                    regionDatas.add(regionData);
                    return regionDatas;
                })
                .observeOn(scheduler)
                .subscribeOn(scheduler)
                .take(1)
                .subscribe(result -> writeDataAndCloseChannel(ctx, result));
    }

    private void handleRemoteSse(ChannelHandlerContext ctx, FullHttpRequest request) {
        HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1,
                HttpResponseStatus.OK);
        HttpHeaders headers = response.headers();
        headers.add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, "*");
        headers.add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, "Origin, X-Requested-With, Accept, Content-Type, Cache-Control");
        headers.set(HttpHeaderNames.CONTENT_TYPE, "text/event-stream");
        headers.set(HttpHeaderNames.CACHE_CONTROL, "no-cache, no-store, max-age=0, must-revalidate");
        headers.set(HttpHeaderNames.PRAGMA, HttpHeaderValues.NO_CACHE);
        headers.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED);
        response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE);
        ctx.writeAndFlush(response);

        final boolean sendThroughTunnelPings = hasTunnelPingParam(request.uri());
        final String uri = uriWithTunnelParamsAdded(getTail(request.uri()));
        List regions = parseRegionsInUri(request.uri());

        log.info("Initiating remote SSE connection to {} in {} (original URI: {}).", uri, regions, request.uri());
        PushConnectionDetails pcd = PushConnectionDetails.from(uri, regions);

        String[] tags = Util.getTaglist(request.uri(), pcd.target, getRegion(request.uri()));

        Counter numDroppedBytesCounter = SpectatorUtils.newCounter(Constants.numDroppedBytesCounterName, pcd.target, tags);
        Counter numDroppedMessagesCounter = SpectatorUtils.newCounter(Constants.numDroppedMessagesCounterName, pcd.target, tags);
        Counter numMessagesCounter = SpectatorUtils.newCounter(Constants.numMessagesCounterName, pcd.target, tags);
        Counter numBytesCounter = SpectatorUtils.newCounter(Constants.numBytesCounterName, pcd.target, tags);
        Counter drainTriggeredCounter = SpectatorUtils.newCounter(Constants.drainTriggeredCounterName, pcd.target, tags);
        Counter numIncomingMessagesCounter = SpectatorUtils.newCounter(Constants.numIncomingMessagesCounterName, pcd.target, tags);

        BlockingQueue queue = new LinkedBlockingQueue(queueCapacity.get());

        drainFuture = scheduledExecutorService.scheduleAtFixedRate(() -> {
            try {
                if (queue.size() > 0 && ctx.channel().isWritable()) {
                    drainTriggeredCounter.increment();
                    final List items = new ArrayList<>(queue.size());
                    synchronized (queue) {
                        queue.drainTo(items);
                    }
                    for (String data : items) {
                        ctx.write(Unpooled.copiedBuffer(data, StandardCharsets.UTF_8));
                        numMessagesCounter.increment();
                        numBytesCounter.increment(data.length());
                    }
                    ctx.flush();
                }
            } catch (Exception ex) {
                log.error("Error writing to channel", ex);
            }
        }, writeIntervalMillis.get(), writeIntervalMillis.get(), TimeUnit.MILLISECONDS);

        subscription = connectionBroker.connect(pcd)
                .filter(event -> !event.equalsIgnoreCase(TunnelPingMessage) || sendThroughTunnelPings)
                .doOnNext(event -> {
                    numIncomingMessagesCounter.increment();
                    if (!Constants.DUMMY_TIMER_DATA.equals(event)) {
                      String data = Constants.SSE_DATA_PREFIX + event + Constants.SSE_DATA_SUFFIX;
                        boolean offer = false;
                        synchronized (queue) {
                            offer = queue.offer(data);
                        }
                        if (!offer) {
                          numDroppedBytesCounter.increment(data.length());
                          numDroppedMessagesCounter.increment();
                      }
                    }
                })
                .subscribe();
    }

    @Override
    public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
        log.info("Channel {} is unregistered. URI: {}", ctx.channel(), uriForLogging);
        unsubscribeIfSubscribed();
        super.channelUnregistered(ctx);
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        log.info("Channel {} is inactive. URI: {}", ctx.channel(), uriForLogging);
        unsubscribeIfSubscribed();
        super.channelInactive(ctx);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        log.warn("Exception caught by channel {}. URI: {}", ctx.channel(), uriForLogging, cause);
        unsubscribeIfSubscribed();
        // This is the tail of handlers. We should close the channel between the server and the client,
        // essentially causing the client to disconnect and terminate.
        ctx.close();
    }

    /** Unsubscribe if it's subscribed. */
    private void unsubscribeIfSubscribed() {
        if (subscription != null && !subscription.isUnsubscribed()) {
            log.info("SSE unsubscribing subscription with URI: {}", uriForLogging);
            subscription.unsubscribe();
        }
        if (drainFuture != null) {
            drainFuture.cancel(false);
        }
        if (scheduledExecutorService != null) {
            scheduledExecutorService.shutdown();
        }
    }

    private boolean hasTunnelPingParam(String uri) {
        return uri != null && uri.contains(TunnelPingParamName);
    }

    private Observable responseToRegionData(String region, HttpClientResponse resp) {
        final int code = resp.getStatus().code();
        return resp.getContent()
                .collect(Unpooled::buffer,
                        ByteBuf::writeBytes)
                .map(byteBuf -> new RegionData(region, true,
                        byteBuf.toString(StandardCharsets.UTF_8), code)
                )
                .onErrorReturn(t -> new RegionData(region, false, t.getMessage(), code));
    }

    private void writeDataAndCloseChannel(ChannelHandlerContext ctx, ArrayList result) {
        try {
            String serialized = responseToString(result);
            HttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                    HttpResponseStatus.OK,
                    Unpooled.copiedBuffer(serialized, Charset.defaultCharset()));

            HttpHeaders headers = response.headers();
            headers.add(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON + "; charset=utf-8");
            headers.add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, "*");
            headers.add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS,
                    "Origin, X-Requested-With, Accept, Content-Type, Cache-Control");
            headers.add(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS,
                    "GET, OPTIONS, PUT, POST, DELETE, CONNECT");
            headers.add(HttpHeaderNames.CONTENT_LENGTH, serialized.length());

            ctx.writeAndFlush(response)
                    .addListener(__ -> ctx.close());
        } catch (Exception ex) {
            log.error("Error serializing cross regional response: {}", ex.getMessage(), ex);
        }
    }

    private String uriWithTunnelParamsAdded(String uri) {
        QueryStringDecoder queryStringDecoder = new QueryStringDecoder(uri);
        QueryStringEncoder queryStringEncoder = new QueryStringEncoder(queryStringDecoder.path());

        queryStringDecoder.parameters().forEach((key, value) -> value.forEach(val -> queryStringEncoder.addParam(key, val)));

        queryStringEncoder.addParam(TunnelPingParamName, "true");
        queryStringEncoder.addParam(TagsParamName, OriginRegionTagName + TagNameValDelimiter + getLocalRegion());
        return queryStringEncoder.toString();
    }

    private static void send100Contine(ChannelHandlerContext ctx) {
        FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1,
                HttpResponseStatus.CONTINUE);
        ctx.writeAndFlush(response);
    }

    private boolean isCrossRegionStreamingPath(String uri) {
        return Util.startsWithAnyOf(getTail(uri), this.pushPrefixes);
    }

    private static String getTail(String uri) {
        return uri.replaceFirst("^/region/.*?/", "/");
    }

    /**
     * Fetches a region from a URI if it contains one, returns garbage if not.
     *
     * @param uri The uri from which to fetch the region.
     * @return The region embedded in the URI, always lower case.
     * */
    private static String getRegion(String uri) {
        return uri.replaceFirst("^/region/", "")
                .replaceFirst("/.*$", "")
                .trim()
                .toLowerCase();
    }

    /**
     * Checks for a specific `all` string to connect to all regions
     * specified by {@link CrossRegionHandler#getTunnelRegions()}
     */
    private static boolean isAllRegion(String region) {
        return region != null && region.trim().equalsIgnoreCase("all");
    }

    private static String responseToString(List dataList) {
        StringBuilder sb = new StringBuilder("[");
        boolean first = true;
        for (RegionData data : dataList) {
            if (first)
                first = false;
            else {
                sb.append(",");
            }
            if (data.isSuccess()) {
                String outputData = getForceWrappedJson(data.getData(), data.getRegion(), data.getResponseCode(), null);
                sb.append(outputData);
            } else {
                sb.append(getForceWrappedJson("", data.getRegion(), data.getResponseCode(), data.getData()));
            }
        }
        sb.append("]");
        return sb.toString();
    }

    private final static String regKey = "mantis.meta.origin";
    private final static String errKey = "mantis.meta.errorString";
    private final static String codeKey = "mantis.meta.origin.response.code";

    public static String getWrappedJson(String data, String region, String err) {
        return getWrappedJsonIntl(data, region, err, 0, false);
    }

    public static String getForceWrappedJson(String data, String region, int code, String err) {
        return getWrappedJsonIntl(data, region, err, code, true);
    }

    private static String getWrappedJsonIntl(String data, String region, String err, int code, boolean forceJson) {
        try {
            JSONObject o = new JSONObject(data);
            o.put(regKey, region);
            if (err != null && !err.isEmpty())
                o.put(errKey, err);
            if (code > 0)
                o.put(codeKey, "" + code);
            return o.toString();
        } catch (JSONException e) {
            try {
                JSONArray a = new JSONArray(data);
                if (!forceJson)
                    return data;
                JSONObject o = new JSONObject();
                o.put(regKey, region);
                if (err != null && !err.isEmpty())
                    o.put(errKey, err);
                if (code > 0)
                    o.put(codeKey, "" + code);
                o.accumulate("response", a);
                return o.toString();
            } catch (JSONException ae) {
                if (!forceJson)
                    return data;
                JSONObject o = new JSONObject();
                o.put(regKey, region);
                if (err != null && !err.isEmpty())
                    o.put(errKey, err);
                if (code > 0)
                    o.put(codeKey, "" + code);
                o.put("response", data);
                return o.toString();
            }
        }
    }
}





© 2015 - 2024 Weber Informatics LLC | Privacy Policy