All downloads are free. Search and download functionality uses the official Maven repository.

org.finos.tracdap.common.data.pipeline.GrpcDownloadSink Maven / Gradle / Ivy

Go to download

TRAC D.A.P. data library, interfaces and core functionality for working with primary data

The newest version!
/*
 * Licensed to the Fintech Open Source Foundation (FINOS) under one or
 * more contributor license agreements. See the NOTICE file distributed
 * with this work for additional information regarding copyright ownership.
 * FINOS licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with the
 * License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.finos.tracdap.common.data.pipeline;

import org.finos.tracdap.common.data.util.Bytes;
import org.finos.tracdap.common.exception.ETracInternal;
import org.finos.tracdap.common.exception.EUnexpected;

import com.google.protobuf.ByteString;
import com.google.protobuf.MessageLite;
import com.google.protobuf.UnsafeByteOperations;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import org.apache.arrow.memory.ArrowBuf;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Flow;
import java.util.function.BiFunction;
import java.util.function.Supplier;


/**
 * Terminal stage of a data pipeline that writes downloaded content to a gRPC server response stream.
 *
 * <p>Two delivery modes are supported, selected by the {@code streaming} constructor argument:
 * <ul>
 *     <li>{@link #STREAMING}: the first message is sent as soon as it is available, then each
 *     data chunk from the pipeline is wrapped and sent as one or more further response messages.</li>
 *     <li>{@link #AGGREGATED}: the first message and all data chunks are held in memory and
 *     sent as a single response message when the pipeline completes.</li>
 * </ul>
 *
 * <p>Data is only requested from the pipeline once the first message has been produced, so that
 * content is not sent ahead of it. Outstanding demand is topped back up to
 * {@code REQUEST_MAX_BUFFER} chunks whenever it drops below {@code REQUEST_MIN_BUFFER}, and only
 * while the response stream reports {@code isReady()}, which applies back-pressure from gRPC.
 *
 * <p>NOTE(review): there is no synchronization in this class; it assumes gRPC callbacks,
 * pipeline callbacks and the first-message future are not delivered concurrently - confirm
 * against the calling code.
 */
public class GrpcDownloadSink {

    public static final boolean STREAMING = true;
    public static final boolean AGGREGATED = false;

    // Flow control: demand is refilled to MAX outstanding chunks once it falls below MIN
    private static final long REQUEST_MAX_BUFFER = 128;
    private static final long REQUEST_MIN_BUFFER = 64;

    private final ServerCallStreamObserver responseStream;
    private final boolean streaming;

    private final Supplier builder;
    private BiFunction dataFunc;

    // Aggregated mode only (both null when streaming): the merged response and the buffered chunks
    private final TBuilder aggregateResponse;
    private final List aggregateBuffer;

    private CompletableFuture firstMessage;
    private Flow.Subscription subscription;
    private Runnable cleanup;

    // Chunk counters; nRequested - nReceived is the demand still outstanding upstream
    private long nRequested;
    private long nReceived;


    /**
     * Creates a sink that writes to the given gRPC response observer.
     *
     * @param response the server-side response observer; it must be a {@link ServerCallStreamObserver},
     *                 which provides the readiness and cancel notifications this sink relies on
     * @param builder supplies fresh response builders
     * @param streaming {@link #STREAMING} or {@link #AGGREGATED}
     * @throws EUnexpected if {@code response} is not a {@link ServerCallStreamObserver}
     */
    public GrpcDownloadSink(StreamObserver response, Supplier builder, boolean streaming) {

        if (!(response instanceof ServerCallStreamObserver))
            throw new EUnexpected();

        responseStream = (ServerCallStreamObserver) response;
        responseStream.setOnCancelHandler(this::apiOnCancel);
        responseStream.setOnReadyHandler(this::apiOnReady);

        this.builder = builder;
        this.streaming = streaming;

        if (streaming) {
            aggregateResponse = null;
            aggregateBuffer = null;
        }
        else {
            aggregateResponse = builder.get();
            aggregateBuffer = new ArrayList<>();
        }
    }

    /**
     * Sets an action to run when the download ends, whether it completes or fails.
     * Only the most recently set action is used, and it runs at most once.
     *
     * @param cleanup the action to run
     */
    public void whenComplete(Runnable cleanup) {
        this.cleanup = cleanup;
    }

    /**
     * Returns an already-completed stage holding {@code request}, as a starting point for chaining.
     *
     * @param request the request value to wrap
     * @return a completed stage containing {@code request}
     */
    public  CompletionStage start(TRequest request) {

        return CompletableFuture.completedFuture(request);
    }

    /**
     * Aborts the download.
     *
     * <p>Releases any buffered chunks (aggregated mode), reports {@code error} to the client,
     * cancels the pipeline subscription if there is one and runs the cleanup action.
     *
     * @param error the failure to report
     * @return always {@code null}; the {@code Void} return type allows use as an
     *         {@code exceptionally} handler
     */
    public Void failed(Throwable error) {

        abort(error, true);
        return null;
    }

    /**
     * Overload of {@link #firstMessage(BiFunction)} that takes an explicit result type.
     * {@code resultType} is not used at runtime; it only helps the compiler infer {@code TResult}.
     */
    @SuppressWarnings("unused")
    public  CompletableFuture
    firstMessage(BiFunction resultFunc, Class resultType) {

        // Sometimes TResult cannot be inferred, allow client code to be explicit

        return firstMessage(resultFunc);
    }

    /**
     * Registers how the first response message is built, and returns the future that feeds it.
     *
     * <p>When the caller completes the returned future with a result, {@code resultFunc} is applied
     * to a fresh builder and that result to produce the first message. In streaming mode the message
     * is sent straight away; in aggregated mode it is merged into the single response. If the future
     * fails, or building or sending the first message fails, the download is aborted in the same way
     * as {@link #failed(Throwable)}.
     *
     * @param resultFunc populates a response builder from the result
     * @return the future the caller must complete with the result
     * @throws EUnexpected if called more than once
     */
    public  CompletableFuture
    firstMessage(BiFunction resultFunc) {

        if (firstMessage != null)
            throw new EUnexpected();

        var future = new CompletableFuture();
        firstMessage = future.thenApply(s -> resultFunc.apply(builder.get(), s));

        firstMessage
            .thenAccept(this::firstMessageComplete)
            .exceptionally(this::firstMessageFailed);

        return future;
    }

    /**
     * Returns the subscriber through which the pipeline delivers data chunks to this sink.
     *
     * @param wrapFunc writes a block of content bytes into a response builder. In streaming mode it
     *                 is applied once per outgoing data message; in aggregated mode it is applied
     *                 once, to the whole content.
     * @return a subscriber for the pipeline's data chunks
     */
    public Flow.Subscriber
    dataStream(BiFunction wrapFunc) {

        this.dataFunc = wrapFunc;
        return new DownloadSubscriber();
    }

    private void apiOnCancel() {
        // TODO: Cancel not supported yet - this handler does not cancel the pipeline subscription
    }

    private void apiOnReady() {

        // The transport can accept more data. Only pull from the pipeline once the first message is
        // available; requesting earlier would let content overtake the first message, or let the
        // pipeline complete before the first message has been sent / merged.
        if (subscription != null && firstMessageReady() && nRequested - nReceived < REQUEST_MIN_BUFFER)
            requestMore();
    }

    /** True once the first message has been produced successfully. */
    private boolean firstMessageReady() {

        return firstMessage != null && firstMessage.isDone() && !firstMessage.isCompletedExceptionally();
    }

    /** Sends (streaming) or merges (aggregated) the first message, then starts pulling data if subscribed. */
    @SuppressWarnings("unchecked")
    private void firstMessageComplete(TBuilder response) {

        if (streaming)
            responseStream.onNext((TResponse) response.build());
        else
            aggregateResponse.mergeFrom(response.buildPartial());

        if (subscription != null)
            requestMore();
    }

    /** Handles a failure anywhere in the first-message chain; behaves exactly like {@link #failed}. */
    private Void firstMessageFailed(Throwable error) {

        return failed(error);
    }

    private void pipelineOnSubscribe(Flow.Subscription subscription) {

        if (this.subscription != null)
            throw new ETracInternal("Download sink is already subscribed");

        this.subscription = subscription;

        // firstMessage() has not been called yet - nothing to wait on here,
        // firstMessageComplete() starts the data flow once the first message is available
        if (firstMessage == null)
            return;

        if (firstMessage.isCompletedExceptionally()) {
            // The first message failed, so the download has been (or is being) aborted
            subscription.cancel();
        }
        else if (firstMessage.isDone()) {
            requestMore();
        }
    }

    private void pipelineOnNext(ArrowBuf chunk) {

        nReceived += 1;

        // Streaming: each chunk is handed to Bytes.readFromStream, which feeds pipelineSendChunk.
        // NOTE(review): presumably readFromStream / readFromBuffer take ownership of the ArrowBufs
        // (they are not closed here on the success path) - confirm in Bytes.
        if (streaming)
            Bytes.readFromStream(chunk, this::pipelineSendChunk);
        else
            aggregateBuffer.add(chunk);

        if (nRequested - nReceived < REQUEST_MIN_BUFFER)
            requestMore();
    }

    /** Wraps one block of content bytes in a fresh response message and sends it (streaming mode). */
    @SuppressWarnings("unchecked")
    private void pipelineSendChunk(ByteBuffer chunk) {

        var protoBuilder = builder.get();
        var protoBytes = UnsafeByteOperations.unsafeWrap(chunk);

        dataFunc.apply(protoBuilder, protoBytes);

        var protoMsg = (TResponse) protoBuilder.build();

        responseStream.onNext(protoMsg);
    }

    @SuppressWarnings("unchecked")
    private void pipelineOnComplete() {

        try {

            // Aggregated mode: all content goes out now as a single message
            if (!streaming) {

                var bufferBytes = Bytes.readFromBuffer(aggregateBuffer);
                var protoBytes = UnsafeByteOperations.unsafeWrap(bufferBytes);

                dataFunc.apply(aggregateResponse, protoBytes);

                var protoMsg = (TResponse) aggregateResponse.build();

                responseStream.onNext(protoMsg);
            }

            responseStream.onCompleted();
        }
        finally {
            runCleanupOnce();
        }
    }

    private void pipelineOnError(Throwable error) {

        // The pipeline has already terminated, so there is no subscription to cancel
        abort(error, false);
    }

    /**
     * Shared failure path.
     *
     * <p>Releases buffered chunks (aggregated mode), reports the error to the client, optionally
     * cancels the pipeline subscription, and always runs the cleanup action.
     */
    private void abort(Throwable error, boolean cancelSubscription) {

        try {

            if (!streaming)
                releaseAggregate();

            responseStream.onError(error);

            if (cancelSubscription && subscription != null)
                subscription.cancel();
        }
        finally {
            runCleanupOnce();
        }
    }

    /** Runs the registered cleanup action, clearing it first so it can never run twice. */
    private void runCleanupOnce() {

        if (cleanup != null) {
            var action = cleanup;
            cleanup = null;
            action.run();
        }
    }

    /** Tops outstanding demand back up to REQUEST_MAX_BUFFER, if the transport is ready for more. */
    private void requestMore() {

        if (!responseStream.isReady())
            return;

        long nPending = nRequested - nReceived;
        long n = REQUEST_MAX_BUFFER - nPending;

        // Demand is already at the maximum (e.g. onReady and the first message both asked for data).
        // Flow.Subscription.request() requires n > 0; a non-positive value is a protocol violation
        // that makes the publisher fail the stream with IllegalArgumentException.
        if (n <= 0)
            return;

        nRequested += n;
        subscription.request(n);
    }

    /** Closes every buffered chunk and empties the buffer (aggregated mode only). */
    private void releaseAggregate() {

        aggregateBuffer.forEach(ArrowBuf::close);
        aggregateBuffer.clear();
    }

    /** Adapts Flow.Subscriber callbacks onto the sink's pipeline handlers. */
    private class DownloadSubscriber implements Flow.Subscriber {

        @Override
        public void onSubscribe(Flow.Subscription subscription) {
            pipelineOnSubscribe(subscription);
        }

        @Override
        public void onNext(ArrowBuf chunk) {
            pipelineOnNext(chunk);
        }

        @Override
        public void onError(Throwable throwable) {
            pipelineOnError(throwable);
        }

        @Override
        public void onComplete() {
            pipelineOnComplete();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy