diff --git a/libtransmission/peer-mgr.cc b/libtransmission/peer-mgr.cc index 425b63034..3b949e542 100644 --- a/libtransmission/peer-mgr.cc +++ b/libtransmission/peer-mgr.cc @@ -12,7 +12,7 @@ #include // time_t #include // std::tie #include // std::back_inserter -#include +#include #include #include @@ -151,6 +151,59 @@ static char const* tr_atomAddrStr(struct peer_atom const* atom) return atom != nullptr ? tr_address_and_port_to_string(addrstr, sizeof(addrstr), &atom->addr, atom->port) : "[no atom]"; } +// a container for keeping track of tr_handshakes +class Handshakes +{ +public: + void add(tr_address const& address, tr_handshake* handshake) + { + TR_ASSERT(!contains(address)); + + handshakes_.emplace_back(std::make_pair(address, handshake)); + } + + [[nodiscard]] bool contains(tr_address const& address) const noexcept + { + return std::any_of( + std::begin(handshakes_), + std::end(handshakes_), + [&address](auto const& pair) { return pair.first == address; }); + } + + void erase(tr_address const& address) + { + for (auto iter = std::begin(handshakes_), end = std::end(handshakes_); iter != end; ++iter) + { + if (iter->first == address) + { + handshakes_.erase(iter); + return; + } + } + } + + [[nodiscard]] auto empty() const noexcept + { + return std::empty(handshakes_); + } + + void abortAll() + { + // make a tmp copy so that calls to tr_handshakeAbort() won't + // be able to invalidate its loop iteration + auto tmp = handshakes_; + for (auto& [addr, handshake] : tmp) + { + tr_handshakeAbort(handshake); + } + + handshakes_ = {}; + } + +private: + std::vector> handshakes_; +}; + /** @brief Opaque, per-torrent data structure for peer connection information */ class tr_swarm { @@ -164,7 +217,7 @@ public: public: tr_swarm_stats stats = {}; - std::map outgoing_handshakes; + Handshakes outgoing_handshakes; tr_ptrArray pool = {}; /* struct peer_atom */ tr_ptrArray peers = {}; /* tr_peerMsgs */ std::vector> webseeds; @@ -201,7 +254,7 @@ struct tr_peerMgr } tr_session* const session; - std::map incoming_handshakes; + Handshakes incoming_handshakes; event* bandwidthTimer = nullptr; event* rechokeTimer = nullptr; event* refillUpkeepTimer = nullptr; @@ -299,8 +352,8 @@ static bool peerIsInUse(tr_swarm const* cs, struct peer_atom const* atom) auto const* const s = const_cast(cs); auto const lock = s->manager->unique_lock(); - return atom->peer != nullptr || s->outgoing_handshakes.count(atom->addr) != 0 || - s->manager->incoming_handshakes.count(atom->addr) != 0; + return atom->peer != nullptr || s->outgoing_handshakes.contains(atom->addr) || + s->manager->incoming_handshakes.contains(atom->addr); } static void swarmFree(tr_swarm* s) @@ -377,12 +430,7 @@ void tr_peerMgrFree(tr_peerMgr* manager) deleteTimers(manager); - /* free the handshakes. Abort invokes handshakeDoneCB(), which removes - * the item from manager->handshakes, so this is a little roundabout... */ - while (!std::empty(manager->incoming_handshakes)) - { - tr_handshakeAbort(std::begin(manager->incoming_handshakes)->second); - } + manager->incoming_handshakes.abortAll(); delete manager; } @@ -1069,7 +1117,7 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const* addr, tr_port tr_logAddTrace(fmt::format("Banned IP address '{}' tried to connect to us", tr_address_to_string(addr))); tr_netClosePeerSocket(session, socket); } - else if (manager->incoming_handshakes.count(*addr) > 0) + else if (manager->incoming_handshakes.contains(*addr)) { tr_netClosePeerSocket(session, socket); } @@ -1080,7 +1128,7 @@ void tr_peerMgrAddIncoming(tr_peerMgr* manager, tr_address const* addr, tr_port tr_peerIoUnref(io); /* balanced by the implicit ref in tr_peerIoNewIncoming() */ - manager->incoming_handshakes.try_emplace(*addr, handshake); + manager->incoming_handshakes.add(*addr, handshake); } } @@ -1408,12 +1456,7 @@ static void stopSwarm(tr_swarm* swarm) removeAllPeers(swarm); - /* disconnect the handshakes. handshakeAbort calls handshakeDoneCB(), - * which removes the handshake from t->outgoing_handshakes... */ - while (!std::empty(swarm->outgoing_handshakes)) - { - tr_handshakeAbort(std::begin(swarm->outgoing_handshakes)->second); - } + swarm->outgoing_handshakes.abortAll(); } void tr_peerMgrStopTorrent(tr_torrent* tor) @@ -3023,7 +3066,7 @@ static void initiateConnection(tr_peerMgr* mgr, tr_swarm* s, struct peer_atom* a tr_peerIoUnref(io); /* balanced by the initial ref in tr_peerIoNewOutgoing() */ - s->outgoing_handshakes.try_emplace(atom->addr, handshake); + s->outgoing_handshakes.add(atom->addr, handshake); } atom->lastConnectionAttemptAt = now;