From 9e06cf8f2e23f9aa708d7b6c16f8fd13bda39399 Mon Sep 17 00:00:00 2001 From: Charles Kerr Date: Fri, 11 Nov 2022 10:09:24 -0600 Subject: [PATCH] refactor: make DHT unblocking (#4122) --- libtransmission/session.cc | 66 +- libtransmission/session.h | 71 +- libtransmission/torrent.cc | 2 - libtransmission/torrent.h | 3 - libtransmission/tr-dht.cc | 1252 ++++++++++--------------- libtransmission/tr-dht.h | 98 +- libtransmission/tr-udp.cc | 59 +- tests/libtransmission/CMakeLists.txt | 2 + tests/libtransmission/dht-test.cc | 643 +++++++++++++ tests/libtransmission/test-fixtures.h | 8 + tests/libtransmission/timer-test.cc | 60 +- 11 files changed, 1413 insertions(+), 851 deletions(-) create mode 100644 tests/libtransmission/dht-test.cc diff --git a/libtransmission/session.cc b/libtransmission/session.cc index fbbe93600..33bb27939 100644 --- a/libtransmission/session.cc +++ b/libtransmission/session.cc @@ -108,9 +108,44 @@ tr_peer_id_t tr_peerIdInit() return peer_id; } -/*** -**** -***/ +/// + +std::vector tr_session::DhtMediator::torrentsAllowingDHT() const +{ + auto ids = std::vector{}; + auto const& torrents = session_.torrents(); + + ids.reserve(std::size(torrents)); + for (auto const* const tor : torrents) + { + if (tor->isRunning && tor->allowsDht()) + { + ids.push_back(tor->id()); + } + } + + return ids; +} + +tr_sha1_digest_t tr_session::DhtMediator::torrentInfoHash(tr_torrent_id_t id) const +{ + if (auto const* const tor = session_.torrents().get(id); tor != nullptr) + { + return tor->infoHash(); + } + + return {}; +} + +void tr_session::DhtMediator::addPex(tr_sha1_digest_t const& info_hash, tr_pex const* pex, size_t n_pex) +{ + if (auto* const tor = session_.torrents().get(info_hash); tor != nullptr) + { + tr_peerMgrAddPex(tor, TR_PEER_FROM_DHT, pex, n_pex); + } +} + +/// bool tr_session::LpdMediator::onPeerFound(std::string_view info_hash_str, tr_address address, tr_port port) { @@ -468,7 +503,6 @@ void tr_session::onNowTimer() // tr_session upkeep tasks to perform once per second tr_timeUpdate(time(nullptr)); - udp_core_->dhtUpkeep(); alt_speeds_.checkScheduler(); // TODO: this seems a little silly. Why do we increment this @@ -607,24 +641,28 @@ void tr_session::setSettings(tr_session_settings settings_in, bool force) port_changed = true; } + bool addr_changed = false; if (new_settings.tcp_enabled) { if (auto const& val = new_settings.bind_address_ipv4; force || port_changed || val != old_settings.bind_address_ipv4) { auto const [addr, is_default] = publicAddress(TR_AF_INET); bound_ipv4_.emplace(eventBase(), addr, local_peer_port_, &tr_session::onIncomingPeerConnection, this); + addr_changed = true; } if (auto const& val = new_settings.bind_address_ipv6; force || port_changed || val != old_settings.bind_address_ipv6) { auto const [addr, is_default] = publicAddress(TR_AF_INET6); bound_ipv6_.emplace(eventBase(), addr, local_peer_port_, &tr_session::onIncomingPeerConnection, this); + addr_changed = true; } } else { bound_ipv4_.reset(); bound_ipv6_.reset(); + addr_changed = true; } if (port_changed) @@ -653,6 +691,15 @@ void tr_session::setSettings(tr_session_settings settings_in, bool force) } } + if (!allowsDHT()) + { + dht_.reset(); + } + else if (force || !dht_ || port_changed || addr_changed || dht_changed) + { + dht_ = tr_dht::create(dht_mediator_, localPeerPort(), udp_core_->socket4(), udp_core_->socket6()); + } + // We need to update bandwidth if speed settings changed. // It's a harmless call, so just call it instead of checking for settings changes updateBandwidth(this, TR_UP); @@ -801,6 +848,14 @@ tr_port_forwarding_state tr_sessionGetPortForwarding(tr_session const* session) return session->port_forwarding_->state(); } +void tr_session::onAdvertisedPeerPortChanged() +{ + for (auto* const tor : torrents()) + { + tr_torrentChangeMyPort(tor); + } +} + /*** **** ***/ @@ -1161,13 +1216,14 @@ void tr_session::closeImplPart1(std::promise* closed_promise) save_timer_.reset(); now_timer_.reset(); rpc_server_.reset(); + dht_.reset(); lpd_.reset(); + port_forwarding_.reset(); bound_ipv6_.reset(); bound_ipv4_.reset(); // tell other items to start shutting down - udp_core_->startShutdown(); announcer_udp_->startShutdown(); // Close the torrents in order of most active to least active diff --git a/libtransmission/session.h b/libtransmission/session.h index 3b76c3cd0..3094a1f48 100644 --- a/libtransmission/session.h +++ b/libtransmission/session.h @@ -42,6 +42,7 @@ #include "session-thread.h" #include "stats.h" #include "torrents.h" +#include "tr-dht.h" #include "tr-lpd.h" #include "utils-ev.h" #include "verify.h" @@ -150,6 +151,36 @@ private: tr_session& session_; }; + class DhtMediator : public tr_dht::Mediator + { + public: + DhtMediator(tr_session& session) noexcept + : session_{ session } + { + } + + ~DhtMediator() noexcept override = default; + + [[nodiscard]] std::vector torrentsAllowingDHT() const override; + + [[nodiscard]] tr_sha1_digest_t torrentInfoHash(tr_torrent_id_t id) const override; + + [[nodiscard]] std::string_view configDir() const override + { + return session_.config_dir_; + } + + [[nodiscard]] libtransmission::TimerMaker& timerMaker() override + { + return session_.timerMaker(); + } + + void addPex(tr_sha1_digest_t const&, tr_pex const* pex, size_t n_pex) override; + + private: + tr_session& session_; + }; + class PortForwardingMediator final : public tr_port_forwarding::Mediator { public: @@ -175,7 +206,11 @@ private: void onPortForwarded(tr_port public_port) override { - session_.advertised_peer_port_ = public_port; + if (session_.advertised_peer_port_ != public_port) + { + session_.advertised_peer_port_ = public_port; + session_.onAdvertisedPeerPortChanged(); + } } private: @@ -241,12 +276,21 @@ private: { public: tr_udp_core(tr_session& session, tr_port udp_port); - ~tr_udp_core(); - static void startShutdown(); - static void dhtUpkeep(); + void sendto(void const* buf, size_t buflen, struct sockaddr const* to, socklen_t const tolen) const; + [[nodiscard]] constexpr auto socket4() const noexcept + { + return udp_socket_; + } + + [[nodiscard]] constexpr auto socket6() const noexcept + { + return udp6_socket_; + } + + private: void set_socket_buffers(); void set_socket_tos() @@ -255,11 +299,6 @@ private: session_.setSocketTOS(udp6_socket_, TR_AF_INET6); } - void sendto(void const* buf, size_t buflen, struct sockaddr const* to, socklen_t const tolen) const; - - void addDhtNode(tr_address const& addr, tr_port port); - - private: tr_port const udp_port_; tr_session& session_; tr_socket_t udp_socket_ = TR_BAD_SOCKET; @@ -847,9 +886,9 @@ public: void addDhtNode(tr_address const& addr, tr_port port) { - if (udp_core_) + if (dht_) { - udp_core_->addDhtNode(addr, port); + dht_->addNode(addr, port); } } @@ -886,7 +925,7 @@ private: [[nodiscard]] tr_port randomPort() const; - void setPeerPort(tr_port port); + void onAdvertisedPeerPortChanged(); struct init_data; void initImpl(init_data&); @@ -1061,7 +1100,7 @@ private: std::optional bound_ipv6_; public: - // depends-on: announcer_udp_ + // depends-on: settings_, announcer_udp_ // FIXME(ckerr): circular dependency udp_core -> announcer_udp -> announcer_udp_mediator -> udp_core std::unique_ptr udp_core_; @@ -1104,6 +1143,9 @@ private: // depends-on: udp_core_ AnnouncerUdpMediator announcer_udp_mediator_{ *this }; + // depends-on: timer_maker_, torrents_, peer_mgr_ + DhtMediator dht_mediator_{ *this }; + public: // depends-on: announcer_udp_mediator_ std::unique_ptr announcer_udp_ = tr_announcer_udp::create(announcer_udp_mediator_); @@ -1111,6 +1153,9 @@ public: // depends-on: settings_, torrents_, web_, announcer_udp_ struct tr_announcer* announcer = nullptr; + // depends-on: public_peer_port_, udp_core_, dht_mediator_ + std::unique_ptr dht_; + private: // depends-on: session_thread_, timer_maker_, settings_, torrents_, web_ std::unique_ptr rpc_server_; diff --git a/libtransmission/torrent.cc b/libtransmission/torrent.cc index 4b815348b..1a37f37f3 100644 --- a/libtransmission/torrent.cc +++ b/libtransmission/torrent.cc @@ -1312,8 +1312,6 @@ static void torrentStartImpl(tr_torrent* const tor) tr_torrentResetTransferStats(tor); tr_announcerTorrentStarted(tor); - tor->dhtAnnounceAt = now + tr_rand_int_weak(20); - tor->dhtAnnounce6At = now + tr_rand_int_weak(20); tor->lpdAnnounceAt = now; tr_peerMgrStartTorrent(tor); } diff --git a/libtransmission/torrent.h b/libtransmission/torrent.h index 62f88de0f..6e7eb05dd 100644 --- a/libtransmission/torrent.h +++ b/libtransmission/torrent.h @@ -748,9 +748,6 @@ public: time_t peer_id_creation_time_ = 0; - time_t dhtAnnounceAt = 0; - time_t dhtAnnounce6At = 0; - time_t lpdAnnounceAt = 0; time_t activityDate = 0; diff --git a/libtransmission/tr-dht.cc b/libtransmission/tr-dht.cc index 1eeb7c1d5..d41757330 100644 --- a/libtransmission/tr-dht.cc +++ b/libtransmission/tr-dht.cc @@ -9,15 +9,14 @@ #include // for abort() #include // for memcpy() #include +#include #include +#include #include -#include #include #include #include -#include #include // for std::tie() -#include #ifdef _WIN32 #include @@ -31,8 +30,6 @@ #include /* sockaddr_in */ #endif -#include - #include #include "transmission.h" @@ -41,756 +38,20 @@ #include "file.h" #include "log.h" #include "net.h" -#include "peer-mgr.h" -#include "session.h" +#include "peer-mgr.h" // for tr_peerMgrCompactToPex() #include "timer.h" -#include "torrent.h" #include "tr-assert.h" #include "tr-dht.h" #include "tr-strbuf.h" #include "variant.h" -#include "utils.h" // tr_time(), _() +#include "utils.h" // for tr_time(), _() using namespace std::literals; -namespace -{ -struct Impl -{ - std::unique_ptr timer; - std::array id = {}; - tr_socket_t udp4_socket = TR_BAD_SOCKET; - tr_socket_t udp6_socket = TR_BAD_SOCKET; - tr_session* session = nullptr; -}; - -Impl impl = {}; -} // namespace - -// mutex-locked wrapper around libdht's API -namespace locked_dht -{ -namespace -{ - -[[nodiscard]] auto unique_lock() -{ - static std::recursive_mutex dht_mutex; - return std::unique_lock(dht_mutex); -} - -} // namespace - -auto getNodes(struct sockaddr_in* sin, int* num, struct sockaddr_in6* sin6, int* num6) -{ - auto lock = unique_lock(); - return dht_get_nodes(sin, num, sin6, num6); -} - -auto init(int s, int s6, unsigned char const* id, unsigned char const* v) -{ - auto lock = unique_lock(); - return dht_init(s, s6, id, v); -} - -auto nodes(int af, int* good_return, int* dubious_return, int* cached_return, int* incoming_return) -{ - auto lock = unique_lock(); - return dht_nodes(af, good_return, dubious_return, cached_return, incoming_return); -} - -auto periodic( - void const* buf, - size_t buflen, - struct sockaddr const* from, - int fromlen, - time_t* tosleep, - dht_callback_t* callback, - void* closure) -{ - auto lock = unique_lock(); - return dht_periodic(buf, buflen, from, fromlen, tosleep, callback, closure); -} - -auto ping_node(struct sockaddr const* sa, int salen) -{ - auto lock = unique_lock(); - return dht_ping_node(sa, salen); -} - -auto search(unsigned char const* id, int port, int af, dht_callback_t* callback, void* closure) -{ - auto lock = unique_lock(); - return dht_search(id, port, af, callback, closure); -} - -auto uninit() -{ - auto lock = unique_lock(); - return dht_uninit(); -} - -} // namespace locked_dht - -enum class Status -{ - Stopped, - Broken, - Poor, - Firewalled, - Good -}; - -static constexpr std::string_view printableStatus(Status status) -{ - switch (status) - { - case Status::Stopped: - return "stopped"sv; - - case Status::Broken: - return "broken"sv; - - case Status::Poor: - return "poor"sv; - - case Status::Firewalled: - return "firewalled"sv; - - case Status::Good: - return "good"sv; - - default: - return "???"sv; - } -} - -bool tr_dhtEnabled() -{ - return impl.session != nullptr; -} - -static constexpr auto getUdpSocket(int af) -{ - switch (af) - { - case AF_INET: - return impl.udp4_socket; - - case AF_INET6: - return impl.udp6_socket; - - default: - return TR_BAD_SOCKET; - } -} - -static auto getStatus(int af, int* const setme_node_count = nullptr) -{ - if (getUdpSocket(af) == TR_BAD_SOCKET) - { - if (setme_node_count != nullptr) - { - *setme_node_count = 0; - } - - return Status::Stopped; - } - - int good = 0; - int dubious = 0; - int incoming = 0; - locked_dht::nodes(af, &good, &dubious, nullptr, &incoming); - - if (setme_node_count != nullptr) - { - *setme_node_count = good + dubious; - } - - if (good < 4 || good + dubious <= 8) - { - return Status::Broken; - } - - if (good < 40) - { - return Status::Poor; - } - - if (incoming < 8) - { - return Status::Firewalled; - } - - return Status::Good; -} - -static constexpr auto isReady(Status const status) -{ - return status >= Status::Firewalled; -} - -static auto isReady(int af) -{ - return isReady(getStatus(af)); -} - -static bool isBootstrapDone(int af = 0) -{ - if (af == 0) - { - return isBootstrapDone(AF_INET) && isBootstrapDone(AF_INET6); - } - - auto const status = getStatus(af, nullptr); - return status == Status::Stopped || isReady(status); -} - -static void nap(int roughly_sec) -{ - int const roughly_msec = roughly_sec * 1000; - int const msec = roughly_msec / 2 + tr_rand_int_weak(roughly_msec); - tr_wait_msec(msec); -} - -static int getBootstrappedAF() -{ - if (isBootstrapDone(AF_INET6)) - { - return AF_INET; - } - - if (isBootstrapDone(AF_INET)) - { - return AF_INET6; - } - - return 0; -} - -static void bootstrapFromName(char const* name, tr_port port, int af) -{ - auto hints = addrinfo{}; - hints.ai_socktype = SOCK_DGRAM; - hints.ai_family = af; - - auto port_str = std::array{}; - *fmt::format_to(std::data(port_str), FMT_STRING("{:d}"), port.host()) = '\0'; - - addrinfo* info = nullptr; - if (int const rc = getaddrinfo(name, std::data(port_str), &hints, &info); rc != 0) - { - tr_logAddWarn(fmt::format( - _("Couldn't look up '{address}:{port}': {error} ({error_code})"), - fmt::arg("address", name), - fmt::arg("port", port.host()), - fmt::arg("error", gai_strerror(rc)), - fmt::arg("error_code", rc))); - return; - } - - addrinfo* infop = info; - while (infop != nullptr) - { - locked_dht::ping_node(infop->ai_addr, infop->ai_addrlen); - - nap(15); - - if (isBootstrapDone(af)) - { - break; - } - - infop = infop->ai_next; - } - - freeaddrinfo(info); -} - -static void bootstrapFromFile(std::string_view config_dir) -{ - if (isBootstrapDone()) - { - return; - } - - // check for a manual bootstrap file. - auto in = std::ifstream{ tr_pathbuf{ config_dir, "/dht.bootstrap"sv } }; - if (!in.is_open()) - { - return; - } - - // format is each line has address, a space char, and port number - tr_logAddTrace("Attempting manual bootstrap"); - auto line = std::string{}; - while (!isBootstrapDone() && std::getline(in, line)) - { - auto line_stream = std::istringstream{ line }; - auto addrstr = std::string{}; - auto hport = uint16_t{}; - line_stream >> addrstr >> hport; - - if (line_stream.bad() || std::empty(addrstr)) - { - tr_logAddWarn(fmt::format(_("Couldn't parse line: '{line}'"), fmt::arg("line", line))); - } - else - { - bootstrapFromName(addrstr.c_str(), tr_port::fromHost(hport), getBootstrappedAF()); - } - } -} - -static void bootstrapStart(std::string_view config_dir, std::vector nodes4, std::vector nodes6) -{ - if (!tr_dhtEnabled()) - { - return; - } - - auto const num4 = std::size(nodes4) / 6; - if (num4 > 0) - { - tr_logAddDebug(fmt::format("Bootstrapping from {} IPv4 nodes", num4)); - } - - auto const num6 = std::size(nodes6) / 18; - if (num6 > 0) - { - tr_logAddDebug(fmt::format("Bootstrapping from {} IPv6 nodes", num6)); - } - - auto const* walk4 = std::data(nodes4); - auto const* walk6 = std::data(nodes6); - for (size_t i = 0; i < std::max(num4, num6); ++i) - { - if (i < num4 && !isBootstrapDone(AF_INET)) - { - auto addr = tr_address{}; - auto port = tr_port{}; - std::tie(addr, walk4) = tr_address::fromCompact4(walk4); - std::tie(port, walk4) = tr_port::fromCompact(walk4); - tr_dhtAddNode(addr, port, true); - } - - if (i < num6 && !isBootstrapDone(AF_INET6)) - { - auto addr = tr_address{}; - auto port = tr_port{}; - std::tie(addr, walk6) = tr_address::fromCompact6(walk6); - std::tie(port, walk6) = tr_port::fromCompact(walk6); - tr_dhtAddNode(addr, port, true); - } - - /* Our DHT code is able to take up to 9 nodes in a row without - dropping any. After that, it takes some time to split buckets. - So ping the first 8 nodes quickly, then slow down. */ - if (i < 8U) - { - nap(2); - } - else - { - nap(15); - } - - if (isBootstrapDone()) - { - break; - } - } - - if (!isBootstrapDone()) - { - bootstrapFromFile(config_dir); - } - - if (!isBootstrapDone()) - { - for (int i = 0; i < 6; ++i) - { - /* We don't want to abuse our bootstrap nodes, so be very - slow. The initial wait is to give other nodes a chance - to contact us before we attempt to contact a bootstrap - node, for example because we've just been restarted. */ - nap(40); - - if (isBootstrapDone()) - { - break; - } - - if (i == 0) - { - tr_logAddDebug("Attempting bootstrap from dht.transmissionbt.com"); - } - - bootstrapFromName("dht.transmissionbt.com", tr_port::fromHost(6881), getBootstrappedAF()); - } - } - - tr_logAddTrace("Finished bootstrapping"); -} - -int tr_dhtInit(tr_session* session, tr_socket_t udp4_socket, tr_socket_t udp6_socket) -{ - if (impl.session != nullptr) /* already initialized */ - { - return -1; - } - - tr_logAddInfo(_("Initializing DHT")); - - if (tr_env_key_exists("TR_DHT_VERBOSE")) - { - dht_debug = stderr; - } - - auto benc = tr_variant{}; - auto const dat_file = tr_pathbuf{ session->configDir(), "/dht.dat"sv }; - auto const ok = tr_variantFromFile(&benc, TR_VARIANT_PARSE_BENC, dat_file.sv()); - - bool have_id = false; - auto nodes = std::vector{}; - auto nodes6 = std::vector{}; - - if (ok) - { - auto sv = std::string_view{}; - have_id = tr_variantDictFindStrView(&benc, TR_KEY_id, &sv); - if (have_id && std::size(sv) == std::size(impl.id)) - { - std::copy(std::begin(sv), std::end(sv), std::data(impl.id)); - } - - size_t raw_len = 0U; - std::byte const* raw = nullptr; - - if (tr_variantDictFindRaw(&benc, TR_KEY_nodes, &raw, &raw_len) && raw_len % 6 == 0) - { - nodes.assign(raw, raw + raw_len); - } - - if (tr_variantDictFindRaw(&benc, TR_KEY_nodes6, &raw, &raw_len) && raw_len % 18 == 0) - { - nodes6.assign(raw, raw + raw_len); - } - - tr_variantClear(&benc); - } - - if (have_id) - { - tr_logAddTrace("Reusing old id"); - } - else - { - /* Note that DHT ids need to be distributed uniformly, - * so it should be something truly random. */ - tr_logAddTrace("Generating new id"); - tr_rand_buffer(std::data(impl.id), std::size(impl.id)); - } - - if (locked_dht::init(udp4_socket, udp6_socket, std::data(impl.id), nullptr) < 0) - { - auto const errcode = errno; - tr_logAddDebug(fmt::format("DHT initialization failed: {} ({})", tr_strerror(errcode), errcode)); - impl = {}; - return -1; - } - - impl.session = session; - impl.udp4_socket = udp4_socket; - impl.udp6_socket = udp4_socket; - - std::thread(bootstrapStart, std::string{ session->configDir() }, nodes, nodes6).detach(); - - static auto constexpr MinInterval = 10ms; - static auto constexpr MaxInterval = 1s; - auto const random_percent = tr_rand_int_weak(1000) / 1000.0; - auto interval = MinInterval + random_percent * (MaxInterval - MinInterval); - impl.timer = session->timerMaker().create([]() { tr_dhtCallback(nullptr, 0, nullptr, 0); }); - impl.timer->startSingleShot(std::chrono::duration_cast(interval)); - - tr_logAddDebug("DHT initialized"); - - return 1; -} - -void tr_dhtUninit() -{ - TR_ASSERT(tr_dhtEnabled()); - - tr_logAddTrace("Uninitializing DHT"); - - impl.timer.reset(); - - /* Since we only save known good nodes, - * avoid erasing older data if we don't know enough nodes. */ - if (!isReady(AF_INET) && !isReady(AF_INET6)) - { - tr_logAddTrace("Not saving nodes, DHT not ready"); - } - else - { - auto constexpr MaxNodes = int{ 300 }; - auto constexpr PortLen = size_t{ 2 }; - auto constexpr CompactAddrLen = size_t{ 4 }; - auto constexpr CompactLen = size_t{ CompactAddrLen + PortLen }; - auto constexpr Compact6AddrLen = size_t{ 16 }; - auto constexpr Compact6Len = size_t{ Compact6AddrLen + PortLen }; - - auto sins = std::array{}; - auto sins6 = std::array{}; - int num = MaxNodes; - int num6 = MaxNodes; - int const n = locked_dht::getNodes(std::data(sins), &num, std::data(sins6), &num6); - tr_logAddTrace(fmt::format("Saving {} ({} + {}) nodes", n, num, num6)); - - tr_variant benc; - tr_variantInitDict(&benc, 3); - tr_variantDictAddRaw(&benc, TR_KEY_id, std::data(impl.id), std::size(impl.id)); - - if (num > 0) - { - auto compact = std::array{}; - char* out = std::data(compact); - for (auto const* in = std::data(sins), *end = in + num; in != end; ++in) - { - memcpy(out, &in->sin_addr, CompactAddrLen); - out += CompactAddrLen; - memcpy(out, &in->sin_port, PortLen); - out += PortLen; - } - - tr_variantDictAddRaw(&benc, TR_KEY_nodes, std::data(compact), out - std::data(compact)); - } - - if (num6 > 0) - { - auto compact6 = std::array{}; - char* out6 = std::data(compact6); - for (auto const* in = std::data(sins6), *end = in + num6; in != end; ++in) - { - memcpy(out6, &in->sin6_addr, Compact6AddrLen); - out6 += Compact6AddrLen; - memcpy(out6, &in->sin6_port, PortLen); - out6 += PortLen; - } - - tr_variantDictAddRaw(&benc, TR_KEY_nodes6, std::data(compact6), out6 - std::data(compact6)); - } - - tr_variantToFile(&benc, TR_VARIANT_FMT_BENC, tr_pathbuf{ impl.session->configDir(), "/dht.dat"sv }); - tr_variantClear(&benc); - } - - locked_dht::uninit(); - - tr_logAddTrace("Done uninitializing DHT"); - - impl = {}; -} - -bool tr_dhtAddNode(tr_address addr, tr_port port, bool bootstrap) -{ - if (!tr_dhtEnabled()) - { - return false; - } - - /* Since we don't want to abuse our bootstrap nodes, - * we don't ping them if the DHT is in a good state. */ - - if (bootstrap && isReady(addr.isIPv4() ? AF_INET : AF_INET6)) - { - return false; - } - - if (addr.isIPv4()) - { - auto sin = sockaddr_in{}; - sin.sin_family = AF_INET; - sin.sin_addr = addr.addr.addr4; - sin.sin_port = port.network(); - locked_dht::ping_node((struct sockaddr*)&sin, sizeof(sin)); - return true; - } - - if (addr.isIPv6()) - { - auto sin6 = sockaddr_in6{}; - sin6.sin6_family = AF_INET6; - sin6.sin6_addr = addr.addr.addr6; - sin6.sin6_port = port.network(); - locked_dht::ping_node((struct sockaddr*)&sin6, sizeof(sin6)); - return true; - } - - return false; -} - -static void callback(void* vsession, int event, unsigned char const* info_hash, void const* data, size_t data_len) -{ - auto* const session = static_cast(vsession); - auto hash = tr_sha1_digest_t{}; - std::copy_n(reinterpret_cast(info_hash), std::size(hash), std::data(hash)); - auto const lock = session->unique_lock(); - auto* const tor = session->torrents().get(hash); - - if (event == DHT_EVENT_VALUES || event == DHT_EVENT_VALUES6) - { - if (tor != nullptr && tor->allowsDht()) - { - auto const pex = event == DHT_EVENT_VALUES ? tr_pex::fromCompact4(data, data_len, nullptr, 0) : - tr_pex::fromCompact6(data, data_len, nullptr, 0); - tr_peerMgrAddPex(tor, TR_PEER_FROM_DHT, std::data(pex), std::size(pex)); - tr_logAddDebugTor( - tor, - fmt::format("Learned {} {} peers from DHT", std::size(pex), event == DHT_EVENT_VALUES6 ? "IPv6" : "IPv4")); - } - } - else if (event == DHT_EVENT_SEARCH_DONE || event == DHT_EVENT_SEARCH_DONE6) - { - if (tor != nullptr) - { - if (event == DHT_EVENT_SEARCH_DONE) - { - tr_logAddTraceTor(tor, "IPv4 DHT announce done"); - } - else - { - tr_logAddTraceTor(tor, "IPv6 DHT announce done"); - } - } - } -} - -static bool announceTorrent(tr_torrent const* const tor, int af, bool announce, tr_port incoming_peer_port) -{ - TR_ASSERT(tor->allowsDht()); - - int numnodes = 0; - auto const status = getStatus(af, &numnodes); - if (status == Status::Stopped) - { - // let the caller believe everything is all right. - return true; - } - - if (status < Status::Poor) - { - tr_logAddTraceTor( - tor, - fmt::format( - "{} DHT not ready ({}, {} nodes)", - af == AF_INET6 ? "IPv6" : "IPv4", - printableStatus(status), - numnodes)); - return false; - } - - auto const* dht_hash = reinterpret_cast(std::data(tor->infoHash())); - auto const hport = announce ? incoming_peer_port.host() : 0; - int const rc = locked_dht::search(dht_hash, hport, af, callback, impl.session); - if (rc < 0) - { - auto const error_code = errno; - tr_logAddWarnTor( - tor, - fmt::format( - _("Unable to announce torrent in DHT with {type}: {error} ({error_code}); state is {state}"), - fmt::arg("type", af == AF_INET6 ? "IPv6" : "IPv4"), - fmt::arg("state", printableStatus(status)), - fmt::arg("error_code", error_code), - fmt::arg("error", tr_strerror(error_code)))); - return false; - } - - tr_logAddTraceTor( - tor, - fmt::format( - "Starting {} DHT announce ({}, {} nodes)", - af == AF_INET6 ? "IPv6" : "IPv4", - printableStatus(status), - numnodes)); - - return true; -} - -void tr_dhtUpkeep() -{ - TR_ASSERT(impl.session != nullptr); - - auto lock = impl.session->unique_lock(); - auto const now = tr_time(); - auto const incoming_peer_port = impl.session->advertisedPeerPort(); - - for (auto* const tor : impl.session->torrents()) - { - if (!tor->isRunning || !tor->allowsDht()) - { - continue; - } - - if (tor->dhtAnnounceAt <= now) - { - auto const ok = announceTorrent(tor, AF_INET, true, incoming_peer_port); - auto const interval = ok ? 25 * 60 + tr_rand_int_weak(3 * 60) : 5 + tr_rand_int_weak(5); - tor->dhtAnnounceAt = now + interval; - } - - if (tor->dhtAnnounce6At <= now) - { - auto const ok = announceTorrent(tor, AF_INET6, true, incoming_peer_port); - auto const interval = ok ? 25 * 60 + tr_rand_int_weak(3 * 60) : 5 + tr_rand_int_weak(5); - tor->dhtAnnounce6At = now + interval; - } - } -} - -void tr_dhtCallback(unsigned char* buf, int buflen, struct sockaddr* from, socklen_t fromlen) -{ - if (!tr_dhtEnabled()) - { - return; - } - - time_t tosleep = 0; - int const rc = locked_dht::periodic(buf, buflen, from, fromlen, &tosleep, callback, impl.session); - - if (rc < 0) - { - if (errno == EINTR) - { - tosleep = 0; - } - else - { - auto const errcode = errno; - tr_logAddDebug(fmt::format("dht_periodic failed: {} ({})", tr_strerror(errcode), errcode)); - if (errcode == EINVAL || errcode == EFAULT) - { - // TODO: maybe just turn it off instead of crashing? - abort(); - } - - tosleep = 1; - } - } - - // Being slightly late is fine, - // and has the added benefit of adding some jitter. - auto const random_percent = tr_rand_int_weak(1000) / 1000.0; - auto const min_interval = std::chrono::seconds{ tosleep }; - auto const max_interval = std::chrono::seconds{ tosleep + 1 }; - auto const interval = min_interval + random_percent * (max_interval - min_interval); - impl.timer->startSingleShot(std::chrono::duration_cast(interval)); -} - +// the dht library needs us to implement these: extern "C" { + // This function should return true when a node is blacklisted. // We don't support using a blacklist with the DHT in Transmission, // since massive (ab)use of this feature could harm the DHT. However, @@ -827,12 +88,12 @@ extern "C" { return -1; } - return size; + return static_cast(size); } int dht_sendto(int sockfd, void const* buf, int len, int flags, struct sockaddr const* to, int tolen) { - return sendto(sockfd, static_cast(buf), len, flags, to, tolen); + return static_cast(sendto(sockfd, static_cast(buf), len, flags, to, tolen)); } #if defined(_WIN32) && !defined(__MINGW32__) @@ -850,3 +111,500 @@ extern "C" #endif } // extern "C" + +class tr_dht_impl final : public tr_dht +{ +private: + using Node = std::pair; + using Nodes = std::deque; + using Id = std::array; + + enum class SwarmStatus + { + Stopped, + Broken, + Poor, + Firewalled, + Good + }; + +public: + tr_dht_impl(Mediator& mediator, tr_port peer_port, tr_socket_t udp4_socket, tr_socket_t udp6_socket) + : peer_port_{ peer_port } + , udp4_socket_{ udp4_socket } + , udp6_socket_{ udp6_socket } + , mediator_{ mediator } + , state_filename_{ tr_pathbuf{ mediator_.configDir(), "/dht.dat" } } + , announce_timer_{ mediator_.timerMaker().create([this]() { onAnnounceTimer(); }) } + , bootstrap_timer_{ mediator_.timerMaker().create([this]() { onBootstrapTimer(); }) } + , periodic_timer_{ mediator_.timerMaker().create([this]() { onPeriodicTimer(); }) } + { + tr_logAddDebug(fmt::format("Starting DHT on port {port}", fmt::arg("port", peer_port.host()))); + + // load up the bootstrap nodes + if (tr_sys_path_exists(state_filename_.c_str())) + { + std::tie(id_, bootstrap_queue_) = loadState(state_filename_); + } + getNodesFromBootstrapFile(tr_pathbuf{ mediator_.configDir(), "/dht.bootstrap"sv }, bootstrap_queue_); + getNodesFromName("dht.transmissionbt.com", tr_port::fromHost(6881), bootstrap_queue_); + bootstrap_timer_->startSingleShot(100ms); + + mediator_.api().init(udp4_socket_, udp6_socket_, std::data(id_), nullptr); + + onAnnounceTimer(); + announce_timer_->startRepeating(1s); + + onPeriodicTimer(); + } + + tr_dht_impl(tr_dht_impl&&) = delete; + tr_dht_impl(tr_dht_impl const&) = delete; + tr_dht_impl& operator=(tr_dht_impl&&) = delete; + tr_dht_impl& operator=(tr_dht_impl const&) = delete; + + ~tr_dht_impl() override + { + tr_logAddTrace("Uninitializing DHT"); + + // Since we only save known good nodes, + // only overwrite older data if we know enough nodes. + if (isReady(AF_INET) || isReady(AF_INET6)) + { + saveState(); + } + + mediator_.api().uninit(); + tr_logAddTrace("Done uninitializing DHT"); + } + + void addNode(tr_address const& addr, tr_port port) override + { + if (addr.isIPv4()) + { + auto sin = sockaddr_in{}; + sin.sin_family = AF_INET; + sin.sin_addr = addr.addr.addr4; + sin.sin_port = port.network(); + mediator_.api().ping_node((struct sockaddr*)&sin, sizeof(sin)); + } + else if (addr.isIPv6()) + { + auto sin6 = sockaddr_in6{}; + sin6.sin6_family = AF_INET6; + sin6.sin6_addr = addr.addr.addr6; + sin6.sin6_port = port.network(); + mediator_.api().ping_node((struct sockaddr*)&sin6, sizeof(sin6)); + } + } + + void handleMessage(unsigned char const* msg, size_t msglen, struct sockaddr* from, socklen_t fromlen) override + { + auto const call_again_in_n_secs = periodic(msg, msglen, from, fromlen); + + // Being slightly late is fine, + // and has the added benefit of adding some jitter. + auto const interval = call_again_in_n_secs + std::chrono::milliseconds{ tr_rand_int_weak(1000) }; + periodic_timer_->startSingleShot(interval); + } + +private: + [[nodiscard]] constexpr tr_socket_t udpSocket(int af) const noexcept + { + switch (af) + { + case AF_INET: + return udp4_socket_; + + case AF_INET6: + return udp6_socket_; + + default: + return TR_BAD_SOCKET; + } + } + + [[nodiscard]] SwarmStatus swarmStatus(int family, int* const setme_node_count = nullptr) const + { + if (udpSocket(family) == TR_BAD_SOCKET) + { + if (setme_node_count != nullptr) + { + *setme_node_count = 0; + } + + return SwarmStatus::Stopped; + } + + int good = 0; + int dubious = 0; + int incoming = 0; + mediator_.api().nodes(family, &good, &dubious, nullptr, &incoming); + + if (setme_node_count != nullptr) + { + *setme_node_count = good + dubious; + } + + if (good < 4 || good + dubious <= 8) + { + return SwarmStatus::Broken; + } + + if (good < 40) + { + return SwarmStatus::Poor; + } + + if (incoming < 8) + { + return SwarmStatus::Firewalled; + } + + return SwarmStatus::Good; + } + + [[nodiscard]] static constexpr auto isReady(SwarmStatus const status) + { + return status >= SwarmStatus::Firewalled; + } + + [[nodiscard]] bool isReady(int af) const noexcept + { + return isReady(swarmStatus(af)); + } + + [[nodiscard]] bool isReady() const noexcept + { + return isReady(AF_INET) && isReady(AF_INET6); + } + + /// + + // how long to wait between adding nodes during bootstrap + [[nodiscard]] static constexpr auto bootstrapInterval(size_t n_added) + { + // Our DHT code is able to take up to 9 nodes in a row without + // dropping any. After that, it takes some time to split buckets. + // So ping the first 8 nodes quickly, then slow down. + if (n_added < 8U) + { + return 2s; + } + + if (n_added < 16U) + { + return 15s; + } + + return 40s; + } + + void onBootstrapTimer() + { + // Since we don't want to abuse our bootstrap nodes, + // we don't ping them if the DHT is in a good state. + if (isReady() || std::empty(bootstrap_queue_)) + { + return; + } + + auto [address, port] = bootstrap_queue_.front(); + bootstrap_queue_.pop_front(); + addNode(address, port); + ++n_bootstrapped_; + + bootstrap_timer_->startSingleShot(bootstrapInterval(n_bootstrapped_)); + } + + /// + + [[nodiscard]] auto announceTorrent(tr_sha1_digest_t const& info_hash, int af, tr_port port) + { + auto const* dht_hash = reinterpret_cast(std::data(info_hash)); + auto const rc = mediator_.api().search(dht_hash, port.host(), af, callback, this); + auto const announce_again_in_n_secs = rc < 0 ? 5s + std::chrono::seconds{ tr_rand_int_weak(5) } : + 25min + std::chrono::seconds{ tr_rand_int_weak(3 * 60) }; + return announce_again_in_n_secs; + } + + void onAnnounceTimer() + { + // don't announce if the swarm isn't ready + if (swarmStatus(AF_INET) < SwarmStatus::Poor && swarmStatus(AF_INET6) < SwarmStatus::Poor) + { + return; + } + + auto const now = tr_time(); + for (auto const id : mediator_.torrentsAllowingDHT()) + { + auto& times = announce_times_[id]; + + if (auto& announce_after = times.ipv4_announce_after; announce_after < now) + { + auto const announce_again_in_n_secs = announceTorrent(mediator_.torrentInfoHash(id), AF_INET, peer_port_); + announce_after = now + std::chrono::seconds{ announce_again_in_n_secs }.count(); + } + + if (auto& announce_after = times.ipv6_announce_after; announce_after < now) + { + auto const announce_again_in_n_secs = announceTorrent(mediator_.torrentInfoHash(id), AF_INET6, peer_port_); + announce_after = now + std::chrono::seconds{ announce_again_in_n_secs }.count(); + } + } + } + + /// + + void onPeriodicTimer() + { + auto const call_again_in_n_secs = periodic(nullptr, 0, nullptr, 0); + + // Being slightly late is fine, + // and has the added benefit of adding some jitter. + auto const interval = call_again_in_n_secs + std::chrono::milliseconds{ tr_rand_int_weak(1000) }; + periodic_timer_->startSingleShot(interval); + } + + [[nodiscard]] std::chrono::seconds periodic( + unsigned char const* msg, + size_t msglen, + struct sockaddr const* from, + socklen_t fromlen) + { + TR_ASSERT_MSG(msglen == 0 || msg[msglen] == '\0', "libdht requires zero-terminated msg"); + + auto call_again_in_n_secs = time_t{}; + mediator_.api().periodic(msg, msglen, from, static_cast(fromlen), &call_again_in_n_secs, callback, this); + return std::chrono::seconds{ call_again_in_n_secs }; + } + + static void callback(void* vself, int event, unsigned char const* info_hash, void const* data, size_t data_len) + { + auto* const self = static_cast(vself); + auto hash = tr_sha1_digest_t{}; + std::copy_n(reinterpret_cast(info_hash), std::size(hash), std::data(hash)); + + if (event == DHT_EVENT_VALUES) + { + auto const pex = tr_pex::fromCompact4(data, data_len, nullptr, 0); + self->mediator_.addPex(hash, std::data(pex), std::size(pex)); + } + else if (event == DHT_EVENT_VALUES6) + { + auto const pex = tr_pex::fromCompact6(data, data_len, nullptr, 0); + self->mediator_.addPex(hash, std::data(pex), std::size(pex)); + } + } + + /// + + void saveState() const + { + auto constexpr MaxNodes = int{ 300 }; + auto constexpr PortLen = size_t{ 2 }; + auto constexpr CompactAddrLen = size_t{ 4 }; + auto constexpr CompactLen = size_t{ CompactAddrLen + PortLen }; + auto constexpr Compact6AddrLen = size_t{ 16 }; + auto constexpr Compact6Len = size_t{ Compact6AddrLen + PortLen }; + + auto sins4 = std::array{}; + auto sins6 = std::array{}; + auto num4 = int{ MaxNodes }; + auto num6 = int{ MaxNodes }; + auto const n = mediator_.api().get_nodes(std::data(sins4), &num4, std::data(sins6), &num6); + tr_logAddTrace(fmt::format("Saving {} ({} + {}) nodes", n, num4, num6)); + + tr_variant benc; + tr_variantInitDict(&benc, 3); + tr_variantDictAddRaw(&benc, TR_KEY_id, std::data(id_), std::size(id_)); + + if (num4 > 0) + { + auto compact = std::array{}; + char* out = std::data(compact); + for (auto const* in = std::data(sins4), *end = in + num4; in != end; ++in) + { + memcpy(out, &in->sin_addr, CompactAddrLen); + out += CompactAddrLen; + memcpy(out, &in->sin_port, PortLen); // saved in network byte order + out += PortLen; + } + + tr_variantDictAddRaw(&benc, TR_KEY_nodes, std::data(compact), out - std::data(compact)); + } + + if (num6 > 0) + { + auto compact6 = std::array{}; + char* out6 = std::data(compact6); + for (auto const* in = std::data(sins6), *end = in + num6; in != end; ++in) + { + memcpy(out6, &in->sin6_addr, Compact6AddrLen); + out6 += Compact6AddrLen; + memcpy(out6, &in->sin6_port, PortLen); // saved in network byte order + out6 += PortLen; + } + + tr_variantDictAddRaw(&benc, TR_KEY_nodes6, std::data(compact6), out6 - std::data(compact6)); + } + + tr_variantToFile(&benc, TR_VARIANT_FMT_BENC, state_filename_); + tr_variantClear(&benc); + } + + [[nodiscard]] static std::pair loadState(std::string_view filename) + { + // Note that DHT ids need to be distributed uniformly, + // so it should be something truly random + auto id = Id{}; + tr_rand_buffer(std::data(id), std::size(id)); + + auto nodes = Nodes{}; + + if (auto dict = tr_variant{}; tr_variantFromFile(&dict, TR_VARIANT_PARSE_BENC, filename)) + { + if (auto sv = std::string_view{}; + tr_variantDictFindStrView(&dict, TR_KEY_id, &sv) && std::size(sv) == std::size(id)) + { + std::copy(std::begin(sv), std::end(sv), std::begin(id)); + } + + size_t raw_len = 0U; + std::byte const* raw = nullptr; + if (tr_variantDictFindRaw(&dict, TR_KEY_nodes, &raw, &raw_len) && raw_len % 6 == 0) + { + auto* walk = raw; + auto const* const end = raw + raw_len; + while (walk < end) + { + auto addr = tr_address{}; + auto port = tr_port{}; + std::tie(addr, walk) = tr_address::fromCompact4(walk); + std::tie(port, walk) = tr_port::fromCompact(walk); + nodes.emplace_back(addr, port); + } + } + + if (tr_variantDictFindRaw(&dict, TR_KEY_nodes6, &raw, &raw_len) && raw_len % 18 == 0) + { + auto* walk = raw; + auto const* const end = raw + raw_len; + while (walk < end) + { + auto addr = tr_address{}; + auto port = tr_port{}; + std::tie(addr, walk) = tr_address::fromCompact6(walk); + std::tie(port, walk) = tr_port::fromCompact(walk); + nodes.emplace_back(addr, port); + } + } + + tr_variantClear(&dict); + } + + return std::make_pair(id, nodes); + } + + /// + + static void getNodesFromBootstrapFile(std::string_view filename, Nodes& nodes) + { + auto in = std::ifstream{ std::string{ filename } }; + if (!in.is_open()) + { + return; + } + + // format is each line has host, a space char, and port number + auto line = std::string{}; + while (std::getline(in, line)) + { + auto line_stream = std::istringstream{ line }; + auto addrstr = std::string{}; + auto hport = uint16_t{}; + line_stream >> addrstr >> hport; + + if (line_stream.bad() || std::empty(addrstr)) + { + tr_logAddWarn(fmt::format( + _("Couldn't parse '{filename}' line: '{line}'"), + fmt::arg("filename", filename), + fmt::arg("line", line))); + } + else + { + getNodesFromName(addrstr.c_str(), tr_port::fromHost(hport), nodes); + } + } + } + + static void getNodesFromName(char const* name, tr_port port_in, Nodes& nodes) + { + auto hints = addrinfo{}; + hints.ai_socktype = SOCK_DGRAM; + hints.ai_family = AF_UNSPEC; + hints.ai_protocol = 0; + hints.ai_flags = 0; + + auto port_str = std::array{}; + *fmt::format_to(std::data(port_str), FMT_STRING("{:d}"), port_in.host()) = '\0'; + + addrinfo* info = nullptr; + if (int const rc = getaddrinfo(name, std::data(port_str), &hints, &info); rc != 0) + { + tr_logAddWarn(fmt::format( + _("Couldn't look up '{address}:{port}': {error} ({error_code})"), + fmt::arg("address", name), + fmt::arg("port", port_in.host()), + fmt::arg("error", gai_strerror(rc)), + fmt::arg("error_code", rc))); + return; + } + + for (auto* infop = info; infop != nullptr; infop = infop->ai_next) + { + if (auto addrport = tr_address::fromSockaddr(infop->ai_addr); addrport) + { + nodes.emplace_back(addrport->first, addrport->second); + } + } + + freeaddrinfo(info); + } + + /// + + tr_port const peer_port_; + tr_socket_t const udp4_socket_; + tr_socket_t const udp6_socket_; + + Mediator& mediator_; + std::string const state_filename_; + std::unique_ptr const announce_timer_; + std::unique_ptr const bootstrap_timer_; + std::unique_ptr const periodic_timer_; + + Id id_ = {}; + + Nodes bootstrap_queue_; + size_t n_bootstrapped_ = 0; + + struct AnnounceInfo + { + time_t ipv4_announce_after = 0; + time_t ipv6_announce_after = 0; + }; + + std::map announce_times_; +}; + +[[nodiscard]] std::unique_ptr tr_dht::create( + Mediator& mediator, + tr_port peer_port, + tr_socket_t udp4_socket, + tr_socket_t udp6_socket) +{ + return std::make_unique(mediator, peer_port, udp4_socket, udp6_socket); +} diff --git a/libtransmission/tr-dht.h b/libtransmission/tr-dht.h index afcab6288..03118185b 100644 --- a/libtransmission/tr-dht.h +++ b/libtransmission/tr-dht.h @@ -8,17 +8,103 @@ #error only libtransmission should #include this header. #endif +#include #include +#include + +#include #include "transmission.h" #include "net.h" // tr_port -int tr_dhtInit(tr_session*, tr_socket_t udp4_socket, tr_socket_t udp6_socket); -void tr_dhtUninit(); +struct tr_pex; -bool tr_dhtEnabled(); +namespace libtransmission +{ +class TimerMaker; +} // namespace libtransmission -bool tr_dhtAddNode(tr_address, tr_port, bool bootstrap); -void tr_dhtUpkeep(); -void tr_dhtCallback(unsigned char* buf, int buflen, struct sockaddr* from, socklen_t fromlen); +class tr_dht +{ +public: + // Wrapper around DHT library. + // This calls `jech/dht` in production, but makes it possible for tests to inject a mock. + struct API + { + virtual ~API() = default; + + virtual int get_nodes(struct sockaddr_in* sin, int* num, struct sockaddr_in6* sin6, int* num6) + { + return ::dht_get_nodes(sin, num, sin6, num6); + } + + virtual int nodes(int af, int* good_return, int* dubious_return, int* cached_return, int* incoming_return) + { + return ::dht_nodes(af, good_return, dubious_return, cached_return, incoming_return); + } + + virtual int periodic( + void const* buf, + size_t buflen, + struct sockaddr const* from, + int fromlen, + time_t* tosleep, + dht_callback_t callback, + void* closure) + { + return ::dht_periodic(buf, buflen, from, fromlen, tosleep, callback, closure); + } + + virtual int ping_node(struct sockaddr const* sa, int salen) + { + return ::dht_ping_node(sa, salen); + } + + virtual int search(unsigned char const* id, int port, int af, dht_callback_t callback, void* closure) + { + return ::dht_search(id, port, af, callback, closure); + } + + virtual int init(int s, int s6, unsigned const char* id, unsigned const char* v) + { + return ::dht_init(s, s6, id, v); + } + + virtual int uninit() + { + return ::dht_uninit(); + } + }; + + class Mediator + { + public: + virtual ~Mediator() = default; + + [[nodiscard]] virtual std::vector torrentsAllowingDHT() const = 0; + [[nodiscard]] virtual tr_sha1_digest_t torrentInfoHash(tr_torrent_id_t) const = 0; + + [[nodiscard]] virtual std::string_view configDir() const = 0; + [[nodiscard]] virtual libtransmission::TimerMaker& timerMaker() = 0; + [[nodiscard]] virtual API& api() + { + return api_; + } + + virtual void addPex(tr_sha1_digest_t const&, tr_pex const* pex, size_t n_pex) = 0; + + private: + API api_; + }; + + [[nodiscard]] static std::unique_ptr create( + Mediator& mediator, + tr_port peer_port, + tr_socket_t udp4_socket, + tr_socket_t udp6_socket); + virtual ~tr_dht() = default; + + virtual void addNode(tr_address const& address, tr_port port) = 0; + virtual void handleMessage(unsigned char const* msg, size_t msglen, struct sockaddr* from, socklen_t fromlen) = 0; +}; diff --git a/libtransmission/tr-udp.cc b/libtransmission/tr-udp.cc index bb57b4203..1f4a6ace8 100644 --- a/libtransmission/tr-udp.cc +++ b/libtransmission/tr-udp.cc @@ -7,12 +7,6 @@ #include #include /* memcmp(), memset() */ -#ifdef _WIN32 -#include /* dup2() */ -#else -#include /* dup2() */ -#endif - #include #include @@ -22,7 +16,6 @@ #include "net.h" #include "session.h" #include "tr-assert.h" -#include "tr-dht.h" #include "tr-utp.h" #include "utils.h" @@ -90,11 +83,6 @@ static void set_socket_buffers(tr_socket_t fd, bool large) } } -void tr_session::tr_udp_core::addDhtNode(tr_address const& addr, tr_port port) -{ - tr_dhtAddNode(addr, port, false); -} - void tr_session::tr_udp_core::set_socket_buffers() { bool const utp = session_.allowsUTP(); @@ -188,13 +176,16 @@ static void event_callback(evutil_socket_t s, [[maybe_unused]] short type, void* TR_ASSERT(vsession != nullptr); TR_ASSERT(type == EV_READ); - auto buf = std::array{}; + auto buf = std::array{}; auto from = sockaddr_storage{}; - auto* session = static_cast(vsession); - - socklen_t fromlen = sizeof(from); - auto const - rc = recvfrom(s, reinterpret_cast(std::data(buf)), std::size(buf) - 1, 0, (struct sockaddr*)&from, &fromlen); + auto fromlen = socklen_t{ sizeof(from) }; + auto const rc = recvfrom( + s, + reinterpret_cast(std::data(buf)), + std::size(buf) - 1, + 0, + reinterpret_cast(&from), + &fromlen); /* Since most packets we receive here are µTP, make quick inline checks for the other protocols. The logic is as follows: @@ -203,14 +194,15 @@ static void event_callback(evutil_socket_t s, [[maybe_unused]] short type, void* is between 0 and 3 - the above cannot be µTP packets, since these start with a 4-bit version number (1). */ + auto* session = static_cast(vsession); if (rc > 0) { if (buf[0] == 'd') { - if (session->allowsDHT()) + if (session->dht_) { - buf[rc] = '\0'; /* required by the DHT code */ - tr_dhtCallback(std::data(buf), rc, (struct sockaddr*)&from, fromlen); + buf[rc] = '\0'; // libdht requires zero-terminated messages + session->dht_->handleMessage(std::data(buf), rc, reinterpret_cast(&from), fromlen); } } else if (rc >= 8 && buf[0] == 0 && buf[1] == 0 && buf[2] == 0 && buf[3] <= 3) @@ -294,12 +286,7 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session, tr_port udp_port) set_socket_buffers(); set_socket_tos(); - if (session_.allowsDHT()) - { - tr_dhtInit(&session_, udp_socket_, udp6_socket_); - } - - if (udp4_event_) + if (udp4_event_ != nullptr) { event_add(udp4_event_.get(), nullptr); } @@ -309,26 +296,8 @@ tr_session::tr_udp_core::tr_udp_core(tr_session& session, tr_port udp_port) } } -void tr_session::tr_udp_core::dhtUpkeep() -{ - if (tr_dhtEnabled()) - { - tr_dhtUpkeep(); - } -} - -void tr_session::tr_udp_core::startShutdown() -{ - if (tr_dhtEnabled()) - { - tr_dhtUninit(); - } -} - tr_session::tr_udp_core::~tr_udp_core() { - startShutdown(); - udp6_event_.reset(); if (udp6_socket_ != TR_BAD_SOCKET) diff --git a/tests/libtransmission/CMakeLists.txt b/tests/libtransmission/CMakeLists.txt index f87eace09..064e88b4d 100644 --- a/tests/libtransmission/CMakeLists.txt +++ b/tests/libtransmission/CMakeLists.txt @@ -13,6 +13,7 @@ add_executable(libtransmission-test crypto-test-ref.h crypto-test.cc error-test.cc + dht-test.cc file-piece-map-test.cc file-test.cc getopt-test.cc @@ -65,6 +66,7 @@ target_include_directories(libtransmission-test SYSTEM ${WIDE_INTEGER_INCLUDE_DIRS} ${B64_INCLUDE_DIRS} ${CURL_INCLUDE_DIRS} + ${DHT_INCLUDE_DIRS} ${EVENT2_INCLUDE_DIRS}) target_compile_options(libtransmission-test diff --git a/tests/libtransmission/dht-test.cc b/tests/libtransmission/dht-test.cc new file mode 100644 index 000000000..68d10f561 --- /dev/null +++ b/tests/libtransmission/dht-test.cc @@ -0,0 +1,643 @@ +// This file Copyright (C) 2022 Mnemosyne LLC. +// It may be used under GPLv2 (SPDX: GPL-2.0-only), GPLv3 (SPDX: GPL-3.0-only), +// or any future license endorsed by Mnemosyne LLC. +// License text can be found in the licenses/ folder. + +#include +#include +#include +#include +#include + +#include + +#include "transmission.h" + +#include "file.h" +#include "timer-ev.h" +#include "session-thread.h" // for tr_evthread_init(); + +#include "gtest/gtest.h" +#include "test-fixtures.h" + +#ifdef _WIN32 +#undef gai_strerror +#define gai_strerror gai_strerrorA +#endif + +using namespace std::literals; + +namespace libtransmission::test +{ + +bool waitFor(struct event_base* event_base, std::chrono::milliseconds msec) +{ + return waitFor( + event_base, + []() { return false; }, + msec); +} + +namespace +{ +auto constexpr IdLength = size_t{ 20U }; +auto constexpr MockTimerInterval = 40ms; + +} // namespace + +class DhtTest : public SandboxedTest +{ +protected: + // Helper for creating a mock dht.dat state file + struct MockStateFile + { + // Fake data to be written to the test state file + + std::array id_ = tr_randObj>(); + + std::vector> ipv4_nodes_ = { + std::make_pair(*tr_address::fromString("10.10.10.1"), tr_port::fromHost(128)), + std::make_pair(*tr_address::fromString("10.10.10.2"), tr_port::fromHost(129)), + std::make_pair(*tr_address::fromString("10.10.10.3"), tr_port::fromHost(130)), + std::make_pair(*tr_address::fromString("10.10.10.4"), tr_port::fromHost(131)), + std::make_pair(*tr_address::fromString("10.10.10.5"), tr_port::fromHost(132)) + }; + + std::vector> ipv6_nodes_ = { + std::make_pair(*tr_address::fromString("1002:1035:4527:3546:7854:1237:3247:3217"), tr_port::fromHost(6881)), + std::make_pair(*tr_address::fromString("1002:1035:4527:3546:7854:1237:3247:3218"), tr_port::fromHost(6882)), + std::make_pair(*tr_address::fromString("1002:1035:4527:3546:7854:1237:3247:3219"), tr_port::fromHost(6883)), + std::make_pair(*tr_address::fromString("1002:1035:4527:3546:7854:1237:3247:3220"), tr_port::fromHost(6884)), + std::make_pair(*tr_address::fromString("1002:1035:4527:3546:7854:1237:3247:3221"), tr_port::fromHost(6885)) + }; + + [[nodiscard]] auto nodesString() const + { + auto str = std::string{}; + for (auto const& [addr, port] : ipv4_nodes_) + { + str += addr.readable(port); + str += ','; + } + for (auto const& [addr, port] : ipv6_nodes_) + { + str += addr.readable(port); + str += ','; + } + return str; + } + + [[nodiscard]] static auto filename(std::string_view dirname) + { + return std::string{ dirname } + "/dht.dat"; + } + + void save(std::string_view path) const + { + auto const dat_file = MockStateFile::filename(path); + + auto dict = tr_variant{}; + tr_variantInitDict(&dict, 3U); + tr_variantDictAddRaw(&dict, TR_KEY_id, std::data(id_), std::size(id_)); + auto compact = std::vector{}; + for (auto const& [addr, port] : ipv4_nodes_) + { + addr.toCompact4(std::back_inserter(compact), port); + } + tr_variantDictAddRaw(&dict, TR_KEY_nodes, std::data(compact), std::size(compact)); + compact.clear(); + for (auto const& [addr, port] : ipv6_nodes_) + { + addr.toCompact6(std::back_inserter(compact), port); + } + tr_variantDictAddRaw(&dict, TR_KEY_nodes6, std::data(compact), std::size(compact)); + tr_variantToFile(&dict, TR_VARIANT_FMT_BENC, dat_file); + tr_variantClear(&dict); + } + }; + + // A fake libdht for the tests to call + class MockDht final : public tr_dht::API + { + public: + int get_nodes(struct sockaddr_in* /*sin*/, int* /*max*/, struct sockaddr_in6* /*sin6*/, int* /*max6*/) override + { + return 0; + } + + int nodes(int /*af*/, int* good, int* dubious, int* cached, int* incoming) override + { + if (good != nullptr) + { + *good = good_; + } + + if (dubious != nullptr) + { + *dubious = dubious_; + } + + if (cached != nullptr) + { + *cached = cached_; + } + + if (incoming != nullptr) + { + *incoming = incoming_; + } + + return 0; + } + + int periodic( + void const* /*buf*/, + size_t /*buflen*/, + sockaddr const /*from*/*, + int /*fromlen*/, + time_t* /*tosleep*/, + dht_callback_t /*callback*/, + void* /*closure*/) override + { + ++n_periodic_calls_; + return 0; + } + + int ping_node(struct sockaddr const* sa, int /*salen*/) override + { + auto addrport = tr_address::fromSockaddr(sa); + auto const [addr, port] = *addrport; + pinged_.push_back(Pinged{ addr, port, tr_time() }); + return 0; + } + + int search(unsigned char const* id, int port, int af, dht_callback_t /*callback*/, void* /*closure*/) override + { + auto info_hash = tr_sha1_digest_t{}; + std::copy_n(reinterpret_cast(id), std::size(info_hash), std::data(info_hash)); + searched_.push_back(Searched{ info_hash, tr_port::fromHost(port), af }); + return 0; + } + + int init(int dht_socket, int dht_socket6, unsigned const char* id, unsigned const char* /*v*/) override + { + inited_ = true; + dht_socket_ = dht_socket; + dht_socket6_ = dht_socket6; + std::copy_n(id, std::size(id_), std::begin(id_)); + return 0; + } + + int uninit() override + { + inited_ = false; + return 0; + } + + constexpr void setHealthySwarm() + { + good_ = 50; + incoming_ = 10; + } + + constexpr void setFirewalledSwarm() + { + good_ = 50; + incoming_ = 0; + } + + constexpr void setPoorSwarm() + { + good_ = 10; + incoming_ = 1; + } + + struct Searched + { + tr_sha1_digest_t info_hash; + tr_port port; + int af; + }; + + struct Pinged + { + tr_address address; + tr_port port; + time_t timestamp; + }; + + int good_ = 0; + int dubious_ = 0; + int cached_ = 0; + int incoming_ = 0; + size_t n_periodic_calls_ = 0; + bool inited_ = false; + std::vector pinged_; + std::vector searched_; + std::array id_ = {}; + int dht_socket_ = TR_BAD_SOCKET; + int dht_socket6_ = TR_BAD_SOCKET; + }; + + // Creates real timers, but with shortened intervals so that tests can run faster + class MockTimer final : public libtransmission::Timer + { + public: + explicit MockTimer(std::unique_ptr real_timer) + : real_timer_{ std::move(real_timer) } + { + } + + void stop() override + { + real_timer_->stop(); + } + + void setCallback(std::function callback) override + { + real_timer_->setCallback(std::move(callback)); + } + + void setRepeating(bool repeating = true) override + { + real_timer_->setRepeating(repeating); + } + + void setInterval(std::chrono::milliseconds /*interval*/) override + { + real_timer_->setInterval(MockTimerInterval); + } + + void start() override + { + real_timer_->start(); + } + + [[nodiscard]] std::chrono::milliseconds interval() const noexcept override + { + return real_timer_->interval(); + } + + [[nodiscard]] bool isRepeating() const noexcept override + { + return real_timer_->isRepeating(); + } + + private: + std::unique_ptr const real_timer_; + }; + + // Creates MockTimers + class MockTimerMaker final : public libtransmission::TimerMaker + { + public: + explicit MockTimerMaker(struct event_base* evb) + : real_timer_maker_{ evb } + { + } + + [[nodiscard]] std::unique_ptr create() override + { + return std::make_unique(real_timer_maker_.create()); + } + + EvTimerMaker real_timer_maker_; + }; + + class MockMediator final : public tr_dht::Mediator + { + public: + explicit MockMediator(struct event_base* event_base) + : mock_timer_maker_{ event_base } + { + } + + [[nodiscard]] std::vector torrentsAllowingDHT() const override + { + return torrents_allowing_dht_; + } + + [[nodiscard]] tr_sha1_digest_t torrentInfoHash(tr_torrent_id_t id) const override + { + if (auto const iter = info_hashes_.find(id); iter != std::end(info_hashes_)) + { + return iter->second; + } + + return {}; + } + + [[nodiscard]] std::string_view configDir() const override + { + return config_dir_; + } + + [[nodiscard]] libtransmission::TimerMaker& timerMaker() override + { + return mock_timer_maker_; + } + + [[nodiscard]] tr_dht::API& api() override + { + return mock_dht_; + } + + void addPex(tr_sha1_digest_t const& /*info_hash*/, tr_pex const* /*pex*/, size_t /*n_pex*/) override + { + } + + std::string config_dir_; + std::vector torrents_allowing_dht_; + std::map info_hashes_; + MockDht mock_dht_; + MockTimerMaker mock_timer_maker_; + }; + + [[nodiscard]] static std::pair getSockaddr(std::string_view name, tr_port port) + { + auto hints = addrinfo{}; + hints.ai_socktype = SOCK_DGRAM; + hints.ai_family = AF_UNSPEC; + + auto const szname = tr_urlbuf{ name }; + auto const port_str = std::to_string(port.host()); + addrinfo* info = nullptr; + if (int const rc = getaddrinfo(szname.c_str(), std::data(port_str), &hints, &info); rc != 0) + { + tr_logAddWarn(fmt::format( + _("Couldn't look up '{address}:{port}': {error} ({error_code})"), + fmt::arg("address", name), + fmt::arg("port", port.host()), + fmt::arg("error", gai_strerror(rc)), + fmt::arg("error_code", rc))); + return {}; + } + + auto opt = tr_address::fromSockaddr(info->ai_addr); + freeaddrinfo(info); + if (opt) + { + return *opt; + } + + return {}; + } + + void SetUp() override + { + SandboxedTest::SetUp(); + + tr_session_thread::tr_evthread_init(); + event_base_ = event_base_new(); + } + + void TearDown() override + { + event_base_free(event_base_); + event_base_ = nullptr; + + SandboxedTest::TearDown(); + } + + struct event_base* event_base_ = nullptr; + + // Arbitrary values. Several tests requires socket/port values + // to be provided but they aren't central to the tests, so they're + // declared here with "Arbitrary" in the name to make that clear. + static auto constexpr ArbitrarySock4 = tr_socket_t{ 404 }; + static auto constexpr ArbitrarySock6 = tr_socket_t{ 418 }; + static auto constexpr ArbitraryPeerPort = tr_port::fromHost(909); +}; + +TEST_F(DhtTest, initsWithCorrectSockets) +{ + static auto constexpr Sock4 = tr_socket_t{ 1000 }; + static auto constexpr Sock6 = tr_socket_t{ 2000 }; + + // Make the DHT + auto mediator = MockMediator{ event_base_ }; + mediator.config_dir_ = sandboxDir(); + auto dht = tr_dht::create(mediator, ArbitraryPeerPort, Sock4, Sock6); + + // Confirm that dht_init() was called with the right sockets + EXPECT_EQ(Sock4, mediator.mock_dht_.dht_socket_); + EXPECT_EQ(Sock6, mediator.mock_dht_.dht_socket6_); +} + +TEST_F(DhtTest, callsUninitOnDestruct) +{ + auto mediator = MockMediator{ event_base_ }; + mediator.config_dir_ = sandboxDir(); + EXPECT_FALSE(mediator.mock_dht_.inited_); + + { + auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6); + EXPECT_TRUE(mediator.mock_dht_.inited_); + + // dht goes out-of-scope here + } + + EXPECT_FALSE(mediator.mock_dht_.inited_); +} + +TEST_F(DhtTest, loadsStateFromStateFile) +{ + auto const state_file = MockStateFile{}; + state_file.save(sandboxDir()); + + // Make the DHT + auto mediator = MockMediator{ event_base_ }; + mediator.config_dir_ = sandboxDir(); + auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6); + + // Wait for all the state nodes to be pinged + auto& pinged = mediator.mock_dht_.pinged_; + auto const n_expected_nodes = std::size(state_file.ipv4_nodes_) + std::size(state_file.ipv6_nodes_); + waitFor(event_base_, [&pinged, n_expected_nodes]() { return std::size(pinged) >= n_expected_nodes; }); + auto actual_nodes_str = std::string{}; + for (auto const& [addr, port, timestamp] : pinged) + { + actual_nodes_str += addr.readable(port); + actual_nodes_str += ','; + } + + /// Confirm that the state was loaded + + // dht_init() should have been called with the state file's id + EXPECT_EQ(state_file.id_, mediator.mock_dht_.id_); + + // dht_ping_nodedht_init() should have been called with state file's nodes + EXPECT_EQ(state_file.nodesString(), actual_nodes_str); +} + +TEST_F(DhtTest, stopsBootstrappingWhenSwarmHealthIsGoodEnough) +{ + auto const state_file = MockStateFile{}; + state_file.save(sandboxDir()); + + // Make the DHT + auto mediator = MockMediator{ event_base_ }; + mediator.config_dir_ = sandboxDir(); + auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6); + + // Wait for N pings to occur... + auto& mock_dht = mediator.mock_dht_; + static auto constexpr TurnGoodAfterNthPing = size_t{ 3 }; + waitFor(event_base_, [&mock_dht]() { return std::size(mock_dht.pinged_) == TurnGoodAfterNthPing; }); + EXPECT_EQ(TurnGoodAfterNthPing, std::size(mock_dht.pinged_)); + + // Now fake that libdht says the swarm is healthy. + // This should cause bootstrapping to end. + mock_dht.setHealthySwarm(); + + // Now test to see if bootstrapping is done. + // There's not public API for `isBootstrapping()`, + // so to test this we just a moment to confirm that no more bootstrap nodes are pinged. + waitFor(event_base_, MockTimerInterval * 10); + + // Confirm that the number of nodes pinged is unchanged, + // indicating that boostrapping is done + EXPECT_EQ(TurnGoodAfterNthPing, std::size(mock_dht.pinged_)); +} + +TEST_F(DhtTest, savesStateIfSwarmIsGood) +{ + auto const state_file = MockStateFile{}; + auto const dat_file = MockStateFile::filename(sandboxDir()); + EXPECT_FALSE(tr_sys_path_exists(dat_file.c_str())); + + { + auto mediator = MockMediator{ event_base_ }; + mediator.config_dir_ = sandboxDir(); + mediator.mock_dht_.setHealthySwarm(); + + auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6); + + // as dht goes out of scope, + // it should save its state if the swarm is healthy + EXPECT_FALSE(tr_sys_path_exists(dat_file.c_str())); + } + + EXPECT_TRUE(tr_sys_path_exists(dat_file.c_str())); +} + +TEST_F(DhtTest, doesNotSaveStateIfSwarmIsBad) +{ + auto const state_file = MockStateFile{}; + auto const dat_file = MockStateFile::filename(sandboxDir()); + EXPECT_FALSE(tr_sys_path_exists(dat_file.c_str())); + + { + auto mediator = MockMediator{ event_base_ }; + mediator.config_dir_ = sandboxDir(); + mediator.mock_dht_.setPoorSwarm(); + + auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6); + + // as dht goes out of scope, + // it should save its state if the swarm is healthy + EXPECT_FALSE(tr_sys_path_exists(dat_file.c_str())); + } + + EXPECT_FALSE(tr_sys_path_exists(dat_file.c_str())); +} + +TEST_F(DhtTest, usesBootstrapFile) +{ + // Make the 'dht.bootstrap' file. + // This a file with each line holding `${host} ${port}` + // which tr-dht will try to ping as nodes + static auto constexpr BootstrapNodeName = "example.com"sv; + static auto constexpr BootstrapNodePort = tr_port::fromHost(8080); + if (auto ofs = std::ofstream{ tr_pathbuf{ sandboxDir(), "/dht.bootstrap" } }; ofs) + { + ofs << BootstrapNodeName << ' ' << BootstrapNodePort.host() << std::endl; + ofs.close(); + } + + // Make the DHT + auto mediator = MockMediator{ event_base_ }; + mediator.config_dir_ = sandboxDir(); + auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6); + + // We didn't create a 'dht.dat' file to load state from, + // so 'dht.bootstrap' should be the first nodes in the bootstrap list. + // Confirm that BootstrapNodeName gets pinged first. + auto const expected = getSockaddr(BootstrapNodeName, BootstrapNodePort); + auto& pinged = mediator.mock_dht_.pinged_; + waitFor( + event_base_, + [&pinged]() { return !std::empty(pinged); }, + 5s); + ASSERT_EQ(1U, std::size(pinged)); + auto const actual = pinged.front(); + EXPECT_EQ(expected.first, actual.address); + EXPECT_EQ(expected.second, actual.port); + EXPECT_EQ(expected.first.readable(expected.second), actual.address.readable(actual.port)); +} + +TEST_F(DhtTest, pingsAddedNodes) +{ + auto mediator = MockMediator{ event_base_ }; + mediator.config_dir_ = sandboxDir(); + auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6); + + EXPECT_EQ(0U, std::size(mediator.mock_dht_.pinged_)); + + auto const addr = *tr_address::fromString("10.10.10.1"); + auto constexpr Port = tr_port::fromHost(128); + dht->addNode(addr, Port); + + ASSERT_EQ(1U, std::size(mediator.mock_dht_.pinged_)); + EXPECT_EQ(addr, mediator.mock_dht_.pinged_.front().address); + EXPECT_EQ(Port, mediator.mock_dht_.pinged_.front().port); +} + +TEST_F(DhtTest, announcesTorrents) +{ + auto constexpr Id = tr_torrent_id_t{ 1 }; + auto constexpr PeerPort = tr_port::fromHost(999); + auto const info_hash = tr_randObj(); + + tr_timeUpdate(time(nullptr)); + + auto mediator = MockMediator{ event_base_ }; + mediator.info_hashes_[Id] = info_hash; + mediator.torrents_allowing_dht_ = { Id }; + mediator.config_dir_ = sandboxDir(); + + // Since we're mocking a swarm that's magically healthy out-of-the-box, + // the DHT object we create can skip bootstrapping and proceed straight + // to announces + auto& mock_dht = mediator.mock_dht_; + mock_dht.setHealthySwarm(); + + auto dht = tr_dht::create(mediator, PeerPort, ArbitrarySock4, ArbitrarySock6); + + waitFor(event_base_, MockTimerInterval * 10); + + ASSERT_EQ(2U, std::size(mock_dht.searched_)); + + EXPECT_EQ(info_hash, mock_dht.searched_[0].info_hash); + EXPECT_EQ(PeerPort, mock_dht.searched_[0].port); + EXPECT_EQ(AF_INET, mock_dht.searched_[0].af); + + EXPECT_EQ(info_hash, mock_dht.searched_[1].info_hash); + EXPECT_EQ(PeerPort, mock_dht.searched_[1].port); + EXPECT_EQ(AF_INET6, mock_dht.searched_[1].af); +} + +TEST_F(DhtTest, callsPeriodicPeriodically) +{ + auto mediator = MockMediator{ event_base_ }; + mediator.config_dir_ = sandboxDir(); + auto dht = tr_dht::create(mediator, ArbitraryPeerPort, ArbitrarySock4, ArbitrarySock6); + + auto& mock_dht = mediator.mock_dht_; + auto const baseline = mock_dht.n_periodic_calls_; + static auto constexpr Periods = 10; + waitFor(event_base_, std::chrono::duration_cast(MockTimerInterval * Periods)); + EXPECT_NEAR(mock_dht.n_periodic_calls_, baseline + Periods, Periods / 2); +} + +} // namespace libtransmission::test diff --git a/tests/libtransmission/test-fixtures.h b/tests/libtransmission/test-fixtures.h index b6224aa34..f066ae297 100644 --- a/tests/libtransmission/test-fixtures.h +++ b/tests/libtransmission/test-fixtures.h @@ -43,6 +43,14 @@ namespace libtransmission namespace test { +template +[[nodiscard]] static auto tr_randObj() +{ + auto ret = T{}; + tr_rand_buffer(&ret, sizeof(ret)); + return ret; +} + using file_func_t = std::function; static void depthFirstWalk(char const* path, file_func_t func) diff --git a/tests/libtransmission/timer-test.cc b/tests/libtransmission/timer-test.cc index d079e2153..5ed856cc7 100644 --- a/tests/libtransmission/timer-test.cc +++ b/tests/libtransmission/timer-test.cc @@ -37,7 +37,7 @@ protected: return std::chrono::duration_cast(val); }; - void sleep_msec(std::chrono::milliseconds msec) + void sleepMsec(std::chrono::milliseconds msec) { EXPECT_FALSE(waitFor( evbase_.get(), @@ -45,7 +45,7 @@ protected: msec)); } - static void EXPECT_TIME( + static void expectTime( std::chrono::milliseconds expected, std::chrono::milliseconds actual, std::chrono::milliseconds allowed_deviation) @@ -59,12 +59,12 @@ protected: // This checks that `actual` is in the bounds of [expected/2 ... expected*1.5] // to confirm that the timer didn't kick too close to the previous or next interval. - static void EXPECT_INTERVAL(std::chrono::milliseconds expected, std::chrono::milliseconds actual) + static void expectInterval(std::chrono::milliseconds expected, std::chrono::milliseconds actual) { - EXPECT_TIME(expected, actual, expected / 2); + expectTime(expected, actual, expected / 2); } - [[nodiscard]] static auto current_time() + [[nodiscard]] static auto currentTime() { return std::chrono::steady_clock::now(); } @@ -133,17 +133,17 @@ TEST_F(TimerTest, singleShotHonorsInterval) timer->setCallback(callback); // run a single-shot timer - auto const begin_time = current_time(); + auto const begin_time = currentTime(); static auto constexpr Interval = 100ms; timer->startSingleShot(Interval); EXPECT_FALSE(timer->isRepeating()); EXPECT_EQ(Interval, timer->interval()); waitFor(evbase_.get(), [&called] { return called; }); - auto const end_time = current_time(); + auto const end_time = currentTime(); // confirm that it kicked at the right interval EXPECT_TRUE(called); - EXPECT_INTERVAL(Interval, AsMSec(end_time - begin_time)); + expectInterval(Interval, AsMSec(end_time - begin_time)); } TEST_F(TimerTest, repeatingHonorsInterval) @@ -160,17 +160,17 @@ TEST_F(TimerTest, repeatingHonorsInterval) timer->setCallback(callback); // start a repeating timer - auto const begin_time = current_time(); + auto const begin_time = currentTime(); static auto constexpr Interval = 100ms; static auto constexpr DesiredLoops = 3; timer->startRepeating(Interval); EXPECT_TRUE(timer->isRepeating()); EXPECT_EQ(Interval, timer->interval()); waitFor(evbase_.get(), [&n_calls] { return n_calls >= DesiredLoops; }); - auto const end_time = current_time(); + auto const end_time = currentTime(); // confirm that it kicked the right number of times - EXPECT_INTERVAL(Interval * DesiredLoops, AsMSec(end_time - begin_time)); + expectInterval(Interval * DesiredLoops, AsMSec(end_time - begin_time)); EXPECT_EQ(DesiredLoops, n_calls); } @@ -190,12 +190,12 @@ TEST_F(TimerTest, restartWithDifferentInterval) auto const test = [this, &n_calls, &timer](auto interval) { auto const next = n_calls + 1; - auto const begin_time = current_time(); + auto const begin_time = currentTime(); timer->startSingleShot(interval); waitFor(evbase_.get(), [&n_calls, next]() { return n_calls >= next; }); - auto const end_time = current_time(); + auto const end_time = currentTime(); - EXPECT_INTERVAL(interval, AsMSec(end_time - begin_time)); + expectInterval(interval, AsMSec(end_time - begin_time)); }; test(100ms); @@ -219,12 +219,12 @@ TEST_F(TimerTest, restartWithSameInterval) auto const test = [this, &n_calls, &timer](auto interval) { auto const next = n_calls + 1; - auto const begin_time = current_time(); + auto const begin_time = currentTime(); timer->startSingleShot(interval); waitFor(evbase_.get(), [&n_calls, next]() { return n_calls >= next; }); - auto const end_time = current_time(); + auto const end_time = currentTime(); - EXPECT_INTERVAL(interval, AsMSec(end_time - begin_time)); + expectInterval(interval, AsMSec(end_time - begin_time)); }; test(timer->interval()); @@ -246,31 +246,31 @@ TEST_F(TimerTest, repeatingThenSingleShot) timer->setCallback(callback); // start a repeating timer and confirm that it's running - auto begin_time = current_time(); + auto begin_time = currentTime(); static auto constexpr RepeatingInterval = 100ms; static auto constexpr DesiredLoops = 2; timer->startRepeating(RepeatingInterval); EXPECT_EQ(RepeatingInterval, timer->interval()); EXPECT_TRUE(timer->isRepeating()); waitFor(evbase_.get(), [&n_calls]() { return n_calls >= DesiredLoops; }); - auto end_time = current_time(); - EXPECT_TIME(RepeatingInterval * DesiredLoops, AsMSec(end_time - begin_time), RepeatingInterval / 2); + auto end_time = currentTime(); + expectTime(RepeatingInterval * DesiredLoops, AsMSec(end_time - begin_time), RepeatingInterval / 2); // now restart it as a single shot auto const baseline = n_calls; - begin_time = current_time(); + begin_time = currentTime(); static auto constexpr SingleShotInterval = 25ms; timer->startSingleShot(SingleShotInterval); EXPECT_EQ(SingleShotInterval, timer->interval()); EXPECT_FALSE(timer->isRepeating()); waitFor(evbase_.get(), [&n_calls]() { return n_calls >= DesiredLoops + 1; }); - end_time = current_time(); + end_time = currentTime(); // confirm that the single shot interval was honored - EXPECT_INTERVAL(SingleShotInterval, AsMSec(end_time - begin_time)); + expectInterval(SingleShotInterval, AsMSec(end_time - begin_time)); // confirm that the timer only kicks once, since it was converted into single-shot - sleep_msec(SingleShotInterval * 3); + sleepMsec(SingleShotInterval * 3); EXPECT_EQ(baseline + 1, n_calls); } @@ -294,13 +294,13 @@ TEST_F(TimerTest, singleShotStop) EXPECT_FALSE(timer->isRepeating()); // wait half the interval, then stop the timer - sleep_msec(Interval / 2); + sleepMsec(Interval / 2); EXPECT_EQ(0U, n_calls); timer->stop(); // wait until the timer has gone past. // since we stopped it, callback should not have been called. - sleep_msec(Interval); + sleepMsec(Interval); EXPECT_EQ(0U, n_calls); } @@ -324,13 +324,13 @@ TEST_F(TimerTest, repeatingStop) EXPECT_TRUE(timer->isRepeating()); // wait half the interval, then stop the timer - sleep_msec(Interval / 2); + sleepMsec(Interval / 2); EXPECT_EQ(0U, n_calls); timer->stop(); // wait until the timer has gone past. // since we stopped it, callback should not have been called. - sleep_msec(Interval); + sleepMsec(Interval); EXPECT_EQ(0U, n_calls); } @@ -354,13 +354,13 @@ TEST_F(TimerTest, destroyedTimersStop) EXPECT_TRUE(timer->isRepeating()); // wait half the interval, then destroy the timer - sleep_msec(Interval / 2); + sleepMsec(Interval / 2); EXPECT_EQ(0U, n_calls); timer.reset(); // wait until the timer has gone past. // since we destroyed it, callback should not have been called. - sleep_msec(Interval); + sleepMsec(Interval); EXPECT_EQ(0U, n_calls); }