diff --git a/libtransmission/bandwidth.cc b/libtransmission/bandwidth.cc index 048a07309..283b91117 100644 --- a/libtransmission/bandwidth.cc +++ b/libtransmission/bandwidth.cc @@ -139,7 +139,7 @@ void tr_bandwidth::allocateBandwidth( tr_priority_t parent_priority, tr_direction dir, unsigned int period_msec, - std::vector& peer_pool) + std::vector>& peer_pool) { tr_priority_t const priority = std::max(parent_priority, this->priority_); @@ -151,10 +151,10 @@ void tr_bandwidth::allocateBandwidth( } /* add this bandwidth's peer, if any, to the peer pool */ - if (this->peer_ != nullptr) + if (auto shared = this->peer_.lock(); shared) { - this->peer_->priority = priority; - peer_pool.push_back(this->peer_); + shared->priority = priority; + peer_pool.push_back(std::move(shared)); } // traverse & repeat for the subtree @@ -199,33 +199,34 @@ void tr_bandwidth::allocate(tr_direction dir, unsigned int period_msec) { TR_ASSERT(tr_isDirection(dir)); + // keep these peers alive for the scope of this function + auto refs = std::vector>{}; + auto high = std::vector{}; auto low = std::vector{}; auto normal = std::vector{}; - auto tmp = std::vector{}; /* allocateBandwidth () is a helper function with two purposes: * 1. allocate bandwidth to b and its subtree * 2. accumulate an array of all the peerIos from b and its subtree. */ - this->allocateBandwidth(TR_PRI_LOW, dir, period_msec, tmp); + this->allocateBandwidth(TR_PRI_LOW, dir, period_msec, refs); - for (auto* io : tmp) + for (auto& io : refs) { - tr_peerIoRef(io); io->flushOutgoingProtocolMsgs(); switch (io->priority) { case TR_PRI_HIGH: - high.push_back(io); + high.push_back(io.get()); [[fallthrough]]; case TR_PRI_NORMAL: - normal.push_back(io); + normal.push_back(io.get()); [[fallthrough]]; default: - low.push_back(io); + low.push_back(io.get()); } } @@ -241,15 +242,10 @@ void tr_bandwidth::allocate(tr_direction dir, unsigned int period_msec) * enable on-demand IO for peers with bandwidth left to burn. * This on-demand IO is enabled until (1) the peer runs out of bandwidth, * or (2) the next tr_bandwidth::allocate () call, when we start over again. */ - for (auto* io : tmp) + for (auto& io : refs) { io->setEnabled(dir, io->hasBandwidthLeft(dir)); } - - for (auto* io : tmp) - { - tr_peerIoUnref(io); - } } /*** diff --git a/libtransmission/bandwidth.h b/libtransmission/bandwidth.h index 2ef7f489c..9ece06760 100644 --- a/libtransmission/bandwidth.h +++ b/libtransmission/bandwidth.h @@ -12,6 +12,7 @@ #include #include // size_t #include // uint64_t +#include #include #include "transmission.h" @@ -98,12 +99,10 @@ public: tr_bandwidth(tr_bandwidth&&) = delete; tr_bandwidth(tr_bandwidth&) = delete; - /** - * @brief Sets new peer, nullptr is allowed. - */ - constexpr void setPeer(tr_peerIo* peer) noexcept + // @brief Sets the peer. nullptr is allowed. + void setPeer(std::weak_ptr peer) noexcept { - this->peer_ = peer; + this->peer_ = std::move(peer); } /** @@ -258,12 +257,12 @@ private: tr_priority_t parent_priority, tr_direction dir, unsigned int period_msec, - std::vector& peer_pool); + std::vector>& peer_pool); mutable std::array band_ = {}; std::vector children_; tr_bandwidth* parent_ = nullptr; - tr_peerIo* peer_ = nullptr; + std::weak_ptr peer_; tr_priority_t priority_ = 0; }; diff --git a/libtransmission/cache.cc b/libtransmission/cache.cc index e8a680147..3784cca59 100644 --- a/libtransmission/cache.cc +++ b/libtransmission/cache.cc @@ -21,7 +21,6 @@ #include "torrent.h" #include "torrents.h" #include "tr-assert.h" -#include "trevent.h" #include "utils.h" // tr_time(), tr_formatter Cache::Key Cache::makeKey(tr_torrent const* torrent, tr_block_info::Location loc) noexcept diff --git a/libtransmission/handshake.cc b/libtransmission/handshake.cc index 33e30528c..4a33125a4 100644 --- a/libtransmission/handshake.cc +++ b/libtransmission/handshake.cc @@ -114,8 +114,12 @@ enum handshake_state_t struct tr_handshake { - tr_handshake(std::shared_ptr mediator_in, tr_encryption_mode encryption_mode_in) + tr_handshake( + std::shared_ptr mediator_in, + std::shared_ptr io_in, + tr_encryption_mode encryption_mode_in) : mediator{ std::move(mediator_in) } + , io{ std::move(io_in) } , dh{ mediator->privateKey() } , encryption_mode{ encryption_mode_in } { @@ -125,16 +129,9 @@ struct tr_handshake tr_handshake(tr_handshake const&) = delete; tr_handshake& operator=(tr_handshake&&) = delete; tr_handshake& operator=(tr_handshake const&) = delete; + ~tr_handshake() = default; - ~tr_handshake() - { - if (io != nullptr) - { - tr_peerIoUnref(io); /* balanced by the ref in tr_handshakeNew */ - } - } - - [[nodiscard]] auto constexpr isIncoming() const noexcept + [[nodiscard]] auto isIncoming() const noexcept { return io->isIncoming(); } @@ -143,7 +140,7 @@ struct tr_handshake bool haveReadAnythingFromPeer = false; bool haveSentBitTorrentHandshake = false; - tr_peerIo* io = nullptr; + std::shared_ptr const io; DH dh = {}; handshake_state_t state = AWAITING_HANDSHAKE; tr_encryption_mode encryption_mode; @@ -1133,20 +1130,18 @@ static void gotError(tr_peerIo* io, short what, void* vhandshake) tr_handshake* tr_handshakeNew( std::shared_ptr mediator, - tr_peerIo* io, + std::shared_ptr io, tr_encryption_mode encryption_mode, tr_handshake_done_func done_func, void* done_func_user_data) { - auto* const handshake = new tr_handshake{ std::move(mediator), encryption_mode }; - handshake->io = io; + auto* const handshake = new tr_handshake{ std::move(mediator), std::move(io), encryption_mode }; handshake->done_func = done_func; handshake->done_func_user_data = done_func_user_data; handshake->timeout_timer = handshake->mediator->createTimer(); handshake->timeout_timer->setCallback([handshake]() { tr_handshakeAbort(handshake); }); handshake->timeout_timer->startSingleShot(HandshakeTimeoutSec); - tr_peerIoRef(io); /* balanced by the unref in ~tr_handshake() */ handshake->io->setCallbacks(canRead, nullptr, gotError, handshake); if (handshake->isIncoming()) @@ -1169,13 +1164,3 @@ tr_handshake* tr_handshakeNew( return handshake; } - -tr_peerIo* tr_handshakeStealIO(tr_handshake* handshake) -{ - TR_ASSERT(handshake != nullptr); - TR_ASSERT(handshake->io != nullptr); - - tr_peerIo* io = handshake->io; - handshake->io = nullptr; - return io; -} diff --git a/libtransmission/handshake.h b/libtransmission/handshake.h index cef8a4c0f..9367367ea 100644 --- a/libtransmission/handshake.h +++ b/libtransmission/handshake.h @@ -31,7 +31,7 @@ struct tr_handshake; struct tr_handshake_result { struct tr_handshake* handshake; - tr_peerIo* io; + std::shared_ptr io; bool readAnythingFromPeer; bool isConnected; void* userData; @@ -77,13 +77,11 @@ using tr_handshake_done_func = bool (*)(tr_handshake_result const& result); /** @brief create a new handshake */ tr_handshake* tr_handshakeNew( std::shared_ptr mediator, - tr_peerIo* io, + std::shared_ptr io, tr_encryption_mode encryption_mode, tr_handshake_done_func done_func, void* done_func_user_data); void tr_handshakeAbort(tr_handshake* handshake); -tr_peerIo* tr_handshakeStealIO(tr_handshake* handshake); - /** @} */ diff --git a/libtransmission/peer-io.cc b/libtransmission/peer-io.cc index 6c28833eb..f86045e9b 100644 --- a/libtransmission/peer-io.cc +++ b/libtransmission/peer-io.cc @@ -26,7 +26,6 @@ #include "peer-io.h" #include "tr-assert.h" #include "tr-utp.h" -#include "trevent.h" /* tr_runInEventThread() */ #include "utils.h" #ifdef _WIN32 @@ -113,12 +112,11 @@ static void didWriteWrapper(tr_peerIo* io, unsigned int bytes_transferred) } } -static void canReadWrapper(tr_peerIo* io) +static void canReadWrapper(tr_peerIo* io_in) { + auto const io = io_in->shared_from_this(); tr_logAddTraceIo(io, "canRead"); - tr_peerIoRef(io); - tr_session const* const session = io->session; /* try to consume the input buffer */ @@ -134,7 +132,7 @@ static void canReadWrapper(tr_peerIo* io) { size_t piece = 0; size_t const oldLen = evbuffer_get_length(io->inbuf.get()); - int const ret = io->canRead(io, io->userData, &piece); + int const ret = io->canRead(io.get(), io->userData, &piece); size_t const used = oldLen - evbuffer_get_length(io->inbuf.get()); unsigned int const overhead = guessPacketOverhead(used); @@ -175,12 +173,8 @@ static void canReadWrapper(tr_peerIo* io) err = true; break; } - - TR_ASSERT(tr_isPeerIo(io)); } } - - tr_peerIoUnref(io); } static void event_read_cb(evutil_socket_t fd, short /*event*/, void* vio) @@ -491,7 +485,7 @@ static uint64 utp_callback(utp_callback_arguments* args) #endif /* #ifdef WITH_UTP */ -tr_peerIo* tr_peerIo::create( +std::shared_ptr tr_peerIo::create( tr_session* session, tr_bandwidth* parent, tr_address const* addr, @@ -519,7 +513,9 @@ tr_peerIo* tr_peerIo::create( maybeSetCongestionAlgorithm(socket.handle.tcp, session->peerCongestionAlgorithm()); } - auto* io = new tr_peerIo{ session, torrent_hash, is_incoming, *addr, port, is_seed, current_time, parent }; + auto io = std::shared_ptr{ + new tr_peerIo{ session, torrent_hash, is_incoming, *addr, port, is_seed, current_time, parent } + }; io->socket = socket; io->bandwidth().setPeer(io); tr_logAddTraceIo(io, fmt::format("bandwidth is {}; its parent is {}", fmt::ptr(&io->bandwidth()), fmt::ptr(parent))); @@ -528,15 +524,15 @@ tr_peerIo* tr_peerIo::create( { case TR_PEER_SOCKET_TYPE_TCP: tr_logAddTraceIo(io, fmt::format("socket (tcp) is {}", socket.handle.tcp)); - io->event_read = event_new(session->eventBase(), socket.handle.tcp, EV_READ, event_read_cb, io); - io->event_write = event_new(session->eventBase(), socket.handle.tcp, EV_WRITE, event_write_cb, io); + io->event_read = event_new(session->eventBase(), socket.handle.tcp, EV_READ, event_read_cb, io.get()); + io->event_write = event_new(session->eventBase(), socket.handle.tcp, EV_WRITE, event_write_cb, io.get()); break; #ifdef WITH_UTP case TR_PEER_SOCKET_TYPE_UTP: tr_logAddTraceIo(io, fmt::format("socket (utp) is {}", fmt::ptr(socket.handle.utp))); - utp_set_userdata(socket.handle.utp, io); + utp_set_userdata(socket.handle.utp, io.get()); break; #endif @@ -563,7 +559,7 @@ void tr_peerIo::utpInit([[maybe_unused]] struct_utp_context* ctx) #endif } -tr_peerIo* tr_peerIo::newIncoming( +std::shared_ptr tr_peerIo::newIncoming( tr_session* session, tr_bandwidth* parent, tr_address const* addr, @@ -577,7 +573,7 @@ tr_peerIo* tr_peerIo::newIncoming( return tr_peerIo::create(session, parent, addr, port, current_time, nullptr, true, false, socket); } -tr_peerIo* tr_peerIo::newOutgoing( +std::shared_ptr tr_peerIo::newOutgoing( tr_session* session, tr_bandwidth* parent, tr_address const* addr, @@ -620,7 +616,6 @@ tr_peerIo* tr_peerIo::newOutgoing( static void event_enable(tr_peerIo* io, short event) { - TR_ASSERT(tr_amInEventThread(io->session)); TR_ASSERT(io->session != nullptr); TR_ASSERT(io->session->events != nullptr); @@ -659,8 +654,6 @@ static void event_enable(tr_peerIo* io, short event) static void event_disable(tr_peerIo* io, short event) { - TR_ASSERT(tr_amInEventThread(io->session)); - TR_ASSERT(io->session != nullptr); TR_ASSERT(io->session->events != nullptr); bool const need_events = io->socket.type == TR_PEER_SOCKET_TYPE_TCP; @@ -699,8 +692,6 @@ static void event_disable(tr_peerIo* io, short event) void tr_peerIo::setEnabled(tr_direction dir, bool is_enabled) { TR_ASSERT(tr_isDirection(dir)); - TR_ASSERT(tr_amInEventThread(session)); - TR_ASSERT(session->events != nullptr); short const event = dir == TR_UP ? EV_WRITE : EV_READ; @@ -757,55 +748,18 @@ static void io_close_socket(tr_peerIo* io) } } -static void io_dtor(tr_peerIo* const io) +tr_peerIo::~tr_peerIo() { - TR_ASSERT(tr_isPeerIo(io)); - TR_ASSERT(tr_amInEventThread(io->session)); - TR_ASSERT(io->session->events != nullptr); + auto const lock = session->unique_lock(); + TR_ASSERT(session->events != nullptr); - tr_logAddTraceIo(io, "in tr_peerIo destructor"); - event_disable(io, EV_READ | EV_WRITE); - io_close_socket(io); + this->canRead = nullptr; + this->didWrite = nullptr; + this->gotError = nullptr; - io->magic_number = ~0; - delete io; -} - -static void tr_peerIoFree(tr_peerIo* io) -{ - if (io != nullptr) - { - tr_logAddTraceIo(io, "in tr_peerIoFree"); - io->canRead = nullptr; - io->didWrite = nullptr; - io->gotError = nullptr; - tr_runInEventThread(io->session, io_dtor, io); - } -} - -void tr_peerIoRefImpl(char const* file, int line, tr_peerIo* io) -{ - TR_ASSERT(tr_isPeerIo(io)); - - tr_logAddTraceIo( - io, - fmt::format("{}:{} incrementing the IO's refcount from {} to {}", file, line, io->refCount, io->refCount + 1)); - - ++io->refCount; -} - -void tr_peerIoUnrefImpl(char const* file, int line, tr_peerIo* io) -{ - TR_ASSERT(tr_isPeerIo(io)); - - tr_logAddTraceIo( - io, - fmt::format("{}:{} decrementing the IO's refcount from {} to {}", file, line, io->refCount, io->refCount - 1)); - - if (--io->refCount == 0) - { - tr_peerIoFree(io); - } + tr_logAddTraceIo(this, "in tr_peerIo destructor"); + event_disable(this, EV_READ | EV_WRITE); + io_close_socket(this); } std::string tr_peerIo::addrStr() const diff --git a/libtransmission/peer-io.h b/libtransmission/peer-io.h index 7a67fdb0d..d67c78696 100644 --- a/libtransmission/peer-io.h +++ b/libtransmission/peer-io.h @@ -66,14 +66,23 @@ struct evbuffer_deleter using tr_evbuffer_ptr = std::unique_ptr; -class tr_peerIo +namespace libtransmission::test +{ + +class HandshakeTest; + +} // namespace libtransmission::test + +class tr_peerIo final : public std::enable_shared_from_this { using DH = tr_message_stream_encryption::DH; using Filter = tr_message_stream_encryption::Filter; public: + ~tr_peerIo(); + // TODO: 8 constructor args is too many; maybe a builder object? - static tr_peerIo* newOutgoing( + static std::shared_ptr newOutgoing( tr_session* session, tr_bandwidth* parent, struct tr_address const* addr, @@ -83,7 +92,7 @@ public: bool is_seed, bool utp); - static tr_peerIo* newIncoming( + static std::shared_ptr newIncoming( tr_session* session, tr_bandwidth* parent, struct tr_address const* addr, @@ -91,19 +100,6 @@ public: time_t current_time, struct tr_peer_socket const socket); - // this is only public for testing purposes. - // production code should use newOutgoing() or newIncoming() - static tr_peerIo* create( - tr_session* session, - tr_bandwidth* parent, - tr_address const* addr, - tr_port port, - time_t current_time, - tr_sha1_digest_t const* torrent_hash, - bool is_incoming, - bool is_seed, - struct tr_peer_socket const socket); - void clear(); void readBytes(void* bytes, size_t byte_count); @@ -218,7 +214,7 @@ public: bandwidth_.setParent(parent); } - [[nodiscard]] constexpr auto isIncoming() noexcept + [[nodiscard]] constexpr auto isIncoming() const noexcept { return is_incoming_; } @@ -235,12 +231,6 @@ public: void setCallbacks(tr_can_read_cb readcb, tr_did_write_cb writecb, tr_net_error_cb errcb, void* user_data); - // TODO(ckerr): yikes, unlike other class' magic_numbers it looks - // like this one isn't being used just for assertions, but also in - // didWriteWrapper() to see if the tr_peerIo got freed during the - // notify-consumed events. Fix this before removing this field. - int magic_number = PEER_IO_MAGIC_NUMBER; - struct tr_peer_socket socket = {}; tr_session* const session; @@ -260,9 +250,6 @@ public: struct event* event_read = nullptr; struct event* event_write = nullptr; - // TODO: use std::shared_ptr instead of manual refcounting? - int refCount = 1; - short int pendingEvents = 0; tr_priority_t priority = TR_PRI_NORMAL; @@ -297,6 +284,21 @@ public: static void utpInit(struct_utp_context* ctx); private: + friend class libtransmission::test::HandshakeTest; + + // this is only public for testing purposes. + // production code should use newOutgoing() or newIncoming() + static std::shared_ptr create( + tr_session* session, + tr_bandwidth* parent, + tr_address const* addr, + tr_port port, + time_t current_time, + tr_sha1_digest_t const* torrent_hash, + bool is_incoming, + bool is_seed, + struct tr_peer_socket const socket); + tr_peerIo( tr_session* session_in, tr_sha1_digest_t const* torrent_hash, @@ -347,24 +349,11 @@ private: bool fast_extension_supported_ = false; }; -void tr_peerIoRefImpl(char const* file, int line, tr_peerIo* io); - -#define tr_peerIoRef(io) tr_peerIoRefImpl(__FILE__, __LINE__, (io)) - -void tr_peerIoUnrefImpl(char const* file, int line, tr_peerIo* io); - -#define tr_peerIoUnref(io) tr_peerIoUnrefImpl(__FILE__, __LINE__, (io)) - constexpr bool tr_isPeerIo(tr_peerIo const* io) { - return io != nullptr && io->magic_number == PEER_IO_MAGIC_NUMBER && io->refCount >= 0 && - tr_address_is_valid(&io->address()); + return io != nullptr && tr_address_is_valid(&io->address()); } -/** -*** -**/ - void evbuffer_add_uint8(struct evbuffer* outbuf, uint8_t addme); void evbuffer_add_uint16(struct evbuffer* outbuf, uint16_t hs); void evbuffer_add_uint32(struct evbuffer* outbuf, uint32_t hl); @@ -373,5 +362,3 @@ void evbuffer_add_uint64(struct evbuffer* outbuf, uint64_t hll); void evbuffer_add_hton_16(struct evbuffer* buf, uint16_t val); void evbuffer_add_hton_32(struct evbuffer* buf, uint32_t val); void evbuffer_add_hton_64(struct evbuffer* buf, uint64_t val); - -/* @} */ diff --git a/libtransmission/peer-mgr.cc b/libtransmission/peer-mgr.cc index 83c54b307..315cbf8bc 100644 --- a/libtransmission/peer-mgr.cc +++ b/libtransmission/peer-mgr.cc @@ -1118,7 +1118,7 @@ static struct peer_atom* ensureAtomExists( return tor->max_connected_peers; } -static void createBitTorrentPeer(tr_torrent* tor, tr_peerIo* io, struct peer_atom* atom, tr_quark client) +static void createBitTorrentPeer(tr_torrent* tor, std::shared_ptr io, struct peer_atom* atom, tr_quark client) { TR_ASSERT(atom != nullptr); TR_ASSERT(tr_isTorrent(tor)); @@ -1126,7 +1126,7 @@ static void createBitTorrentPeer(tr_torrent* tor, tr_peerIo* io, struct peer_ato tr_swarm* swarm = tor->swarm; - auto* peer = tr_peerMsgsNew(tor, atom, io, peerCallbackFunc, swarm); + auto* peer = tr_peerMsgsNew(tor, atom, std::move(io), peerCallbackFunc, swarm); peer->client = client; atom->is_connected = true; @@ -1232,10 +1232,8 @@ static bool on_handshake_done(tr_handshake_result const& result) client = tr_quark_new(std::data(buf)); } - /* this steals its refcount too, which is balanced by our unref in peerDelete() */ - tr_peerIo* stolen = tr_handshakeStealIO(result.handshake); - stolen->setParent(&s->tor->bandwidth_); - createBitTorrentPeer(s->tor, stolen, atom, client); + result.io->setParent(&s->tor->bandwidth_); + createBitTorrentPeer(s->tor, result.io, atom, client); success = true; } @@ -1263,11 +1261,8 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const* addr, tr_port else /* we don't have a connection to them yet... */ { auto mediator = std::make_shared(*session); - tr_peerIo* const io = tr_peerIo::newIncoming(session, &session->top_bandwidth_, addr, port, tr_time(), socket); - tr_handshake* const handshake = tr_handshakeNew(mediator, io, session->encryptionMode(), on_handshake_done, manager); - - tr_peerIoUnref(io); /* balanced by the implicit ref in tr_peerIo::NewIncoming() */ - + auto io = tr_peerIo::newIncoming(session, &session->top_bandwidth_, addr, port, tr_time(), socket); + auto* const handshake = tr_handshakeNew(mediator, std::move(io), session->encryptionMode(), on_handshake_done, manager); manager->incoming_handshakes.add(*addr, handshake); } } @@ -2836,7 +2831,7 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom) tr_logAddTraceSwarm(s, fmt::format("Starting an OUTGOING {} connection with {}", utp ? " µTP" : "TCP", atom.readable())); - tr_peerIo* const io = tr_peerIo::newOutgoing( + auto io = tr_peerIo::newOutgoing( mgr->session, &mgr->session->top_bandwidth_, &atom.addr, @@ -2855,12 +2850,12 @@ void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, peer_atom& atom) else { auto mediator = std::make_shared(*mgr->session); - tr_handshake* handshake = tr_handshakeNew(mediator, io, mgr->session->encryptionMode(), on_handshake_done, mgr); - - TR_ASSERT(io->torrentHash()); - - tr_peerIoUnref(io); /* balanced by the initial ref in tr_peerIo::newOutgoing() */ - + auto* const handshake = tr_handshakeNew( + mediator, + std::move(io), + mgr->session->encryptionMode(), + on_handshake_done, + mgr); s->outgoing_handshakes.add(atom.addr, handshake); } diff --git a/libtransmission/peer-msgs.cc b/libtransmission/peer-msgs.cc index 9e5b290b4..03a7b1914 100644 --- a/libtransmission/peer-msgs.cc +++ b/libtransmission/peer-msgs.cc @@ -248,12 +248,17 @@ static void updateDesiredRequestCount(tr_peerMsgsImpl* msgs); class tr_peerMsgsImpl final : public tr_peerMsgs { public: - tr_peerMsgsImpl(tr_torrent* torrent_in, peer_atom* atom_in, tr_peerIo* io_in, tr_peer_callback callback, void* callbackData) + tr_peerMsgsImpl( + tr_torrent* torrent_in, + peer_atom* atom_in, + std::shared_ptr io_in, + tr_peer_callback callback, + void* callbackData) : tr_peerMsgs{ torrent_in, atom_in } , outMessagesBatchPeriod{ LowPriorityIntervalSecs } , torrent{ torrent_in } , outMessages{ evbuffer_new() } - , io{ io_in } + , io{ std::move(io_in) } , have_{ torrent_in->pieceCount() } , callback_{ callback } , callbackData_{ callbackData } @@ -300,10 +305,9 @@ public: set_active(TR_UP, false); set_active(TR_DOWN, false); - if (this->io != nullptr) + if (this->io) { this->io->clear(); - tr_peerIoUnref(this->io); /* balanced by the ref in handshakeDoneCB() */ } evbuffer_free(this->outMessages); @@ -816,7 +820,7 @@ public: evbuffer* const outMessages; /* all the non-piece messages */ - tr_peerIo* const io; + std::shared_ptr const io; struct QueuedPeerRequest : public peer_request { @@ -864,9 +868,14 @@ private: static auto constexpr SendPexInterval = 90s; }; -tr_peerMsgs* tr_peerMsgsNew(tr_torrent* torrent, peer_atom* atom, tr_peerIo* io, tr_peer_callback callback, void* callback_data) +tr_peerMsgs* tr_peerMsgsNew( + tr_torrent* torrent, + peer_atom* atom, + std::shared_ptr io, + tr_peer_callback callback, + void* callback_data) { - return new tr_peerMsgsImpl(torrent, atom, io, callback, callback_data); + return new tr_peerMsgsImpl(torrent, atom, std::move(io), callback, callback_data); } /** @@ -2377,7 +2386,7 @@ static void peerPulse(void* vmsgs) auto* msgs = static_cast(vmsgs); time_t const now = tr_time(); - if (tr_isPeerIo(msgs->io)) + if (msgs->io) { updateDesiredRequestCount(msgs); updateBlockRequests(msgs); diff --git a/libtransmission/peer-msgs.h b/libtransmission/peer-msgs.h index 93a889bce..aa3c52022 100644 --- a/libtransmission/peer-msgs.h +++ b/libtransmission/peer-msgs.h @@ -12,6 +12,7 @@ #include // int8_t #include // size_t #include // time_t +#include #include #include "bitfield.h" @@ -77,7 +78,7 @@ protected: tr_peerMsgs* tr_peerMsgsNew( tr_torrent* torrent, peer_atom* atom, - tr_peerIo* io, + std::shared_ptr io, tr_peer_callback callback, void* callback_data); diff --git a/tests/libtransmission/handshake-test.cc b/tests/libtransmission/handshake-test.cc index 995d79c51..860efe7e5 100644 --- a/tests/libtransmission/handshake-test.cc +++ b/tests/libtransmission/handshake-test.cc @@ -33,197 +33,199 @@ namespace test auto constexpr MaxWaitMsec = int{ 5000 }; -using HandshakeTest = SessionTest; - -class MediatorMock final : public tr_handshake_mediator +class HandshakeTest : public SessionTest { public: - explicit MediatorMock(tr_session* session) - : session_{ session } + class MediatorMock final : public tr_handshake_mediator { - } - - virtual ~MediatorMock() = default; - - [[nodiscard]] std::optional torrentInfo(tr_sha1_digest_t const& info_hash) const override - { - if (auto const iter = torrents.find(info_hash); iter != std::end(torrents)) + public: + explicit MediatorMock(tr_session* session) + : session_{ session } { - return iter->second; } - return {}; - } + virtual ~MediatorMock() = default; - [[nodiscard]] std::optional torrentInfoFromObfuscated(tr_sha1_digest_t const& obfuscated) const override - { - for (auto const& [info_hash, info] : torrents) + [[nodiscard]] std::optional torrentInfo(tr_sha1_digest_t const& info_hash) const override { - if (obfuscated == tr_sha1::digest("req2"sv, info.info_hash)) + if (auto const iter = torrents.find(info_hash); iter != std::end(torrents)) { - return info; + return iter->second; } + + return {}; } - return {}; - } + [[nodiscard]] std::optional torrentInfoFromObfuscated(tr_sha1_digest_t const& obfuscated) const override + { + for (auto const& [info_hash, info] : torrents) + { + if (obfuscated == tr_sha1::digest("req2"sv, info.info_hash)) + { + return info; + } + } - [[nodiscard]] std::unique_ptr createTimer() override - { - return session_->timerMaker().create(); - } + return {}; + } - [[nodiscard]] bool isDHTEnabled() const override - { - return false; - } + [[nodiscard]] std::unique_ptr createTimer() override + { + return session_->timerMaker().create(); + } - [[nodiscard]] bool allowsTCP() const override - { - return true; - } + [[nodiscard]] bool isDHTEnabled() const override + { + return false; + } - [[nodiscard]] bool isPeerKnownSeed(tr_torrent_id_t /*tor_id*/, tr_address /*addr*/) const override - { - return false; - } + [[nodiscard]] bool allowsTCP() const override + { + return true; + } - [[nodiscard]] size_t pad(void* setme, [[maybe_unused]] size_t maxlen) const override - { - TR_ASSERT(maxlen > 10); - auto const len = size_t{ 10 }; - std::fill_n(static_cast(setme), 10, ' '); - return len; - } + [[nodiscard]] bool isPeerKnownSeed(tr_torrent_id_t /*tor_id*/, tr_address /*addr*/) const override + { + return false; + } - [[nodiscard]] tr_message_stream_encryption::DH::private_key_bigend_t privateKey() const override - { - return private_key_; - } + [[nodiscard]] size_t pad(void* setme, [[maybe_unused]] size_t maxlen) const override + { + TR_ASSERT(maxlen > 10); + auto const len = size_t{ 10 }; + std::fill_n(static_cast(setme), 10, ' '); + return len; + } - void setUTPFailed(tr_sha1_digest_t const& /*info_hash*/, tr_address /*addr*/) override - { - } + [[nodiscard]] tr_message_stream_encryption::DH::private_key_bigend_t privateKey() const override + { + return private_key_; + } - void setPrivateKeyFromBase64(std::string_view b64) - { - auto const str = tr_base64_decode(b64); - assert(std::size(str) == std::size(private_key_)); - std::copy_n(reinterpret_cast(std::data(str)), std::size(str), std::begin(private_key_)); - } + void setUTPFailed(tr_sha1_digest_t const& /*info_hash*/, tr_address /*addr*/) override + { + } - tr_session* const session_; - std::map torrents; - tr_message_stream_encryption::DH::private_key_bigend_t private_key_ = {}; -}; + void setPrivateKeyFromBase64(std::string_view b64) + { + auto const str = tr_base64_decode(b64); + assert(std::size(str) == std::size(private_key_)); + std::copy_n(reinterpret_cast(std::data(str)), std::size(str), std::begin(private_key_)); + } -template -void sendToClient(evutil_socket_t sock, Span const& data) -{ - auto const* walk = std::data(data); - static_assert(sizeof(*walk) == 1); - size_t len = std::size(data); - - while (len > 0) - { -#if defined(_WIN32) - auto const n = send(sock, reinterpret_cast(walk), len, 0); -#else - auto const n = write(sock, walk, len); -#endif - assert(n >= 0); - len -= n; - walk += n; - } -} - -void sendB64ToClient(evutil_socket_t sock, std::string_view b64) -{ - sendToClient(sock, tr_base64_decode(b64)); -} - -auto constexpr ReservedBytesNoExtensions = std::array{ 0, 0, 0, 0, 0, 0, 0, 0 }; -auto constexpr PlaintextProtocolName = "\023BitTorrent protocol"sv; -auto const DefaultPeerAddr = *tr_address::fromString("127.0.0.1"sv); -auto const DefaultPeerPort = tr_port::fromHost(8080); -auto const TorrentWeAreSeeding = tr_handshake_mediator::torrent_info{ tr_sha1::digest("abcde"sv), - tr_peerIdInit(), - tr_torrent_id_t{ 100 }, - true /*is_done*/ }; -auto const UbuntuTorrent = tr_handshake_mediator::torrent_info{ *tr_sha1_from_string( - "2c6b6858d61da9543d4231a71db4b1c9264b0685"sv), - tr_peerIdInit(), - tr_torrent_id_t{ 101 }, - false /*is_done*/ }; - -auto createIncomingIo(tr_session* session) -{ - auto sockpair = std::array{ -1, -1 }; - EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno); - auto const now = tr_time(); - auto const peer_socket = tr_peer_socket_tcp_create(sockpair[0]); - auto* const - io = tr_peerIo::newIncoming(session, &session->top_bandwidth_, &DefaultPeerAddr, DefaultPeerPort, now, peer_socket); - return std::make_pair(io, sockpair[1]); -} - -auto createOutgoingIo(tr_session* session, tr_sha1_digest_t const& info_hash) -{ - auto sockpair = std::array{ -1, -1 }; - EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno); - auto const now = tr_time(); - auto const peer_socket = tr_peer_socket_tcp_create(sockpair[0]); - auto* const io = tr_peerIo::create( - session, - &session->top_bandwidth_, - &DefaultPeerAddr, - DefaultPeerPort, - now, - &info_hash, - false /*is_incoming*/, - false /*is_seed*/, - peer_socket); - return std::make_pair(io, sockpair[1]); -} - -constexpr auto makePeerId(std::string_view sv) -{ - auto peer_id = tr_peer_id_t{}; - for (size_t i = 0, n = std::size(sv); i < n; ++i) - { - peer_id[i] = sv[i]; - } - return peer_id; -} - -auto makeRandomPeerId() -{ - auto peer_id = tr_peer_id_t{}; - tr_rand_buffer(std::data(peer_id), std::size(peer_id)); - auto const peer_id_prefix = "-UW110Q-"sv; - std::copy(std::begin(peer_id_prefix), std::end(peer_id_prefix), std::begin(peer_id)); - return peer_id; -} - -auto runHandshake( - std::shared_ptr mediator, - tr_peerIo* io, - tr_encryption_mode encryption_mode = TR_CLEAR_PREFERRED) -{ - auto result = std::optional{}; - - static auto const DoneCallback = [](auto const& resin) - { - *static_cast*>(resin.userData) = resin; - return true; + tr_session* const session_; + std::map torrents; + tr_message_stream_encryption::DH::private_key_bigend_t private_key_ = {}; }; - tr_handshakeNew(std::move(mediator), io, encryption_mode, DoneCallback, &result); + template + void sendToClient(evutil_socket_t sock, Span const& data) + { + auto const* walk = std::data(data); + static_assert(sizeof(*walk) == 1); + size_t len = std::size(data); - waitFor([&result]() { return result.has_value(); }, MaxWaitMsec); + while (len > 0) + { +#if defined(_WIN32) + auto const n = send(sock, reinterpret_cast(walk), len, 0); +#else + auto const n = write(sock, walk, len); +#endif + assert(n >= 0); + len -= n; + walk += n; + } + } - return result; -} + void sendB64ToClient(evutil_socket_t sock, std::string_view b64) + { + sendToClient(sock, tr_base64_decode(b64)); + } + + static auto constexpr ReservedBytesNoExtensions = std::array{ 0, 0, 0, 0, 0, 0, 0, 0 }; + static auto constexpr PlaintextProtocolName = "\023BitTorrent protocol"sv; + + tr_address const DefaultPeerAddr = *tr_address::fromString("127.0.0.1"sv); + tr_port const DefaultPeerPort = tr_port::fromHost(8080); + tr_handshake_mediator::torrent_info const TorrentWeAreSeeding{ tr_sha1::digest("abcde"sv), + tr_peerIdInit(), + tr_torrent_id_t{ 100 }, + true /*is_done*/ }; + tr_handshake_mediator::torrent_info const UbuntuTorrent{ *tr_sha1_from_string("2c6b6858d61da9543d4231a71db4b1c9264b0685"sv), + tr_peerIdInit(), + tr_torrent_id_t{ 101 }, + false /*is_done*/ }; + + auto createIncomingIo(tr_session* session) + { + auto sockpair = std::array{ -1, -1 }; + EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno); + auto const now = tr_time(); + auto const peer_socket = tr_peer_socket_tcp_create(sockpair[0]); + auto + io = tr_peerIo::newIncoming(session, &session->top_bandwidth_, &DefaultPeerAddr, DefaultPeerPort, now, peer_socket); + return std::make_pair(io, sockpair[1]); + } + + auto createOutgoingIo(tr_session* session, tr_sha1_digest_t const& info_hash) + { + auto sockpair = std::array{ -1, -1 }; + EXPECT_EQ(0, evutil_socketpair(LOCAL_SOCKETPAIR_AF, SOCK_STREAM, 0, std::data(sockpair))) << tr_strerror(errno); + auto const now = tr_time(); + auto const peer_socket = tr_peer_socket_tcp_create(sockpair[0]); + auto io = tr_peerIo::create( + session, + &session->top_bandwidth_, + &DefaultPeerAddr, + DefaultPeerPort, + now, + &info_hash, + false /*is_incoming*/, + false /*is_seed*/, + peer_socket); + return std::make_pair(io, sockpair[1]); + } + + static constexpr auto makePeerId(std::string_view sv) + { + auto peer_id = tr_peer_id_t{}; + for (size_t i = 0, n = std::size(sv); i < n; ++i) + { + peer_id[i] = sv[i]; + } + return peer_id; + } + + static auto makeRandomPeerId() + { + auto peer_id = tr_peer_id_t{}; + tr_rand_buffer(std::data(peer_id), std::size(peer_id)); + auto const peer_id_prefix = "-UW110Q-"sv; + std::copy(std::begin(peer_id_prefix), std::end(peer_id_prefix), std::begin(peer_id)); + return peer_id; + } + + static auto runHandshake( + std::shared_ptr mediator, + std::shared_ptr io, + tr_encryption_mode encryption_mode = TR_CLEAR_PREFERRED) + { + auto result = std::optional{}; + + static auto const DoneCallback = [](auto const& resin) + { + *static_cast*>(resin.userData) = resin; + return true; + }; + + tr_handshakeNew(std::move(mediator), std::move(io), encryption_mode, DoneCallback, &result); + + waitFor([&result]() { return result.has_value(); }, MaxWaitMsec); + + return result; + } +}; TEST_F(HandshakeTest, incomingPlaintext) { @@ -257,7 +259,6 @@ TEST_F(HandshakeTest, incomingPlaintext) EXPECT_TRUE(io->torrentHash()); EXPECT_EQ(TorrentWeAreSeeding.info_hash, *io->torrentHash()); - tr_peerIoUnref(io); evutil_closesocket(sock); } @@ -284,7 +285,6 @@ TEST_F(HandshakeTest, incomingPlaintextUnknownInfoHash) EXPECT_FALSE(res->peer_id); EXPECT_FALSE(io->torrentHash()); - tr_peerIoUnref(io); evutil_closesocket(sock); } @@ -313,7 +313,6 @@ TEST_F(HandshakeTest, outgoingPlaintext) EXPECT_EQ(UbuntuTorrent.info_hash, *io->torrentHash()); EXPECT_EQ(tr_sha1_to_string(UbuntuTorrent.info_hash), tr_sha1_to_string(*io->torrentHash())); - tr_peerIoUnref(io); evutil_closesocket(sock); } @@ -353,7 +352,6 @@ TEST_F(HandshakeTest, incomingEncrypted) EXPECT_EQ(UbuntuTorrent.info_hash, *io->torrentHash()); EXPECT_EQ(tr_sha1_to_string(UbuntuTorrent.info_hash), tr_sha1_to_string(*io->torrentHash())); - tr_peerIoUnref(io); evutil_closesocket(sock); } @@ -387,7 +385,6 @@ TEST_F(HandshakeTest, incomingEncryptedUnknownInfoHash) EXPECT_TRUE(res->readAnythingFromPeer); EXPECT_FALSE(io->torrentHash()); - tr_peerIoUnref(io); evutil_closesocket(sock); } @@ -432,7 +429,6 @@ TEST_F(HandshakeTest, outgoingEncrypted) EXPECT_EQ(UbuntuTorrent.info_hash, *io->torrentHash()); EXPECT_EQ(tr_sha1_to_string(UbuntuTorrent.info_hash), tr_sha1_to_string(*io->torrentHash())); - tr_peerIoUnref(io); evutil_closesocket(sock); }