All Downloads are FREE. Search and download functionalities are using the official Maven repository.

org.tarantool.protocol.ProtoUtils Maven / Gradle / Ivy

package org.tarantool.protocol;

import org.tarantool.Base64;
import org.tarantool.Code;
import org.tarantool.CommunicationException;
import org.tarantool.CountInputStreamImpl;
import org.tarantool.Key;
import org.tarantool.MsgPackLite;
import org.tarantool.TarantoolException;

import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.channels.NonReadableChannelException;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SocketChannel;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;

public abstract class ProtoUtils {

    public static final int LENGTH_OF_SIZE_MESSAGE = 5;

    private static final int DEFAULT_INITIAL_REQUEST_SIZE = 4096;
    private static final String WELCOME = "Tarantool ";

    /**
     * Reads tarantool binary protocol's packet from {@code inputStream}.
     *
     * @param inputStream ready to use input stream
     * @param msgPackLite MessagePack decoder instance
     *
     * @return Nonnull instance of packet
     *
     * @throws IOException in case of any io-error
     */
    public static TarantoolPacket readPacket(InputStream inputStream, MsgPackLite msgPackLite) throws IOException {
        CountInputStreamImpl msgStream = new CountInputStreamImpl(inputStream);

        int size = ((Number) msgPackLite.unpack(msgStream)).intValue();
        long mark = msgStream.getBytesRead();

        Map headers = (Map) msgPackLite.unpack(msgStream);

        Map body = null;
        if (msgStream.getBytesRead() - mark < size) {
            body = (Map) msgPackLite.unpack(msgStream);
        }

        return new TarantoolPacket(headers, body);
    }

    /**
     * Reads a tarantool's binary protocol packet from the reader.
     *
     * @param bufferReader readable channel that have to be in blocking mode
     *                     or instance of {@link ReadableViaSelectorChannel}
     * @param msgPackLite MessagePack decoder instance
     *
     * @return tarantool binary protocol message wrapped by instance of {@link TarantoolPacket}
     *
     * @throws IOException                 if any IO-error occurred during read from the channel
     * @throws CommunicationException      input stream bytes constitute msg pack message in wrong format
     * @throws NonReadableChannelException If this channel was not opened for reading
     */
    public static TarantoolPacket readPacket(ReadableByteChannel bufferReader, MsgPackLite msgPackLite)
        throws CommunicationException, IOException {

        ByteBuffer buffer = ByteBuffer.allocate(LENGTH_OF_SIZE_MESSAGE);
        bufferReader.read(buffer);

        buffer.flip();
        int size = ((Number) msgPackLite.unpack(new ByteBufferBackedInputStream(buffer))).intValue();

        buffer = ByteBuffer.allocate(size);
        bufferReader.read(buffer);

        buffer.flip();
        ByteBufferBackedInputStream msgBytesStream = new ByteBufferBackedInputStream(buffer);
        Object unpackedHeaders = msgPackLite.unpack(msgBytesStream);
        if (!(unpackedHeaders instanceof Map)) {
            //noinspection ConstantConditions
            throw new CommunicationException(
                "Error while unpacking headers of tarantool response: " +
                    "expected type Map but was " +
                    unpackedHeaders != null ? unpackedHeaders.getClass().toString() : "null"
            );
        }
        //noinspection unchecked (checked above)
        Map headers = (Map) unpackedHeaders;

        Map body = null;
        if (msgBytesStream.hasAvailable()) {
            Object unpackedBody = msgPackLite.unpack(msgBytesStream);
            if (!(unpackedBody instanceof Map)) {
                //noinspection ConstantConditions
                throw new CommunicationException(
                    "Error while unpacking body of tarantool response: " +
                        "expected type Map but was " +
                        unpackedBody != null ? unpackedBody.getClass().toString() : "null"
                );
            }
            //noinspection unchecked (checked above)
            body = (Map) unpackedBody;
        }

        return new TarantoolPacket(headers, body);
    }

    /**
     * Connects to a tarantool node described by {@code socket}. Performs an authentication if required
     *
     * @param socket   a socket channel to a tarantool node
     * @param username auth username
     * @param password auth password
     * @param msgPackLite MessagePack encoder / decoder instance
     *
     * @return object with information about a connection/
     *
     * @throws IOException            in case of any IO fails
     * @throws CommunicationException when welcome string is invalid
     * @throws TarantoolException     in case of failed authentication
     */
    public static TarantoolGreeting connect(Socket socket,
                                            String username,
                                            String password,
                                            MsgPackLite msgPackLite) throws IOException {
        byte[] inputBytes = new byte[64];

        InputStream inputStream = socket.getInputStream();
        inputStream.read(inputBytes);

        String firstLine = new String(inputBytes);
        assertCorrectWelcome(firstLine, socket.getRemoteSocketAddress());
        String serverVersion = firstLine.substring(WELCOME.length());

        inputStream.read(inputBytes);
        String salt = new String(inputBytes);
        if (username != null && password != null) {
            ByteBuffer authPacket = createAuthPacket(username, password, salt, msgPackLite);

            OutputStream os = socket.getOutputStream();
            os.write(authPacket.array(), 0, authPacket.remaining());
            os.flush();

            TarantoolPacket responsePacket = readPacket(socket.getInputStream(), msgPackLite);
            assertNoErrCode(responsePacket);
        }

        return new TarantoolGreeting(serverVersion);
    }

    /**
     * Connects to a tarantool node described by {@code socketChannel}. Performs an authentication if required.
     *
     * @param channel  a socket channel to tarantool node. The channel have to be in blocking mode
     * @param username auth username
     * @param password auth password
     * @param msgPackLite MessagePack encoder / decoder instance
     *
     * @return object with information about a connection/
     *
     * @throws IOException            in case of any IO fails
     * @throws CommunicationException when welcome string is invalid
     * @throws TarantoolException     in case of failed authentication
     */
    public static TarantoolGreeting connect(SocketChannel channel,
                                            String username,
                                            String password,
                                            MsgPackLite msgPackLite) throws IOException {
        ByteBuffer welcomeBytes = ByteBuffer.wrap(new byte[64]);
        channel.read(welcomeBytes);

        String firstLine = new String(welcomeBytes.array());
        assertCorrectWelcome(firstLine, channel.getRemoteAddress());
        final String serverVersion = firstLine.substring(WELCOME.length());

        ((Buffer)welcomeBytes).clear();
        channel.read(welcomeBytes);
        String salt = new String(welcomeBytes.array());

        if (username != null && password != null) {
            ByteBuffer authPacket = createAuthPacket(username, password, salt, msgPackLite);
            writeFully(channel, authPacket);

            TarantoolPacket authResponse = readPacket(channel, msgPackLite);
            assertNoErrCode(authResponse);
        }

        return new TarantoolGreeting(serverVersion);
    }

    private static void assertCorrectWelcome(String firstLine, SocketAddress remoteAddress) {
        if (!firstLine.startsWith(WELCOME)) {
            String errMsg = "Failed to connect to node " + remoteAddress.toString() +
                ": Welcome message should starts with tarantool but starts with '" +
                firstLine +
                "'";
            throw new CommunicationException(errMsg, new IllegalStateException("Invalid welcome packet"));
        }
    }

    private static void assertNoErrCode(TarantoolPacket authResponse) {
        Long code = (Long) authResponse.getHeaders().get(Key.CODE.getId());
        if (code != 0) {
            Object error = authResponse.getBody().get(Key.ERROR.getId());
            String errorMsg = error instanceof String ? (String) error : new String((byte[]) error);
            throw new TarantoolException(code, errorMsg);
        }
    }

    public static void writeFully(OutputStream stream, ByteBuffer buffer) throws IOException {
        stream.write(buffer.array());
        stream.flush();
    }

    public static void writeFully(SocketChannel channel, ByteBuffer buffer) throws IOException {
        long code = 0;
        while (buffer.remaining() > 0 && (code = channel.write(buffer)) > -1) {
        }
        if (code < 0) {
            throw new SocketException("write failed code: " + code);
        }
    }

    public static ByteBuffer createAuthPacket(String username,
                                              final String password,
                                              String salt,
                                              MsgPackLite msgPackLite) throws IOException {
        final MessageDigest sha1;
        try {
            sha1 = MessageDigest.getInstance("SHA-1");
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
        List auth = new ArrayList(2);
        auth.add("chap-sha1");

        byte[] p = sha1.digest(password.getBytes());

        sha1.reset();
        byte[] p2 = sha1.digest(p);

        sha1.reset();
        sha1.update(Base64.decode(salt), 0, 20);
        sha1.update(p2);
        byte[] scramble = sha1.digest();
        for (int i = 0, e = 20; i < e; i++) {
            p[i] ^= scramble[i];
        }
        auth.add(p);

        return createPacket(
            DEFAULT_INITIAL_REQUEST_SIZE, msgPackLite,
            Code.AUTH, 0L, null, Key.USER_NAME, username, Key.TUPLE, auth
        );
    }

    public static ByteBuffer createPacket(MsgPackLite msgPackLite,
                                          Code code,
                                          Long syncId,
                                          Long schemaId,
                                          Object... args) throws IOException {
        return createPacket(DEFAULT_INITIAL_REQUEST_SIZE, msgPackLite, code, syncId, schemaId, args);
    }

    public static ByteBuffer createPacket(int initialRequestSize,
                                          MsgPackLite msgPackLite,
                                          Code code,
                                          Long syncId,
                                          Long schemaId,
                                          Object... args) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream(initialRequestSize);
        bos.write(new byte[5]);
        final DataOutputStream ds = new DataOutputStream(bos);
        Map header = new EnumMap<>(Key.class);
        Map body = new EnumMap<>(Key.class);
        header.put(Key.CODE, code);
        header.put(Key.SYNC, syncId);
        if (schemaId != null) {
            header.put(Key.SCHEMA_ID, schemaId);
        }
        if (args != null) {
            for (int i = 0, e = args.length; i < e; i += 2) {
                Object value = args[i + 1];
                body.put((Key) args[i], value);
            }
        }
        msgPackLite.pack(header, ds);
        msgPackLite.pack(body, ds);
        ds.flush();
        ByteBuffer buffer = bos.toByteBuffer();
        buffer.put(0, (byte) 0xce);
        buffer.putInt(1, bos.size() - 5);
        return buffer;
    }

    private static class ByteArrayOutputStream extends java.io.ByteArrayOutputStream {
        public ByteArrayOutputStream(int size) {
            super(size);
        }

        ByteBuffer toByteBuffer() {
            return ByteBuffer.wrap(buf, 0, count);
        }
    }

}




© 2015 - 2025 Weber Informatics LLC | Privacy Policy