All downloads are free. Search and download functionality uses the official Maven repository.

com.subgraph.orchid.connections.ConnectionHandshakeV3 Maven / Gradle / Ivy

package com.subgraph.orchid.connections;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.GeneralSecurityException;
import java.security.Principal;
import java.security.PublicKey;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPublicKey;

import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;

import com.subgraph.orchid.Cell;
import com.subgraph.orchid.ConnectionHandshakeException;
import com.subgraph.orchid.ConnectionIOException;

/**
 * Version 3 variant of the connection handshake.
 *
 * <p>{@link #runHandshake()} announces version 3, then expects a CERTS cell
 * carrying exactly one link certificate (type 1) and one identity certificate
 * (type 2), followed by an optional AUTH_CHALLENGE cell and a NETINFO cell.
 * The received certificates are then checked against each other and against
 * the key presented on the TLS connection before the local NETINFO is sent.
 *
 * <p>All data in the CERTS cell is supplied by the remote peer, so every
 * malformed or unexpected value is reported as a
 * {@link ConnectionHandshakeException} rather than surfacing as a
 * {@code NullPointerException} or {@code ClassCastException}.
 */
public class ConnectionHandshakeV3 extends ConnectionHandshake {

	/** Number of certificates that must be present in the CERTS cell. */
	private static final int EXPECTED_CERTIFICATE_COUNT = 2;
	/** Certificate type tag of the link certificate inside a CERTS cell. */
	private static final int CERT_TYPE_LINK = 1;
	/** Certificate type tag of the identity certificate inside a CERTS cell. */
	private static final int CERT_TYPE_IDENTITY = 2;
	/** Required bit length of the RSA modulus of the router identity key. */
	private static final int IDENTITY_KEY_BITS = 1024;

	private X509Certificate linkCertificate;
	private X509Certificate identityCertificate;
	
	ConnectionHandshakeV3(ConnectionImpl connection, SSLSocket socket) {
		super(connection, socket);
	}

	/**
	 * Runs the complete handshake: version exchange, certificate and NETINFO
	 * reception, certificate verification, and finally sending NETINFO.
	 *
	 * @throws ConnectionIOException if the peer sends unexpected or invalid data
	 */
	void runHandshake() throws IOException, InterruptedException, ConnectionIOException {
		sendVersions(3);
		receiveVersions();
		recvCerts();
		recvAuthChallengeAndNetinfo();
		verifyCertificates();
		sendNetinfo();
	}
	
	/**
	 * Reads the CERTS cell and stores the link and identity certificates.
	 *
	 * @throws ConnectionHandshakeException if the cell does not contain exactly
	 *         one parseable certificate of each of the two expected types
	 */
	void recvCerts() throws ConnectionHandshakeException  {
		final Cell cell = expectCell(Cell.CERTS);
		final int ncerts = cell.getByte();
		if(ncerts != EXPECTED_CERTIFICATE_COUNT) {
			throw new ConnectionHandshakeException("Expecting 2 certificates and got "+ ncerts);
		}

		// Reset so that duplicate detection below starts from a clean state.
		linkCertificate = null;
		identityCertificate = null;
		
		for(int i = 0; i < ncerts; i++) {
			final int type = cell.getByte();
			if(type == CERT_TYPE_LINK) {
				linkCertificate = testAndReadCertificate(cell, linkCertificate, "Link (type = 1)");
			} else if(type == CERT_TYPE_IDENTITY) {
				identityCertificate = testAndReadCertificate(cell, identityCertificate, "Identity (type = 2)");
			} else {
				throw new ConnectionHandshakeException("Unexpected certificate type = "+ type + " in CERTS cell");
			}
		}
		
	}
	
	/**
	 * Returns the RSA public key of the certificate the peer presented on the
	 * TLS connection.
	 *
	 * @return the key, or {@code null} if the peer certificate is unavailable
	 *         or its public key is not an RSA key
	 */
	RSAPublicKey getConnectionPublicKey() {
		final X509Certificate certificate = getConnectionCertificateFromSession(socket.getSession());
		if(certificate == null) {
			return null;
		}
		final PublicKey key = certificate.getPublicKey();
		return (key instanceof RSAPublicKey) ? (RSAPublicKey) key : null;
	}
	

	/**
	 * Reads the next certificate from the cell unless one of this type has
	 * already been seen.
	 *
	 * @param cell         CERTS cell positioned at the certificate length field
	 * @param currentValue certificate of this type read so far, or {@code null}
	 * @param type         human readable type label used in error messages
	 * @return the newly parsed certificate, never {@code null}
	 * @throws ConnectionHandshakeException on a duplicate or unparseable certificate
	 */
	private X509Certificate testAndReadCertificate(Cell cell, X509Certificate currentValue, String type) throws ConnectionHandshakeException {
		if(currentValue != null) {
			throw new ConnectionHandshakeException("Duplicate "+ type + " certificates in CERTS cell");
		}
		return readCertificateFromCell(cell, type);
	}
	
	/**
	 * Parses one length-prefixed X.509 certificate from the cell.
	 *
	 * <p>A parse failure is reported as a handshake error. Returning
	 * {@code null} here would both defeat the duplicate check in
	 * {@link #testAndReadCertificate} and cause a later
	 * {@code NullPointerException} during verification.
	 */
	private X509Certificate readCertificateFromCell(Cell cell, String type) throws ConnectionHandshakeException {
		final int clen = cell.getShort();
		final byte[] certificateBuffer = new byte[clen];
		cell.getByteArray(certificateBuffer);
		try {
			final CertificateFactory certificateFactory = CertificateFactory.getInstance("X.509");
			final ByteArrayInputStream bis = new ByteArrayInputStream(certificateBuffer);
			return (X509Certificate) certificateFactory.generateCertificate(bis);
		} catch (CertificateException e) {
			throw new ConnectionHandshakeException("Could not parse "+ type + " certificate in CERTS cell: "+ e.getMessage());
		}
	}
	
	/**
	 * Verifies the certificates received in the CERTS cell.
	 *
	 * <p>Checks, in order: the identity key (via {@code verifyIdentityKey}),
	 * that it is an RSA key with a 1024 bit modulus, that both certificates are
	 * currently valid and signed by the identity key, and that the link
	 * certificate carries the same RSA modulus as the TLS connection key.
	 *
	 * @throws ConnectionHandshakeException if any of these checks fails
	 */
	void verifyCertificates() throws ConnectionHandshakeException {
		if(identityCertificate == null || linkCertificate == null) {
			throw new ConnectionHandshakeException("Missing certificate from CERTS cell");
		}

		final PublicKey publicKey = identityCertificate.getPublicKey();
		verifyIdentityKey(publicKey);
		if(!(publicKey instanceof RSAPublicKey)) {
			throw new ConnectionHandshakeException("Router identity key in CERTS cell is not an RSA key");
		}
		final RSAPublicKey rsaPublicKey = (RSAPublicKey) publicKey;
		
		if(rsaPublicKey.getModulus().bitLength() != IDENTITY_KEY_BITS) {
			throw new ConnectionHandshakeException("Invalid RSA modulus length in router identity key");
		}
		
		try {
			// The identity certificate is self-signed; the link certificate is signed by the identity key.
			identityCertificate.checkValidity();
			identityCertificate.verify(rsaPublicKey);
			linkCertificate.checkValidity();
			linkCertificate.verify(rsaPublicKey);
		} catch (GeneralSecurityException e) {
			throw new ConnectionHandshakeException("Router presented invalid certificate chain in CERTS cell");
		}
	
		final PublicKey linkKey = linkCertificate.getPublicKey();
		if(!(linkKey instanceof RSAPublicKey)) {
			throw new ConnectionHandshakeException("Link certificate key in CERTS cell is not an RSA key");
		}
		final RSAPublicKey connectionKey = getConnectionPublicKey();
		if(connectionKey == null) {
			throw new ConnectionHandshakeException("Could not obtain RSA public key from connection certificate");
		}
		if(!connectionKey.getModulus().equals(((RSAPublicKey) linkKey).getModulus())) {
			throw new ConnectionHandshakeException("Link certificate in CERTS cell does not match connection certificate");
		}
	}

	/**
	 * Reads the NETINFO cell, skipping over a single preceding AUTH_CHALLENGE
	 * cell if the peer sends one.
	 */
	void recvAuthChallengeAndNetinfo() throws ConnectionHandshakeException {
		final Cell first = expectCell(Cell.AUTH_CHALLENGE, Cell.NETINFO);
		// The AUTH_CHALLENGE contents are not used; only the NETINFO that follows is processed.
		final Cell netinfo = (first.getCommand() == Cell.NETINFO) ? first : expectCell(Cell.NETINFO);
		processNetInfo(netinfo);
	}
	
	/**
	 * Decides from the peer's TLS certificate whether this handshake variant
	 * should be used.
	 *
	 * @param session established TLS session with the peer
	 * @return {@code true} if the peer certificate is self-signed, has a subject
	 *         or issuer DN accepted by {@link #testDName(Principal)}, or has an
	 *         RSA modulus longer than 1024 bits; {@code false} otherwise or if
	 *         no X.509 peer certificate is available
	 */
	public static boolean sessionSupportsHandshake(SSLSession session) {
		final X509Certificate cert = getConnectionCertificateFromSession(session);
		if(cert == null) {
			return false;
		}
		return isSelfSigned(cert) || testDName(cert.getSubjectDN()) ||
				testDName(cert.getIssuerDN()) || testModulusLength(cert); 
	}
	
	/**
	 * Returns the peer's own certificate (first entry of the peer chain), or
	 * {@code null} if the peer is unverified or did not present an X.509
	 * certificate.
	 *
	 * <p>Uses {@link SSLSession#getPeerCertificates()} rather than the
	 * deprecated {@code getPeerCertificateChain()}, which is not supported on
	 * recent JDK releases.
	 */
	static private X509Certificate getConnectionCertificateFromSession(SSLSession session) {
		try {
			final java.security.cert.Certificate[] chain = session.getPeerCertificates();
			if(chain.length == 0 || !(chain[0] instanceof X509Certificate)) {
				return null;
			}
			return (X509Certificate) chain[0];
		} catch (SSLPeerUnverifiedException e) {
			return null;
		}
	}
	
	/**
	 * Returns {@code true} if the certificate's signature verifies under its
	 * own public key.
	 */
	static private boolean isSelfSigned(X509Certificate certificate) {
		try {
			certificate.verify(certificate.getPublicKey());
			return true;
		} catch (Exception e) {
			// Deliberately broad: any failure to verify simply means "not self-signed".
			return false;
		}
	}
	
    /*
     * * Some component other than "commonName" is set in the subject or
     *   issuer DN of the certificate.
     *   
     * * The commonName of the subject or issuer of the certificate ends
     *   with a suffix other than ".net".
     *
     * A comma in the DN string is taken to mean that more than one component
     * is present. A DN without any CN also yields true, because the empty
     * string does not end with ".net".
     */
	static private boolean testDName(Principal dn) {
		final String dname = dn.getName();
		if(dname.indexOf(",") >= 0) {
			return true;
		}
		return !getCN(dname).endsWith(".net");
	}

	/*
	 * * The certificate's public key modulus is longer than 1024 bits.
	 *
	 * Certificates whose key is not an RSA key yield false.
	 */
	static private boolean testModulusLength(X509Certificate cert) {
		final PublicKey key = cert.getPublicKey();
		if(!(key instanceof RSAPublicKey)) {
			return false;
		}
		final BigInteger modulus = ((RSAPublicKey) key).getModulus();
		return modulus.bitLength() > 1024;
	}

	/**
	 * Extracts the common name component from a DN string.
	 *
	 * @return the substring starting at {@code "CN="} (prefix included) up to
	 *         the next comma or the end of the string, or the empty string if
	 *         the DN contains no {@code "CN="}
	 */
	static private String getCN(String dname) {
		final int idx = dname.indexOf("CN=");
		if(idx == -1) {
			return "";
		}
		final int comma = dname.indexOf(',', idx);
		return (comma == -1) ? dname.substring(idx) : dname.substring(idx, comma);
	}
}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy