All downloads are FREE. Search and download functionality uses the official Maven repository.

org.bouncycastle.jsse.provider.ProvSSLSessionContext Maven / Gradle / Ivy

The newest version!
package org.bouncycastle.jsse.provider;

import java.lang.ref.ReferenceQueue;
import java.lang.ref.SoftReference;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.logging.Logger;

import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionContext;

import org.bouncycastle.tls.SessionID;
import org.bouncycastle.tls.TlsSession;
import org.bouncycastle.tls.TlsUtils;
import org.bouncycastle.tls.crypto.impl.jcajce.JcaTlsCrypto;

/**
 * {@link SSLSessionContext} implementation that caches {@link ProvSSLSession}s, indexed both by session ID
 * and by peer ({@code host:port}).
 * <p>
 * Sessions are held through {@link SoftReference}s ({@code SessionEntry}) so the garbage collector may
 * reclaim them; reclaimed entries are purged lazily from both indexes by {@code processQueue()}. The by-ID
 * index is an access-ordered {@link LinkedHashMap} acting as a simple LRU cache bounded by
 * {@link #getSessionCacheSize()} (0 means unbounded). Mutable cache state is only touched from
 * {@code synchronized} methods of this class (directly or via the private helpers they call).
 */
class ProvSSLSessionContext
    implements SSLSessionContext
{
    private static final Logger LOG = Logger.getLogger(ProvSSLSessionContext.class.getName());

    // Initial cache size limit, taken from the "javax.net.ssl.sessionCacheSize" system property
    // (presumably 20480 when unset, constrained to [0, Integer.MAX_VALUE]; 0 means unlimited).
    private static final int provSessionCacheSize = PropertyUtils
        .getIntegerSystemProperty("javax.net.ssl.sessionCacheSize", 20480, 0, Integer.MAX_VALUE);

    // NOTE: This is configured as a simple LRU cache using the "access order" constructor
    @SuppressWarnings("serial")
    protected final Map<SessionID, SessionEntry> sessionsByID = new LinkedHashMap<SessionID, SessionEntry>(16,
        0.75f, true)
    {
        @Override
        protected boolean removeEldestEntry(Map.Entry<SessionID, SessionEntry> eldest)
        {
            // Evict the least-recently-used entry once a non-zero limit is exceeded, and keep the by-peer
            // index consistent by dropping the evicted entry from it as well.
            boolean shouldRemove = sessionCacheSize > 0 && size() > sessionCacheSize;
            if (shouldRemove)
            {
                removeSessionByPeer(eldest.getValue());
            }
            return shouldRemove;
        }
    };
    // Secondary index: lower-cased "host:port" -> most recently reported entry for that peer.
    protected final Map<String, SessionEntry> sessionsByPeer = new HashMap<String, SessionEntry>();
    // Receives SessionEntry references whose ProvSSLSession has been reclaimed by the GC.
    protected final ReferenceQueue<ProvSSLSession> sessionsQueue = new ReferenceQueue<ProvSSLSession>();

    protected final ContextData contextData;

    protected int sessionCacheSize = provSessionCacheSize;
    protected int sessionTimeoutSeconds = 86400; // 24hrs (in seconds)

    ProvSSLSessionContext(ContextData contextData)
    {
        this.contextData = contextData;
    }

    ProvSSLContextSpi getSSLContext()
    {
        return contextData.getContext();
    }

    JcaTlsCrypto getCrypto()
    {
        return contextData.getCrypto();
    }

    /**
     * Looks up a cached session by session ID.
     *
     * @param sessionID the raw session ID; {@code null} or empty yields {@code null}
     * @return the cached session (with its last-accessed time updated), or {@code null} if it is not cached,
     *         has been reclaimed, has expired or is no longer valid (stale entries are removed)
     */
    synchronized ProvSSLSession getSessionImpl(byte[] sessionID)
    {
        processQueue();

        return accessSession(mapGet(sessionsByID, makeSessionID(sessionID)));
    }

    /**
     * Looks up the most recently reported cached session for a peer.
     *
     * @param hostName the peer host; {@code null} yields {@code null}
     * @param port the peer port; a negative value yields {@code null}
     * @return the cached session (with its last-accessed time updated), or {@code null} if none is usable
     */
    synchronized ProvSSLSession getSessionImpl(String hostName, int port)
    {
        processQueue();

        SessionEntry sessionEntry = mapGet(sessionsByPeer, makePeerKey(hostName, port));
        ProvSSLSession session = accessSession(sessionEntry);
        if (session != null)
        {
            // NOTE: For the current simple cache implementation, need to 'access' the sessionByIDs entry
            sessionsByID.get(sessionEntry.getSessionID());
        }
        return session;
    }

    /**
     * Removes the session with the given ID from both the by-ID and by-peer indexes, if present.
     *
     * @param sessionID the raw session ID; {@code null} or empty is a no-op
     */
    synchronized void removeSession(byte[] sessionID)
    {
        SessionEntry sessionEntry = mapRemove(sessionsByID, makeSessionID(sessionID));
        if (null != sessionEntry)
        {
            removeSessionByPeer(sessionEntry);
        }
    }

    /**
     * Registers a {@link TlsSession} with this context and returns its {@link ProvSSLSession} wrapper.
     * <p>
     * When {@code addToCache} is {@code false}, a fresh, uncached wrapper is returned. Otherwise an existing
     * cached wrapper is reused if it wraps the very same {@code TlsSession} instance; if not, a new wrapper is
     * created and, when the session ID is non-empty, stored in the cache (replacing any previous entry for
     * that ID). A cached entry is also (re)registered under its peer key.
     */
    synchronized ProvSSLSession reportSession(String peerHost, int peerPort, TlsSession tlsSession,
        JsseSessionParameters jsseSessionParameters, boolean addToCache)
    {
        processQueue();

        if (!addToCache)
        {
            return new ProvSSLSession(this, peerHost, peerPort, tlsSession, jsseSessionParameters);
        }

        SessionID sessionID = makeSessionID(tlsSession.getSessionID());
        SessionEntry sessionEntry = mapGet(sessionsByID, sessionID);

        ProvSSLSession session = sessionEntry == null ? null : sessionEntry.get();
        if (null == session || session.getTlsSession() != tlsSession)
        {
            session = new ProvSSLSession(this, peerHost, peerPort, tlsSession, jsseSessionParameters);

            if (null != sessionID)
            {
                sessionEntry = new SessionEntry(sessionID, session, sessionsQueue);

                SessionEntry replaced = sessionsByID.put(sessionID, sessionEntry);
                if (null != replaced)
                {
                    // The superseded entry is no longer reachable by ID; drop its by-peer mapping as well so
                    // a lookup by host/port cannot return a session that has left the ID index.
                    removeSessionByPeer(replaced);
                }
            }
        }

        if (null != sessionEntry)
        {
            mapAdd(sessionsByPeer, sessionEntry.getPeerKey(), sessionEntry);
        }

        return session;
    }

    public synchronized Enumeration<byte[]> getIds()
    {
        removeAllExpiredSessions();

        ArrayList<byte[]> ids = new ArrayList<byte[]>(sessionsByID.size());
        for (SessionID sessionID : sessionsByID.keySet())
        {
            ids.add(sessionID.getBytes());
        }
        return Collections.enumeration(ids);
    }

    public SSLSession getSession(byte[] sessionID)
    {
        if (sessionID == null)
        {
            throw new NullPointerException("'sessionID' cannot be null");
        }

        return getSessionImpl(sessionID);
    }

    public synchronized int getSessionCacheSize()
    {
        return sessionCacheSize;
    }

    public synchronized int getSessionTimeout()
    {
        return sessionTimeoutSeconds;
    }

    /**
     * Sets the maximum number of cached sessions (0 means unlimited) and immediately evicts the
     * least-recently-used sessions in excess of a non-zero limit.
     *
     * @throws IllegalArgumentException if {@code size} is negative
     */
    public synchronized void setSessionCacheSize(int size) throws IllegalArgumentException
    {
        if (sessionCacheSize == size)
        {
            return;
        }

        if (size < 0)
        {
            throw new IllegalArgumentException("'size' cannot be < 0");
        }

        this.sessionCacheSize = size;

        removeAllExpiredSessions();

        // Immediately remove LRU sessions in excess of the new limit
        if (sessionCacheSize > 0)
        {
            int currentSize = sessionsByID.size();
            if (currentSize > sessionCacheSize)
            {
                // Iteration over an access-ordered LinkedHashMap starts at the least-recently-used entry
                Iterator<SessionEntry> iter = sessionsByID.values().iterator();
                while (iter.hasNext() && currentSize > sessionCacheSize)
                {
                    SessionEntry sessionEntry = iter.next();
                    iter.remove();
                    removeSessionByPeer(sessionEntry);
                    --currentSize;
                }
            }
        }
    }

    /**
     * Sets the session timeout in seconds (0 disables expiry) and immediately purges expired sessions.
     *
     * @throws IllegalArgumentException if {@code seconds} is negative
     */
    public synchronized void setSessionTimeout(int seconds) throws IllegalArgumentException
    {
        if (sessionTimeoutSeconds == seconds)
        {
            return;
        }

        if (seconds < 0)
        {
            throw new IllegalArgumentException("'seconds' cannot be < 0");
        }

        this.sessionTimeoutSeconds = seconds;

        removeAllExpiredSessions();
    }

    /**
     * Returns the entry's session if it is still alive, unexpired and valid, marking it as accessed now;
     * otherwise removes the entry from both indexes and returns {@code null}.
     */
    private ProvSSLSession accessSession(SessionEntry sessionEntry)
    {
        if (sessionEntry != null)
        {
            ProvSSLSession session = sessionEntry.get();
            if (session != null)
            {
                long currentTimeMillis = System.currentTimeMillis();
                if (!invalidateIfCreatedBefore(sessionEntry, getCreationTimeLimit(currentTimeMillis)))
                {
                    session.accessedAt(currentTimeMillis);
                    return session;
                }
            }

            removeSession(sessionEntry);
        }
        return null;
    }

    /**
     * Returns the earliest creation time (in milliseconds) a session may have and still be unexpired at
     * {@code currentTimeMillis}, or {@link Long#MIN_VALUE} when expiry is disabled (timeout of 0).
     */
    private long getCreationTimeLimit(long currentTimeMillis)
    {
        return sessionTimeoutSeconds < 1 ? Long.MIN_VALUE : (currentTimeMillis - 1000L * sessionTimeoutSeconds);
    }

    /**
     * Invalidates the entry's session if it was created before {@code creationTimeLimit}.
     *
     * @return {@code true} if the entry should be dropped: its session was reclaimed or is no longer valid
     */
    private boolean invalidateIfCreatedBefore(SessionEntry sessionEntry, long creationTimeLimit)
    {
        ProvSSLSession session = sessionEntry.get();
        if (session == null)
        {
            return true;
        }
        if (session.getCreationTime() < creationTimeLimit)
        {
            session.invalidatedBySessionContext();
        }
        return !session.isValid();
    }

    /** Drains the reference queue, removing entries whose sessions have been reclaimed by the GC. */
    private void processQueue()
    {
        int count = 0;

        SessionEntry sessionEntry;
        while ((sessionEntry = (SessionEntry)sessionsQueue.poll()) != null)
        {
            removeSession(sessionEntry);
            ++count;
        }

        if (count > 0)
        {
            LOG.fine("Processed " + count + " session entries (soft references) from the reference queue");
        }
    }

    /** Purges reclaimed entries, then removes (and invalidates) every expired or invalid cached session. */
    private void removeAllExpiredSessions()
    {
        processQueue();

        long creationTimeLimit = getCreationTimeLimit(System.currentTimeMillis());

        Iterator<SessionEntry> iter = sessionsByID.values().iterator();
        while (iter.hasNext())
        {
            SessionEntry sessionEntry = iter.next();
            if (invalidateIfCreatedBefore(sessionEntry, creationTimeLimit))
            {
                iter.remove();
                removeSessionByPeer(sessionEntry);
            }
        }
    }

    /** Removes this exact entry (identity match) from both indexes. */
    private void removeSession(SessionEntry sessionEntry)
    {
        mapRemove(sessionsByID, sessionEntry.getSessionID(), sessionEntry);

        removeSessionByPeer(sessionEntry);
    }

    /** Removes this exact entry (identity match) from the by-peer index; returns whether it was removed. */
    private boolean removeSessionByPeer(SessionEntry sessionEntry)
    {
        return mapRemove(sessionsByPeer, sessionEntry.getPeerKey(), sessionEntry);
    }

    private static String makePeerKey(ProvSSLSession session)
    {
        return session == null ? null : makePeerKey(session.getPeerHost(), session.getPeerPort());
    }

    /** Builds the lower-cased {@code host:port} key, or {@code null} if the host is null or the port negative. */
    private static String makePeerKey(String hostName, int port)
    {
        return (hostName == null || port < 0) ? null : (hostName + ':' + Integer.toString(port)).toLowerCase(Locale.ENGLISH);
    }

    /** Wraps a raw session ID, mapping {@code null} or empty to {@code null} (i.e. "not cacheable"). */
    private static SessionID makeSessionID(byte[] sessionID)
    {
        return TlsUtils.isNullOrEmpty(sessionID) ? null : new SessionID(sessionID);
    }

    /** Puts {@code key -> value}; a {@code null} key is silently ignored. */
    private static <K, V> void mapAdd(Map<K, V> map, K key, V value)
    {
        if (map == null || value == null)
        {
            throw new NullPointerException();
        }
        if (key != null)
        {
            map.put(key, value);
        }
    }

    /** Gets the value for {@code key}; a {@code null} key yields {@code null}. */
    private static <K, V> V mapGet(Map<K, V> map, K key)
    {
        if (map == null)
        {
            throw new NullPointerException();
        }
        return key == null ? null : map.get(key);
    }

    /** Removes and returns the value for {@code key}; a {@code null} key yields {@code null}. */
    private static <K, V> V mapRemove(Map<K, V> map, K key)
    {
        if (map == null)
        {
            throw new NullPointerException();
        }
        return key == null ? null : map.remove(key);
    }

    /**
     * Removes {@code key} only if it currently maps to this exact {@code value} (identity comparison).
     *
     * @return {@code true} if the mapping was removed
     */
    private static <K, V> boolean mapRemove(Map<K, V> map, K key, V value)
    {
        if (map == null || value == null)
        {
            throw new NullPointerException();
        }
        if (key != null)
        {
            // TODO[jsse] From 1.8 there is a 2-argument remove method to accomplish this 
            V removed = map.remove(key);
            if (removed == value)
            {
                return true;
            }
            if (removed != null)
            {
                // Restore the unrelated mapping. NOTE: for an access-ordered LinkedHashMap this re-insertion
                // moves the entry to the most-recently-used position.
                map.put(key, removed);
            }
        }
        return false;
    }

    /**
     * Soft reference to a cached {@link ProvSSLSession}, remembering the keys it is indexed under so it can be
     * removed from both indexes after the session itself has been reclaimed.
     */
    private static final class SessionEntry
        extends SoftReference<ProvSSLSession>
    {
        private final SessionID sessionID;
        private final String peerKey;

        SessionEntry(SessionID sessionID, ProvSSLSession session, ReferenceQueue<ProvSSLSession> queue)
        {
            super(session, queue);

            if (sessionID == null || session == null || queue == null)
            {
                throw new NullPointerException();
            }

            this.sessionID = sessionID;
            this.peerKey = makePeerKey(session);
        }

        public String getPeerKey()
        {
            return peerKey;
        }

        public SessionID getSessionID()
        {
            return sessionID;
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy