All downloads are free. Search and download functionality uses the official Maven repository.

com.marklogic.client.dataservices.impl.IOEndpointImpl Maven / Gradle / Ivy

The newest version!
/*
 * Copyright © 2024 MarkLogic Corporation. All Rights Reserved.
 */
package com.marklogic.client.dataservices.impl;

import com.fasterxml.jackson.databind.JsonNode;
import com.marklogic.client.DatabaseClient;
import com.marklogic.client.MarkLogicInternalException;
import com.marklogic.client.SessionState;
import com.marklogic.client.dataservices.IOEndpoint;
import com.marklogic.client.impl.NodeConverter;
import com.marklogic.client.io.marker.BufferableHandle;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.*;
import java.util.function.Consumer;

abstract class IOEndpointImpl implements IOEndpoint {
    private static final Logger logger = LoggerFactory.getLogger(IOEndpointImpl.class);

    final static int DEFAULT_MAX_RETRIES  = 100;
    final static int DEFAULT_BATCH_SIZE   = 100;

    private final DatabaseClient    client;
    private final IOCallerImpl caller;

    public IOEndpointImpl(DatabaseClient client, IOCallerImpl caller) {
        if (client == null)
            throw new IllegalArgumentException("null client");
        if (caller == null)
            throw new IllegalArgumentException("null caller");
        this.client = client;
        this.caller = caller;
    }

    int initBatchSize(IOCallerImpl caller) {
        JsonNode apiDeclaration = caller.getApiDeclaration();
        if (apiDeclaration.has("$bulk") && apiDeclaration.get("$bulk").isObject()
                && apiDeclaration.get("$bulk").has("inputBatchSize")
                && apiDeclaration.get("$bulk").get("inputBatchSize").isInt()) {
            return apiDeclaration.get("$bulk").get("inputBatchSize").asInt();
        }
        return DEFAULT_BATCH_SIZE;
    }

    DatabaseClient getClient() {
        return this.client;
    }
    private IOCallerImpl getCaller() {
        return this.caller;
    }

    @Override
    public String getEndpointPath() {
        return getCaller().getEndpointPath();
    }
    @Override
    public boolean allowsEndpointState() {
        return (getEndpointStateParamdef() != null);
    }
    BaseCallerImpl.ParamdefImpl getEndpointStateParamdef() {
        return getCaller().getEndpointStateParamdef();
    }
    @Override
    public boolean allowsEndpointConstants() {
        return (getEndpointConstantsParamdef() != null);
    }
    BaseCallerImpl.ParamdefImpl getEndpointConstantsParamdef() {
        return getCaller().getEndpointConstantsParamdef();
    }
    @Override
    public boolean allowsInput() {
        return (getInputParamdef() != null);
    }
    BaseCallerImpl.ParamdefImpl getInputParamdef() {
        return getCaller().getInputParamdef();
    }
    @Override
    public boolean allowsSession() {
        return (getSessionParamdef() != null);
    }
    BaseCallerImpl.ParamdefImpl getSessionParamdef() {
        return getCaller().getSessionParamdef();
    }

    @Override
    public SessionState newSessionState() {
        if (!allowsEndpointState())
            throw new IllegalStateException("endpoint does not support session state");
        return getCaller().newSessionState();
    }
    @Override
    public CallContextImpl newCallContext(){
        return newCallContext(false);
    }
    CallContextImpl newCallContext(boolean legacyContext){
        return new CallContextImpl<>(this, legacyContext);
    }

    CallContextImpl[] checkAllowedArgs(IOEndpoint.CallContext[] callCtxts) {
        if (callCtxts == null || callCtxts.length ==0)
            throw new IllegalArgumentException("null or empty contexts for call");
        CallContextImpl[] contexts = new CallContextImpl[callCtxts.length];
        for (int i=0; i < callCtxts.length; i++) {
            contexts[i] = checkAllowedArgs(callCtxts[i]);
        }
        return contexts;
    }
    CallContextImpl checkAllowedArgs(CallContext callCtxt) {
        if (!(callCtxt instanceof CallContextImpl)) {
            throw new IllegalArgumentException("Unknown implementation of call context");
        }
        CallContextImpl context = (CallContextImpl) callCtxt;
        if (context.getEndpointState() != null && !allowsEndpointState())
            throw new IllegalArgumentException("endpoint does not accept endpointState parameter");
        if (context.getSessionState() != null && !allowsSession())
            throw new IllegalArgumentException("endpoint does not accept session parameter");
        if (context.getEndpointConstants() != null && !allowsEndpointConstants())
            throw new IllegalArgumentException(
                    "endpoint does not accept "+context.getEndpointConstantsParamName()+" parameter");
        return context;
    }

    static abstract class BulkIOEndpointCallerImpl implements IOEndpoint.BulkIOEndpointCaller {
        enum WorkPhase {
            INITIALIZING, RUNNING, INTERRUPTING, INTERRUPTED, COMPLETED
        }

        private final IOEndpointImpl endpoint;
        private WorkPhase phase = WorkPhase.INITIALIZING;

        private CallContextImpl callContext;

        private CallerThreadPoolExecutor callerThreadPoolExecutor;
        private LinkedBlockingQueue> callContextQueue;
        private int threadCount;

        private long callCount = 0;

        // constructor for calling in the application thread
        BulkIOEndpointCallerImpl(IOEndpointImpl endpoint, CallContextImpl callContext) {
            this.endpoint = endpoint;
            this.callContext = callContext;
// TODO: should only create a session ID if needed
            getSession();
        }
        // constructor for concurrent calling in multiple worker threads
        BulkIOEndpointCallerImpl(IOEndpointImpl endpoint, CallContextImpl[] callContexts, int threadCount, int queueSize) {
            this.endpoint = endpoint;
            this.callerThreadPoolExecutor = new CallerThreadPoolExecutor<>(threadCount, queueSize, this);
            this.callContextQueue = new LinkedBlockingQueue<>(Arrays.asList(callContexts));
            this.threadCount = threadCount;
        }
        private void init(IOEndpointImpl endpoint, int threadCount, int queueSize) {
        }

        long getCallCount() {
            return callCount;
        }
        void incrementCallCount() {
            callCount++;
        }
        CallContextImpl getCallContext() {
            return this.callContext;
        }
        CallerThreadPoolExecutor getCallerThreadPoolExecutor() {
            return this.callerThreadPoolExecutor;
        }
        LinkedBlockingQueue> getCallContextQueue() {
            return this.callContextQueue;
        }
        int getThreadCount(){
            return this.threadCount;
        }

        boolean allowsEndpointState() {
            return callContext.getEndpoint().allowsEndpointState();
        }
        boolean allowsEndpointConstants() {
            checkCallContext();
            return callContext.getEndpoint().allowsEndpointConstants();
        }

		boolean allowsSession() {
            return callContext.getEndpoint().allowsSession();
        }
        SessionState getSession() {
            if (!allowsSession())
                return null;
            checkCallContext();
            if (callContext.getSessionState() == null) {
                // no need to refresh the session id preemptively before timeout
                // because a timed-out session id is merely a new session id
                callContext.withSessionState(callContext.getEndpoint().getCaller().newSessionState());
            }
            return callContext.getSessionState();
        }
        boolean allowsInput() {
            return callContext.getEndpoint().allowsInput();
        }

        boolean queueInput(I input, BlockingQueue queue, int batchSize) {
            if (input == null) return false;
            try {
                queue.put(input);
            } catch (InterruptedException e) {
                throw new IllegalStateException("InputStream was not added to the queue." + e.getMessage());
            }
            return checkQueue(queue, batchSize);
        }
        boolean queueAllInput(I[] input, BlockingQueue queue, int batchSize) {
            if (input == null || input.length == 0) return false;
            try {
                for (I item: input) {
                    queue.put(item);
                }
            } catch (InterruptedException e) {
                throw new IllegalStateException("InputStream was not added to the queue." + e.getMessage());
            }
            return checkQueue(queue, batchSize);
        }
        boolean checkQueue(BlockingQueue queue, int batchSize) {
            if ((queue.size() % batchSize) > 0)
                return false;

            switch (getPhase()) {
                case INITIALIZING:
                    setPhase(WorkPhase.RUNNING);
                    break;
                case RUNNING:
                    break;
                case INTERRUPTING:
                case INTERRUPTED:
                case COMPLETED:
                    throw new IllegalStateException(
                        "can only accept input when initializing or running and not when input is "+
                        getPhase().name().toLowerCase());
                default:
                    throw new MarkLogicInternalException(
                            "unexpected state for " + callContext.getEndpoint().getEndpointPath() + " during loop: " + getPhase().name());
            }

            return true;
        }
        I[] getInputBatch(BlockingQueue queue, int batchSize) {
            List inputStreamList = new ArrayList<>();
            queue.drainTo(inputStreamList, batchSize);
            return inputStreamList.toArray(endpoint.getCaller().newContentInputArray(inputStreamList.size()));
        }
        void processOutputBatch(O[] output, Consumer outputListener) {
            if (output == null || output.length == 0) return;

            for (O value: output) {
                outputListener.accept(value);
            }
        }

        WorkPhase getPhase() {
            return this.phase;
        }
        void setPhase(WorkPhase phase) {
            this.phase = phase;
        }

        @Override
        public void interrupt() {
            if (this.phase == WorkPhase.RUNNING)
                setPhase(WorkPhase.INTERRUPTING);
        }

        private void checkCallContext() {
            if(this.callContext == null)
                throw new InternalError("Can only call set and get methods for call state when using a single CallContext.");
        }

        void submitTask(Callable callable) throws RejectedExecutionException{
            FutureTask futureTask = new FutureTask<>(callable);
            getCallerThreadPoolExecutor().execute(futureTask);
        }

        void checkEndpoint(IOEndpointImpl endpoint, String endpointType) {
            if(getCallContext().getEndpoint() != endpoint)
                throw new IllegalArgumentException("Endpoint must be "+endpointType);
        }

        static class CallerThreadPoolExecutor extends ThreadPoolExecutor {

            private Boolean awaitingTermination;
            private final BulkIOEndpointCallerImpl bulkIOEndpointCaller;
            CallerThreadPoolExecutor(int threadCount, int queueSize, BulkIOEndpointCallerImpl bulkIOEndpointCaller) {

                super(threadCount, threadCount, 0, TimeUnit.MILLISECONDS,
                        new LinkedBlockingQueue<>(queueSize), new CallerRunsPolicy());
                this.bulkIOEndpointCaller = bulkIOEndpointCaller;
            }

            Boolean isAwaitingTermination() {
                return this.awaitingTermination;
            }
            synchronized void awaitTermination() throws InterruptedException {
                if (bulkIOEndpointCaller.getCallContextQueue().isEmpty() && getActiveCount()<=1) {
                    shutdown();
                }
                else {
                    awaitingTermination = true;
                    awaitTermination(Long.MAX_VALUE, TimeUnit.DAYS);
                }
            }
        }
    }
}