// All downloads are free. Search and download functionality uses the official Maven repository.
//
// com.marklogic.xcc.impl.ContentSourceImpl (Maven / Gradle / Ivy)
//
// There is a newer version: 11.3.0
// Show newest version
/*
 * Copyright 2003-2018 MarkLogic Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.marklogic.xcc.impl;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.Constructor;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Principal;
import java.security.PrivilegedAction;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Formatter;
import java.util.logging.Handler;
import java.util.logging.Level;
import java.util.logging.LogManager;
import java.util.logging.Logger;

import javax.security.auth.Subject;
import javax.security.auth.login.AppConfigurationEntry;
import javax.security.auth.login.Configuration;
import javax.security.auth.login.LoginContext;
import javax.security.auth.login.LoginException;
import javax.security.auth.kerberos.KerberosTicket;

import org.ietf.jgss.GSSContext;
import org.ietf.jgss.GSSCredential;
import org.ietf.jgss.GSSException;
import org.ietf.jgss.GSSManager;
import org.ietf.jgss.GSSName;
import org.ietf.jgss.Oid;

import sun.security.krb5.KrbException;
import sun.security.krb5.PrincipalName;

import com.marklogic.io.Base64;
import com.marklogic.io.IOHelper;
import com.marklogic.xcc.ContentSource;
import com.marklogic.xcc.Session;
import com.marklogic.xcc.UserCredentials;
import com.marklogic.xcc.spi.ConnectionProvider;

@SuppressWarnings("deprecation")
public class ContentSourceImpl implements ContentSource {
	
    public static enum AuthType {
        NONE, BASIC, DIGEST, NEGOTIATE
    };
    
    private static final String DEFAULT_LOGGER_NAME = "com.marklogic.xcc";
    private static final String XCC_LOGGING_CONFIG_FILE = "xcc.logging.properties";
    private static final String XCC_CONFIG_FILE = "xcc.properties";
    private static final String SYSTEM_LOGGING_CONFIG_CLASS = "java.util.logging.config.class";
    private static final String SYSTEM_LOGGING_CONFIG_FILE = "java.util.logging.config.file";

    private final ConnectionProvider connectionProvider;
    private final String user;
    private final char[] password;
    private final String contentBase;
    private boolean authenticationPreemptive = false; 
    private boolean challengeIgnored = false; // for regression testing only
    /**
     * logger is initiated before initializeConfig()
     */
    private Logger logger = newDefaultLogger();

    private AuthType authType = AuthType.NONE;
    private String challenge;

    private static Random random = new Random();
    
    private Credentials credentials;

    private static Logger newDefaultLogger() {
        LogManager logManager = LogManager.getLogManager();
        Logger logger = logManager.getLogger(DEFAULT_LOGGER_NAME);

        if (logger != null) {
            return logger;
        }

        if ((System.getProperty(SYSTEM_LOGGING_CONFIG_CLASS) != null)
                || (System.getProperty(SYSTEM_LOGGING_CONFIG_FILE) != null)) {
            // If custom config file or class, don't override anything
            return Logger.getLogger(DEFAULT_LOGGER_NAME);
        }

        return customizedLogger(logManager);
    }

    private void initializeConfig() {
        URL url = getClass().getClassLoader().getResource(XCC_CONFIG_FILE);
        Properties props = System.getProperties();
        if (url != null) {
            try (FileInputStream is = new FileInputStream(url.getPath())) {
                props.load(is);
            } catch (IOException e) {
                logger.log(Level.WARNING,
                    "property file not found:" + url.getPath());
            }
        }
    }
    
    public ContentSourceImpl(ConnectionProvider connectionProvider, String user, char[] password, String contentBase) {
        this.connectionProvider = connectionProvider;
        this.user = user;
        this.password = password;
        credentials = new Credentials(user, password);

        String cbName = contentBase;

        if (cbName != null) {
            cbName = cbName.trim();

            if (cbName.length() == 0) {
                cbName = null;
            }
        }

        this.contentBase = cbName;
        initializeConfig();
    }

    public ConnectionProvider getConnectionProvider() {
		return connectionProvider;
	}

    public Session newSession() {
        return (new SessionImpl(this, connectionProvider, credentials, contentBase));
    }

    public Session newSession(String userName, char[] password) {
        return (new SessionImpl(this, connectionProvider, 
                new Credentials(userName, password), contentBase));
    }

    public Session newSession(String user, char[] password, String contentBaseArg) {
        String contentBase = (contentBaseArg == null) ? this.contentBase : contentBaseArg;
        
        return (new SessionImpl(this, connectionProvider, 
                new Credentials(user, password), contentBase));
    }

    public Session newSession(String databaseId) {
        return (new SessionImpl(this, connectionProvider, credentials, databaseId));
    }

    public Logger getDefaultLogger() {
        return logger;
    }

    public void setDefaultLogger(Logger logger) {
        this.logger = logger;
    }

    public boolean isAuthenticationPreemptive() {
    	return this.authenticationPreemptive;
    }
    
    public void setAuthenticationPreemptive(boolean value) {
    	this.authenticationPreemptive = value;
    }

    public void setAuthChallenge(String challenge) {
    	synchronized(this) {
    		this.authType = AuthType.valueOf(challenge.split(" ")[0].toUpperCase());
    		this.challenge = challenge;
    	}
    }

    /**
     * For regression testing only; returns whether session to ignore authentication challenges and fail immediately.
     */
    public boolean isChallengeIgnored() {
        return challengeIgnored;
    }

    /**
     * For regression testing only; tells session to ignore authentication challenges and fail immediately.
     */
    public void setChallengeIgnored(boolean challengeIgnored) {
        this.challengeIgnored = challengeIgnored;
    }

    public String getAuthString(String method, String uri, UserCredentials credentials) {
        AuthType authType;
        String challenge;
        synchronized(this) {
            authType = this.authType;
            challenge = this.challenge;
        }
        switch (authType) {
        case BASIC:
            return credentials.toHttpBasicAuth();
        case DIGEST:
            return credentials.toHttpDigestAuth(method, uri, challenge);
        case NEGOTIATE:
            return credentials.toHttpNegotiateAuth(connectionProvider.getHostName(), challenge);
        default:
            return isAuthenticationPreemptive() ? credentials.toHttpBasicAuth() : null;
        }
    }

    @Override
    public String toString() {
        return "user=" + ((user == null) ? "{none}" : user) + ", cb="
                + ((contentBase == null) ? "{none}" : contentBase) + " [provider: " + connectionProvider.toString()
                + "]";
    }

    // -------------------------------------------------------------

    private static Logger customizedLogger(LogManager logManager) {
        Properties props = loadLoggingPropertiesFromResource();
        Logger logger = Logger.getLogger(DEFAULT_LOGGER_NAME);
        List handlers = getLoggerHandlers(logger, logManager, props);

        for (Iterator it = handlers.iterator(); it.hasNext();) {
            logger.addHandler(it.next());
        }

        boolean useParentHandlers = getUseParentHandlersFlag(logger, logManager, props);

        logger.setUseParentHandlers(useParentHandlers);

        logManager.addLogger(logger);

        return logger;
    }

    private static Properties loadLoggingPropertiesFromResource() {
        Properties props = new Properties();
        URL url = ClassLoader.getSystemResource(XCC_LOGGING_CONFIG_FILE);
        if (url != null) {
            try (FileInputStream is = new FileInputStream(url.getPath())) {
                props.load(is);
                return props;
            } catch (IOException e) {
                //property file not found
                Logger logger = Logger.getLogger(DEFAULT_LOGGER_NAME);
                if(logger!=null) {
                    logger.warning("property file not found: " + url);
                }
            }
        }
        // Load properties internally from com.marklogic.xcc package in
        // xcc.jar
        try (InputStream is = 
             ContentSource.class.getResourceAsStream(XCC_LOGGING_CONFIG_FILE)) {
            if (is != null) {
                props.load(is);
            }
        } catch (IOException e) {
            // property file not found
            Logger logger = Logger.getLogger(DEFAULT_LOGGER_NAME);
            if (logger!=null) {
                logger.warning("Error loading default logging file: " + 
                    e.getMessage());
            }
        }
        return props;
    }

    private static List getLoggerHandlers(Logger logger, LogManager logManager, Properties props) {
        String propName = logger.getName() + ".handlers";
        String handlerPropVal = getPropertyValue(propName, logManager, props);

        if (handlerPropVal == null) {
            return new ArrayList(0);
        }

        String[] handlerClassNames = handlerPropVal.split("\\\\s*,?\\\\s*");
        List handlers = new ArrayList(handlerClassNames.length);
        Level level = getLoggerLevel(logger, logManager, props);

        if (level != null)
            logger.setLevel(level);

        for (int i = 0; i < handlerClassNames.length; i++) {
            try {
                Class handlerClass = Class.forName(handlerClassNames[i]).asSubclass(Handler.class);
                Handler handler = handlerClass.newInstance();
                Formatter formatter = getFormatter(handler, logManager, props);

                handlers.add(handler);
                if (formatter != null)
                    handler.setFormatter(formatter);
                if (level != null)
                    handler.setLevel(level);
            } catch (Exception e) {
                // Do nothing, can't instantiate the handler class
            }
        }

        return handlers;
    }

    private static Formatter getFormatter(Handler handler, LogManager logManager, Properties props) {
        String propName = handler.getClass().getName() + ".formatter";
        String formatterClassName = getPropertyValue(propName, logManager, props);

        try {
            Class clazz = Class.forName(formatterClassName).asSubclass(Formatter.class);
            Constructor cons = null;

            try {
                cons = clazz.getConstructor(new Class[] { Properties.class, LogManager.class });
            } catch (Exception e) {
                // do nothing, may not be our LogFormatter class
            }

            if (cons != null) {
                return cons.newInstance(new Object[] { props, logManager });
            }

            return (Formatter)Class.forName(formatterClassName).newInstance();
        } catch (Exception e) {
            return null;
        }
    }

    private static Level getLoggerLevel(Logger logger, LogManager logManager, Properties props) {
        String propName = logger.getName() + ".level";
        String levelName = getPropertyValue(propName, logManager, props);

        try {
            return Level.parse(levelName);
        } catch (Exception e) {
            return null;
        }
    }

    private static boolean getUseParentHandlersFlag(Logger logger, LogManager logManager, Properties props) {
        String propName = logger.getName() + ".useParentHandlers";
        String propValue = getPropertyValue(propName, logManager, props);

        if (propValue == null) {
            return false;
        }

        try {
            return Boolean.valueOf(propValue).booleanValue();
        } catch (Exception e) {
            return false;
        }
    }

    private static String getPropertyValue(String propName, LogManager logManager, Properties props) {
        String propVal = props.getProperty(propName);

        if (propVal != null) {
            return propVal.trim();
        }

        propVal = logManager.getProperty(propName);

        if (propVal != null) {
            return propVal.trim();
        }

        return null;
    }

    // -------------------------------------------------------------

    static class Credentials implements UserCredentials {
        private String user;
        private char[] password;
        private String basicAuth;
        private String HA1;

        private LoginContext loginContext;

    /**
     * Class to create Kerberos Configuration object which specifies the
     * Kerberos Login Module to be used for authentication.
     *
     */
    private class KerberosLoginConfiguration extends Configuration {

        @Override
        public AppConfigurationEntry[] getAppConfigurationEntry(String name) {
            Map options = new HashMap();
            options.put("refreshKrb5Config", "true");
            options.put("useTicketCache", "true");
            return new AppConfigurationEntry[] {
                    new AppConfigurationEntry("com.sun.security.auth.module.Krb5LoginModule",
                            AppConfigurationEntry.LoginModuleControlFlag.REQUIRED, options) };
        }
    }


    /**
     * This method checks the validity of the TGT in the cache and build the
     * Subject inside the LoginContext using Krb5LoginModule and the TGT cached
     * by the Kerberos client. It assumes that a valid TGT is already present in
     * the kerberos client's cache.
     *
     * @throws KrbException
     * @throws IOException
     * @throws LoginException
     */
    private void buildSubjectCredentials() throws KrbException, IOException, LoginException {
        Subject subject = new Subject();
        sun.security.krb5.Credentials cred;
        // Check if the cache already has valid TGT information. If not, throw exceptions
        if(user != null && !user.equals("")) {
            cred = sun.security.krb5.Credentials.acquireTGTFromCache(new PrincipalName(user), null);
        }
        else {
            cred = sun.security.krb5.Credentials.acquireTGTFromCache(null, null);
        }
        if (cred == null) {
            throw new KrbException("No ticket granting ticket in the cache");
        } else {
            Date endTime = cred.getEndTime();
            if (endTime != null) {
                if (endTime.compareTo(new Date()) == -1) {
                    throw new KrbException("The ticket granting ticket in the cache is no longer valid");
                }
            }
        }

        /*
         * We are not getting the TGT from KDC here. The actual TGT is got from
         * the KDC using kinit or equivalent but we use the cached TGT in order
         * to build the LoginContext and populate the TGT inside the Subject
         * using Krb5LoginModule
         */
        loginContext = new LoginContext("Krb5LoginContext", subject, null, new KerberosLoginConfiguration());
        loginContext.login();
    }

    /**
     * Creates a privileged action which will be executed as the Subject using
     * Subject.doAs() method. We do this in order to create a context of the
     * user who has the service ticket and reuse this context for subsequent
     * requests
     */
    private static class CreateAuthorizationHeaderAction implements PrivilegedAction {
        String clientPrincipalName;
        String serverPrincipalName;

        private StringBuffer outputToken = new StringBuffer();

        private CreateAuthorizationHeaderAction(final String clientPrincipalName, final String serverPrincipalName) {
            this.clientPrincipalName = clientPrincipalName;
            this.serverPrincipalName = serverPrincipalName;
        }

        private String getNegotiateToken() {
            return outputToken.toString();
        }

        /*
         * Here GSS API takes care of getting the service ticket from the Subject
         * cache or by using the TGT information populated in the subject which is
         * done by buildSubjectCredentials method. The service ticket received is
         * populated in the subject's private credentials along with the TGT
         * information since we will be executing this method as the Subject.
         * For subsequent requests, the cached service ticket will be re-used.
         * For this to work the System property
         * javax.security.auth.useSubjectCredsOnly must be set to true.
         */
        public Object run() {
            try {
                Oid krb5Mechanism = new Oid("1.2.840.113554.1.2.2");
                Oid krb5PrincipalNameType = new Oid("1.2.840.113554.1.2.2.1");
                final GSSManager manager = GSSManager.getInstance();
                final GSSName clientName = manager.createName(clientPrincipalName, krb5PrincipalNameType);
                final GSSCredential clientCred = manager.createCredential(clientName, 8 * 3600, krb5Mechanism,
                        GSSCredential.INITIATE_ONLY);
                final GSSName serverName = manager.createName(serverPrincipalName, krb5PrincipalNameType);

                final GSSContext context = manager.createContext(serverName, krb5Mechanism, clientCred,
                        GSSContext.DEFAULT_LIFETIME);

                byte[] inToken = new byte[0]; // since
                byte[] outToken = context.initSecContext(inToken, 0, inToken.length);
                outputToken.append(new String(Base64.encodeBytes(outToken,Base64.DONT_BREAK_LINES)));
                context.dispose();
            } catch (GSSException exception) {
                throw new RuntimeException(exception.getMessage());
            }
            return null;
        }
    }

    /**
     * This method builds the Authorization header for Kerberos. It
     * generates a request token based on the service ticket, client principal name and
     * time-stamp
     *
     * @param serverPrincipalName
     *            the name registered with the KDC of the service for which we
     *            need to authenticate
     * @return the HTTP Authorization header token
     */
    private String getAuthorizationHeader(String serverPrincipalName) throws GSSException, LoginException, KrbException, IOException
    {
        /*
         * Get the principal from the Subject's private credentials and populate
         * the client and server principal name for the GSS API
         */
        final String clientPrincipal = getClientPrincipalName();
        final CreateAuthorizationHeaderAction action = new CreateAuthorizationHeaderAction(clientPrincipal,
                serverPrincipalName);

        /*
         * Check if the TGT in the Subject's private credentials are valid. If
         * valid, then we use the TGT in the Subject's private credentials. If
         * not, we build the Subject's private credentials again from valid TGT
         * in the Kerberos client cache.
         */
        Set privateCreds = loginContext.getSubject().getPrivateCredentials();
        for (Object privateCred : privateCreds) {
            if (privateCred instanceof KerberosTicket) {
                String serverPrincipalTicketName = ((KerberosTicket) privateCred).getServer().getName();
                if ((serverPrincipalTicketName.startsWith("krbtgt"))
                        && ((KerberosTicket) privateCred).getEndTime().compareTo(new Date()) == -1) {
                    buildSubjectCredentials();
                    break;
                }
            }
        }

        /*
         * Subject.doAs takes in the Subject context and the action to be run as
         * arguments. This method executes the action as the Subject given in
         * the argument. We do this in order to provide the Subject's context so
         * that we reuse the service ticket which will be populated in the
         * Subject rather than getting the service ticket from the KDC for each
         * request. The GSS API populates the service ticket in the Subject and
         * reuses it
         *
         */
        Subject.doAs(loginContext.getSubject(), action);
        return action.getNegotiateToken();
    }

    /**
     * This method is responsible for getting the client principal name from the
     * subject's principal set
     *
     * @return String the Kerberos principal name populated in the subject
     * @throws IllegalStateException
     *             if there is more than 0 or more than 1 principal is present
     */
    private String getClientPrincipalName() {
        final Set principalSet = loginContext.getSubject().getPrincipals();
        if (principalSet.size() != 1)
            throw new IllegalStateException(
                    "Only one principal per subject is expected. Found 0 or more than one principals :" + principalSet);
        return principalSet.iterator().next().getName();
    }

        public Credentials(String user, char[] password) {
            this.user = user;
            this.password = password;
            if (user != null && password != null) {
                initBasicAuth();
            }
        }

        public String getUserName() {
            return user;
        }
        
        void initBasicAuth(Charset encoding) 
                throws UnsupportedEncodingException {
            byte[] ubytes = (user + ":").getBytes(encoding);
            ByteBuffer pbuf = encoding.encode(CharBuffer.wrap(password));
            byte[] upbytes = new byte[pbuf.remaining() + ubytes.length];
            System.arraycopy(ubytes, 0, upbytes, 0, ubytes.length);
            pbuf.get(upbytes, ubytes.length, pbuf.remaining());
            basicAuth = "basic " + 
                    Base64.encodeBytes(upbytes, Base64.DONT_BREAK_LINES);
        }
        
        void initBasicAuth() {
            try {
                initBasicAuth(Charset.forName("UTF-8"));
            } catch (UnsupportedEncodingException e) {
                try {
                    initBasicAuth(Charset.defaultCharset());
                } catch (UnsupportedEncodingException e1) {
                }
            }
        }

        public String toHttpBasicAuth() {  
            if ((user == null) || ((password == null) && basicAuth == null)) {
                throw new IllegalStateException("Invalid authentication credentials");
            }
            if (password != null) {
                Arrays.fill(password, (char)0);
                password = null;
            }
            return basicAuth;
        }

        private static final AtomicLong nonceCounter = new AtomicLong();

        public String toHttpDigestAuth(String method, String uri, String challengeHeader) {
            if ((user == null) || (password == null && HA1 == null)) {
                throw new IllegalStateException("Invalid authentication credentials");
            }
            
            if ((challengeHeader == null) || !challengeHeader.startsWith("Digest ")) {
                return null;
            }

            String pairs[] = challengeHeader.substring("Digest ".length()).split(", +");

            Map params = new HashMap();

            for (String pair : pairs) {
                String nv[] = pair.split("=", 2);
                params.put(nv[0].toLowerCase(), nv[1].substring(1, nv[1].length() - 1));
            }

            String realm = params.get("realm");

            if (HA1 == null) {
                HA1 = digestCalcHA1(user, realm, password);
            }

            String nonce = params.get("nonce");
            String qop = params.get("qop");
            String opaque = params.get("opaque");

            byte[] bytes = new byte[16];

            synchronized (random) {
                random.nextBytes(bytes);
            }

            String cNonce = IOHelper.bytesToHex(bytes);

            String nonceCount = Long.toHexString(nonceCounter.incrementAndGet());

            String response = digestCalcResponse(HA1, nonce, nonceCount, cNonce, qop, method, uri);

            StringBuilder buf = new StringBuilder();

            buf.append("Digest username=\"");
            buf.append(user);
            buf.append("\", realm=\"");
            buf.append(realm);
            buf.append("\", nonce=\"");
            buf.append(nonce);
            buf.append("\", uri=\"");
            buf.append(uri);
            buf.append("\", qop=\"auth\", nc=\"");
            buf.append(nonceCount);
            buf.append("\", cnonce=\"");
            buf.append(cNonce);
            buf.append("\", response=\"");
            buf.append(response);
            buf.append("\", opaque=\"");
            buf.append(opaque);
            buf.append("\"");

            return buf.toString();
        }

        public String toHttpNegotiateAuth(String hostName, String challenge) {

            try {
                 if (loginContext == null) 
                   buildSubjectCredentials();
                 String authLine = new String("Negotiate " + getAuthorizationHeader("HTTP/" + hostName));
                 return authLine;

            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        @Override
        public String toString() {
            return "user=" + user;
        }
    }

    public static String digestCalcResponse(String HA1, String nonce, String nonceCount, String cNonce, String qop,
            String method, String uri) {

        try {
            MessageDigest digest = MessageDigest.getInstance("MD5");

            StringBuilder plaintext = new StringBuilder();

            plaintext.append(method);
            plaintext.append(":");
            plaintext.append(uri);

            digest.update(plaintext.toString().getBytes(), 0, plaintext.length());

            String HA2 = IOHelper.bytesToHex(digest.digest());

            plaintext.setLength(0);
            plaintext.append(HA1);
            plaintext.append(":");
            plaintext.append(nonce);
            plaintext.append(":");
            if (qop != null) {
                plaintext.append(nonceCount);
                plaintext.append(":");
                plaintext.append(cNonce);
                plaintext.append(":");
                plaintext.append(qop);
                plaintext.append(":");
            }
            plaintext.append(HA2);

            digest.update(plaintext.toString().getBytes(), 0, plaintext.length());

            return IOHelper.bytesToHex(digest.digest());
        } catch (NoSuchAlgorithmException e) {
            // this really shouldn't happen
            throw new RuntimeException(e);
        }
    }

    public static String digestCalcHA1(String userName, String realm, char[] password) {

        try {
            MessageDigest digest = MessageDigest.getInstance("MD5");

            StringBuilder plaintext = new StringBuilder();

            plaintext.append(userName);
            plaintext.append(":");
            plaintext.append(realm);
            plaintext.append(":");
            byte[] ubytes = plaintext.toString().getBytes();
            
            ByteBuffer pbuf = 
                    Charset.defaultCharset().encode(CharBuffer.wrap(password));
            byte[] upbytes = new byte[pbuf.remaining() + ubytes.length];
            System.arraycopy(ubytes, 0, upbytes, 0, ubytes.length);
            pbuf.get(upbytes, ubytes.length, pbuf.remaining());
            
            digest.update(upbytes, 0, upbytes.length);
            return IOHelper.bytesToHex(digest.digest());
        } catch (NoSuchAlgorithmException e) {
            // this really shouldn't happen
            throw new RuntimeException(e);
        }
    }

    @Override
    public Session newSession(String userName, String password) {
        return newSession(userName, 
                password == null ? null : password.toCharArray());
    }

    @Override
    public Session newSession(String userName, String password,
            String contentbaseId) {
        return newSession(userName, 
                password == null ? null : password.toCharArray(), 
                        contentbaseId);
    }
}