org.lealone.plugins.mysql.server.MySQLServerConnection Maven / Gradle / Ivy
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.lealone.plugins.mysql.server;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.Properties;
import org.lealone.common.exceptions.DbException;
import org.lealone.common.logging.Logger;
import org.lealone.common.logging.LoggerFactory;
import org.lealone.common.util.StringUtils;
import org.lealone.db.ConnectionInfo;
import org.lealone.db.Constants;
import org.lealone.db.result.Result;
import org.lealone.db.session.ServerSession;
import org.lealone.db.value.Value;
import org.lealone.db.value.ValueNull;
import org.lealone.net.AsyncConnection;
import org.lealone.net.NetBuffer;
import org.lealone.net.NetBufferOutputStream;
import org.lealone.net.WritableChannel;
import org.lealone.plugins.mysql.server.handler.AuthPacketHandler;
import org.lealone.plugins.mysql.server.handler.CommandPacketHandler;
import org.lealone.plugins.mysql.server.handler.PacketHandler;
import org.lealone.plugins.mysql.server.protocol.AuthPacket;
import org.lealone.plugins.mysql.server.protocol.EOFPacket;
import org.lealone.plugins.mysql.server.protocol.ErrorPacket;
import org.lealone.plugins.mysql.server.protocol.ExecutePacket;
import org.lealone.plugins.mysql.server.protocol.FieldPacket;
import org.lealone.plugins.mysql.server.protocol.Fields;
import org.lealone.plugins.mysql.server.protocol.HandshakePacket;
import org.lealone.plugins.mysql.server.protocol.OkPacket;
import org.lealone.plugins.mysql.server.protocol.PacketInput;
import org.lealone.plugins.mysql.server.protocol.PacketOutput;
import org.lealone.plugins.mysql.server.protocol.PreparedOkPacket;
import org.lealone.plugins.mysql.server.protocol.ResultSetHeaderPacket;
import org.lealone.plugins.mysql.server.protocol.RowDataPacket;
import org.lealone.plugins.mysql.server.util.PacketUtil;
import org.lealone.server.Scheduler;
import org.lealone.sql.PreparedSQLStatement;
/**
 * Server side of a single MySQL client connection.
 *
 * <p>Lifecycle as implemented here:
 * <ul>
 * <li>{@link #handshake(int)} sends the handshake packet and installs an {@link AuthPacketHandler}.</li>
 * <li>{@link #authenticate(AuthPacket)} opens a Lealone {@link ServerSession} in "MySQL" mode and,
 * on success, installs a {@link CommandPacketHandler}.</li>
 * <li>Every incoming packet is routed to the current handler by {@link #handle(NetBuffer)}.</li>
 * </ul>
 */
public class MySQLServerConnection extends AsyncConnection {

    private static final Logger logger = LoggerFactory.getLogger(MySQLServerConnection.class);

    /**
     * Pre-encoded OK packet sent after successful authentication.
     *
     * <p>Layout: 3-byte little-endian payload length (7), sequence id 2, then the payload: 0x00
     * header, 0 affected rows, 0 last insert id, status flags 0x0002, 0 warnings.
     */
    private static final byte[] AUTH_OK = new byte[] { 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0 };

    private static final byte[] EMPTY = new byte[0];

    /** MySQL error code ER_UNKNOWN_STMT_HANDLER. */
    private static final int ER_UNKNOWN_STMT_HANDLER = 1243;

    private final MySQLServer server;
    private final Scheduler scheduler;

    /** Holds the 4-byte header (3-byte payload length + sequence id) of the packet being read. */
    private final ByteBuffer packetLengthByteBuffer = ByteBuffer.allocateDirect(4);

    private ServerSession session;
    private PacketHandler packetHandler;
    private AuthPacket authPacket;
    private int nextStatementId;
    private byte[] seed;

    protected MySQLServerConnection(MySQLServer server, WritableChannel channel, Scheduler scheduler) {
        super(channel, true);
        this.server = server;
        this.scheduler = scheduler;
    }

    /** Returns the session of the current database, or {@code null} before authentication. */
    public ServerSession getSession() {
        return session;
    }

    /** Returns the authentication seed sent in the handshake, or {@code null} before it. */
    public byte[] getSeed() {
        return seed;
    }

    /**
     * Sends the initial handshake packet.
     *
     * <p>This is the first thing the server does after a client connects.
     *
     * @param threadId the id placed in the handshake packet
     */
    void handshake(int threadId) {
        // The AuthPacketHandler checks whether the connecting user is legitimate.
        packetHandler = new AuthPacketHandler(this);
        HandshakePacket p = HandshakePacket.create(threadId);

        // Keep the complete authentication data (seed + auth plugin data part 2);
        // restOfScrambleBuff must not be used. The seed is assigned before the write is
        // scheduled, so it is already in place when the handshake goes out and the
        // client replies.
        byte[] part1 = p.seed;
        byte[] part2 = p.authPluginDataPart2;
        byte[] fullSeed = new byte[part1.length + part2.length];
        System.arraycopy(part1, 0, fullSeed, 0, part1.length);
        System.arraycopy(part2, 0, fullSeed, part1.length, part2.length);
        seed = fullSeed;

        // Let the scheduler perform the write: the channel may not be registered yet.
        PacketOutput out = getPacketOutput();
        scheduler.handle(() -> p.write(out));
    }

    /**
     * Authenticates the client by opening a session for the requested database.
     *
     * <p>If the client named no database (null or empty), {@code MySQLServer.DATABASE_NAME} is
     * used. On failure, an error packet is sent and the connection is closed and removed from the
     * server. On success, a {@link CommandPacketHandler} is installed and {@code AUTH_OK} is sent.
     *
     * @param authPacket the client's authentication packet
     */
    public void authenticate(AuthPacket authPacket) {
        this.authPacket = authPacket;
        try {
            String dbName = authPacket.database;
            if (dbName == null || dbName.isEmpty()) {
                dbName = MySQLServer.DATABASE_NAME;
            }
            session = createSession(authPacket, dbName);
            registerConnectionIdAlias(session);
        } catch (Throwable e) {
            logger.error("Failed to create session", e);
            sendErrorMessage(e);
            close();
            server.removeConnection(this);
            return;
        }
        // Authentication succeeded: from now on a CommandPacketHandler processes every
        // command (including SQL).
        packetHandler = new CommandPacketHandler(this);
        sendMessage(AUTH_OK);
    }

    /**
     * Creates the CONNECTION_ID alias, if it does not exist yet, in the database behind
     * {@code s}.
     */
    private static void registerConnectionIdAlias(ServerSession s) {
        String sql = "CREATE ALIAS IF NOT EXISTS CONNECTION_ID DETERMINISTIC FOR "
                + "\"org.lealone.plugins.mysql.sql.expression.MySQLFunction.getConnectionId\"";
        s.prepareStatement(sql).executeUpdate();
    }

    /**
     * Opens a local (non-remote) session in MySQL mode for {@code dbName}.
     *
     * <p>The client's password bytes are passed hex-encoded as a password hash, with the handshake
     * seed as salt.
     *
     * @throws IllegalArgumentException if {@code dbName} is null or contains ';'
     */
    private ServerSession createSession(AuthPacket authPacket, String dbName) {
        // dbName is client-controlled and is concatenated into the URL below. In the H2-style
        // URL syntax, ';' starts connection settings, so reject it. This stops a client
        // from injecting settings (possibly before it has authenticated).
        if (dbName == null || dbName.indexOf(';') >= 0) {
            throw new IllegalArgumentException("Invalid database name: " + dbName);
        }
        Properties info = new Properties();
        info.put("MODE", "MySQL");
        info.put("USER", authPacket.user);
        info.put("PASSWORD", StringUtils.convertBytesToHex(getPassword(authPacket)));
        info.put("PASSWORD_HASH", "true");
        String url = Constants.URL_PREFIX + Constants.URL_TCP + server.getHost() + ":" + server.getPort()
                + "/" + dbName;
        ConnectionInfo ci = new ConnectionInfo(url, info);
        ci.setSalt(seed);
        ci.setRemote(false);
        return (ServerSession) ci.createSession();
    }

    /** Returns the client's password bytes, or an empty array when none were sent. */
    private static byte[] getPassword(AuthPacket authPacket) {
        if (authPacket.password == null || authPacket.password.length == 0)
            return EMPTY;
        return authPacket.password;
    }

    /**
     * Switches this connection to {@code dbName}.
     *
     * <p>A new session is opened and the CONNECTION_ID alias is registered in it, as
     * {@link #authenticate(AuthPacket)} does. Only then is the previous session closed, so a failed
     * switch leaves the current session untouched. Statements prepared in the previous session are
     * not carried over.
     *
     * @param dbName the database requested by the client
     */
    public void initDatabase(String dbName) {
        ServerSession newSession = createSession(authPacket, dbName);
        try {
            registerConnectionIdAlias(newSession);
        } catch (RuntimeException e) {
            newSession.close();
            throw e;
        }
        ServerSession oldSession = session;
        session = newSession;
        if (oldSession != null) {
            try {
                oldSession.close();
            } catch (Throwable e) {
                // The switch itself succeeded; do not fail the command because cleanup failed.
                logger.error("Failed to close previous session", e);
            }
        }
    }

    /** Removes the prepared statement with the given id from the session cache and closes it. */
    public void closeStatement(int statementId) {
        PreparedSQLStatement command = (PreparedSQLStatement) session.removeCache(statementId, true);
        if (command != null) {
            command.close();
        }
    }

    /**
     * Prepares {@code sql}, caches it under a new statement id, and replies with a prepared-OK
     * packet.
     *
     * <p>NOTE(review): in the MySQL protocol, COM_STMT_PREPARE_OK is followed by parameter and
     * column definition packets when those counts are non-zero. They are not written here; confirm
     * that the target clients tolerate this.
     */
    public void prepareStatement(String sql) {
        PreparedSQLStatement command = session.prepareStatement(sql, -1);
        int statementId = ++nextStatementId;
        command.setId(statementId);
        session.addCache(statementId, command);
        PreparedOkPacket packet = new PreparedOkPacket();
        packet.packetId = 1;
        packet.statementId = statementId;
        // Only queries produce result columns. Non-query statements need not provide metadata.
        packet.columnsNumber = command.isQuery() ? command.getMetaData().getVisibleColumnCount() : 0;
        packet.parametersNumber = command.getParameters().size();
        packet.write(getPacketOutput());
    }

    /**
     * Executes a previously prepared statement.
     *
     * <p>An unknown statement id produces an ER_UNKNOWN_STMT_HANDLER error packet.
     *
     * <p>NOTE(review): parameter values carried by the ExecutePacket are not bound here; confirm
     * whether the caller binds them before this method runs.
     */
    public void executeStatement(ExecutePacket packet) {
        int statementId = (int) packet.statementId;
        PreparedSQLStatement ps = (PreparedSQLStatement) session.getCache(statementId);
        if (ps == null) {
            sendErrorMessage(ER_UNKNOWN_STMT_HANDLER,
                    "Unknown prepared statement handler (" + statementId + ") given to mysqld_stmt_execute");
            return;
        }
        executeStatement(ps, ps.getSQL());
    }

    /** Prepares and executes {@code sql}, writing a result set or an OK packet. */
    public void executeStatement(String sql) {
        executeStatement(null, sql);
    }

    /**
     * Executes {@code ps}, or a freshly prepared command for {@code sql} when {@code ps} is null.
     *
     * <p>Queries produce a result set. Other statements produce an OK packet with the update
     * count. Failures are logged and reported to the client as an error packet.
     */
    private void executeStatement(PreparedSQLStatement ps, String sql) {
        logger.info("execute sql: " + sql);
        try {
            if (ps == null)
                ps = (PreparedSQLStatement) session.prepareSQLCommand(sql, -1);
            if (ps.isQuery()) {
                Result result = ps.executeQuery(-1).get();
                writeQueryResult(result);
            } else {
                int updateCount = ps.executeUpdate().get();
                writeUpdateResult(updateCount);
            }
        } catch (Throwable e) {
            logger.error("Failed to execute statement: " + sql, e);
            sendErrorMessage(e);
        }
    }

    /**
     * Writes {@code result} as a text-protocol result set.
     *
     * <p>The packets are written in order: header, one field packet per visible column, EOF, one
     * row packet per row, and a final EOF.
     */
    private void writeQueryResult(Result result) {
        int fieldCount = result.getVisibleColumnCount();
        // Sequence ids are a single byte on the wire; byte arithmetic wraps the same way.
        byte packetId = 0;

        ResultSetHeaderPacket header = PacketUtil.getHeader(fieldCount);
        header.packetId = ++packetId;

        FieldPacket[] fields = new FieldPacket[fieldCount];
        for (int i = 0; i < fieldCount; i++) {
            // Locale.ROOT keeps lower-casing independent of the JVM default locale.
            String name = result.getColumnName(i).toLowerCase(Locale.ROOT);
            fields[i] = PacketUtil.getField(name, Fields.toMySQLType(result.getColumnType(i)));
            fields[i].packetId = ++packetId;
        }

        EOFPacket eof = new EOFPacket();
        eof.packetId = ++packetId;

        PacketOutput out = getPacketOutput();
        header.write(out);
        for (FieldPacket field : fields) {
            field.write(out);
        }
        eof.write(out);

        // Drive the loop by next() instead of bounding it by getRowCount(). For lazily produced
        // results the row count may not be known up front (verify for this Result type).
        while (result.next()) {
            Value[] values = result.currentRow();
            RowDataPacket row = new RowDataPacket(fieldCount);
            for (int j = 0; j < fieldCount; j++) {
                Value v = values[j];
                if (v == ValueNull.INSTANCE) {
                    // NOTE(review): SQL NULL is sent as an empty value, indistinguishable from an
                    // empty string. The MySQL text protocol uses 0xFB for NULL; confirm how
                    // RowDataPacket encodes this.
                    row.add(EMPTY);
                } else {
                    // getString() is the plain value text; toString() is not guaranteed to be.
                    // UTF-8 matches the encoding sendErrorMessage uses.
                    row.add(v.getString().getBytes(StandardCharsets.UTF_8));
                }
            }
            row.packetId = ++packetId;
            row.write(out);
        }

        EOFPacket lastEof = new EOFPacket();
        lastEof.packetId = ++packetId;
        lastEof.write(out);
    }

    private void writeUpdateResult(int updateCount) {
        writeOkPacket(updateCount);
    }

    /** Writes an OK packet reporting zero affected rows. */
    public void writeOkPacket() {
        writeOkPacket(0);
    }

    /** Writes an OK packet (sequence id 1, status flags 2) reporting {@code updateCount} rows. */
    private void writeOkPacket(int updateCount) {
        OkPacket packet = new OkPacket();
        packet.packetId = 1;
        packet.affectedRows = updateCount;
        packet.serverStatus = 2;
        packet.write(getPacketOutput());
    }

    /**
     * Encodes {@code src} with {@code charset}.
     *
     * <p>Falls back to the platform default charset when {@code charset} is null or unsupported.
     * Returns null for null input.
     */
    private static byte[] encodeString(String src, String charset) {
        if (src == null) {
            return null;
        }
        if (charset == null) {
            return src.getBytes();
        }
        try {
            return src.getBytes(charset);
        } catch (UnsupportedEncodingException e) {
            return src.getBytes();
        }
    }

    /** Sends {@code e} to the client as an error packet, converting it to a DbException if needed. */
    private void sendErrorMessage(Throwable e) {
        if (e instanceof DbException) {
            DbException dbe = (DbException) e;
            sendErrorMessage(dbe.getErrorCode(), dbe.getMessage());
        } else {
            sendErrorMessage(DbException.convert(e));
        }
    }

    /** Writes an error packet with sequence id 0 and the UTF-8 encoded {@code msg}. */
    public void sendErrorMessage(int errno, String msg) {
        ErrorPacket err = new ErrorPacket();
        err.packetId = 0;
        err.errno = errno;
        err.message = encodeString(msg, "utf-8");
        err.write(getPacketOutput());
    }

    private PacketOutput getPacketOutput() {
        return new PacketOutput(writableChannel, scheduler.getDataBufferFactory());
    }

    /** Writes pre-encoded bytes directly to the channel; I/O failures are logged. */
    private void sendMessage(byte[] data) {
        try (NetBufferOutputStream out = new NetBufferOutputStream(writableChannel, data.length,
                scheduler.getDataBufferFactory())) {
            out.write(data);
            out.flush(false);
        } catch (IOException e) {
            logger.error("Failed to send message", e);
        }
    }

    @Override
    public ByteBuffer getPacketLengthByteBuffer() {
        return packetLengthByteBuffer;
    }

    /** Decodes the 3-byte little-endian payload length from the header buffer. */
    @Override
    public int getPacketLength() {
        int length = (packetLengthByteBuffer.get(0) & 0xff);
        length |= (packetLengthByteBuffer.get(1) & 0xff) << 8;
        length |= (packetLengthByteBuffer.get(2) & 0xff) << 16;
        return length;
    }

    /**
     * Reassembles one full packet and dispatches it to the current packet handler.
     *
     * <p>The reassembled packet is the 4-byte header followed by the payload in {@code buffer}.
     * The header buffer is cleared and {@code buffer} is recycled even if reading fails. Handler
     * failures are logged and reported to the client as an error packet.
     */
    @Override
    public void handle(NetBuffer buffer) {
        if (!buffer.isOnlyOnePacket()) {
            DbException.throwInternalError("NetBuffer must be OnlyOnePacket");
        }
        try {
            byte[] packet;
            try {
                int length = buffer.length();
                packet = new byte[length + 4];
                // Absolute reads, like getPacketLength(), so the copy does not depend on the
                // header buffer's current position.
                for (int i = 0; i < 4; i++) {
                    packet[i] = packetLengthByteBuffer.get(i);
                }
                buffer.read(packet, 4, length);
            } finally {
                packetLengthByteBuffer.clear();
                buffer.recycle();
            }
            packetHandler.handle(new PacketInput(packet));
        } catch (Throwable e) {
            logger.error("Failed to handle packet", e);
            sendErrorMessage(e);
        }
    }
}
© 2015 - 2024 Weber Informatics LLC | Privacy Policy