edu.uvm.ccts.common.rmi.RMISSLClientSocketFactory Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of ccts-common Show documentation
Show all versions of ccts-common Show documentation
A library of useful generic objects and tools consolidated here to simplify all UVM CCTS projects
/*
* Copyright 2015 The University of Vermont and State
* Agricultural College. All rights reserved.
*
* Written by Matthew B. Storer
*
* This file is part of CCTS Common.
*
* CCTS Common is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* CCTS Common is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with CCTS Common. If not, see <http://www.gnu.org/licenses/>.
*/
package edu.uvm.ccts.common.rmi;
import edu.uvm.ccts.common.util.SSLUtil;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import javax.net.SocketFactory;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.TrustManagerFactory;
import java.io.IOException;
import java.io.Serializable;
import java.net.Socket;
import java.rmi.server.RMIClientSocketFactory;
import java.security.KeyStore;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
/**
 * An {@link RMIClientSocketFactory} that creates TLS client sockets using key material loaded
 * by {@code SSLUtil.loadKeyStoreFromResource}.
 *
 * <p>The {@link SSLContext} is never serialized with this factory. Instead, the constructor
 * registers the resulting {@link SocketFactory} in a JVM-wide registry keyed by the key-store
 * path, and the serialized form carries only that path. A deserialized copy can therefore create
 * sockets only if an instance with the same path has already been constructed in the receiving
 * JVM. Otherwise {@link #createSocket(String, int)} fails with an {@link IOException}.
 *
 * <p>Created by mstorer on 8/13/14.
 */
public class RMISSLClientSocketFactory implements RMIClientSocketFactory, Serializable {
    private static final Log log = LogFactory.getLog(RMISSLClientSocketFactory.class);

    // Key-store location, and also the registry key used by createSocket. This field is left
    // non-final on purpose: changing its modifiers would change the default serialVersionUID.
    // NOTE(review): consider declaring an explicit serialVersionUID.
    private String path;

    /**
     * Builds an SSL socket factory from the key store at {@code path} and registers it in the
     * JVM-wide registry under that path.
     *
     * @param path       key-store location passed to {@code SSLUtil.loadKeyStoreFromResource};
     *                   must not be {@code null}
     * @param passphrase passphrase used both to load the key store and to initialize its keys
     * @throws NullPointerException if {@code path} is {@code null}
     * @throws RuntimeException     if loading the key store or initializing the SSL context
     *                              fails; checked exceptions are wrapped, and every failure is
     *                              logged before it is rethrown
     */
    public RMISSLClientSocketFactory(String path, char[] passphrase) {
        this.path = Objects.requireNonNull(path, "path");
        try {
            KeyStore ks = SSLUtil.loadKeyStoreFromResource(path, passphrase);

            KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
            kmf.init(ks, passphrase);

            // The same key store supplies both the client's key material and its trust anchors.
            TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
            tmf.init(ks);

            SSLContext ctx = SSLContext.getInstance("TLS");
            ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);

            SocketFactoryRegistry.getInstance().register(path, ctx.getSocketFactory());
        } catch (Exception e) {
            log.error("caught " + e.getClass().getName() + " - " + e.getMessage(), e);
            if (e instanceof RuntimeException) {
                throw (RuntimeException) e;
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates an SSL socket connected to {@code host:port}, using the factory registered for
     * this instance's key-store path.
     *
     * @throws IOException if no factory is registered for the path in this JVM, or if the
     *                     underlying socket creation fails
     */
    @Override
    public Socket createSocket(String host, int port) throws IOException {
        SocketFactory sf = SocketFactoryRegistry.getInstance().get(path);
        if (sf == null) {
            // Only the constructor registers factories, so a deserialized copy in a JVM that
            // never constructed an instance for this path has nothing to look up.
            throw new IOException("no SSL socket factory registered for key store '" + path + "'");
        }
        return sf.createSocket(host, port);
    }

    /** Hash code consistent with {@link #equals(Object)}: derived from the class and the path. */
    @Override
    public int hashCode() {
        return Objects.hash(getClass(), path);
    }

    /**
     * Two factories are equal when they have the same class and the same key-store path.
     * Factories backed by different key stores are not functionally equivalent, so the
     * {@link RMIClientSocketFactory} contract requires them to compare unequal.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return Objects.equals(path, ((RMISSLClientSocketFactory) obj).path);
    }

    /** JVM-wide map from key-store path to the SSL socket factory built for that path. */
    private static final class SocketFactoryRegistry {
        // Initialized when the nested class is first used. Class initialization is thread-safe,
        // unlike the previous unsynchronized check-then-create.
        private static final SocketFactoryRegistry INSTANCE = new SocketFactoryRegistry();

        private static SocketFactoryRegistry getInstance() {
            return INSTANCE;
        }

        // Concurrent map: constructors may register while createSocket reads on other threads.
        private final Map<String, SocketFactory> map =
                new ConcurrentHashMap<String, SocketFactory>();

        private SocketFactoryRegistry() {
        }

        private void register(String key, SocketFactory sf) {
            map.put(key, sf);
        }

        /** Returns the factory for {@code key}, or {@code null} if none is registered. */
        private SocketFactory get(String key) {
            // ConcurrentHashMap rejects null keys; treat a null key as "not registered".
            return key == null ? null : map.get(key);
        }

        // NOTE(review): not called anywhere in this class.
        private void deregister(String key) {
            map.remove(key);
        }
    }
}