All downloads are free. The search and download features use the official Maven repository.

com.marklogic.xcc.impl.MLCloudAuthManager Maven / Gradle / Ivy

There is a newer version: 11.3.0
Show newest version
/*
 * Copyright (c) 2023 MarkLogic Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.marklogic.xcc.impl;

import com.marklogic.xcc.ContentSource;
import com.marklogic.xcc.Request;
import com.marklogic.xcc.Session;
import com.marklogic.xcc.exceptions.RequestException;
import com.marklogic.xcc.impl.handlers.MLCloudRequestController;

import java.util.concurrent.*;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Utility class that manages the connection between xcc and MarkLogic Cloud.
 * It maintains a thread pool with only 1 (daemon) thread that obtains the
 * first session token and then schedules session token renewal periodically.
 *
 * <p>One {@code MLCloudAuthContext} is kept per api key, looked up by the
 * {@code String.hashCode()} of the key. The token state of a context is
 * published through a {@code volatile} field so that readers on other threads
 * see the token written by the scheduler thread.
 *
 * <p>NOTE(review): contexts are keyed by the hash code of the api key only, so
 * two distinct api keys with equal hash codes would share a single context.
 */
public final class MLCloudAuthManager {
    // Renew the session token 1 minute before it expires
    private static final long RENEW_OFFSET_MINUTES = 1L;
    // Shortest renewal period that is ever scheduled. scheduleWithFixedDelay
    // rejects a delay <= 0 with IllegalArgumentException, which would happen
    // for tokens whose lifetime is <= RENEW_OFFSET_MINUTES.
    private static final long MIN_RENEW_PERIOD_MINUTES = 1L;
    private static final Logger logger;
    private static final ScheduledExecutorService scheduler;
    // Api key hash --> MLCloudAuthContext
    private static final ConcurrentHashMap<Integer, MLCloudAuthContext>
        contextMap = new ConcurrentHashMap<>();

    static {
        logger = Logger.getLogger(MLCloudAuthManager.class.getName());
        // Single daemon thread so that token renewal never keeps the JVM alive.
        scheduler = Executors.newScheduledThreadPool(1, r -> {
            Thread t = new Thread(r);
            t.setDaemon(true);
            t.setPriority(Thread.MAX_PRIORITY);
            return t;
        });
    }

    private MLCloudAuthManager() {}

    /**
     * Registers an authentication context for the api key of the given
     * content source, if none is registered yet. On first registration the
     * initial session token is obtained (this call blocks until that attempt
     * finishes) and periodic renewal is scheduled.
     *
     * @param cs content source whose user credentials carry the MarkLogic
     *           Cloud api key
     */
    public static synchronized void createMLCloudAuthContext(ContentSource cs) {
        int apiKeyHash = new String(cs.getUserCredentials().
            getMLCloudAuthConfig().getApiKey()).hashCode();
        if (!contextMap.containsKey(apiKeyHash)) {
            // For each unique api key, maintain a dedicated ContentSource
            // object. Obtain the first session token and then schedule the
            // token renewal tasks.
            contextMap.put(apiKeyHash, new MLCloudAuthContext(cs));
            runObtainSessionTokenOneTime(apiKeyHash);
            runRenewSessionToken(apiKeyHash);
        }
    }

    /**
     * Runs one token request on the scheduler thread and waits for it to
     * finish. Failures are logged, not propagated.
     */
    private static void runObtainSessionTokenOneTime(int hash) {
        MLCloudAuthContext context = contextMap.get(hash);
        if (context != null && context.tokenConfig != null) {
            // Paranoia: If someone just renewed the token, and it's not
            // expiring, don't do anything
            if (!context.tokenToExpire()) return;
        }
        if (logger.isLoggable(Level.INFO)) {
            logger.log(Level.INFO, "Scheduled to obtain session token one time.");
        }
        final ScheduledFuture<?> handle =
            scheduler.schedule(new TokenRunner(hash), 0, TimeUnit.MINUTES);
        try {
            handle.get();
        } catch (InterruptedException e) {
            // Restore the interrupt status so the caller can still observe it.
            Thread.currentThread().interrupt();
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "Exception obtaining session token " +
                    "from MarkLogic Cloud token endpoint. " + e.getMessage(), e);
            }
        } catch (ExecutionException e) {
            if (logger.isLoggable(Level.WARNING)) {
                logger.log(Level.WARNING, "Exception obtaining session token " +
                    "from MarkLogic Cloud token endpoint. " + e.getMessage(), e);
            }
        }
    }

    /**
     * Schedules periodic token renewal for the context registered under the
     * given hash. Does nothing when no token has been obtained yet.
     */
    private static void runRenewSessionToken(int hash) {
        MLCloudAuthContext context = contextMap.get(hash);
        if (context == null || context.tokenConfig == null) {
            return;
        }
        long initDelay = context.getInitDelay();
        // A period <= 0 would make scheduleWithFixedDelay throw.
        long period = Math.max(MIN_RENEW_PERIOD_MINUTES, context.getPeriod());
        if (logger.isLoggable(Level.INFO)) {
            logger.log(Level.INFO, "Scheduled to renew session token " +
                "periodically. initDelay=" + initDelay + " minutes, " +
                "period=" + period + " minutes.");
        }
        final TokenRunner runner = new TokenRunner(hash);
        // An exception escaping a periodic task suppresses all subsequent
        // executions, so a single failed renewal must not propagate.
        scheduler.scheduleWithFixedDelay(() -> {
            try {
                runner.run();
            } catch (RuntimeException e) {
                if (logger.isLoggable(Level.WARNING)) {
                    logger.log(Level.WARNING, "Exception renewing session " +
                        "token; will retry at the next scheduled renewal. " +
                        e.getMessage(), e);
                }
            }
        }, initDelay, period, TimeUnit.MINUTES);
    }

    /** Returns the context registered for the api key, or null if none. */
    private static MLCloudAuthContext contextFor(char[] apiKey) {
        return contextMap.get(new String(apiKey).hashCode());
    }

    /**
     * Stores a new session token for the context registered under the api
     * key. Does nothing when no context is registered for that key.
     *
     * @param apiKey          api key identifying the context
     * @param curToken        the current session token
     * @param tokenExpiration expiration time of the token in milliseconds
     * @param tokenDuration   lifetime of the token in minutes
     */
    public static void setTokenConfig(char[] apiKey,
        char[] curToken, long tokenExpiration, long tokenDuration) {
        MLCloudAuthContext context = contextFor(apiKey);
        if (context != null) {
            context.setTokenConfig(curToken, tokenExpiration, tokenDuration);
        }
    }

    /**
     * @param apiKey api key identifying the context
     * @return the current session token, or null when no context or no token
     *         is available for the api key
     */
    public static String getSessionToken(char[] apiKey) {
        MLCloudAuthContext.TokenConfig config = tokenConfigFor(apiKey);
        return config == null ? null : new String(config.curToken);
    }

    /**
     * @param apiKey api key identifying the context
     * @return expiration time of the current token in milliseconds, or 0 when
     *         no context or no token is available for the api key
     */
    public static long getTokenExpiration(char[] apiKey) {
        MLCloudAuthContext.TokenConfig config = tokenConfigFor(apiKey);
        return config == null ? 0 : config.tokenExpiration;
    }

    /**
     * @param apiKey api key identifying the context
     * @return lifetime of the current token in minutes, or 0 when no context
     *         or no token is available for the api key
     */
    public static long getTokenDuration(char[] apiKey) {
        MLCloudAuthContext.TokenConfig config = tokenConfigFor(apiKey);
        return config == null ? 0 : config.tokenDuration;
    }

    /**
     * Reads the token state for an api key exactly once, so callers work on a
     * consistent snapshot. Returns null when there is no context or no token.
     */
    private static MLCloudAuthContext.TokenConfig tokenConfigFor(char[] apiKey) {
        MLCloudAuthContext context = contextFor(apiKey);
        return context == null ? null : context.tokenConfig;
    }

    /** Per-api-key state: the content source and the latest token. */
    private static class MLCloudAuthContext {
        private final ContentSource cs;
        // Written via setTokenConfig and read from other threads; volatile
        // makes the latest (immutable) TokenConfig visible to all of them.
        private volatile TokenConfig tokenConfig;

        private MLCloudAuthContext(ContentSource cs) {
            this.cs = cs;
        }

        private void setTokenConfig(char[] curToken, long tokenExpiration,
                                    long tokenDuration) {
            this.tokenConfig = new TokenConfig(curToken, tokenExpiration,
                tokenDuration);
        }

        /** Minutes from now until the renewal point of the current token. */
        private long minutesUntilRenewal() {
            final TokenConfig config = tokenConfig;
            return TimeUnit.MILLISECONDS.toMinutes(config.tokenExpiration -
                System.currentTimeMillis()) - RENEW_OFFSET_MINUTES;
        }

        private long getInitDelay() {
            return minutesUntilRenewal();
        }

        private long getPeriod() {
            return tokenConfig.tokenDuration - RENEW_OFFSET_MINUTES;
        }

        private boolean tokenToExpire() {
            return minutesUntilRenewal() <= 0;
        }

        /** Immutable snapshot of one session token. */
        private static class TokenConfig {
            private final char[] curToken;
            // Expiration time of the session token in milliseconds
            private final long tokenExpiration;
            // Lifetime of the session token in minutes
            private final long tokenDuration;

            private TokenConfig(char[] curToken, long tokenExpiration,
                                long tokenDuration) {
                this.curToken = curToken;
                this.tokenExpiration = tokenExpiration;
                this.tokenDuration = tokenDuration;
            }
        }
    }

    /**
     * Task that issues an empty ad hoc query through
     * {@link MLCloudRequestController} for the context registered under an
     * api key hash. Throws RuntimeException when the context is unknown or the
     * request fails.
     */
    private static class TokenRunner implements Runnable {
        private final int apiKeyHash;

        public TokenRunner(int apiKeyHash) {
            this.apiKeyHash = apiKeyHash;
        }

        @Override
        public void run() {
            MLCloudAuthContext context = contextMap.get(apiKeyHash);
            if (context == null)
                throw new RuntimeException("Unrecognized user api key.");
            if (logger.isLoggable(Level.INFO)) {
                logger.log(Level.INFO, "Connecting to MarkLogic Cloud to " +
                    "obtain session token.");
            }
            Session session = context.cs.newSession();
            try {
                Request request = session.newAdhocQuery("()");
                MLCloudRequestController controller =
                    new MLCloudRequestController(
                        session.getUserCredentials().getMLCloudAuthConfig());
                controller.runRequest(
                    session.getContentSource().getConnectionProvider(), request,
                    session.getLogger());
            } catch (RequestException e) {
                throw new RuntimeException(
                    "Exception obtaining session token.", e);
            } finally {
                // Release the session whether or not the request succeeded.
                session.close();
            }
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy