1 Star 1 Fork 0

俞寅达/nicenet

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
Tcp.cpp 14.84 KB
一键复制 编辑 原始数据 按行查看 历史
俞寅达 提交于 2021-11-17 11:43 . support coroutine
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707
#include "Tcp.h"
#include "Service.h"
#include <asio/asio.hpp>
#include <micro-ecc/uECC.h>
#include "Log.h"
#include "Message.h"
#include "Clock.h"
#include <cstring>
#include <map>
#include <memory>
#include <vector>
// asio/yield.hpp defines the reenter/yield/fork macros; keep it after all other includes.
#include <asio/yield.hpp>
extern "C"
{
extern void *sha3(const void *in, size_t inlen, void *md, int mdlen);
}
namespace nicehero
{
// Looks up (creating on first use) the message-parser table registered for
// the dynamic session type described by `typeInfo`.
// The table lives in a function-local static std::map; map nodes are never
// relocated, so the returned reference stays valid until static destruction.
// NOTE(review): keyed by type_info address, which assumes one type_info
// object per type (may not hold across shared-library boundaries).
TcpMessageParser& getTcpMessagerParse(const std::type_info& typeInfo)
{
	static std::map<const std::type_info*, TcpMessageParser> parsersByType;
	const std::type_info* key = &typeInfo;
	return parsersByType[key];
}
// Private implementation of TcpSession: owns the asio socket and remembers
// the io_context the socket was created on.
class TcpSessionImpl
{
public:
	// explicit: a TcpSession& must never silently convert into an impl object.
	// The io_context is fetched from getWorkerService() once and the socket is
	// bound to that same context. Members are initialised in declaration
	// order, so m_iocontext is set before m_socket uses it.
	explicit TcpSessionImpl(TcpSession& session)
		:m_iocontext(getWorkerService()),m_socket(m_iocontext),m_session(session)
	{
	}
	asio::io_context& m_iocontext;
	asio::ip::tcp::socket m_socket;
	// Back-reference to the owning session (non-owning).
	TcpSession& m_session;
};
class TcpServerImpl
{
public:
TcpServerImpl(asio::ip::address ip,ui16 port,TcpServer& server_)
:m_acceptor(getWorkerService(),{ ip,port }),m_server(server_)
{
}
~TcpServerImpl()
{
nlogerr("~TcpServerImpl()");
}
void accept()
{
TcpSessionPtr s = TcpSessionPtr(m_server.createSession());
TcpSessionS* ss = dynamic_cast<TcpSessionS*>(s.get());
if (ss)
{
ss->m_TcpServer = &m_server;
ss->m_MessageParser = &getTcpMessagerParse(typeid(*ss));
}
m_acceptor.async_accept(s->m_impl->m_socket,
[this,s](std::error_code ec) {
if (ec)
{
nlogerr(ec.message().c_str());
}
s->init(m_server);
accept();
});
}
asio::ip::tcp::acceptor m_acceptor;
TcpServer& m_server;
};
// Creates the listening socket on ip:port.
// Failures are logged rather than thrown; on failure m_impl stays empty
// and the server does not listen.
TcpServer::TcpServer(const std::string& ip, ui16 port)
{
	asio::error_code ec;
	auto addr = asio::ip::address::from_string(ip, ec);
	if (ec)
	{
		nlogerr("TcpServer::TcpServer ip error:%s", ip.c_str());
		// On failure from_string() yields a default (unspecified/wildcard)
		// address. Listening on it would silently expose the server on every
		// interface, so do not bind at all.
		return;
	}
	try
	{
		m_impl = std::unique_ptr<TcpServerImpl>(new TcpServerImpl(addr, port,*this));
	}
	catch (const asio::system_error& err)
	{
		nlogerr("cannot open %s:%d", ip.c_str(), int(port));
		nlogerr("%s",err.what());
	}
}
// No manual cleanup: all members, including m_impl, release themselves.
TcpServer::~TcpServer() = default;
// Factory used by the accept loop. Returns a freshly allocated TcpSessionS;
// the caller takes ownership (it is wrapped in a TcpSessionPtr).
TcpSessionS* TcpServer::createSession()
{
	TcpSessionS* fresh = new TcpSessionS();
	return fresh;
}
// Registers `session` under `uid`. If a session is already registered for
// that uid, its socket is closed before the slot is overwritten, so a
// newer connection replaces an older one.
void TcpServer::addSession(const tcpuid& uid, TcpSessionPtr session)
{
	auto found = m_sessions.find(uid);
	const bool alreadyPresent = (found != m_sessions.end());
	if (alreadyPresent)
		found->second->close();
	m_sessions[uid] = session;
}
// Closes and unregisters the session stored under `uid`, but only when it
// is the same connection instance (matching serialID). A stale connection
// that was already replaced cannot evict its replacement.
void TcpServer::removeSession(const tcpuid& uid, ui64 serialID)
{
	auto found = m_sessions.find(uid);
	if (found == m_sessions.end())
	{
		return;
	}
	if (found->second->m_serialID != serialID)
	{
		return;
	}
	found->second->close();
	m_sessions.erase(uid);
}
void TcpServer::accept()
{
m_impl->accept();
}
// Server-side session. All setup happens in the TcpSession base and in init().
TcpSessionS::TcpSessionS() = default;
// Nothing to release beyond what the members and base class handle.
TcpSessionS::~TcpSessionS() = default;
// Server side, step 1 of the handshake. Sends the peer
//   [server public key | 8-byte timestamp | sha3(key|timestamp) | signature of that hash]
// stores the hash in m_hash (the client must sign it), and continues with
// init2() once the whole hello has been written.
void TcpSessionS::init(TcpServer& server)
{
	auto self(shared_from_this());
	// The asynchronous write reads this buffer after init() has returned, so it
	// cannot live on the stack. It is heap-allocated and owned by the handler.
	std::shared_ptr<std::vector<ui8>> buff = std::make_shared<std::vector<ui8>>(
		PUBLIC_KEY_SIZE + 8 + HASH_SIZE + SIGN_SIZE, ui8(0));
	ui8* p = buff->data();
	memcpy(p, server.m_publicKey, PUBLIC_KEY_SIZE);
	ui64 now = nNow;
	// memcpy instead of a *(ui64*) store: p + PUBLIC_KEY_SIZE may be misaligned.
	memcpy(p + PUBLIC_KEY_SIZE, &now, sizeof(now));
	sha3(p, PUBLIC_KEY_SIZE + 8, p + PUBLIC_KEY_SIZE + 8, HASH_SIZE);
	m_hash = std::string((const char*)(p + PUBLIC_KEY_SIZE + 8),HASH_SIZE);
	if (uECC_sign(server.m_privateKey
		, p + PUBLIC_KEY_SIZE + 8, HASH_SIZE
		, p + PUBLIC_KEY_SIZE + 8 + HASH_SIZE
		, uECC_secp256k1()) != 1
		)
	{
		nlogerr("uECC_sign() failed\n");
		return;
	}
	// Capture the server by pointer, not via an implicit [&] capture of a reference parameter.
	TcpServer* srv = &server;
	// asio::async_write completes only when the whole buffer is sent;
	// async_write_some could complete after a partial write.
	asio::async_write(m_impl->m_socket,
		asio::buffer(*buff),
		[self, srv, buff](std::error_code ec, size_t) {
			if (ec)
			{
				nlogerr("%d\n", ec.value());
				return;
			}
			self->init2(*srv);
		});
}
// Server side, step 2 of the handshake. Waits up to 2 seconds for the
// client's reply [client public key | signature of m_hash] and validates it.
// On success it assigns m_uid / m_serialID, then (via nicehero::post)
// registers the session with the server and starts the read loop. Any
// failure simply drops the connection. The timeout closes the socket.
void TcpSessionS::init2(TcpServer& server)
{
	auto self(shared_from_this());
	TcpServer* srv = &server;
	// Timer on the socket's own io_context, so the timeout close() and the
	// read completion are dispatched by the same io_context.
	std::shared_ptr<asio::steady_timer> t = std::make_shared<asio::steady_timer>(m_impl->m_iocontext);
	// Owned by the handler so it outlives this call.
	std::shared_ptr<std::vector<ui8>> reply = std::make_shared<std::vector<ui8>>(PUBLIC_KEY_SIZE + SIGN_SIZE, ui8(0));
	// async_read completes only once the full reply has arrived (or on
	// error/close), so a reply split across TCP segments is not rejected.
	// It never reads past the reply, so bytes of later messages stay queued.
	asio::async_read(m_impl->m_socket,
		asio::buffer(*reply),
		[this, self, t, srv, reply](std::error_code ec, std::size_t len) {
			t->cancel();
			if (ec)
			{
				return;
			}
			if (len < reply->size())
			{
				return;
			}
			const ui8* data_ = reply->data();
			// A peer presenting the server's own public key is rejected.
			if (memcmp(srv->m_publicKey, data_, PUBLIC_KEY_SIZE) == 0)
			{
				return;
			}
			if (uECC_verify(data_, (const ui8*)m_hash.c_str(), HASH_SIZE, data_ + PUBLIC_KEY_SIZE, uECC_secp256k1()) != 1)
			{
				return;
			}
			m_uid = std::string((const char*)data_, PUBLIC_KEY_SIZE);
			static ui64 nowSerialID = 10000;
			m_serialID = nowSerialID++;
			nicehero::post([this, self, srv] {
				srv->addSession(m_uid, self);
				doRead();
			});
		});
	t->expires_from_now(std::chrono::seconds(2));
	t->async_wait([self](std::error_code ec) {
		if (!ec)
		{
			nlog("session connecting timeout");
			self->close();
		}
	});
}
void TcpSessionS::removeSelf()
{
auto self(shared_from_this());
nicehero::post([&,self] {
removeSelfImpl();
});
}
// Asks the owning server (if any) to drop this exact connection
// (matched by uid and serial id).
void TcpSessionS::removeSelfImpl()
{
	if (m_TcpServer == nullptr)
	{
		return;
	}
	m_TcpServer->removeSession(m_uid, m_serialID);
}
// Creates the session's implementation (socket on a worker io_context)
// and starts with no send in flight.
TcpSession::TcpSession()
{
	m_IsSending = false;
	m_impl.reset(new TcpSessionImpl(*this));
}
void TcpSession::init(TcpServer& server)
{
}
void TcpSession::init()
{
}
void TcpSession::init2(TcpServer& server)
{
}
void TcpSession::doRead()
{
auto self(shared_from_this());
this->m_impl->m_socket.async_wait(asio::ip::tcp::socket::wait_read,
[self,this](std::error_code ec) {
if (ec)
{
self->removeSelf();
return;
}
unsigned char data_[NETWORK_BUF_SIZE];
ui32 len = (ui32)self->m_impl->m_socket.read_some(asio::buffer(data_), ec);
if (ec)
{
self->removeSelf();
return;
}
if (len > 0)
{
if (!parseMsg(data_, len))
{
self->removeSelf();
return;
}
}
doRead();
});
}
bool TcpSession::parseMsg(unsigned char* data, ui32 len)
{
if (len > (ui32)NETWORK_BUF_SIZE)
{
return false;
}
Message& prevMsg = m_PreMsg;
auto self(shared_from_this());
if (prevMsg.m_buff == nullptr)
{
if (len < 4)
{
// nlog("TcpSession::parseMsg len < 4");
memcpy(&prevMsg.m_writePoint, data, len);
prevMsg.m_buff = (unsigned char*)&prevMsg.m_writePoint;
prevMsg.m_readPoint = len;
return true;
}
ui32 msgLen = *((ui32*)data);
if (msgLen > MSG_SIZE)
{
return false;
}
if (msgLen <= len)
{
auto recvMsg = make_copyable<Message>(data, *((ui32*)data));
// if (m_MessageParser && m_MessageParser->m_commands[recvMsg->getMsgID()] == nullptr)
// {
// nlogerr("TcpSession::parseMsg err 1");
// }
nicehero::post([self,recvMsg] {
self->handleMessage(recvMsg);
});
if (msgLen < len)
{
return parseMsg( data + msgLen, len - msgLen);
}
else
{
return true;
}
}
else
{
prevMsg.m_buff = new unsigned char[msgLen];
memcpy(prevMsg.m_buff, data, len);
prevMsg.m_writePoint = len;
return true;
}
}
ui32 msgLen = 0;
ui32 cutSize = 0;
if (prevMsg.m_buff == (unsigned char*)&prevMsg.m_writePoint)
{
if (prevMsg.m_readPoint + len < 4)
{
// nlog("TcpSession::parseMsg prevMsg.m_readPoint + len < 4");
memcpy(((unsigned char*)&prevMsg.m_writePoint) + prevMsg.m_readPoint
, data, len);
prevMsg.m_readPoint = prevMsg.m_readPoint + len;
return true;
}
cutSize = 4 - prevMsg.m_readPoint;
memcpy(((unsigned char*)&prevMsg.m_writePoint) + prevMsg.m_readPoint
, data, cutSize);
msgLen = prevMsg.m_writePoint;
prevMsg.m_buff = new unsigned char[msgLen];
memcpy(prevMsg.m_buff, &msgLen, 4);
prevMsg.m_readPoint = 4;
prevMsg.m_writePoint = 4;
}
msgLen = prevMsg.getSize();
if (msgLen > MSG_SIZE)
{
return false;
}
if (len + prevMsg.m_writePoint - cutSize >= msgLen)
{
// ui32 oldWritePoint = 0;//test value
// oldWritePoint = prevMsg.m_writePoint;//test value
memcpy(prevMsg.m_buff + prevMsg.m_writePoint, data + cutSize, msgLen - prevMsg.m_writePoint);
data = data + cutSize + (msgLen - prevMsg.m_writePoint);
len = len - cutSize - (msgLen - prevMsg.m_writePoint);
auto recvMsg = MessagePtr();
recvMsg->swap(prevMsg);
// if (m_MessageParser && m_MessageParser->m_commands[recvMsg->getMsgID()] == nullptr)
// {
// nlogerr("TcpSession::parseMsg err 2");
// }
nicehero::post([this,self,recvMsg] {
self->handleMessage(recvMsg);
});
if (len > 0)
{
return parseMsg( data, len);
}
return true;
}
// nlog("TcpSession::parseMsg else");
memcpy(prevMsg.m_buff + prevMsg.m_writePoint, data + cutSize, len - cutSize);
prevMsg.m_writePoint += len - cutSize;
return true;
}
// Base-class removal hooks: intentionally no-ops here.
// TcpSessionS and TcpSessionC define their own removeSelf().
void TcpSession::removeSelf()
{
}
void TcpSession::removeSelfImpl()
{
}
// Dispatches `msg` to the command registered for its message id in
// m_MessageParser. Unknown ids are logged and ignored. Without a parser,
// the message is dropped silently.
void TcpSession::handleMessage(MessagePtr msg)
{
	if (!m_MessageParser)
	{
		return;
	}
	auto msgID = msg->getMsgID();
	auto&& command = m_MessageParser->m_commands[msgID];
	if (command == nullptr)
	{
		nlogerr("TcpSession::handleMessage undefined msg:%d", ui32(msgID));
		return;
	}
	command(shared_from_this(),msg);
}
void TcpSession::close()
{
m_impl->m_socket.close();
}
// Installs the command table used by handleMessage() (non-owning pointer).
void TcpSession::setMessageParser(TcpMessageParser* messageParser)
{
	this->m_MessageParser = messageParser;
}
std::string& TcpSession::getUid()
{
return m_uid;
}
// Queues `msg` for sending. The caller's Message is swapped into a private
// holder immediately; the queue itself is only touched inside a task posted
// to this session's io_context. A send is started unless one is in flight.
void TcpSession::doSend(Message& msg)
{
	auto self(shared_from_this());
	auto pending = std::make_shared<Message>();
	pending->swap(msg);
	m_impl->m_iocontext.post([this, self, pending] {
		// Original design note: m_SendList/m_IsSending are only used from this
		// io_context, so no lock is taken.
		m_SendList.emplace_back();
		m_SendList.back().swap(*pending);
		if (!m_IsSending)
		{
			doSend();
		}
	});
}
void TcpSession::doSend()
{
auto self(shared_from_this());
m_IsSending = true;
while (m_SendList.size() > 0 && m_SendList.front().m_buff == nullptr)
{
m_SendList.pop_front();
}
if (m_SendList.empty())
{
m_IsSending = false;
return;
}
ui8* data = m_SendList.front().m_buff;
ui32 size_ = m_SendList.front().getSize();
if (m_SendList.size() > 1 && size_ <= (ui32)NETWORK_BUF_SIZE)
{
ui8 data2[NETWORK_BUF_SIZE];
size_ = 0;
while (m_SendList.size() > 0)
{
Message& msg = m_SendList.front();
if (size_ + msg.getSize() > (ui32)NETWORK_BUF_SIZE)
{
break;
}
memcpy(data2 + size_, msg.m_buff, msg.getSize());
size_ += msg.getSize();
m_SendList.pop_front();
}
asio::async_write(m_impl->m_socket,asio::buffer(data2, size_)
, asio::transfer_at_least(size_)
, [this,self, size_](asio::error_code ec, std::size_t s) {
if (ec)
{
removeSelf();
return;
}
if (s < size_)
{
nlogerr("async_write buffer(data2, size_) err s:%d < size_:%d", int(s), int(size_));
removeSelf();
return;
}
if (m_SendList.size() > 0)
{
doSend();
return;
}
m_IsSending = false;
});
}
else
{
asio::async_write(m_impl->m_socket
, asio::buffer(data, size_)
, asio::transfer_at_least(size_)
, [this, self, size_](asio::error_code ec, std::size_t s) {
if (ec)
{
removeSelf();
return;
}
if (s < size_)
{
nlogerr("async_write buffer(data2, size_) err s:%d < size_:%d", int(s), int(size_));
removeSelf();
return;
}
m_SendList.pop_front();
if (m_SendList.size() > 0)
{
doSend();
return;
}
m_IsSending = false;
});
}
}
void TcpSession::sendMessage(Message& msg)
{
doSend(msg);
}
void TcpSession::sendMessage(const Serializable& msg)
{
Message msg_;
msg.toMsg(msg_);
sendMessage(msg_);
}
// Client-side session. It starts un-handshaken (m_isInit == false).
// NOTE(review): if TcpSessionC derives from TcpSession, the base constructor
// already created m_impl and this replaces it with a second instance.
// Kept as-is.
TcpSessionC::TcpSessionC()
{
	m_isInit = false;
	m_impl.reset(new TcpSessionImpl(*this));
}
// Nothing to release beyond what the members and base class handle.
TcpSessionC::~TcpSessionC() = default;
bool TcpSessionC::connect(const std::string& ip, ui16 port)
{
asio::error_code ec;
auto addr = asio::ip::address::from_string(ip, ec);
if (ec)
{
nlogerr("TcpSessionC::TcpSessionC ip error:%s", ip.c_str());
return false;
}
m_impl->m_socket.connect({addr,port} , ec);
if (ec)
{
nlogerr("TcpSessionC::TcpSessionC connect error:%s", ec.message().c_str());
return false;
}
return true;
}
void TcpSessionC::init(bool isAsync)
{
std::shared_ptr<asio::steady_timer> t = std::make_shared<asio::steady_timer>(getWorkerService());
auto f = [&, t](std::error_code ec) {
t->cancel();
if (ec)
{
nlogerr("TcpSessionC::init err %s", ec.message().c_str());
return;
}
ui8 data_[PUBLIC_KEY_SIZE + 8 + HASH_SIZE + SIGN_SIZE] = "";
std::size_t len = m_impl->m_socket.read_some(
asio::buffer(data_, sizeof(data_)), ec);
if (ec)
{
nlogerr("TcpSessionC::init err %s", ec.message().c_str());
return;
}
if (len < sizeof(data_))
{
nlogerr("server sign data len err");
return;
}
if (checkServerSign(data_) == 1)
{
nlogerr("server sign err");
return;
}
ui8 sendSign[PUBLIC_KEY_SIZE + SIGN_SIZE] = { 0 };
memcpy(sendSign, m_publicKey, PUBLIC_KEY_SIZE);
if (uECC_sign(m_privateKey
, (const ui8*)data_ + PUBLIC_KEY_SIZE + 8, HASH_SIZE
, sendSign + PUBLIC_KEY_SIZE
, uECC_secp256k1()) != 1)
{
nlogerr("uECC_sign() failed\n");
return;
}
m_uid = std::string((const char*)data_, PUBLIC_KEY_SIZE);
static ui64 nowSerialID = 10000;
m_serialID = nowSerialID++;
m_impl->m_socket.write_some(asio::buffer(sendSign, PUBLIC_KEY_SIZE + SIGN_SIZE), ec);
if (ec)
{
nlogerr("TcpSessionC::init err %s", ec.message().c_str());
return;
}
m_isInit = true;
m_MessageParser = &getTcpMessagerParse(typeid(*this));
};
if (isAsync)
{
m_impl->m_socket.async_wait(
asio::ip::tcp::socket::wait_read,f);
t->expires_from_now(std::chrono::seconds(2));
t->async_wait([&](std::error_code ec) {
if (!ec)
{
close();
}
});
}
else
{
f(std::error_code());
}
}
void TcpSessionC::startRead()
{
doRead();
}
void TcpSessionC::removeSelf()
{
close();
}
// Validates the server hello received by the client.
//
// Layout of data_ (as read below):
//   [0, PUBLIC_KEY_SIZE)                           server public key
//   [PUBLIC_KEY_SIZE, +8)                          server timestamp (ui64, native byte order)
//   [PUBLIC_KEY_SIZE + 8, +HASH_SIZE)              sha3 of the first two fields
//   [PUBLIC_KEY_SIZE + 8 + HASH_SIZE, +SIGN_SIZE)  signature of that hash
//
// @return 1 if the key equals our own key, the hash does not match, or the
//         signature does not verify.
//         2 if otherwise valid but the clocks differ by more than 10 (in
//         nNow's unit, presumably seconds — TODO confirm).
//         0 if valid.
int TcpSessionC::checkServerSign(ui8* data_)
{
	if (memcmp(m_publicKey, data_, PUBLIC_KEY_SIZE) == 0)
	{
		nlogerr("same publicKey");
		return 1;
	}
	// memcpy: data_ + PUBLIC_KEY_SIZE may be misaligned for a ui64 access.
	ui64 serverTime = 0;
	memcpy(&serverTime, data_ + PUBLIC_KEY_SIZE, sizeof(serverTime));
	const ui64 localTime = nNow;
	// Absolute difference without unsigned wrap-around; the old
	// serverTime - 10 / serverTime + 10 bounds overflowed near 0 / max.
	const ui64 skew = localTime > serverTime ? localTime - serverTime : serverTime - localTime;
	int ret = 0;
	if (skew > 10)
	{
		nlogerr("your time is diff from serverTime");
		ret = 2;
	}
	// Byte buffer sized exactly HASH_SIZE (sha3 writes HASH_SIZE bytes).
	ui8 checkHash[HASH_SIZE] = { 0 };
	sha3(data_, PUBLIC_KEY_SIZE + 8, checkHash, HASH_SIZE);
	if (memcmp(checkHash, data_ + PUBLIC_KEY_SIZE + 8, HASH_SIZE) != 0)
	{
		nlogerr("error check hash");
		return 1;
	}
	if (uECC_verify(data_, (const ui8*)(data_ + PUBLIC_KEY_SIZE + 8), HASH_SIZE, data_ + PUBLIC_KEY_SIZE + 8 + HASH_SIZE, uECC_secp256k1()) != 1)
	{
		nlogerr("error check hash2");
		return 1;
	}
	return ret;
}
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/nicehero/nicenet.git
[email protected]:nicehero/nicenet.git
nicehero
nicenet
nicenet
master

搜索帮助