fix: assertion failure in readBtPiece() (#5097)

This commit is contained in:
Charles Kerr
2023-03-03 17:43:51 -06:00
committed by GitHub
parent 8de331c6c4
commit 7f9ef4730e
2 changed files with 201 additions and 225 deletions

View File

@@ -81,6 +81,47 @@ auto constexpr FextAllowedFast = uint8_t{ 17 };
// see also LtepMessageIds below // see also LtepMessageIds below
auto constexpr Ltep = uint8_t{ 20 }; auto constexpr Ltep = uint8_t{ 20 };
[[nodiscard]] constexpr std::string_view debug_name(uint8_t type) noexcept
{
switch (type)
{
case Bitfield:
return "bitfield"sv;
case Cancel:
return "cancel"sv;
case Choke:
return "choke"sv;
case FextAllowedFast:
return "fext-allow-fast"sv;
case FextHaveAll:
return "fext-have-all"sv;
case FextHaveNone:
return "fext-have-none"sv;
case FextReject:
return "fext-reject"sv;
case FextSuggest:
return "fext-suggest"sv;
case Have:
return "have"sv;
case Interested:
return "interested"sv;
case Ltep:
return "ltep"sv;
case NotInterested:
return "not-interested"sv;
case Piece:
return "piece"sv;
case Port:
return "port"sv;
case Request:
return "request"sv;
case Unchoke:
return "unchoke"sv;
default:
return "unknown"sv;
}
}
} // namespace BtPeerMsgs } // namespace BtPeerMsgs
namespace LtepMessages namespace LtepMessages
@@ -145,14 +186,6 @@ auto constexpr MaxPexPeerCount = size_t{ 50 };
// --- // ---
enum class AwaitingBt
{
Length,
Id,
Message,
Piece
};
enum class EncryptionPreference enum class EncryptionPreference
{ {
Unknown, Unknown,
@@ -186,9 +219,9 @@ peer_request blockToReq(tr_torrent const* tor, tr_block_index_t block)
* the current message that it's sending us. */ * the current message that it's sending us. */
struct tr_incoming struct tr_incoming
{ {
uint8_t id = 0; // the protocol message, e.g. BtPeerMsgs::Piece std::optional<uint32_t> length; // the full message payload length. Includes the +1 for id length
uint32_t length = 0; // the full message payload length. Includes the +1 for id length std::optional<uint8_t> id; // the protocol message, e.g. BtPeerMsgs::Piece
std::optional<peer_request> block_req; // metadata for incoming blocks libtransmission::Buffer payload;
struct incoming_piece_data struct incoming_piece_data
{ {
@@ -237,6 +270,7 @@ void updateDesiredRequestCount(tr_peerMsgsImpl* msgs);
#define logdbg(msgs, text) myLogMacro(msgs, TR_LOG_DEBUG, text) #define logdbg(msgs, text) myLogMacro(msgs, TR_LOG_DEBUG, text)
#define logtrace(msgs, text) myLogMacro(msgs, TR_LOG_TRACE, text) #define logtrace(msgs, text) myLogMacro(msgs, TR_LOG_TRACE, text)
#define logwarn(msgs, text) myLogMacro(msgs, TR_LOG_WARN, text)
/** /**
* Low-level communication state information about a connected peer. * Low-level communication state information about a connected peer.
@@ -660,7 +694,6 @@ public:
* very quickly; others aren't as urgent. */ * very quickly; others aren't as urgent. */
int8_t outMessagesBatchPeriod; int8_t outMessagesBatchPeriod;
AwaitingBt state = AwaitingBt::Length;
uint8_t ut_pex_id = 0; uint8_t ut_pex_id = 0;
uint8_t ut_metadata_id = 0; uint8_t ut_metadata_id = 0;
@@ -983,16 +1016,11 @@ void sendLtepHandshake(tr_peerMsgsImpl* msgs)
tr_variantClear(&val); tr_variantClear(&val);
} }
void parseLtepHandshake(tr_peerMsgsImpl* msgs, uint32_t len) void parseLtepHandshake(tr_peerMsgsImpl* msgs, libtransmission::Buffer& payload)
{ {
msgs->peerSentLtepHandshake = true; msgs->peerSentLtepHandshake = true;
// LTEP messages are usually just a couple hundred bytes, auto const handshake_sv = payload.pullup_sv();
// so try using a strbuf to handle it on the stack
auto tmp = tr_strbuf<char, 512>{};
tmp.resize(len);
msgs->io->read_bytes(std::data(tmp), std::size(tmp));
auto const handshake_sv = tmp.sv();
auto val = tr_variant{}; auto val = tr_variant{};
if (!tr_variantFromBuf(&val, TR_VARIANT_PARSE_BENC | TR_VARIANT_PARSE_INPLACE, handshake_sv) || !tr_variantIsDict(&val)) if (!tr_variantFromBuf(&val, TR_VARIANT_PARSE_BENC | TR_VARIANT_PARSE_INPLACE, handshake_sv) || !tr_variantIsDict(&val))
@@ -1089,16 +1117,14 @@ void parseLtepHandshake(tr_peerMsgsImpl* msgs, uint32_t len)
tr_variantClear(&val); tr_variantClear(&val);
} }
void parseUtMetadata(tr_peerMsgsImpl* msgs, uint32_t msglen) void parseUtMetadata(tr_peerMsgsImpl* msgs, libtransmission::Buffer& payload_in)
{ {
int64_t msg_type = -1; int64_t msg_type = -1;
int64_t piece = -1; int64_t piece = -1;
int64_t total_size = 0; int64_t total_size = 0;
auto tmp = std::vector<char>{}; auto const tmp = payload_in.pullup_sv();
tmp.resize(msglen); auto const* const msg_end = std::data(tmp) + std::size(tmp);
msgs->io->read_bytes(std::data(tmp), std::size(tmp));
char const* const msg_end = std::data(tmp) + std::size(tmp);
auto dict = tr_variant{}; auto dict = tr_variant{};
char const* benc_end = nullptr; char const* benc_end = nullptr;
@@ -1158,7 +1184,7 @@ void parseUtMetadata(tr_peerMsgsImpl* msgs, uint32_t msglen)
} }
} }
void parseUtPex(tr_peerMsgsImpl* msgs, uint32_t msglen) void parseUtPex(tr_peerMsgsImpl* msgs, libtransmission::Buffer& payload)
{ {
auto* const tor = msgs->torrent; auto* const tor = msgs->torrent;
if (!tor->allowsPex()) if (!tor->allowsPex())
@@ -1166,9 +1192,7 @@ void parseUtPex(tr_peerMsgsImpl* msgs, uint32_t msglen)
return; return;
} }
auto tmp = std::vector<char>{}; auto const tmp = payload.pullup_sv();
tmp.resize(msglen);
msgs->io->read_bytes(std::data(tmp), std::size(tmp));
if (tr_variant val; tr_variantFromBuf(&val, TR_VARIANT_PARSE_BENC | TR_VARIANT_PARSE_INPLACE, tmp)) if (tr_variant val; tr_variantFromBuf(&val, TR_VARIANT_PARSE_BENC | TR_VARIANT_PARSE_INPLACE, tmp))
{ {
@@ -1208,18 +1232,16 @@ void parseUtPex(tr_peerMsgsImpl* msgs, uint32_t msglen)
} }
} }
void parseLtep(tr_peerMsgsImpl* msgs, uint32_t msglen) void parseLtep(tr_peerMsgsImpl* msgs, libtransmission::Buffer& payload)
{ {
TR_ASSERT(msglen > 0); TR_ASSERT(!std::empty(payload));
auto ltep_msgid = uint8_t{}; auto const ltep_msgid = payload.to_uint8();
msgs->io->read_uint8(&ltep_msgid);
msglen--;
if (ltep_msgid == LtepMessages::Handshake) if (ltep_msgid == LtepMessages::Handshake)
{ {
logtrace(msgs, "got ltep handshake"); logtrace(msgs, "got ltep handshake");
parseLtepHandshake(msgs, msglen); parseLtepHandshake(msgs, payload);
if (msgs->io->supports_ltep()) if (msgs->io->supports_ltep())
{ {
@@ -1231,73 +1253,23 @@ void parseLtep(tr_peerMsgsImpl* msgs, uint32_t msglen)
{ {
logtrace(msgs, "got ut pex"); logtrace(msgs, "got ut pex");
msgs->peerSupportsPex = true; msgs->peerSupportsPex = true;
parseUtPex(msgs, msglen); parseUtPex(msgs, payload);
} }
else if (ltep_msgid == UT_METADATA_ID) else if (ltep_msgid == UT_METADATA_ID)
{ {
logtrace(msgs, "got ut metadata"); logtrace(msgs, "got ut metadata");
msgs->peerSupportsMetadataXfer = true; msgs->peerSupportsMetadataXfer = true;
parseUtMetadata(msgs, msglen); parseUtMetadata(msgs, payload);
} }
else else
{ {
logtrace(msgs, fmt::format(FMT_STRING("skipping unknown ltep message ({:d})"), static_cast<int>(ltep_msgid))); logtrace(msgs, fmt::format(FMT_STRING("skipping unknown ltep message ({:d})"), static_cast<int>(ltep_msgid)));
msgs->io->read_buffer_drain(msglen);
} }
} }
ReadState readBtLength(tr_peerMsgsImpl* msgs, size_t inlen) using ReadResult = std::pair<ReadState, size_t /*n_piece_data_bytes_read*/>;
{
auto len = uint32_t{};
if (inlen < sizeof(len))
{
return READ_LATER;
}
msgs->io->read_uint32(&len); ReadResult process_peer_message(tr_peerMsgsImpl* msgs, uint8_t id, libtransmission::Buffer& payload);
if (len == 0) /* peer sent us a keepalive message */
{
logtrace(msgs, "got KeepAlive");
}
else
{
msgs->incoming.length = len;
msgs->state = AwaitingBt::Id;
}
return READ_NOW;
}
ReadState readBtMessage(tr_peerMsgsImpl* /*msgs*/, size_t /*inlen*/);
ReadState readBtId(tr_peerMsgsImpl* msgs, size_t inlen)
{
if (inlen < sizeof(uint8_t))
{
return READ_LATER;
}
auto id = uint8_t{};
msgs->io->read_uint8(&id);
msgs->incoming.id = id;
logtrace(
msgs,
fmt::format(FMT_STRING("msgs->incoming.id is now {:d}: msgs->incoming.length is {:d}"), id, msgs->incoming.length));
if (id == BtPeerMsgs::Piece)
{
msgs->state = AwaitingBt::Piece;
return READ_NOW;
}
if (msgs->incoming.length != 1)
{
msgs->state = AwaitingBt::Message;
return READ_NOW;
}
return readBtMessage(msgs, inlen - 1);
}
void prefetchPieces(tr_peerMsgsImpl* msgs) void prefetchPieces(tr_peerMsgsImpl* msgs)
{ {
@@ -1413,109 +1385,74 @@ bool messageLengthIsCorrect(tr_peerMsgsImpl const* msg, uint8_t id, uint32_t len
int clientGotBlock(tr_peerMsgsImpl* msgs, std::unique_ptr<std::vector<uint8_t>> block_data, tr_block_index_t block); int clientGotBlock(tr_peerMsgsImpl* msgs, std::unique_ptr<std::vector<uint8_t>> block_data, tr_block_index_t block);
ReadState readBtPiece(tr_peerMsgsImpl* msgs, size_t inlen, size_t* setme_piece_bytes_read) ReadResult read_piece_data(tr_peerMsgsImpl* msgs, libtransmission::Buffer& payload)
{ {
TR_ASSERT(msgs->io->read_buffer_size() >= inlen); // <index><begin><block>
auto const piece = payload.to_uint32();
auto const offset = payload.to_uint32();
auto const len = std::size(payload);
logtrace(msgs, "In readBtPiece"); auto const loc = msgs->torrent->pieceLoc(piece, offset);
// If this is the first we've seen of the piece data, parse out the header
auto& incoming = msgs->incoming;
if (!incoming.block_req)
{
if (inlen < 8)
{
return READ_LATER;
}
auto req = peer_request{};
msgs->io->read_uint32(&req.index);
msgs->io->read_uint32(&req.offset);
req.length = incoming.length - 9;
logtrace(msgs, fmt::format(FMT_STRING("got incoming block header {:d}:{:d}->{:d}"), req.index, req.offset, req.length));
incoming.block_req = req;
return READ_NOW;
}
auto& req = incoming.block_req;
auto const loc = msgs->torrent->pieceLoc(req->index, req->offset);
auto const block = loc.block; auto const block = loc.block;
auto const block_size = msgs->torrent->blockSize(block); auto const block_size = msgs->torrent->blockSize(block);
auto const n_this_pass = std::min(size_t{ req->length }, inlen); if (loc.block_offset + len > block_size)
TR_ASSERT(loc.block_offset + n_this_pass <= block_size);
if (n_this_pass == 0)
{ {
return READ_LATER; logwarn(msgs, fmt::format("got unaligned piece {:d}:{:d}->{:d}", piece, offset, len));
return { READ_ERR, len };
} }
auto& incoming_block = incoming.blocks.try_emplace(block, block_size).first->second; if (!tr_peerMgrDidPeerRequest(msgs->torrent, msgs, block))
msgs->io->read_bytes(std::data(*incoming_block.buf) + loc.block_offset, n_this_pass);
msgs->publish(tr_peer_event::GotPieceData(n_this_pass));
*setme_piece_bytes_read += n_this_pass;
incoming_block.have.setSpan(loc.block_offset, loc.block_offset + n_this_pass);
logtrace(msgs, fmt::format("got {:d} bytes for req {:d}:{:d}->{:d}", n_this_pass, req->index, req->offset, req->length));
// if we haven't gotten the full response yet,
// update what part of `req` is unfulfilled and wait for more
if (req->length > n_this_pass)
{ {
req->length -= n_this_pass; logwarn(msgs, fmt::format("got unrequested piece {:d}:{:d}->{:d}", piece, offset, len));
auto const new_loc = msgs->torrent->byteLoc(loc.byte + n_this_pass); return { READ_ERR, len };
req->index = new_loc.piece;
req->offset = new_loc.piece_offset;
return READ_LATER;
} }
// we've got the entire response message auto& blocks = msgs->incoming.blocks;
req.reset(); auto& incoming_block = blocks.try_emplace(block, block_size).first->second;
msgs->state = AwaitingBt::Length; payload.to_buf(std::data(*incoming_block.buf) + loc.block_offset, len);
msgs->publish(tr_peer_event::GotPieceData(len));
incoming_block.have.setSpan(loc.block_offset, loc.block_offset + len);
logtrace(msgs, fmt::format("got {:d} bytes for req {:d}:{:d}->{:d}", len, piece, offset, len));
// if we haven't gotten the entire block yet, wait for more // if we haven't gotten the entire block yet, wait for more
if (!incoming_block.have.hasAll()) if (!incoming_block.have.hasAll())
{ {
return READ_LATER; return { READ_LATER, len };
} }
// we've got the entire block, so send it along. // we've got the entire block, so send it along.
auto block_buf = std::move(incoming_block.buf); auto block_buf = std::move(incoming_block.buf);
incoming.blocks.erase(block); // note: invalidates `incoming_block` local blocks.erase(block); // note: invalidates `incoming_block` local
auto const ok = clientGotBlock(msgs, std::move(block_buf), block) == 0; auto const ok = clientGotBlock(msgs, std::move(block_buf), block) == 0;
return ok ? READ_NOW : READ_ERR; return { ok ? READ_NOW : READ_ERR, len };
} }
ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen) ReadResult process_peer_message(tr_peerMsgsImpl* msgs, uint8_t id, libtransmission::Buffer& payload)
{ {
uint8_t const id = msgs->incoming.id;
#ifdef TR_ENABLE_ASSERTS
auto const start_buflen = msgs->io->read_buffer_size();
#endif
bool const fext = msgs->io->supports_fext(); bool const fext = msgs->io->supports_fext();
auto ui32 = uint32_t{}; auto ui32 = uint32_t{};
auto msglen = uint32_t{ msgs->incoming.length };
TR_ASSERT(msglen > 0);
--msglen; /* id length */
logtrace( logtrace(
msgs, msgs,
fmt::format(FMT_STRING("got BT id {:d}, len {:d}, buffer size is {:d}"), static_cast<int>(id), msglen, inlen)); fmt::format(
"got peer msg '{:s}' ({:d}) with payload len {:d}",
BtPeerMsgs::debug_name(id),
static_cast<int>(id),
std::size(payload)));
if (inlen < msglen) if (!messageLengthIsCorrect(msgs, id, sizeof(id) + std::size(payload)))
{
return READ_LATER;
}
if (!messageLengthIsCorrect(msgs, id, msglen + 1))
{ {
logdbg( logdbg(
msgs, msgs,
fmt::format(FMT_STRING("bad packet - BT message #{:d} with a length of {:d}"), static_cast<int>(id), msglen)); fmt::format(
"bad msg: '{:s}' ({:d}) with payload len {:d}",
BtPeerMsgs::debug_name(id),
static_cast<int>(id),
std::size(payload)));
msgs->publish(tr_peer_event::GotError(EMSGSIZE)); msgs->publish(tr_peer_event::GotError(EMSGSIZE));
return READ_ERR; return { READ_ERR, {} };
} }
switch (id) switch (id)
@@ -1552,13 +1489,13 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
break; break;
case BtPeerMsgs::Have: case BtPeerMsgs::Have:
msgs->io->read_uint32(&ui32); ui32 = payload.to_uint32();
logtrace(msgs, fmt::format(FMT_STRING("got Have: {:d}"), ui32)); logtrace(msgs, fmt::format(FMT_STRING("got Have: {:d}"), ui32));
if (msgs->torrent->hasMetainfo() && ui32 >= msgs->torrent->pieceCount()) if (msgs->torrent->hasMetainfo() && ui32 >= msgs->torrent->pieceCount())
{ {
msgs->publish(tr_peer_event::GotError(ERANGE)); msgs->publish(tr_peer_event::GotError(ERANGE));
return READ_ERR; return { READ_ERR, {} };
} }
/* a peer can send the same HAVE message twice... */ /* a peer can send the same HAVE message twice... */
@@ -1574,10 +1511,9 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
case BtPeerMsgs::Bitfield: case BtPeerMsgs::Bitfield:
{ {
logtrace(msgs, "got a bitfield"); logtrace(msgs, "got a bitfield");
auto tmp = std::vector<uint8_t>(msglen); auto const [buf, buflen] = payload.pullup();
msgs->io->read_bytes(std::data(tmp), std::size(tmp)); msgs->have_ = tr_bitfield{ msgs->torrent->hasMetainfo() ? msgs->torrent->pieceCount() : buflen * 8 };
msgs->have_ = tr_bitfield{ msgs->torrent->hasMetainfo() ? msgs->torrent->pieceCount() : std::size(tmp) * 8 }; msgs->have_.setRaw(reinterpret_cast<uint8_t const*>(buf), buflen);
msgs->have_.setRaw(std::data(tmp), std::size(tmp));
msgs->publish(tr_peer_event::GotBitfield(&msgs->have_)); msgs->publish(tr_peer_event::GotBitfield(&msgs->have_));
msgs->invalidatePercentDone(); msgs->invalidatePercentDone();
break; break;
@@ -1586,9 +1522,9 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
case BtPeerMsgs::Request: case BtPeerMsgs::Request:
{ {
struct peer_request r; struct peer_request r;
msgs->io->read_uint32(&r.index); r.index = payload.to_uint32();
msgs->io->read_uint32(&r.offset); r.offset = payload.to_uint32();
msgs->io->read_uint32(&r.length); r.length = payload.to_uint32();
logtrace(msgs, fmt::format(FMT_STRING("got Request: {:d}:{:d}->{:d}"), r.index, r.offset, r.length)); logtrace(msgs, fmt::format(FMT_STRING("got Request: {:d}:{:d}->{:d}"), r.index, r.offset, r.length));
peerMadeRequest(msgs, &r); peerMadeRequest(msgs, &r);
break; break;
@@ -1597,9 +1533,9 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
case BtPeerMsgs::Cancel: case BtPeerMsgs::Cancel:
{ {
struct peer_request r; struct peer_request r;
msgs->io->read_uint32(&r.index); r.index = payload.to_uint32();
msgs->io->read_uint32(&r.offset); r.offset = payload.to_uint32();
msgs->io->read_uint32(&r.length); r.length = payload.to_uint32();
msgs->cancels_sent_to_client.add(tr_time(), 1); msgs->cancels_sent_to_client.add(tr_time(), 1);
logtrace(msgs, fmt::format(FMT_STRING("got a Cancel {:d}:{:d}->{:d}"), r.index, r.offset, r.length)); logtrace(msgs, fmt::format(FMT_STRING("got a Cancel {:d}:{:d}->{:d}"), r.index, r.offset, r.length));
@@ -1620,7 +1556,7 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
} }
case BtPeerMsgs::Piece: case BtPeerMsgs::Piece:
TR_ASSERT(false); /* handled elsewhere! */ return read_piece_data(msgs, payload);
break; break;
case BtPeerMsgs::Port: case BtPeerMsgs::Port:
@@ -1633,8 +1569,7 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
{ {
logtrace(msgs, "Got a BtPeerMsgs::Port"); logtrace(msgs, "Got a BtPeerMsgs::Port");
auto hport = uint16_t{}; auto const hport = payload.to_uint16();
msgs->io->read_uint16(&hport); // read_uint16 performs ntoh
if (auto const dht_port = tr_port::fromHost(hport); !std::empty(dht_port)) if (auto const dht_port = tr_port::fromHost(hport); !std::empty(dht_port))
{ {
msgs->dht_port = dht_port; msgs->dht_port = dht_port;
@@ -1645,32 +1580,32 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
case BtPeerMsgs::FextSuggest: case BtPeerMsgs::FextSuggest:
logtrace(msgs, "Got a BtPeerMsgs::FextSuggest"); logtrace(msgs, "Got a BtPeerMsgs::FextSuggest");
msgs->io->read_uint32(&ui32);
if (fext) if (fext)
{ {
msgs->publish(tr_peer_event::GotSuggest(ui32)); auto const piece = payload.to_uint32();
msgs->publish(tr_peer_event::GotSuggest(piece));
} }
else else
{ {
msgs->publish(tr_peer_event::GotError(EMSGSIZE)); msgs->publish(tr_peer_event::GotError(EMSGSIZE));
return READ_ERR; return { READ_ERR, {} };
} }
break; break;
case BtPeerMsgs::FextAllowedFast: case BtPeerMsgs::FextAllowedFast:
logtrace(msgs, "Got a BtPeerMsgs::FextAllowedFast"); logtrace(msgs, "Got a BtPeerMsgs::FextAllowedFast");
msgs->io->read_uint32(&ui32);
if (fext) if (fext)
{ {
msgs->publish(tr_peer_event::GotAllowedFast(ui32)); auto const piece = payload.to_uint32();
msgs->publish(tr_peer_event::GotAllowedFast(piece));
} }
else else
{ {
msgs->publish(tr_peer_event::GotError(EMSGSIZE)); msgs->publish(tr_peer_event::GotError(EMSGSIZE));
return READ_ERR; return { READ_ERR, {} };
} }
break; break;
@@ -1687,7 +1622,7 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
else else
{ {
msgs->publish(tr_peer_event::GotError(EMSGSIZE)); msgs->publish(tr_peer_event::GotError(EMSGSIZE));
return READ_ERR; return { READ_ERR, {} };
} }
break; break;
@@ -1704,7 +1639,7 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
else else
{ {
msgs->publish(tr_peer_event::GotError(EMSGSIZE)); msgs->publish(tr_peer_event::GotError(EMSGSIZE));
return READ_ERR; return { READ_ERR, {} };
} }
break; break;
@@ -1712,10 +1647,9 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
case BtPeerMsgs::FextReject: case BtPeerMsgs::FextReject:
{ {
struct peer_request r; struct peer_request r;
logtrace(msgs, "Got a BtPeerMsgs::FextReject"); r.index = payload.to_uint32();
msgs->io->read_uint32(&r.index); r.offset = payload.to_uint32();
msgs->io->read_uint32(&r.offset); r.length = payload.to_uint32();
msgs->io->read_uint32(&r.length);
if (fext) if (fext)
{ {
@@ -1725,7 +1659,7 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
else else
{ {
msgs->publish(tr_peer_event::GotError(EMSGSIZE)); msgs->publish(tr_peer_event::GotError(EMSGSIZE));
return READ_ERR; return { READ_ERR, {} };
} }
break; break;
@@ -1733,20 +1667,15 @@ ReadState readBtMessage(tr_peerMsgsImpl* msgs, size_t inlen)
case BtPeerMsgs::Ltep: case BtPeerMsgs::Ltep:
logtrace(msgs, "Got a BtPeerMsgs::Ltep"); logtrace(msgs, "Got a BtPeerMsgs::Ltep");
parseLtep(msgs, msglen); parseLtep(msgs, payload);
break; break;
default: default:
logtrace(msgs, fmt::format(FMT_STRING("peer sent us an UNKNOWN: {:d}"), static_cast<int>(id))); logtrace(msgs, fmt::format(FMT_STRING("peer sent us an UNKNOWN: {:d}"), static_cast<int>(id)));
msgs->io->read_buffer_drain(msglen);
break; break;
} }
TR_ASSERT(msglen + 1 == msgs->incoming.length); return { READ_NOW, {} };
TR_ASSERT(msgs->io->read_buffer_size() == start_buflen - msglen);
msgs->state = AwaitingBt::Length;
return READ_NOW;
} }
/* returns 0 on success, or an errno on failure */ /* returns 0 on success, or an errno on failure */
@@ -1812,48 +1741,81 @@ void didWrite(tr_peerIo* /*io*/, size_t bytes_written, bool was_piece_data, void
ReadState canRead(tr_peerIo* io, void* vmsgs, size_t* piece) ReadState canRead(tr_peerIo* io, void* vmsgs, size_t* piece)
{ {
auto* msgs = static_cast<tr_peerMsgsImpl*>(vmsgs); auto* msgs = static_cast<tr_peerMsgsImpl*>(vmsgs);
size_t const inlen = io->read_buffer_size();
logtrace( // https://www.bittorrent.org/beps/bep_0003.html
msgs, // Next comes an alternating stream of length prefixes and messages.
fmt::format(FMT_STRING("canRead: inlen is {:d}, msgs->state is {:d}"), inlen, static_cast<int>(msgs->state))); // Messages of length zero are keepalives, and ignored.
// All non-keepalive messages start with a single byte which gives their type.
//
// https://wiki.theory.org/BitTorrentSpecification
// All of the remaining messages in the protocol take the form of
// <length prefix><message ID><payload>. The length prefix is a four byte
// big-endian value. The message ID is a single decimal byte.
// The payload is message dependent.
auto ret = ReadState{}; // read <length prefix>
if (inlen == 0) auto& current_message_len = msgs->incoming.length; // the full message payload length. Includes the +1 for id length
if (!current_message_len)
{ {
ret = READ_LATER; auto message_len = uint32_t{};
if (io->read_buffer_size() < sizeof(message_len))
{
return READ_LATER;
} }
else if (msgs->state == AwaitingBt::Piece)
{
ret = readBtPiece(msgs, inlen, piece);
}
else
{
switch (msgs->state)
{
case AwaitingBt::Length:
ret = readBtLength(msgs, inlen);
break;
case AwaitingBt::Id: io->read_uint32(&message_len);
ret = readBtId(msgs, inlen); current_message_len = message_len;
break;
case AwaitingBt::Message: // The keep-alive message is a message with zero bytes,
ret = readBtMessage(msgs, inlen); // specified with the length prefix set to zero.
break; // There is no message ID and no payload.
if (auto const is_keepalive = message_len == uint32_t{}; is_keepalive)
default: {
#ifdef TR_ENABLE_ASSERTS logtrace(msgs, "got KeepAlive");
TR_ASSERT_MSG(false, fmt::format(FMT_STRING("unhandled peer messages state {:d}"), static_cast<int>(msgs->state))); current_message_len.reset();
#else return READ_NOW;
ret = READ_ERR;
break;
#endif
} }
} }
return ret; // read <message ID>
auto& current_message_type = msgs->incoming.id;
if (!current_message_type)
{
auto message_type = uint8_t{};
if (io->read_buffer_size() < sizeof(message_type))
{
return READ_LATER;
}
io->read_uint8(&message_type);
current_message_type = message_type;
}
// read <payload>
auto& payload = msgs->incoming.payload;
auto const full_payload_len = *current_message_len - sizeof(uint8_t /*message_type*/);
auto n_left = full_payload_len - std::size(payload);
while (n_left > 0U && io->read_buffer_size() > 0U)
{
auto buf = std::array<char, tr_block_info::BlockSize>{};
auto const n_this_pass = std::min({ n_left, io->read_buffer_size(), std::size(buf) });
io->read_bytes(std::data(buf), n_this_pass);
payload.add(std::data(buf), n_this_pass);
n_left -= n_this_pass;
logtrace(msgs, fmt::format("read {:d} payload bytes; {:d} left to go", n_this_pass, n_left));
}
if (n_left > 0U)
{
return READ_LATER;
}
auto const [read_state, n_piece_bytes_read] = process_peer_message(msgs, *current_message_type, payload);
current_message_type.reset();
current_message_len.reset();
payload.clear();
*piece = n_piece_bytes_read;
return read_state;
} }
// --- // ---

View File

@@ -10,6 +10,7 @@
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <string> #include <string>
#include <string_view>
#include <event2/buffer.h> #include <event2/buffer.h>
@@ -197,6 +198,13 @@ public:
return evbuffer_remove(buf_.get(), tgt, n_bytes); return evbuffer_remove(buf_.get(), tgt, n_bytes);
} }
[[nodiscard]] auto to_uint8()
{
auto tmp = uint8_t{};
to_buf(&tmp, sizeof(tmp));
return tmp;
}
[[nodiscard]] uint16_t to_uint16() [[nodiscard]] uint16_t to_uint16()
{ {
auto tmp = uint16_t{}; auto tmp = uint16_t{};
@@ -247,6 +255,12 @@ public:
return { reinterpret_cast<std::byte*>(evbuffer_pullup(buf_.get(), -1)), size() }; return { reinterpret_cast<std::byte*>(evbuffer_pullup(buf_.get(), -1)), size() };
} }
[[nodiscard]] auto pullup_sv()
{
auto const [buf, buflen] = pullup();
return std::string_view{ reinterpret_cast<char const*>(buf), buflen };
}
void reserve(size_t n_bytes) void reserve(size_t n_bytes)
{ {
evbuffer_expand(buf_.get(), n_bytes - size()); evbuffer_expand(buf_.get(), n_bytes - size());