All Downloads are FREE. Search and download functionalities are using the official Maven repository.

bt.net.MSEHandshakeProcessor Maven / Gradle / Ivy

There is a newer version: 1.10
Show newest version
package bt.net;

import bt.metainfo.TorrentId;
import bt.protocol.DecodingContext;
import bt.protocol.Handshake;
import bt.protocol.Message;
import bt.protocol.Protocols;
import bt.protocol.crypto.EncryptionPolicy;
import bt.protocol.crypto.MSECipher;
import bt.protocol.handler.MessageHandler;
import bt.torrent.TorrentDescriptor;
import bt.torrent.TorrentRegistry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ReadableByteChannel;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.time.Duration;
import java.util.Arrays;
import java.util.Optional;
import java.util.Random;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
 * Implements Message Stream Encryption (MSE) protocol negotiation.
 *
 * <p>The handshake consists of the following blocking steps
 * (A is the connection initiator, B is the receiver):
 * <pre>{@code
 * 1. A->B: Diffie Hellman Ya, PadA
 * 2. B->A: Diffie Hellman Yb, PadB
 * 3. A->B:
 *  - HASH('req1', S),
 *  - HASH('req2', SKEY) xor HASH('req3', S),
 *  - ENCRYPT(VC, crypto_provide, len(PadC), PadC, len(IA)),
 *  - ENCRYPT(IA)
 * 4. B->A:
 *  - ENCRYPT(VC, crypto_select, len(padD), padD),
 *  - ENCRYPT2(Payload Stream)
 * 5. A->B: ENCRYPT2(Payload Stream)
 * }</pre>
 *
 * This class is not a part of the public API and is subject to change.
 */
class MSEHandshakeProcessor {
    private static final Logger LOGGER = LoggerFactory.getLogger(MSEHandshakeProcessor.class);

    /** Overall timeout applied to each blocking read performed during the negotiation. */
    private static final Duration RECEIVE_TIMEOUT = Duration.ofSeconds(10);
    /** Pause between consecutive read attempts while waiting for more data. */
    private static final Duration WAIT_BETWEEN_READS = Duration.ofSeconds(1);
    /** Maximum length of any padding block (PadA, PadB, PadC, PadD), in bytes. */
    private static final int PADDING_MAX_LENGTH = 512;
    /** Length of a SHA-1 hash, in bytes. */
    private static final int HASH_LENGTH = 20;
    /** Verification constant (VC): eight zero bytes. */
    private static final byte[] VC_RAW_BYTES = new byte[8];
    /** Shared source of randomness for padding lengths and contents. */
    private static final SecureRandom RANDOM = new SecureRandom();

    private final MSEKeyPairGenerator keyGenerator;
    private final TorrentRegistry torrentRegistry;
    private final MessageHandler messageHandler;
    private final EncryptionPolicy localEncryptionPolicy;
    private final int bufferSize;

    /**
     * @param torrentRegistry Registry used to look up the torrent requested by an incoming peer
     * @param messageHandler Handler used to detect a plaintext handshake and to build message readers/writers
     * @param localEncryptionPolicy Encryption policy of the local peer
     * @param bufferSize Size of the incoming and outgoing buffers allocated per connection, in bytes
     * @param privateKeySize Private key size, passed to the key pair generator
     */
    MSEHandshakeProcessor(TorrentRegistry torrentRegistry,
                          MessageHandler messageHandler,
                          EncryptionPolicy localEncryptionPolicy,
                          int bufferSize,
                          int privateKeySize) {
        this.keyGenerator = new MSEKeyPairGenerator(privateKeySize);
        this.torrentRegistry = torrentRegistry;
        this.messageHandler = messageHandler;
        this.localEncryptionPolicy = localEncryptionPolicy;
        this.bufferSize = bufferSize;
    }

    /**
     * Negotiates encryption for an outgoing connection (local peer acts as the initiator, A).
     *
     * <p>If the peer's advertised policy allows it, the negotiation is skipped entirely
     * and a plaintext reader/writer is returned.
     *
     * @param peer Remote peer
     * @param channel Channel that is connected to the remote peer
     * @param torrentId Torrent that the connection is established for (used as SKEY)
     * @return Message reader/writer operating on top of either the plaintext or the encrypted channel
     * @throws IOException if an I/O error happens
     * @throws RuntimeException if local and peer's encryption policies are incompatible
     * @throws IllegalStateException if the encryption policy can't be negotiated
     *                               or the peer's padding is longer than allowed
     */
    MessageReaderWriter negotiateOutgoing(Peer peer, ByteChannel channel, TorrentId torrentId) throws IOException {
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("Negotiating encryption for outgoing connection: {}", peer);
        }

        // check if the encryption negotiation can be skipped or preemptively aborted
        EncryptionPolicy peerEncryptionPolicy = peer.getOptions().getEncryptionPolicy();
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("Peer {} has encryption policy: {}", peer, peerEncryptionPolicy);
        }
        assertPolicyIsCompatible(peerEncryptionPolicy);

        ByteBuffer in = ByteBuffer.allocateDirect(bufferSize);
        ByteBuffer out = ByteBuffer.allocateDirect(bufferSize);

        if (peerEncryptionPolicy == EncryptionPolicy.REQUIRE_PLAINTEXT ||
                (peerEncryptionPolicy == EncryptionPolicy.PREFER_PLAINTEXT &&
                        (localEncryptionPolicy == EncryptionPolicy.PREFER_PLAINTEXT
                                || localEncryptionPolicy == EncryptionPolicy.REQUIRE_PLAINTEXT))) {
            // if peer requires plaintext and we support it, then do not negotiate encryption and use plaintext right away
            return createReaderWriter(peer, channel, in, out);
        }

        ByteChannelReader reader = reader(channel);

        // 1. A->B: Diffie Hellman Ya, PadA
        // send our public key
        KeyPair keys = keyGenerator.generateKeyPair();
        out.put(keys.getPublic().getEncoded());
        out.put(getPadding(PADDING_MAX_LENGTH));
        out.flip();
        channel.write(out);
        out.clear();

        // 2. B->A: Diffie Hellman Yb, PadB
        // receive peer's public key
        int phase1Min = keyGenerator.getPublicKeySize();
        int phase1Limit = phase1Min + PADDING_MAX_LENGTH;
        int phase1Read = reader.readBetween(phase1Min, phase1Limit).read(in);
        in.flip();
        BigInteger peerPublicKey = BigIntegers.decodeUnsigned(in, phase1Min);
        in.clear(); // discard the padding, if present

        // calculate shared secret S
        BigInteger S = keyGenerator.calculateSharedSecret(peerPublicKey, keys.getPrivate());

        // 3. A->B:
        MessageDigest digest = getDigest("SHA-1");
        byte[] encodedSecret = BigIntegers.encodeUnsigned(S, keyGenerator.getPublicKeySize());
        // - HASH('req1', S)
        out.put(hash(digest, "req1", encodedSecret));
        // - HASH('req2', SKEY) xor HASH('req3', S)
        byte[] b1 = hash(digest, "req2", torrentId.getBytes());
        byte[] b2 = hash(digest, "req3", encodedSecret);
        out.put(xor(b1, b2));
        // write
        out.flip();
        channel.write(out);
        out.clear();

        byte[] Sbytes = BigIntegers.encodeUnsigned(S, MSEKeyPairGenerator.PUBLIC_KEY_BYTES);
        MSECipher cipher = MSECipher.forInitiator(Sbytes, torrentId);
        ByteChannel encryptedChannel = new EncryptedChannel(channel, cipher.getDecryptionCipher(), cipher.getEncryptionCipher());
        // - ENCRYPT(VC, crypto_provide, len(PadC), PadC, len(IA))
        out.put(VC_RAW_BYTES);
        out.put(getCryptoProvideBitfield(localEncryptionPolicy));
        byte[] padding = getZeroPadding(PADDING_MAX_LENGTH);
        out.put(Protocols.getShortBytes(padding.length));
        out.put(padding);
        // - ENCRYPT(IA)
        // do not write IA (initial payload data) for now, wait for encryption negotiation
        out.putShort((short) 0); // IA length = 0
        out.flip();
        encryptedChannel.write(out);
        out.clear();

        // 4. B->A:
        // - ENCRYPT(VC, crypto_select, len(padD), padD)
        // compute what the encrypted VC looks like on the wire, using a separate cipher instance
        // so that the state of the real cipher is not affected
        byte[] encryptedVC;
        {
            MSECipher throwawayCipher = MSECipher.forInitiator(Sbytes, torrentId);
            try {
                encryptedVC = throwawayCipher.getDecryptionCipher().doFinal(VC_RAW_BYTES);
            } catch (Exception e) {
                throw new RuntimeException("Failed to encrypt VC", e);
            }
        }
        // synchronize on the incoming stream of data
        int handshakeTailMin = 4/*crypto_select*/ + 2/*padding_len*/;
        int phase2Min = encryptedVC.length + handshakeTailMin;
        // account for (phase1Limit - phase1Read) in case the padding from phase1 arrives later than expected
        int phase2Limit = (phase1Limit - phase1Read) + phase2Min + PADDING_MAX_LENGTH;

        int initpos = in.position();
        // use plaintext reader because encryption stream is not synced yet and will potentially produce garbage
        int phase2Read = reader.readBetween(phase2Min, phase2Limit).sync(in, encryptedVC);
        int matchpos = in.position();

        // the rest of the data is known to be encrypted
        ByteChannelReader encryptedReader = reader(encryptedChannel);
        // but we need to align the incoming (decrypting) cipher
        // for the number of encrypted bytes that have already arrived
        // and decrypt these bytes in the incoming data buffer for later processing
        in.limit(initpos + phase2Read);
        {
            // skip the VC in the decryption stream
            cipher.getDecryptionCipher().update(new byte[VC_RAW_BYTES.length]);
            byte[] encryptedData = new byte[in.remaining()];
            in.get(encryptedData);
            in.position(matchpos);
            // presumably a javax.crypto.Cipher, whose update() returns null for empty input,
            // so there is nothing to decrypt (and nothing to put) when no bytes follow the VC
            if (encryptedData.length > 0) {
                byte[] decryptedData = cipher.getDecryptionCipher().update(encryptedData);
                if (decryptedData != null) {
                    in.put(decryptedData);
                }
            }
            in.position(matchpos);
        }

        // we may still need to receive some handshake data (e.g. padding), that is arriving later than expected
        int available = in.remaining();
        if (available < handshakeTailMin) {
            int lim = in.limit();
            in.limit(in.capacity());
            // append after the bytes that have already been received and decrypted
            in.position(lim);
            int read = encryptedReader.readAtLeast(handshakeTailMin - available)
                    .readNoMoreThan((handshakeTailMin - available) + PADDING_MAX_LENGTH)
                    .read(in);
            in.position(matchpos);
            in.limit(lim + read);
        }

        byte[] crypto_select = new byte[4];
        in.get(crypto_select);
        EncryptionPolicy negotiatedEncryptionPolicy = selectPolicy(crypto_select, localEncryptionPolicy);
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("Negotiated encryption policy: {}, peer: {}", negotiatedEncryptionPolicy, peer);
        }

        int theirPadding = in.getShort() & 0xFFFF;
        if (theirPadding > PADDING_MAX_LENGTH) {
            // sanity check
            throw new IllegalStateException("Padding is too long: " + theirPadding);
        }
        int missing = (theirPadding - in.remaining());
        if (missing > 0) {
            int pos = in.position();
            int lim = in.limit();
            in.limit(in.capacity());
            // append after the part of the padding that has already arrived
            in.position(lim);
            encryptedReader.readAtLeast(missing).read(in);
            in.flip();
            in.position(pos);
        }

        // account for the upper layer protocol data that has already arrived
        in.position(in.position() + theirPadding);
        in.compact();
        out.clear();

        // - ENCRYPT2(Payload Stream)
        return createNegotiatedReaderWriter(peer, negotiatedEncryptionPolicy, channel, encryptedChannel, in, out);
    }

    /**
     * Negotiates encryption for an incoming connection (local peer acts as the receiver, B).
     *
     * <p>If the first received bytes can be decoded as a plaintext {@link Handshake},
     * the negotiation is skipped and a plaintext reader/writer is returned
     * (provided that the local policy is compatible with plaintext).
     *
     * @param peer Remote peer
     * @param channel Channel that is connected to the remote peer
     * @return Message reader/writer operating on top of either the plaintext or the encrypted channel
     * @throws IOException if an I/O error happens
     * @throws RuntimeException if the peer uses plaintext and the local policy is incompatible with it,
     *                          or if the received data can't be decrypted
     * @throws IllegalStateException if the requested torrent is unknown or inactive, the VC is invalid,
     *                               the padding is too long or the encryption policy can't be negotiated
     */
    MessageReaderWriter negotiateIncoming(Peer peer, ByteChannel channel) throws IOException {
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("Negotiating encryption for incoming connection: {}", peer);
        }

        ByteBuffer in = ByteBuffer.allocateDirect(bufferSize);
        ByteBuffer out = ByteBuffer.allocateDirect(bufferSize);
        ByteChannelReader reader = reader(channel);

        // 1. A->B: Diffie Hellman Ya, PadA
        // receive initiator's public key
        // specify lower threshold on the amount of bytes to receive,
        // as we will try to decode plaintext message of an unknown length first
        int phase0Min = 20;
        int phase0Read = reader.readAtLeast(phase0Min).read(in);
        in.flip();
        // try to determine the protocol from the first received bytes
        DecodingContext context = new DecodingContext(peer);
        int consumed = 0;
        try {
            consumed = messageHandler.decode(context, in);
        } catch (Exception e) {
            // deliberately not rethrown: failure to decode means that the data is not a plaintext message,
            // in which case we proceed with the encryption negotiation
            if (LOGGER.isTraceEnabled()) {
                LOGGER.trace("Failed to decode a plaintext message from peer: {}", peer, e);
            }
        }
        // TODO: can this be done without knowing the protocol specifics? (KeepAlive can be especially misleading: 0x00 0x00 0x00 0x00)
        if (consumed > 0 && context.getMessage() instanceof Handshake) {
            // decoding was successful, can use plaintext (if supported)
            assertPolicyIsCompatible(EncryptionPolicy.REQUIRE_PLAINTEXT);
            return createReaderWriter(peer, channel, in, out);
        }

        int phase1Min = keyGenerator.getPublicKeySize();
        int phase1Limit = phase1Min + PADDING_MAX_LENGTH;
        // switch the buffer back to "write" mode, positioned right after the bytes that have already arrived
        in.limit(in.capacity());
        in.position(phase0Read);
        // verify that there is a sufficient amount of bytes to decode peer's public key
        int phase1Read;
        if (phase0Read < phase1Min) {
            phase1Read = reader.readAtLeast(phase1Min - phase0Read).readNoMoreThan(phase1Limit - phase0Read).read(in);
        } else {
            phase1Read = 0;
        }

        in.flip();
        BigInteger peerPublicKey = BigIntegers.decodeUnsigned(in, keyGenerator.getPublicKeySize());
        in.clear(); // discard padding

        // 2. B->A: Diffie Hellman Yb, PadB
        // send our public key
        KeyPair keys = keyGenerator.generateKeyPair();
        out.put(keys.getPublic().getEncoded());
        out.put(getPadding(PADDING_MAX_LENGTH));
        out.flip();
        channel.write(out);
        out.clear();

        // calculate shared secret S
        BigInteger S = keyGenerator.calculateSharedSecret(peerPublicKey, keys.getPrivate());

        // 3. A->B:
        // req1 hash + req2/req3 hash + VC + crypto_provide + len(PadC) + PadC + len(IA)
        int phase2Min = HASH_LENGTH + HASH_LENGTH + VC_RAW_BYTES.length + 4 + 2 + 0 + 2;
        int phase2Limit = HASH_LENGTH + HASH_LENGTH + VC_RAW_BYTES.length + 4 + 2 + PADDING_MAX_LENGTH + 2;
        MessageDigest digest = getDigest("SHA-1");
        byte[] encodedSecret = BigIntegers.encodeUnsigned(S, keyGenerator.getPublicKeySize());
        // padding from phase 1 may be arriving later than expected, so we need to synchronize
        // on the incoming stream of data, looking for a correct S hash
        byte[] bytes = new byte[HASH_LENGTH];
        // - HASH('req1', S)
        byte[] req1hash = hash(digest, "req1", encodedSecret);
        // syncing will also ensure that the peer knows S (otherwise synchronization will fail due to not finding the pattern)
        int phase2Read = reader.readAtLeast(phase2Min)
                .readNoMoreThan(phase2Limit + (phase1Limit - (phase0Read + phase1Read))) // account for padding arriving later
                .sync(in, req1hash);
        // NOTE(review): the buffer was cleared after decoding the public key, so the bytes read by sync()
        // presumably start at offset 0 (the outgoing path uses initpos + phase2Read in the same situation);
        // confirm that adding phase1Read is intended when phase1Read > 0
        in.limit(phase1Read + phase2Read);

        // - HASH('req2', SKEY) xor HASH('req3', S)
        in.get(bytes); // read SKEY/S hash
        TorrentId requestedTorrent = null;
        byte[] b2 = hash(digest, "req3", encodedSecret);
        // the peer does not send the torrent ID in clear, so try all known torrents until the hash matches
        for (TorrentId torrentId : torrentRegistry.getTorrentIds()) {
            byte[] b1 = hash(digest, "req2", torrentId.getBytes());
            if (Arrays.equals(xor(b1, b2), bytes)) {
                requestedTorrent = torrentId;
                break;
            }
        }
        // check that torrent is supported and active
        if (requestedTorrent == null) {
            throw new IllegalStateException("Unsupported torrent requested");
        }
        Optional<TorrentDescriptor> descriptor = torrentRegistry.getDescriptor(requestedTorrent);
        if (descriptor.isPresent() && !descriptor.get().isActive()) {
            // don't throw an exception if descriptor is not present -- torrent might be being fetched at the time
            throw new IllegalStateException("Inactive torrent requested: " + requestedTorrent);
        }

        byte[] Sbytes = BigIntegers.encodeUnsigned(S, MSEKeyPairGenerator.PUBLIC_KEY_BYTES);
        MSECipher cipher = MSECipher.forReceiver(Sbytes, requestedTorrent);
        ByteChannel encryptedChannel = new EncryptedChannel(channel, cipher.getDecryptionCipher(), cipher.getEncryptionCipher());
        ByteChannelReader encryptedReader = reader(encryptedChannel);

        // - ENCRYPT(VC, crypto_provide, len(PadC), PadC, len(IA))
        // decrypt encrypted leftovers from step #3 in place
        int pos = in.position();
        byte[] leftovers = new byte[in.remaining()];
        in.get(leftovers);
        in.position(pos);
        // presumably a javax.crypto.Cipher, whose update() returns null for empty input
        if (leftovers.length > 0) {
            try {
                byte[] decrypted = cipher.getDecryptionCipher().update(leftovers);
                if (decrypted != null) {
                    in.put(decrypted);
                }
            } catch (Exception e) {
                throw new RuntimeException("Failed to decrypt leftover bytes: " + leftovers.length, e);
            }
        }
        in.position(pos);

        byte[] theirVC = new byte[VC_RAW_BYTES.length];
        in.get(theirVC);
        if (!Arrays.equals(VC_RAW_BYTES, theirVC)) {
            throw new IllegalStateException("Invalid VC: "+ Arrays.toString(theirVC));
        }

        byte[] crypto_provide = new byte[4];
        in.get(crypto_provide);
        EncryptionPolicy negotiatedEncryptionPolicy = selectPolicy(crypto_provide, localEncryptionPolicy);
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("Negotiated encryption policy: {}, peer: {}", negotiatedEncryptionPolicy, peer);
        }

        int theirPadding = in.getShort() & 0xFFFF;
        if (theirPadding > PADDING_MAX_LENGTH) {
            // sanity check
            throw new IllegalStateException("Padding is too long: " + theirPadding);
        }
        int position = in.position(); // mark
        // check if the whole padding block and the IA length field that follows it have already arrived;
        // the number of missing bytes must be calculated before the buffer is switched to "write" mode
        int missing = (theirPadding + 2/*IA length*/) - in.remaining();
        if (missing > 0) {
            int dataEnd = in.limit();
            in.limit(in.capacity());
            // append after the bytes that have already been received and decrypted
            in.position(dataEnd);
            encryptedReader.readAtLeast(missing).read(in);
            in.flip();
            in.position(position); // reset
        }

        in.position(position + theirPadding); // discard padding
        // Initial Payload length (0..65535 bytes)
        in.getShort(); // currently we ignore IA, it will be processed by the upper layer
        in.compact();

        // 4. B->A:
        // - ENCRYPT(VC, crypto_select, len(padD), padD)
        // - ENCRYPT2(Payload Stream)
        out.put(VC_RAW_BYTES);
        out.put(getCryptoProvideBitfield(negotiatedEncryptionPolicy));
        byte[] padding = getZeroPadding(PADDING_MAX_LENGTH);
        out.putShort((short) padding.length);
        out.put(padding);
        out.flip();
        encryptedChannel.write(out);
        out.clear();

        return createNegotiatedReaderWriter(peer, negotiatedEncryptionPolicy, channel, encryptedChannel, in, out);
    }

    /**
     * Creates a reader for the given channel, configured with this class's receive timeout
     * and wait interval between reads.
     */
    private ByteChannelReader reader(ReadableByteChannel channel) {
        return ByteChannelReader.forChannel(channel).withTimeout(RECEIVE_TIMEOUT).waitBetweenReads(WAIT_BETWEEN_READS);
    }

    /**
     * @throws RuntimeException if the local encryption policy is not compatible with the peer's policy
     */
    private void assertPolicyIsCompatible(EncryptionPolicy peerEncryptionPolicy) {
        if (!localEncryptionPolicy.isCompatible(peerEncryptionPolicy)) {
            throw new RuntimeException("Encryption policies are incompatible: peer's (" + peerEncryptionPolicy.name()
                    + "), local (" + localEncryptionPolicy.name() + ")");
        }
    }

    /**
     * Creates a reader/writer on top of the plaintext or the encrypted channel,
     * depending on the negotiated encryption policy.
     *
     * @throws IllegalStateException if the negotiated policy is not one of the known values
     */
    private MessageReaderWriter createNegotiatedReaderWriter(Peer peer,
                                                             EncryptionPolicy negotiatedEncryptionPolicy,
                                                             ByteChannel plaintextChannel,
                                                             ByteChannel encryptedChannel,
                                                             ByteBuffer in,
                                                             ByteBuffer out) {
        switch (negotiatedEncryptionPolicy) {
            case REQUIRE_PLAINTEXT:
            case PREFER_PLAINTEXT: {
                return createReaderWriter(peer, plaintextChannel, in, out);
            }
            case PREFER_ENCRYPTED:
            case REQUIRE_ENCRYPTED: {
                return createReaderWriter(peer, encryptedChannel, in, out);
            }
            default: {
                throw new IllegalStateException("Unknown encryption policy: " + negotiatedEncryptionPolicy.name());
            }
        }
    }

    /**
     * Creates a reader/writer that reads messages from and writes messages to the given channel,
     * using the provided buffers.
     */
    private MessageReaderWriter createReaderWriter(Peer peer, ByteChannel channel, ByteBuffer in, ByteBuffer out) {
        Supplier<Message> reader = new DefaultMessageReader(peer, channel, messageHandler, in);
        Consumer<Message> writer = new DefaultMessageWriter(channel, peer, messageHandler, out);
        return new DelegatingMessageReaderWriter(reader, writer);
    }

    /**
     * @param length Maximum padding length (inclusive)
     * @return Array of random length in [0, length], filled with random bytes
     */
    private byte[] getPadding(int length) {
        byte[] padding = getZeroPadding(length);
        RANDOM.nextBytes(padding);
        return padding;
    }

    /**
     * @param length Maximum padding length (inclusive)
     * @return Array of random length in [0, length], filled with zeros
     */
    private byte[] getZeroPadding(int length) {
        return new byte[RANDOM.nextInt(length + 1)];
    }

    /**
     * @throws RuntimeException if the requested digest algorithm is not available
     */
    private MessageDigest getDigest(String algorithm) {
        try {
            return MessageDigest.getInstance(algorithm);
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("Digest algorithm is not available: " + algorithm, e);
        }
    }

    /**
     * Calculates HASH(label, payload), where label is encoded as ASCII.
     * The digest is reset upon completion and can be reused for subsequent calculations.
     */
    private static byte[] hash(MessageDigest digest, String label, byte[] payload) {
        digest.update(label.getBytes(StandardCharsets.US_ASCII));
        digest.update(payload);
        return digest.digest();
    }

    /**
     * @return Byte-wise XOR of two arrays of equal length
     * @throws IllegalStateException if the arrays have different lengths
     */
    private byte[] xor(byte[] b1, byte[] b2) {
        if (b1.length != b2.length) {
            throw new IllegalStateException("Lengths do not match: " + b1.length + ", " + b2.length);
        }
        byte[] result = new byte[b1.length];
        for (int i = 0; i < b1.length; i++) {
            result[i] = (byte) (b1[i] ^ b2[i]);
        }
        return result;
    }

    /**
     * Builds a 4-byte crypto_provide/crypto_select bitfield for the given policy.
     * Only the last byte is used: bit 0x01 stands for plaintext, bit 0x02 stands for encryption.
     */
    private byte[] getCryptoProvideBitfield(EncryptionPolicy encryptionPolicy) {
        byte[] crypto_provide = new byte[4];
        switch (encryptionPolicy) {
            case REQUIRE_PLAINTEXT: {
                crypto_provide[3] = 1; // only 0x01
                break;
            }
            case PREFER_PLAINTEXT:
            case PREFER_ENCRYPTED: {
                crypto_provide[3] = 3; // both 0x01 and 0x02
                break;
            }
            case REQUIRE_ENCRYPTED: {
                crypto_provide[3] = 2; // only 0x02
                break;
            }
            default: {
                // do nothing, bitfield is all zeros
            }
        }
        return crypto_provide;
    }

    /**
     * Selects the effective encryption policy, based on the options provided by the peer and the local policy.
     *
     * @param crypto_provide 4-byte bitfield received from the peer (only the last byte is inspected)
     * @param localEncryptionPolicy Local encryption policy
     * @return Either {@link EncryptionPolicy#REQUIRE_PLAINTEXT} or {@link EncryptionPolicy#REQUIRE_ENCRYPTED}
     * @throws IllegalStateException if the peer provides no options, none of the provided options
     *                               is acceptable for the local policy, or the local policy is unknown
     */
    private EncryptionPolicy selectPolicy(byte[] crypto_provide, EncryptionPolicy localEncryptionPolicy) {
        boolean plaintextProvided = (crypto_provide[3] & 0x01) == 0x01;
        boolean encryptionProvided = (crypto_provide[3] & 0x02) == 0x02;

        EncryptionPolicy selected = null;
        if (plaintextProvided || encryptionProvided) {
            switch (localEncryptionPolicy) {
                case REQUIRE_PLAINTEXT: {
                    if (plaintextProvided) {
                        selected = EncryptionPolicy.REQUIRE_PLAINTEXT;
                    }
                    break;
                }
                case PREFER_PLAINTEXT: {
                    selected = plaintextProvided ? EncryptionPolicy.REQUIRE_PLAINTEXT : EncryptionPolicy.REQUIRE_ENCRYPTED;
                    break;
                }
                case PREFER_ENCRYPTED: {
                    selected = encryptionProvided ? EncryptionPolicy.REQUIRE_ENCRYPTED : EncryptionPolicy.REQUIRE_PLAINTEXT;
                    break;
                }
                case REQUIRE_ENCRYPTED: {
                    if (encryptionProvided) {
                        selected = EncryptionPolicy.REQUIRE_ENCRYPTED;
                    }
                    break;
                }
                default: {
                    throw new IllegalStateException("Unknown encryption policy: " + localEncryptionPolicy.name());
                }
            }
        }

        if (selected == null) {
            throw new IllegalStateException("Failed to negotiate the encryption policy: local policy (" +
                    localEncryptionPolicy.name() + "), peer's policy (" + Arrays.toString(crypto_provide) + ")");
        }
        return selected;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy