com.moomoo.openapi.MMAPI_Conn Maven / Gradle / Ivy
Go to download
Show more of this group Show more artifacts with this name
Show all versions of moomoo-api Show documentation
Show all versions of moomoo-api Show documentation
Moomoo OpenAPI quantitative transaction interface for Java.
The newest version!
package com.moomoo.openapi;
import com.moomoo.openapi.pb.*;
import com.google.protobuf.GeneratedMessageV3;
import com.google.protobuf.InvalidProtocolBufferException;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.time.Clock;
import java.util.*;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicInteger;
class SentProtoData {
ProtoHeader header;
long sentTime;
Semaphore sem;
SentProtoData(boolean isSync) {
if (isSync)
sem = new Semaphore(0);
}
}
/**
* 表示和OpenD的一条连接
*/
public class MMAPI_Conn implements AutoCloseable, ConnHandler {
public static final int INIT_FAIL = 100;
public static final int REPLY_TIMEOUT = 12 * 1000; //12s
public static final int CONNECT_TIMEOUT = 10 * 1000;
protected long localConnID;
private MMSPI_Conn connSpi;
final private Object connSpiLock = new Object();
String clientID = "";
int clientVer = 0;
KeyPair rsaKeyPair = null;
AesCbcCipher cipher;
AtomicInteger nextPacketSN = new AtomicInteger(1);
ConnStatus connStatus = ConnStatus.START;
String ip = "";
int port;
boolean isEncrypt = false;
HashMap sentProtoMap = new HashMap<>();
SimpleBuffer readBuf = new SimpleBuffer(64*1024);
long connID;
long userID;
String aesKey = "";
String aesCBIV = "";
int keepAliveInterval = 8 * 1000; //心跳间隔,单位ms
long lastKeepAliveTime = 0; //上次心跳时间,单位ms
public MMAPI_Conn() {
}
/**
* 设置连接回调
* @param callback 回调
*/
public void setConnSpi(MMSPI_Conn callback) {
synchronized (connSpiLock) {
connSpi = callback;
}
}
/***
* 关闭连接
*/
public void close() {
synchronized (this) {
if (connStatus != ConnStatus.START && connStatus != ConnStatus.CLOSED) {
NetManager.getInstance().close(localConnID);
} else {
connStatus = ConnStatus.CLOSED;
}
}
}
/***
* 断开与OpenD的连接
* @return 是否成功
* @deprecated 已经废弃,统一使用close关闭
*/
@Deprecated
public boolean disconnect() {
close();
return true;
}
/***
* 设置客户端信息,用于记录
* @param clientID 客户端名字
* @param clientVer 客户端版本
*/
public void setClientInfo(String clientID, int clientVer) {
synchronized (this) {
this.clientID = clientID;
this.clientVer = clientVer;
}
}
/***
* 设置加密私钥
* @param key 私钥字符串
*/
public void setRSAPrivateKey(String key) {
byte[] binaryKey = key.getBytes(StandardCharsets.UTF_8);
synchronized (this) {
try{
rsaKeyPair = RsaUtil.loadKeyPairFromArray(binaryKey);
} catch (IOException ex) {
throw new APIError(String.format("Invalid rsa private key: %s", ex.getMessage()), ex);
}
}
}
/**
* 连接并初始化
* @param ip 地址
* @param port 端口
* @param isEnableEncrypt 启用加密
* @return bool 是否启动了执行,不代表连接结果,结果通过OnInitConnect回调
*/
public boolean initConnect(String ip, int port, boolean isEnableEncrypt) {
if (ip == null || ip.isEmpty()) {
throw new IllegalArgumentException("ip is invalid");
}
if (port <= 0) {
throw new IllegalArgumentException("port is invalid");
}
synchronized (this) {
this.ip = ip;
this.port = port;
this.isEncrypt = isEnableEncrypt;
this.localConnID = NetManager.getInstance().connect(new InetSocketAddress(ip, port), CONNECT_TIMEOUT, this);
}
return true;
}
/**
* 此连接的连接ID,连接的唯一标识,InitConnect协议返回,没有初始化前为0
* @return 连接id
*/
public long getConnectID() {
synchronized (this) {
return this.connID;
}
}
public long getLocalConnID() {
synchronized (this) {
return this.localConnID;
}
}
/**
* 发送协议数据
* @param protoID 协议id,@see ProtoID
* @param req proto数据包
* @return 数据包序列号
*/
protected int sendProto(int protoID, GeneratedMessageV3 req) {
synchronized (this) {
MMAPI.netEventListener.onBeginSend(this, protoID);
if (connStatus != ConnStatus.CONNECTED && connStatus != ConnStatus.READY)
return 0;
byte[] body = req.toByteArray();
byte[] bodySha1;
try {
bodySha1 = SHA1Util.calc(body);
if (protoID == ProtoID.INIT_CONNECT) {
if (rsaKeyPair != null) {
body = RsaUtil.encrypt(body, rsaKeyPair.getPublic());
}
} else if (isEncrypt && cipher != null) {
body = cipher.encrypt(body);
}
} catch (NoSuchAlgorithmException ex) {
throw new APIError(String.format("Calc body sha1 fail: %s", ex.getMessage()), ex);
} catch (GeneralSecurityException ex) {
throw new APIError(String.format("RSA encrypt fail: %s", ex.getMessage()), ex);
}
ProtoHeader header = new ProtoHeader();
header.nProtoID = protoID;
header.nProtoFmtType = (byte)0;
header.nProtoVer = 0;
header.nSerialNo = nextPacketSN.getAndIncrement();
header.nBodyLen = body.length;
header.arrBodySHA1 = bodySha1;
SentProtoData sentData = new SentProtoData(false);
sentData.header = header;
sentData.sentTime = Clock.systemUTC().millis();
sentProtoMap.put(header.nSerialNo, sentData);
byte[] buffer = new byte[ProtoHeader.HEADER_SIZE + body.length];
header.write(buffer);
System.arraycopy(body, 0, buffer, ProtoHeader.HEADER_SIZE, body.length);
NetManager.getInstance().send(localConnID, buffer);
MMAPI.netEventListener.onEndSend(this, protoID, header.nSerialNo);
return header.nSerialNo;
}
}
/**
* 发送的协议收到了返回
* @param replyType 表示返回数据的错误类型
* @param protoHeader 包头
* @param data 包体
*/
protected void onReply(ReqReplyType replyType, ProtoHeader protoHeader, byte[] data) {
}
/**
* 收到OpenD的推送
* @param protoHeader 包头
* @param data 包体
*/
protected void onPush(ProtoHeader protoHeader, byte[] data) {
}
protected void onInitConnect(long errCode, String desc) {
synchronized (this) {
if (errCode == 0 && connStatus == ConnStatus.CONNECTED) {
connStatus = ConnStatus.READY;
} else {
close();
}
}
synchronized (connSpiLock) {
if (connSpi != null) {
connSpi.onInitConnect(this, errCode, desc);
}
}
}
@Override
public final void onConnect(long connID, ConnErr err, String msg) {
MMAPI.netEventListener.onConnect(connID, err, msg);
synchronized (this) {
if (err == ConnErr.OK) {
connStatus = ConnStatus.CONNECTED;
sendInitConnect();
return;
} else {
connStatus = ConnStatus.CLOSED;
}
}
long errCode = makeInitConnectErrCode(connErrToConnectFailType(err), 0);
onInitConnect(errCode, msg);
}
@Override
public final void onDisConnect(long connID, ConnErr err, String msg) {
synchronized (this) {
connStatus = ConnStatus.CLOSED;
Set> items = sentProtoMap.entrySet();
for (Iterator> iter = items.iterator(); iter.hasNext(); ) {
Map.Entry item = iter.next();
SentProtoData protoData = item.getValue();
handleReplyPacket(ReqReplyType.DisConnect, protoData.header, null, false);
iter.remove();
}
}
synchronized (connSpiLock) {
if (connSpi != null) {
long errCode = makeInitConnectErrCode(connErrToConnectFailType(err), 0);
connSpi.onDisconnect(this, errCode);
}
}
}
@Override
public final void onRecv(long connID, byte[] data, int offset, int count) {
synchronized (this) {
for (int remain = count; remain > 0; ) {
int readCount = readBuf.append(data, offset, remain);
offset += readCount;
remain -= readCount;
handleReadBuf();
}
}
}
@Override
public void onTick() {
long now = Clock.systemUTC().millis();
synchronized (this) {
if (now - lastKeepAliveTime >= keepAliveInterval && connStatus == ConnStatus.READY) {
sendKeepAlive();
lastKeepAliveTime = now;
}
Set> items = sentProtoMap.entrySet();
for (Iterator> iter = items.iterator(); iter.hasNext(); ) {
Map.Entry item = iter.next();
SentProtoData protoData = item.getValue();
if (now - protoData.sentTime > REPLY_TIMEOUT) {
handleReplyPacket(ReqReplyType.Timeout, protoData.header, null, false);
iter.remove();
}
}
}
}
private void handleReadBuf() {
while (ProtoHeader.HEADER_SIZE <= readBuf.length) {
ProtoHeader header = ProtoHeader.parse(readBuf.buf, readBuf.start);
if (header == null)
break;
if (header.nBodyLen + ProtoHeader.HEADER_SIZE > readBuf.limit) {
readBuf.resize(header.nBodyLen + ProtoHeader.HEADER_SIZE);
}
if (ProtoHeader.HEADER_SIZE + header.nBodyLen > readBuf.length)
break;
readBuf.consume(ProtoHeader.HEADER_SIZE);
byte[] body = null;
ReqReplyType replyType = ReqReplyType.SvrReply;
if (header.nBodyLen > 0) {
body = new byte[header.nBodyLen];
System.arraycopy(readBuf.buf, readBuf.start, body, 0, header.nBodyLen);
try {
if (isEncrypt) {
if (header.nProtoID == ProtoID.INIT_CONNECT) {
body = RsaUtil.decrypt(body, rsaKeyPair.getPrivate());
} else {
body = cipher.decrypt(body);
}
}
} catch (Exception ex) {
body = null;
replyType = ReqReplyType.Invalid;
} finally {
readBuf.consume(header.nBodyLen);
}
}
MMAPI.netEventListener.onBeginRecv(this, header.nProtoID, header.nSerialNo);
try {
if (ProtoID.isPushProto(header.nProtoID)) {
if (replyType == ReqReplyType.SvrReply) {
handlePushPacket(header, body);
}
} else {
handleReplyPacket(replyType, header, body, true);
}
}
finally {
MMAPI.netEventListener.onEndRecv(this, header.nProtoID, header.nSerialNo);
}
}
}
private void handleReplyPacket(ReqReplyType replyType, ProtoHeader header, byte[] body, boolean isRemoveFromSent) {
SentProtoData sentData = sentProtoMap.getOrDefault(header.nSerialNo, null);
if (sentData == null)
return;
if (sentData.header.nProtoID != header.nProtoID)
return;
if (isRemoveFromSent)
sentProtoMap.remove(header.nSerialNo);
if (header.nProtoID == ProtoID.INIT_CONNECT)
handleInitConnectRsp(replyType, header, body);
else
onReply(replyType, header, body);
}
private int sendInitConnect() {
InitConnect.C2S.Builder c2s = InitConnect.C2S.newBuilder().setClientVer(clientVer)
.setClientID(clientID)
.setRecvNotify(true)
.setPushProtoFmt(Common.ProtoFmt.ProtoFmt_Protobuf_VALUE)
.setProgrammingLanguage("Java");
if (isEncrypt)
c2s.setPacketEncAlgo(Common.PacketEncAlgo.PacketEncAlgo_AES_CBC_VALUE);
else
c2s.setPacketEncAlgo(Common.PacketEncAlgo.PacketEncAlgo_None_VALUE);
InitConnect.Request req = InitConnect.Request.newBuilder().setC2S(c2s).build();
return sendProto(ProtoID.INIT_CONNECT, req);
}
private int sendKeepAlive() {
KeepAlive.C2S c2s = KeepAlive.C2S.newBuilder().setTime(Clock.systemUTC().millis() / 1000).build();
KeepAlive.Request req = KeepAlive.Request.newBuilder().setC2S(c2s).build();
return sendProto(ProtoID.KEEP_ALIVE, req);
}
private void handleInitConnectRsp(ReqReplyType replyType, ProtoHeader header, byte[] body) {
InitConnect.Response rsp;
long errCode = 0;
String desc = "";
if (replyType == ReqReplyType.SvrReply) {
try {
rsp = InitConnect.Response.parseFrom(body);
if (rsp.getRetType() == Common.RetType.RetType_Succeed_VALUE) {
synchronized (this) {
connID = rsp.getS2C().getConnID();
userID = rsp.getS2C().getLoginUserID();
keepAliveInterval = rsp.getS2C().getKeepAliveInterval() * 1000 * 4 / 5;
aesKey = rsp.getS2C().getConnAESKey();
aesCBIV = rsp.getS2C().getAesCBCiv();
cipher = new AesCbcCipher(aesKey.getBytes(StandardCharsets.UTF_8),
aesCBIV.getBytes(StandardCharsets.UTF_8));
}
} else {
errCode = makeInitConnectErrCode(INIT_FAIL, InitFailType.OPENDREJECT.getCode());
if (rsp.getRetMsg() != null && !rsp.getRetMsg().isEmpty()) {
desc = rsp.getRetMsg();
} else {
desc = String.format("retType=%d", rsp.getRetType());
}
}
} catch (InvalidProtocolBufferException e) {
errCode = makeInitConnectErrCode(INIT_FAIL, InitFailType.OPENDREJECT.getCode());
desc = String.format("Parse packet fail, serialNO=%d", header.nSerialNo);
}
} else if (replyType == ReqReplyType.Timeout){
errCode = makeInitConnectErrCode(INIT_FAIL, InitFailType.TIMEOUT.getCode());
} else if (replyType == ReqReplyType.DisConnect) {
errCode = makeInitConnectErrCode(INIT_FAIL, InitFailType.DISCONNECT.getCode());
} else if (replyType == ReqReplyType.Invalid) {
errCode = makeInitConnectErrCode(INIT_FAIL, InitFailType.UNKNOW.getCode());
desc = String.format("Invalid packet body, serialNO=%d", header.nSerialNo);
}
onInitConnect(errCode, desc);
}
private void handlePushPacket(ProtoHeader header, byte[] body) {
onPush(header, body);
}
private long makeInitConnectErrCode(int high, int low) {
long errCode;
errCode = (long)high << 32;
errCode |= low;
return errCode;
}
private int connErrToConnectFailType(ConnErr err) {
switch (err) {
case OK:
return ConnectFailType.NONE.getCode();
case CONNECT_FAIL:
case CONNECT_TIMEOUT:
return ConnectFailType.CONNECTFAILED.getCode();
case SEND_FAIL:
return ConnectFailType.SENDFAILED.getCode();
case READ_FAIL:
return ConnectFailType.RECVFAILED.getCode();
}
return ConnectFailType.UNKNOWN.getCode();
}
public ConnStatus getConnStatus() {
synchronized (this) {
return connStatus;
}
}
}