All downloads are free. Search and download functionality uses the official Maven repository.

com.opdar.gulosity.connection.MysqlConnection Maven / Gradle / Ivy

The newest version!
package com.opdar.gulosity.connection;

import com.opdar.gulosity.base.Constants;
import com.opdar.gulosity.connection.entity.Column;
import com.opdar.gulosity.connection.parser.Body;
import com.opdar.gulosity.connection.parser.ColumnParser;
import com.opdar.gulosity.connection.parser.RowParser;
import com.opdar.gulosity.connection.protocol.ErrorProtocol;
import com.opdar.gulosity.connection.protocol.HeaderProtocol;
import com.opdar.gulosity.entity.MysqlAuthInfoEntity;
import com.opdar.gulosity.utils.BufferUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Created by Shey on 2016/8/19.
 */
public class MysqlConnection {
    // Client capability flags announced in the authentication response (MySQL protocol values).
    private static final int CLIENT_LONG_PASSWORD = 1;
    private static final int CLIENT_LONG_FLAG = 4;
    private static final int CLIENT_CONNECT_WITH_DB = 8;
    private static final int CLIENT_PROTOCOL_41 = 512;
    private static final int CLIENT_TRANSACTIONS = 8192;
    private static final int CLIENT_SECURE_CONNECTION = 32768;

    /** Maximum packet size announced to the server (16 MiB). */
    private static final int MAX_PACKET_LENGTH = 1 << 24;

    /** Length of auth-plugin-data-part-1 in the handshake packet. */
    private static final int SCRAMBLE_PART1_LENGTH = 8;
    /** Bytes of auth-plugin-data-part-2 that are used; the byte after them is a NUL terminator. */
    private static final int SCRAMBLE_PART2_LENGTH = 12;

    private final MysqlAuthInfoEntity authInfo;
    private SocketChannel channel;
    private int soTimeout = 30 * 1000;
    // NOTE(review): this value is never applied to the socket. Setting it explicitly would pin the
    // kernel receive buffer at 16 KiB, so confirm the intent before wiring it up.
    private int receiveBufferSize = 16 * 1024;
    private int sendBufferSize = 16 * 1024;
    private long connectionId = -1;
    /** Character set id sent to the server during authentication; 33 is utf8_general_ci. */
    private int charsetNumber = 33;
    /** Server scramble (part 1 + part 2) captured from the handshake, used to hash the password. */
    private byte[] scramble;
    private String serverVersion = "";
    private final AtomicBoolean connected = new AtomicBoolean();
    private final Logger logger = LoggerFactory.getLogger(getClass());

    /**
     * Returns the live flag that becomes {@code true} after a successful {@link #connect()} and
     * {@code false} again after {@link #close()}.
     */
    public AtomicBoolean getConnected() {
        return connected;
    }

    public MysqlConnection(MysqlAuthInfoEntity authInfo) {
        this.authInfo = authInfo;
    }

    public MysqlAuthInfoEntity getAuthInfo() {
        return authInfo;
    }

    /** Returns the underlying channel, or {@code null} if {@link #connect()} was never called. */
    public SocketChannel getChannel() {
        return channel;
    }

    /**
     * Opens a TCP connection and authenticates with the MySQL server.
     *
     * <p>The steps are:
     * <ol>
     *   <li>connect to {@code authInfo.getAddress()};</li>
     *   <li>read the server's initial handshake packet;</li>
     *   <li>send a protocol-4.1 authentication response;</li>
     *   <li>check the server's reply.</li>
     * </ol>
     *
     * <p>On success the {@link #getConnected()} flag is set to {@code true}. On any failure the
     * freshly opened channel is closed before the exception propagates, so the socket is not leaked.
     *
     * @throws IOException if the socket cannot be opened or connected, if the server sends an
     *     error or unexpected packet during the handshake, or if authentication is rejected
     */
    public void connect() throws IOException {
        this.channel = SocketChannel.open();
        boolean success = false;
        try {
            configureSocket(channel);
            channel.connect(authInfo.getAddress());

            Body handshake = Body.get(channel);
            checkHandshakePacket(handshake);
            initHandshakeV10(handshake.getBody());

            auth411(handshake.getHeader());
            checkAuthResult(Body.get(channel));
            success = true;
        } finally {
            if (!success) {
                closeQuietly(channel);
            }
        }
        if (connected.compareAndSet(false, true)) {
            logger.info("Auth Success.");
        }
    }

    /** Applies this connection's socket options to {@code ch} before it is connected. */
    private void configureSocket(SocketChannel ch) throws IOException {
        ch.socket().setKeepAlive(true);
        ch.socket().setReuseAddress(true);
        // NOTE(review): SO_TIMEOUT only bounds reads made through socket().getInputStream().
        // Reads issued directly on the channel are not bounded by it; confirm how Body.get reads.
        ch.socket().setSoTimeout(soTimeout);
        ch.socket().setTcpNoDelay(true);
        ch.socket().setSendBufferSize(sendBufferSize);
    }

    /**
     * Fails unless the first packet from the server looks like a handshake.
     *
     * @throws IOException if the packet is an error packet, an EOF packet, or carries any other
     *     negative state
     */
    private static void checkHandshakePacket(Body body) throws IOException {
        if (body.getState() >= 0) {
            return;
        }
        if (body.getState() == Constants.MYSQL.ERR_PACKET) {
            ErrorProtocol error = new ErrorProtocol();
            error.fromBytes(body.getBody());
            throw new IOException("handshake exception:\n" + error.toString());
        }
        if (body.getState() == Constants.MYSQL.EOF) {
            throw new IOException("Unexpected EOF packet at handshake phase.");
        }
        throw new IOException("unexpected packet with field_count=" + body.getState());
    }

    /**
     * Fails unless the server accepted the authentication response.
     *
     * <p>NOTE(review): -1 and -2 presumably are the signed forms of the 0xFF (ERR) and 0xFE
     * (old-password / auth-switch request) header bytes; confirm against Body.getState.
     *
     * @throws IOException if the server rejected the credentials or replied unexpectedly
     */
    private static void checkAuthResult(Body body) throws IOException {
        if (body.getState() >= 0) {
            return;
        }
        if (body.getState() == -1) {
            ErrorProtocol err = new ErrorProtocol();
            err.fromBytes(body.getBody());
            throw new IOException("Error When doing Client Authentication:" + err.toString());
        }
        if (body.getState() == -2) {
            throw new IOException("Not support old password.");
        }
        throw new IOException("unexpected packet with field_count=" + body.getState());
    }

    /**
     * Sends a protocol-4.1 authentication response (HandshakeResponse41) to the server.
     *
     * <p>Strings are encoded as UTF-8 to match {@link #charsetNumber} (33, utf8). When a
     * non-empty database name is configured, it is appended and the
     * {@code CLIENT_CONNECT_WITH_DB} capability is announced; otherwise the server ignores it.
     *
     * @param header header of the server's handshake packet; the response uses its sequence + 1
     * @throws IOException if writing to the channel fails
     */
    private void auth411(HeaderProtocol header) throws IOException {
        String databaseName = authInfo.getDatabaseName();
        boolean withDatabase = databaseName != null && !databaseName.isEmpty();

        int clientFlag = CLIENT_LONG_PASSWORD | CLIENT_LONG_FLAG | CLIENT_PROTOCOL_41
                | CLIENT_TRANSACTIONS | CLIENT_SECURE_CONNECTION;
        if (withDatabase) {
            clientFlag |= CLIENT_CONNECT_WITH_DB;
        }

        ByteArrayOutputStream out = new ByteArrayOutputStream();
        BufferUtils.writeInt(clientFlag, out);
        BufferUtils.writeInt(MAX_PACKET_LENGTH, out);
        out.write(this.charsetNumber);
        // Reserved filler: 23 zero bytes.
        out.write(new byte[23]);

        // User name, NUL-terminated.
        out.write(authInfo.getUserName().getBytes(StandardCharsets.UTF_8));
        out.write(0x00);

        // Auth response, length-prefixed; an empty password is sent as a zero length.
        String password = authInfo.getPassWord();
        if (password == null || password.isEmpty()) {
            out.write(0x00);
        } else {
            try {
                byte[] encryptedPassword =
                        scramble411(password.getBytes(StandardCharsets.UTF_8), scramble);
                BufferUtils.writeLength(encryptedPassword, out);
            } catch (NoSuchAlgorithmException e) {
                throw new RuntimeException("加密失败", e);
            }
        }

        // Initial schema, NUL-terminated; only read by the server when CLIENT_CONNECT_WITH_DB is set.
        if (withDatabase) {
            out.write(databaseName.getBytes(StandardCharsets.UTF_8));
            out.write(0x00);
        }

        byte[] auth = out.toByteArray();
        HeaderProtocol h = new HeaderProtocol();
        h.setBodyLength(auth.length);
        h.setSequence((byte) (header.getSequence() + 1));
        ByteBuffer[] packet = {ByteBuffer.wrap(h.toBytes()), ByteBuffer.wrap(auth)};
        // A gathering write may be partial; keep writing until the last buffer is drained.
        while (packet[packet.length - 1].hasRemaining()) {
            channel.write(packet);
        }
    }

    /**
     * Computes the mysql_native_password auth token:
     * {@code SHA1(password) XOR SHA1(scramble + SHA1(SHA1(password)))}.
     *
     * @param pass     password bytes
     * @param scrumble server scramble (20 bytes when taken from the handshake)
     * @return the 20-byte token to send to the server
     * @throws NoSuchAlgorithmException if SHA-1 is unavailable
     */
    public byte[] scramble411(byte[] pass, byte[] scrumble) throws NoSuchAlgorithmException {
        MessageDigest md = MessageDigest.getInstance("SHA-1");
        // digest() resets the MessageDigest, so it can be reused for each stage.
        byte[] stage1 = md.digest(pass);   // SHA1(password)
        byte[] stage2 = md.digest(stage1); // SHA1(SHA1(password))
        md.update(scrumble);
        byte[] token = md.digest(stage2);  // SHA1(scramble + stage2)
        for (int i = 0; i < token.length; i++) {
            token[i] = (byte) (token[i] ^ stage1[i]);
        }
        return token;
    }

    /**
     * Parses the body of a HandshakeV10 packet.
     *
     * <p>Records the server version and connection id, and captures the 20-byte scramble
     * (8 + 12 bytes) used for password hashing. All other fields are skipped; in particular, the
     * server's default charset does not override {@link #charsetNumber}.
     */
    private void initHandshakeV10(ByteBuffer body) {
        body.get(); // protocol version [1], unused

        // Server version, NUL-terminated.
        ByteArrayOutputStream version = new ByteArrayOutputStream();
        while (body.hasRemaining()) {
            byte b = body.get();
            if (b == 0x00) {
                break;
            }
            version.write(b);
        }
        serverVersion = new String(version.toByteArray(), StandardCharsets.UTF_8);

        // Connection id [4], little-endian unsigned. Widen to long before shifting so that
        // ids >= 2^31 do not sign-extend into negative values.
        connectionId = (body.get() & 0xFFL)
                | (body.get() & 0xFFL) << 8
                | (body.get() & 0xFFL) << 16
                | (body.get() & 0xFFL) << 24;

        // auth-plugin-data-part-1 [8]
        byte[] part1 = new byte[SCRAMBLE_PART1_LENGTH];
        body.get(part1);

        body.get();             // filler [1]
        body.get(new byte[2]);  // capability flags, lower 2 bytes
        body.get();             // server default character set [1]
        body.get(new byte[2]);  // status flags [2]
        body.get(new byte[2]);  // capability flags, upper 2 bytes
        body.get();             // auth-plugin-data length [1]
        body.get(new byte[10]); // reserved [10], all zero

        // auth-plugin-data-part-2: 12 meaningful bytes followed by a NUL terminator (not read).
        byte[] part2 = new byte[SCRAMBLE_PART2_LENGTH];
        body.get(part2);

        // The two halves together form the scramble used for authentication.
        scramble = new byte[part1.length + part2.length];
        System.arraycopy(part1, 0, scramble, 0, part1.length);
        System.arraycopy(part2, 0, scramble, part1.length, part2.length);
    }

    /**
     * Closes the channel (if one was opened) and clears the connected flag.
     *
     * <p>Close failures are logged rather than thrown.
     */
    public void close() {
        try {
            if (channel != null) {
                channel.close();
            }
        } catch (IOException e) {
            logger.warn("Failed to close MySQL connection channel.", e);
        } finally {
            connected.set(false);
        }
    }

    /** Closes {@code ch} after a failed connect, logging (not throwing) any close failure. */
    private void closeQuietly(SocketChannel ch) {
        try {
            ch.close();
        } catch (IOException e) {
            logger.warn("Failed to close channel after unsuccessful connect.", e);
        }
    }

    /**
     * Blocks the calling thread until the connected flag becomes {@code true}.
     *
     * <p>This never returns if no connect attempt succeeds.
     */
    public void waitConnect() {
        while (!connected.get()) {
            // Let other threads run instead of burning a core in a tight spin.
            Thread.yield();
        }
    }
}




© 2015 - 2024 Weber Informatics LLC | Privacy Policy