All Downloads are FREE. Search and download functionalities are using the official Maven repository.

bt.protocol.StandardBittorrentProtocol Maven / Gradle / Ivy

There is a newer version: 1.10
Show newest version
/*
 * Copyright (c) 2016—2017 Andrei Tomashpolskiy and individual contributors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package bt.protocol;

import bt.BtException;
import bt.metainfo.TorrentId;
import bt.module.MessageHandlers;
import bt.net.PeerId;
import bt.protocol.handler.BitfieldHandler;
import bt.protocol.handler.CancelHandler;
import bt.protocol.handler.ChokeHandler;
import bt.protocol.handler.HaveHandler;
import bt.protocol.handler.InterestedHandler;
import bt.protocol.handler.MessageHandler;
import bt.protocol.handler.NotInterestedHandler;
import bt.protocol.handler.PieceHandler;
import bt.protocol.handler.RequestHandler;
import bt.protocol.handler.UnchokeHandler;
import com.google.inject.Inject;

import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

import static bt.protocol.Protocols.readInt;

public class StandardBittorrentProtocol implements MessageHandler {

    /**
     * BitTorrent message prefix size in bytes.
     *
     * @since 1.0
     */
    public static final int MESSAGE_LENGTH_PREFIX_SIZE = 4;
    /**
     * BitTorrent message ID size in bytes.
     *
     * @since 1.0
     */
    public static final int MESSAGE_TYPE_SIZE = 1;
    /**
     * BitTorrent message prefix size in bytes.
     * Message prefix is a concatenation of message length prefix and message ID.
     *
     * @since 1.0
     */
    public static final int MESSAGE_PREFIX_SIZE = MESSAGE_LENGTH_PREFIX_SIZE + MESSAGE_TYPE_SIZE;

    /**
     * @since 1.0
     */
    public static final int CHOKE_ID = 0;

    /**
     * @since 1.0
     */
    public static final int UNCHOKE_ID = 1;

    /**
     * @since 1.0
     */
    public static final int INTERESTED_ID = 2;

    /**
     * @since 1.0
     */
    public static final int NOT_INTERESTED_ID = 3;

    /**
     * @since 1.0
     */
    public static final int HAVE_ID = 4;

    /**
     * @since 1.0
     */
    public static final int BITFIELD_ID = 5;

    /**
     * @since 1.0
     */
    public static final int REQUEST_ID = 6;

    /**
     * @since 1.0
     */
    public static final int PIECE_ID = 7;

    /**
     * @since 1.0
     */
    public static final int CANCEL_ID = 8;

    private static final String PROTOCOL_NAME = "BitTorrent protocol";

    private static final byte[] PROTOCOL_NAME_BYTES;
    private static final byte[] HANDSHAKE_PREFIX;
    private static final int HANDSHAKE_RESERVED_OFFSET;
    private static final int HANDSHAKE_RESERVED_LENGTH = 8;

    private static final byte[] KEEPALIVE = new byte[]{0,0,0,0};

    static {

        PROTOCOL_NAME_BYTES = PROTOCOL_NAME.getBytes(Charset.forName("ASCII"));
        int protocolNameLength = PROTOCOL_NAME_BYTES.length;
        int prefixLength = 1;

        HANDSHAKE_RESERVED_OFFSET = 1 + protocolNameLength;
        HANDSHAKE_PREFIX = new byte[HANDSHAKE_RESERVED_OFFSET];
        HANDSHAKE_PREFIX[0] = (byte) protocolNameLength;
        System.arraycopy(PROTOCOL_NAME_BYTES, 0, HANDSHAKE_PREFIX, prefixLength, protocolNameLength);
    }

    private Map> handlers;
    private Map> uniqueTypes;
    private Map, MessageHandler> handlersByType;
    private Map, Integer> idMap;

    @Inject
    public StandardBittorrentProtocol(@MessageHandlers Map> extraHandlers) {

        Map> handlers = new HashMap<>();
        handlers.put(CHOKE_ID, new ChokeHandler());
        handlers.put(UNCHOKE_ID, new UnchokeHandler());
        handlers.put(INTERESTED_ID, new InterestedHandler());
        handlers.put(NOT_INTERESTED_ID, new NotInterestedHandler());
        handlers.put(HAVE_ID, new HaveHandler());
        handlers.put(BITFIELD_ID, new BitfieldHandler());
        handlers.put(REQUEST_ID, new RequestHandler());
        handlers.put(PIECE_ID, new PieceHandler());
        handlers.put(CANCEL_ID, new CancelHandler());

        extraHandlers.forEach((messageId, handler) -> {
            if (handlers.containsKey(messageId)) {
                throw new BtException("Duplicate handler for message ID: " + messageId);
            }
            handlers.put(messageId, handler);
        });

        Map, Integer> idMap = new HashMap<>();
        Map, MessageHandler> handlersByType = new HashMap<>();
        Map> uniqueTypes = new HashMap<>();

        handlers.forEach((messageId, handler) -> {

            if (handler.getSupportedTypes().isEmpty()) {
                throw new BtException("No supported types declared in handler: " + handler.getClass().getName());
            } else {
                uniqueTypes.put(messageId, handler.getSupportedTypes().iterator().next());
            }

            handler.getSupportedTypes().forEach(messageType -> {
                    if (idMap.containsKey(messageType)) {
                        throw new BtException("Duplicate handler for message type: " + messageType.getSimpleName());
                    }
                    idMap.put(messageType, messageId);
                    handlersByType.put(messageType, handler);
                }
            );
        });

        this.handlers = handlers;
        this.idMap = idMap;
        this.handlersByType = handlersByType;
        this.uniqueTypes = uniqueTypes;
    }

    @Override
    public Collection> getSupportedTypes() {
        return null;
    }

    @Override
    public final Class readMessageType(ByteBuffer buffer) {

        Objects.requireNonNull(buffer);

        if (!buffer.hasRemaining()) {
            return null;
        }

        int position = buffer.position();
        byte first = buffer.get();
        if (first == PROTOCOL_NAME.length()) {
            return Handshake.class;
        }
        buffer.position(position);

        Integer length = readInt(buffer);
        if (length == null) {
            return null;
        } else if (length == 0) {
            return KeepAlive.class;
        }

        if (buffer.hasRemaining()) {
            int messageTypeId = buffer.get();
            Class messageType;

            messageType = uniqueTypes.get(messageTypeId);
            if (messageType == null) {
                MessageHandler handler = handlers.get(messageTypeId);
                if (handler == null) {
                    throw new InvalidMessageException("Unknown message type ID: " + messageTypeId);
                }
                messageType = handler.readMessageType(buffer);
            }
            return messageType;
        }

        return null;
    }

    @Override
    public final int decode(DecodingContext context, ByteBuffer buffer) {

        Objects.requireNonNull(context);
        Objects.requireNonNull(buffer);

        if (!buffer.hasRemaining()) {
            return 0;
        }

        int position = buffer.position();
        Class messageType = readMessageType(buffer);
        if (messageType == null) {
            return 0;
        }

        if (Handshake.class.equals(messageType)) {
            buffer.position(position);
            return decodeHandshake(context, buffer);
        }
        if (KeepAlive.class.equals(messageType)) {
            context.setMessage(KeepAlive.instance());
            return KEEPALIVE.length;
        }

        MessageHandler handler = Objects.requireNonNull(handlersByType.get(messageType));
        buffer.position(position);
        return handler.decode(context, buffer);
    }

    @SuppressWarnings("unchecked")
    @Override
    public final boolean encode(EncodingContext context, Message message, ByteBuffer buffer) {

        Objects.requireNonNull(buffer);
        Integer messageId = idMap.get(Objects.requireNonNull(message).getClass());

        if (Handshake.class.equals(message.getClass())) {
            Handshake handshake = (Handshake) message;
            return writeHandshake(buffer, handshake.getReserved(), handshake.getTorrentId(), handshake.getPeerId());
        }
        if (KeepAlive.class.equals(message.getClass())) {
            return writeKeepAlive(buffer);
        }

        if (messageId == null) {
            throw new InvalidMessageException("Unknown message type: " + message.getClass().getSimpleName());
        }
        return ((MessageHandler)handlers.get(messageId)).encode(context, message, buffer);
    }

    // keep-alive: 
    private static boolean writeKeepAlive(ByteBuffer buffer) {
        if (buffer.remaining() < KEEPALIVE.length) {
            return false;
        }
        buffer.put(KEEPALIVE);
        return true;
    }

    // handshake: 
    private static boolean writeHandshake(ByteBuffer buffer, byte[] reserved, TorrentId torrentId, PeerId peerId) {

        if (reserved.length != HANDSHAKE_RESERVED_LENGTH) {
            throw new InvalidMessageException("Invalid reserved bytes: expected " + HANDSHAKE_RESERVED_LENGTH
                    + " bytes, received " + reserved.length);
        }

        int length = HANDSHAKE_PREFIX.length + HANDSHAKE_RESERVED_LENGTH +
                TorrentId.length() + PeerId.length();
        if (buffer.remaining() < length) {
            return false;
        }

        buffer.put(HANDSHAKE_PREFIX);
        buffer.put(reserved);
        buffer.put(torrentId.getBytes());
        buffer.put(peerId.getBytes());

        return true;
    }

    private static int decodeHandshake(DecodingContext context, ByteBuffer buffer) {

        int consumed = 0;
        int offset = HANDSHAKE_RESERVED_OFFSET;
        int length = HANDSHAKE_RESERVED_LENGTH + TorrentId.length() + PeerId.length();
        int limit = offset + length;

        if (buffer.remaining() >= limit) {

            buffer.get(); // skip message ID

            byte[] protocolNameBytes = new byte[PROTOCOL_NAME.length()];
            buffer.get(protocolNameBytes);
            if (!Arrays.equals(PROTOCOL_NAME_BYTES, protocolNameBytes)) {
                throw new InvalidMessageException("Unexpected protocol name (decoded with ASCII): " +
                        new String(protocolNameBytes, Charset.forName("ASCII")));
            }

            byte[] reserved = new byte[HANDSHAKE_RESERVED_LENGTH];
            buffer.get(reserved);

            byte[] infoHash = new byte[TorrentId.length()];
            buffer.get(infoHash);

            byte[] peerId = new byte[PeerId.length()];
            buffer.get(peerId);

            context.setMessage(new Handshake(reserved, TorrentId.fromBytes(infoHash), PeerId.fromBytes(peerId)));
            consumed = limit;
        }

        return consumed;
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy