refactor: try again to use getaddrinfo in announcer_udp (#4201)

This commit is contained in:
Charles Kerr
2022-11-16 15:13:31 -06:00
committed by GitHub
parent 7a2d2ff0cf
commit a45cc2a79d
10 changed files with 411 additions and 768 deletions

View File

@@ -11,8 +11,6 @@
0A6169A80FE5C9A200C66CE6 /* bitfield.h in Headers */ = {isa = PBXBuildFile; fileRef = 0A6169A60FE5C9A200C66CE6 /* bitfield.h */; }; 0A6169A80FE5C9A200C66CE6 /* bitfield.h in Headers */ = {isa = PBXBuildFile; fileRef = 0A6169A60FE5C9A200C66CE6 /* bitfield.h */; };
0A89346B736DBCF81F3A4850 /* torrent-metainfo.cc in Sources */ = {isa = PBXBuildFile; fileRef = 0A89346B736DBCF81F3A4851 /* torrent-metainfo.cc */; }; 0A89346B736DBCF81F3A4850 /* torrent-metainfo.cc in Sources */ = {isa = PBXBuildFile; fileRef = 0A89346B736DBCF81F3A4851 /* torrent-metainfo.cc */; };
0A89346B736DBCF81F3A4852 /* torrent-metainfo.h in Headers */ = {isa = PBXBuildFile; fileRef = 0A89346B736DBCF81F3A4853 /* torrent-metainfo.h */; }; 0A89346B736DBCF81F3A4852 /* torrent-metainfo.h in Headers */ = {isa = PBXBuildFile; fileRef = 0A89346B736DBCF81F3A4853 /* torrent-metainfo.h */; };
11524394C75E57E52CD9ADF0 /* dns.h in Headers */ = {isa = PBXBuildFile; fileRef = 11524394C75E57E52CD9ADF1 /* dns.h */; settings = {ATTRIBUTES = (Public, ); }; };
11524394C75E57E52CD9ADF2 /* dns-ev.h in Headers */ = {isa = PBXBuildFile; fileRef = 11524394C75E57E52CD9ADF3 /* dns-ev.h */; settings = {ATTRIBUTES = (Private, ); }; };
1BB44E07B1B52E28291B4E32 /* file-piece-map.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1BB44E07B1B52E28291B4E30 /* file-piece-map.cc */; }; 1BB44E07B1B52E28291B4E32 /* file-piece-map.cc in Sources */ = {isa = PBXBuildFile; fileRef = 1BB44E07B1B52E28291B4E30 /* file-piece-map.cc */; };
1BB44E07B1B52E28291B4E33 /* file-piece-map.h in Headers */ = {isa = PBXBuildFile; fileRef = 1BB44E07B1B52E28291B4E31 /* file-piece-map.h */; }; 1BB44E07B1B52E28291B4E33 /* file-piece-map.h in Headers */ = {isa = PBXBuildFile; fileRef = 1BB44E07B1B52E28291B4E31 /* file-piece-map.h */; };
2856E0656A49F2665D69E760 /* benc.h in Headers */ = {isa = PBXBuildFile; fileRef = 2856E0656A49F2665D69E761 /* benc.h */; }; 2856E0656A49F2665D69E760 /* benc.h in Headers */ = {isa = PBXBuildFile; fileRef = 2856E0656A49F2665D69E761 /* benc.h */; };
@@ -604,8 +602,6 @@
0A89346B736DBCF81F3A4851 /* torrent-metainfo.cc */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = "torrent-metainfo.cc"; sourceTree = "<group>"; }; 0A89346B736DBCF81F3A4851 /* torrent-metainfo.cc */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = "torrent-metainfo.cc"; sourceTree = "<group>"; };
0A89346B736DBCF81F3A4853 /* torrent-metainfo.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "torrent-metainfo.h"; sourceTree = "<group>"; }; 0A89346B736DBCF81F3A4853 /* torrent-metainfo.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "torrent-metainfo.h"; sourceTree = "<group>"; };
1058C7A1FEA54F0111CA2CBB /* Cocoa.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Cocoa.framework; path = System/Library/Frameworks/Cocoa.framework; sourceTree = SDKROOT; }; 1058C7A1FEA54F0111CA2CBB /* Cocoa.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Cocoa.framework; path = System/Library/Frameworks/Cocoa.framework; sourceTree = SDKROOT; };
11524394C75E57E52CD9ADF1 /* dns.h */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.h; fileEncoding = 4; path = dns.h; sourceTree = "<group>"; };
11524394C75E57E52CD9ADF3 /* dns-ev.h */ = {isa = PBXFileReference; explicitFileType = sourcecode.cpp.h; fileEncoding = 4; path = "dns-ev.h"; sourceTree = "<group>"; };
13E42FB307B3F0F600E4EEF1 /* CoreData.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreData.framework; path = System/Library/Frameworks/CoreData.framework; sourceTree = SDKROOT; }; 13E42FB307B3F0F600E4EEF1 /* CoreData.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreData.framework; path = System/Library/Frameworks/CoreData.framework; sourceTree = SDKROOT; };
1BB44E07B1B52E28291B4E30 /* file-piece-map.cc */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = "file-piece-map.cc"; sourceTree = "<group>"; }; 1BB44E07B1B52E28291B4E30 /* file-piece-map.cc */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = "file-piece-map.cc"; sourceTree = "<group>"; };
1BB44E07B1B52E28291B4E31 /* file-piece-map.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "file-piece-map.h"; sourceTree = "<group>"; }; 1BB44E07B1B52E28291B4E31 /* file-piece-map.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = "file-piece-map.h"; sourceTree = "<group>"; };
@@ -1711,8 +1707,6 @@
C11DEA151FCD31C0009E22B9 /* subprocess.h */, C11DEA151FCD31C0009E22B9 /* subprocess.h */,
E975121263DD973CAF4AEBA5 /* timer-ev.cc */, E975121263DD973CAF4AEBA5 /* timer-ev.cc */,
E975121263DD973CAF4AEBA3 /* timer-ev.h */, E975121263DD973CAF4AEBA3 /* timer-ev.h */,
11524394C75E57E52CD9ADF1 /* dns.h */,
11524394C75E57E52CD9ADF3 /* dns-ev.h */,
E975121263DD973CAF4AEBA1 /* timer.h */, E975121263DD973CAF4AEBA1 /* timer.h */,
A20152790D1C26EB0081714F /* torrent-ctor.cc */, A20152790D1C26EB0081714F /* torrent-ctor.cc */,
A47A7C87B8B57BE50DF0D411 /* torrent-files.cc */, A47A7C87B8B57BE50DF0D411 /* torrent-files.cc */,
@@ -2229,8 +2223,6 @@
2856E0656A49F2665D69E760 /* benc.h in Headers */, 2856E0656A49F2665D69E760 /* benc.h in Headers */,
E975121263DD973CAF4AEBA0 /* timer.h in Headers */, E975121263DD973CAF4AEBA0 /* timer.h in Headers */,
E975121263DD973CAF4AEBA2 /* timer-ev.h in Headers */, E975121263DD973CAF4AEBA2 /* timer-ev.h in Headers */,
11524394C75E57E52CD9ADF0 /* dns.h in Headers */,
11524394C75E57E52CD9ADF2 /* dns-ev.h in Headers */,
C1077A4F183EB29600634C22 /* error.h in Headers */, C1077A4F183EB29600634C22 /* error.h in Headers */,
A2679295130E00A000CB7464 /* tr-utp.h in Headers */, A2679295130E00A000CB7464 /* tr-utp.h in Headers */,
A263C6B1F6718E2486DB20E0 /* tr-buffer.h in Headers */, A263C6B1F6718E2486DB20E0 /* tr-buffer.h in Headers */,

View File

@@ -136,8 +136,6 @@ endif()
set(${PROJECT_NAME}_PUBLIC_HEADERS set(${PROJECT_NAME}_PUBLIC_HEADERS
${PROJECT_BINARY_DIR}/version.h ${PROJECT_BINARY_DIR}/version.h
dns-ev.h
dns.h
error-types.h error-types.h
error.h error.h
file.h file.h

View File

@@ -7,6 +7,7 @@
#include <cerrno> // for errno, EAFNOSUPPORT #include <cerrno> // for errno, EAFNOSUPPORT
#include <cstring> // for memset() #include <cstring> // for memset()
#include <ctime> #include <ctime>
#include <future>
#include <list> #include <list>
#include <memory> #include <memory>
#include <string_view> #include <string_view>
@@ -37,6 +38,7 @@
using namespace std::literals; using namespace std::literals;
// size defined by bep15
using tau_connection_t = uint64_t; using tau_connection_t = uint64_t;
using tau_transaction_t = uint32_t; using tau_transaction_t = uint32_t;
@@ -49,7 +51,7 @@ static tau_transaction_t tau_transaction_new()
return tmp; return tmp;
} }
/* used in the "action" field of a request */ // used in the "action" field of a request. Values defined in bep 15.
enum tau_action_t enum tau_action_t
{ {
TAU_ACTION_CONNECT = 0, TAU_ACTION_CONNECT = 0,
@@ -58,46 +60,47 @@ enum tau_action_t
TAU_ACTION_ERROR = 3 TAU_ACTION_ERROR = 3
}; };
static bool is_tau_response_message(tau_action_t action, size_t msglen)
{
if (action == TAU_ACTION_CONNECT)
{
return msglen == 16;
}
if (action == TAU_ACTION_ANNOUNCE)
{
return msglen >= 20;
}
if (action == TAU_ACTION_SCRAPE)
{
return msglen >= 20;
}
if (action == TAU_ACTION_ERROR)
{
return msglen >= 8;
}
return false;
}
static auto constexpr TauRequestTtl = int{ 60 };
/**** /****
*****
***** SCRAPE ***** SCRAPE
*****
****/ ****/
struct tau_scrape_request struct tau_scrape_request
{ {
tau_scrape_request(tr_scrape_request const& in, tr_scrape_response_func callback, void* user_data)
: callback_{ callback }
, user_data_{ user_data }
{
this->response.scrape_url = in.scrape_url;
this->response.row_count = in.info_hash_count;
for (int i = 0; i < this->response.row_count; ++i)
{
this->response.rows[i].seeders = -1;
this->response.rows[i].leechers = -1;
this->response.rows[i].downloads = -1;
this->response.rows[i].info_hash = in.info_hash[i];
}
// build the payload
auto buf = libtransmission::Buffer{};
buf.addUint32(TAU_ACTION_SCRAPE);
buf.addUint32(transaction_id);
for (int i = 0; i < in.info_hash_count; ++i)
{
buf.add(in.info_hash[i]);
}
this->payload.insert(std::end(this->payload), std::begin(buf), std::end(buf));
}
[[nodiscard]] constexpr auto hasCallback() const noexcept
{
return callback_ != nullptr;
}
void requestFinished() void requestFinished()
{ {
if (callback != nullptr) if (callback_ != nullptr)
{ {
callback(&response, user_data); callback_(&response, user_data_);
} }
} }
@@ -140,67 +143,63 @@ struct tau_scrape_request
std::vector<std::byte> payload; std::vector<std::byte> payload;
time_t sent_at; time_t sent_at = 0;
time_t created_at; time_t const created_at = tr_time();
tau_transaction_t transaction_id;
tr_scrape_response response;
tr_scrape_response_func callback;
void* user_data;
};
static tau_scrape_request make_tau_scrape_request(
tr_scrape_request const& in,
tr_scrape_response_func callback,
void* user_data)
{
tau_transaction_t const transaction_id = tau_transaction_new(); tau_transaction_t const transaction_id = tau_transaction_new();
/* build the payload */ tr_scrape_response response = {};
auto buf = libtransmission::Buffer{};
buf.addUint32(TAU_ACTION_SCRAPE);
buf.addUint32(transaction_id);
for (int i = 0; i < in.info_hash_count; ++i)
{
buf.add(in.info_hash[i]);
}
// build the tau_scrape_request private:
auto req = tau_scrape_request{}; tr_scrape_response_func const callback_;
req.callback = callback; void* const user_data_;
req.created_at = tr_time(); };
req.transaction_id = transaction_id;
req.callback = callback;
req.user_data = user_data;
req.response.scrape_url = in.scrape_url;
req.response.row_count = in.info_hash_count;
req.payload.insert(std::end(req.payload), std::begin(buf), std::end(buf));
for (int i = 0; i < req.response.row_count; ++i)
{
req.response.rows[i].seeders = -1;
req.response.rows[i].leechers = -1;
req.response.rows[i].downloads = -1;
req.response.rows[i].info_hash = in.info_hash[i];
}
/* cleanup */
return req;
}
/**** /****
*****
***** ANNOUNCE ***** ANNOUNCE
*****
****/ ****/
struct tau_announce_request struct tau_announce_request
{ {
tau_announce_request(
uint32_t announce_ip,
tr_announce_request const& in,
tr_announce_response_func callback,
void* user_data)
: callback_{ callback }
, user_data_{ user_data }
{
response.seeders = -1;
response.leechers = -1;
response.downloads = -1;
response.info_hash = in.info_hash;
// build the payload
auto buf = libtransmission::Buffer{};
buf.addUint32(TAU_ACTION_ANNOUNCE);
buf.addUint32(transaction_id);
buf.add(in.info_hash);
buf.add(in.peer_id);
buf.addUint64(in.down);
buf.addUint64(in.leftUntilComplete);
buf.addUint64(in.up);
buf.addUint32(get_tau_announce_event(in.event));
buf.addUint32(announce_ip);
buf.addUint32(in.key);
buf.addUint32(in.numwant);
buf.addPort(in.port);
payload.insert(std::end(payload), std::begin(buf), std::end(buf));
}
[[nodiscard]] constexpr auto hasCallback() const noexcept
{
return callback_ != nullptr;
}
void requestFinished() void requestFinished()
{ {
if (this->callback != nullptr) if (callback_ != nullptr)
{ {
this->callback(&this->response, this->user_data); callback_(&this->response, user_data_);
} }
} }
@@ -236,29 +235,27 @@ struct tau_announce_request
} }
} }
std::vector<std::byte> payload; enum tau_announce_event
{
time_t created_at = 0; // Used in the "event" field of an announce request.
time_t sent_at = 0; // These values come from BEP 15
tau_transaction_t transaction_id = 0;
tr_announce_response response = {};
tr_announce_response_func callback = nullptr;
void* user_data = nullptr;
};
enum tau_announce_event
{
/* used in the "event" field of an announce request */
TAU_ANNOUNCE_EVENT_NONE = 0, TAU_ANNOUNCE_EVENT_NONE = 0,
TAU_ANNOUNCE_EVENT_COMPLETED = 1, TAU_ANNOUNCE_EVENT_COMPLETED = 1,
TAU_ANNOUNCE_EVENT_STARTED = 2, TAU_ANNOUNCE_EVENT_STARTED = 2,
TAU_ANNOUNCE_EVENT_STOPPED = 3 TAU_ANNOUNCE_EVENT_STOPPED = 3
}; };
static tau_announce_event get_tau_announce_event(tr_announce_event e) std::vector<std::byte> payload;
{
time_t const created_at = tr_time();
time_t sent_at = 0;
tau_transaction_t const transaction_id = tau_transaction_new();
tr_announce_response response = {};
private:
[[nodiscard]] static constexpr tau_announce_event get_tau_announce_event(tr_announce_event e)
{
switch (e) switch (e)
{ {
case TR_ANNOUNCE_EVENT_COMPLETED: case TR_ANNOUNCE_EVENT_COMPLETED:
@@ -273,50 +270,14 @@ static tau_announce_event get_tau_announce_event(tr_announce_event e)
default: default:
return TAU_ANNOUNCE_EVENT_NONE; return TAU_ANNOUNCE_EVENT_NONE;
} }
} }
static tau_announce_request make_tau_announce_request( tr_announce_response_func const callback_;
uint32_t announce_ip, void* const user_data_;
tr_announce_request const& in, };
tr_announce_response_func callback,
void* user_data)
{
tau_transaction_t const transaction_id = tau_transaction_new();
/* build the payload */
auto buf = libtransmission::Buffer{};
buf.addUint32(TAU_ACTION_ANNOUNCE);
buf.addUint32(transaction_id);
buf.add(in.info_hash);
buf.add(in.peer_id);
buf.addUint64(in.down);
buf.addUint64(in.leftUntilComplete);
buf.addUint64(in.up);
buf.addUint32(get_tau_announce_event(in.event));
buf.addUint32(announce_ip);
buf.addUint32(in.key);
buf.addUint32(in.numwant);
buf.addPort(in.port);
/* build the tau_announce_request */
auto req = tau_announce_request();
req.created_at = tr_time();
req.transaction_id = transaction_id;
req.callback = callback;
req.user_data = user_data;
req.payload.insert(std::end(req.payload), std::begin(buf), std::end(buf));
req.response.seeders = -1;
req.response.leechers = -1;
req.response.downloads = -1;
req.response.info_hash = in.info_hash;
return req;
}
/**** /****
***** ***** TRACKER
***** TRACKERS
*****
****/ ****/
struct tau_tracker struct tau_tracker
@@ -324,16 +285,164 @@ struct tau_tracker
using Mediator = tr_announcer_udp::Mediator; using Mediator = tr_announcer_udp::Mediator;
tau_tracker(Mediator& mediator, tr_interned_string key_in, tr_interned_string host_in, tr_port port_in) tau_tracker(Mediator& mediator, tr_interned_string key_in, tr_interned_string host_in, tr_port port_in)
: mediator_{ mediator } : key{ key_in }
, key{ key_in }
, host{ host_in } , host{ host_in }
, port{ port_in } , port{ port_in }
, mediator_{ mediator }
{ {
} }
[[nodiscard]] auto isIdle() const noexcept [[nodiscard]] auto isIdle() const noexcept
{ {
return std::empty(announces) && std::empty(scrapes) && (dns_request_ == 0U); return std::empty(announces) && std::empty(scrapes) && !addr_pending_dns_;
}
void sendto(void const* buf, size_t buflen)
{
TR_ASSERT(addr_);
if (!addr_)
{
return;
}
auto const& [ss, sslen] = *addr_;
mediator_.sendto(buf, buflen, reinterpret_cast<sockaddr const*>(&ss), sslen);
}
void on_connection_response(tau_action_t action, libtransmission::Buffer& buf)
{
this->connecting_at = 0;
this->connection_transaction_id = 0;
if (action == TAU_ACTION_CONNECT)
{
this->connection_id = buf.toUint64();
this->connection_expiration_time = tr_time() + TauConnectionTtlSecs;
logdbg(this->key, fmt::format("Got a new connection ID from tracker: {}", this->connection_id));
}
else if (action == TAU_ACTION_ERROR)
{
std::string const errmsg = !std::empty(buf) ? buf.toString() : _("Connection failed");
logdbg(this->key, errmsg);
this->failAll(true, false, errmsg);
}
this->upkeep();
}
void upkeep(bool timeout_reqs = true)
{
time_t const now = tr_time();
bool const closing = this->close_at != 0;
// do we have a DNS request that's ready?
if (addr_pending_dns_ && addr_pending_dns_->wait_for(0ms) == std::future_status::ready)
{
addr_ = addr_pending_dns_->get();
addr_pending_dns_.reset();
addr_expires_at_ = now + DnsRetryIntervalSecs;
}
// if the address info is too old, expire it
if (addr_ && (closing || addr_expires_at_ <= now))
{
logtrace(this->host, "Expiring old DNS result");
addr_.reset();
addr_expires_at_ = 0;
}
// are there any requests pending?
if (this->isIdle())
{
return;
}
// if DNS lookup *recently* failed for this host, do nothing
if (!addr_ && now < addr_expires_at_)
{
return;
}
// if we don't have an address yet, try & get one now.
if (!closing && !addr_ && !addr_pending_dns_)
{
addr_pending_dns_ = std::async(std::launch::async, lookup, this->host, this->port, this->key);
return;
}
logtrace(
this->key,
fmt::format(
"connected {} ({} {}) -- connecting_at {}",
this->connection_expiration_time > now,
this->connection_expiration_time,
now,
this->connecting_at));
/* also need a valid connection ID... */
if (addr_ && this->connection_expiration_time <= now && this->connecting_at == 0)
{
this->connecting_at = now;
this->connection_transaction_id = tau_transaction_new();
logtrace(this->key, fmt::format("Trying to connect. Transaction ID is {}", this->connection_transaction_id));
auto buf = libtransmission::Buffer{};
buf.addUint64(0x41727101980LL);
buf.addUint32(TAU_ACTION_CONNECT);
buf.addUint32(this->connection_transaction_id);
auto const contiguous = std::vector<std::byte>(std::begin(buf), std::end(buf));
this->sendto(std::data(contiguous), std::size(contiguous));
return;
}
if (timeout_reqs)
{
timeout_requests();
}
if (addr_ && this->connection_expiration_time > now)
{
send_requests();
}
}
private:
using Sockaddr = std::pair<sockaddr_storage, socklen_t>;
using MaybeSockaddr = std::optional<Sockaddr>;
[[nodiscard]] static MaybeSockaddr lookup(tr_interned_string host, tr_port port, tr_interned_string logname)
{
auto szport = std::array<char, 16>{};
*fmt::format_to(std::data(szport), FMT_STRING("{:d}"), port.host()) = '\0';
auto hints = addrinfo{};
hints.ai_family = AF_UNSPEC;
hints.ai_protocol = IPPROTO_UDP;
hints.ai_socktype = SOCK_DGRAM;
addrinfo* info = nullptr;
if (int const rc = getaddrinfo(host.c_str(), std::data(szport), &hints, &info); rc != 0)
{
logwarn(
logname,
fmt::format(
_("Couldn't look up '{address}:{port}': {error} ({error_code})"),
fmt::arg("address", host.sv()),
fmt::arg("port", port.host()),
fmt::arg("error", gai_strerror(rc)),
fmt::arg("error_code", rc)));
return {};
}
auto ss = sockaddr_storage{};
auto const len = info->ai_addrlen;
memcpy(&ss, info->ai_addr, len);
freeaddrinfo(info);
logdbg(logname, "DNS lookup succeeded");
return std::make_pair(ss, len);
} }
void failAll(bool did_connect, bool did_timeout, std::string_view errmsg) void failAll(bool did_connect, bool did_timeout, std::string_view errmsg)
@@ -352,90 +461,58 @@ struct tau_tracker
this->announces.clear(); this->announces.clear();
} }
void sendto(void const* buf, size_t buflen) ///
void timeout_requests()
{ {
TR_ASSERT(addr_); time_t const now = time(nullptr);
if (!addr_) bool const cancel_all = this->close_at != 0 && (this->close_at <= now);
if (this->connecting_at != 0 && this->connecting_at + TauRequestTtl < now)
{ {
return; auto empty_buf = libtransmission::Buffer{};
on_connection_response(TAU_ACTION_ERROR, empty_buf);
} }
auto [ss, sslen] = *addr_; timeout_requests(this->announces, now, cancel_all, "announce");
timeout_requests(this->scrapes, now, cancel_all, "scrape");
if (ss.ss_family == AF_INET)
{
reinterpret_cast<sockaddr_in*>(&ss)->sin_port = port.network();
}
else if (ss.ss_family == AF_INET6)
{
reinterpret_cast<sockaddr_in6*>(&ss)->sin6_port = port.network();
} }
mediator_.sendto(buf, buflen, reinterpret_cast<sockaddr*>(&ss), sslen); template<typename T>
} void timeout_requests(std::list<T>& requests, time_t now, bool cancel_all, std::string_view name)
Mediator& mediator_;
tr_interned_string const key;
tr_interned_string const host;
tr_port const port;
libtransmission::Dns::Tag dns_request_ = {};
std::optional<std::pair<sockaddr_storage, socklen_t>> addr_;
time_t addr_expires_at_ = 0;
time_t connecting_at = 0;
time_t connection_expiration_time = 0;
tau_connection_t connection_id = 0;
tau_transaction_t connection_transaction_id = 0;
time_t close_at = 0;
static time_t constexpr DnsRetryIntervalSecs = 60 * 60;
std::list<tau_announce_request> announces;
std::list<tau_scrape_request> scrapes;
};
static void tau_tracker_upkeep(struct tau_tracker* /*tracker*/);
static void tau_tracker_on_dns(tau_tracker* const tracker, sockaddr const* sa, socklen_t salen, time_t expires_at)
{
tracker->dns_request_ = {};
if (sa == nullptr)
{ {
auto const errmsg = fmt::format(_("Couldn't find address of tracker '{host}'"), fmt::arg("host", tracker->host)); for (auto it = std::begin(requests); it != std::end(requests);)
logwarn(tracker->key, errmsg); {
tracker->failAll(false, false, errmsg.c_str()); auto& req = *it;
tracker->addr_expires_at_ = tr_time() + tau_tracker::DnsRetryIntervalSecs; if (cancel_all || req.created_at + TauRequestTtl < now)
{
logtrace(this->key, fmt::format("timeout {} req {}", name, fmt::ptr(&req)));
req.fail(false, true, "");
it = requests.erase(it);
} }
else else
{ {
logdbg(tracker->key, "DNS lookup succeeded"); ++it;
auto ss = sockaddr_storage{}; }
memcpy(&ss, sa, salen); }
tracker->addr_.emplace(ss, salen);
tracker->addr_expires_at_ = expires_at;
tau_tracker_upkeep(tracker);
} }
}
static void tau_tracker_send_request(struct tau_tracker* tracker, void const* payload, size_t payload_len) ///
{
logdbg(tracker->key, fmt::format("sending request w/connection id {}", tracker->connection_id));
auto buf = libtransmission::Buffer{}; void send_requests()
buf.addUint64(tracker->connection_id); {
buf.add(payload, payload_len); TR_ASSERT(!addr_pending_dns_);
TR_ASSERT(addr_);
TR_ASSERT(this->connecting_at == 0);
TR_ASSERT(this->connection_expiration_time > tr_time());
auto const contiguous = std::vector<std::byte>(std::begin(buf), std::end(buf)); send_requests(this->announces);
tracker->sendto(std::data(contiguous), std::size(contiguous)); send_requests(this->scrapes);
} }
template<typename T> template<typename T>
static void tau_tracker_send_requests(tau_tracker* tracker, std::list<T>& reqs) void send_requests(std::list<T>& reqs)
{ {
auto const now = tr_time(); auto const now = tr_time();
for (auto it = std::begin(reqs); it != std::end(reqs);) for (auto it = std::begin(reqs); it != std::end(reqs);)
@@ -448,11 +525,11 @@ static void tau_tracker_send_requests(tau_tracker* tracker, std::list<T>& reqs)
continue; continue;
} }
logdbg(tracker->key, fmt::format("sending req {}", fmt::ptr(&req))); logdbg(this->key, fmt::format("sending req {}", fmt::ptr(&req)));
req.sent_at = now; req.sent_at = now;
tau_tracker_send_request(tracker, std::data(req.payload), std::size(req.payload)); send_request(std::data(req.payload), std::size(req.payload));
if (req.callback != nullptr) if (req.hasCallback())
{ {
++it; ++it;
continue; continue;
@@ -461,176 +538,49 @@ static void tau_tracker_send_requests(tau_tracker* tracker, std::list<T>& reqs)
// no response needed, so we can remove it now // no response needed, so we can remove it now
it = reqs.erase(it); it = reqs.erase(it);
} }
}
static void tau_tracker_send_reqs(tau_tracker* tracker)
{
TR_ASSERT(!tracker->dns_request_);
TR_ASSERT(tracker->addr_);
TR_ASSERT(tracker->connecting_at == 0);
TR_ASSERT(tracker->connection_expiration_time > tr_time());
tau_tracker_send_requests(tracker, tracker->announces);
tau_tracker_send_requests(tracker, tracker->scrapes);
}
static void on_tracker_connection_response(struct tau_tracker& tracker, tau_action_t action, libtransmission::Buffer& buf)
{
tracker.connecting_at = 0;
tracker.connection_transaction_id = 0;
if (action == TAU_ACTION_CONNECT)
{
tracker.connection_id = buf.toUint64();
tracker.connection_expiration_time = tr_time() + TauConnectionTtlSecs;
logdbg(tracker.key, fmt::format("Got a new connection ID from tracker: {}", tracker.connection_id));
}
else if (action == TAU_ACTION_ERROR)
{
std::string const errmsg = !std::empty(buf) ? buf.toString() : _("Connection failed");
logdbg(tracker.key, errmsg);
tracker.failAll(true, false, errmsg);
} }
tau_tracker_upkeep(&tracker); void send_request(void const* payload, size_t payload_len)
}
static void tau_tracker_timeout_reqs(struct tau_tracker* tracker)
{
time_t const now = time(nullptr);
bool const cancel_all = tracker->close_at != 0 && (tracker->close_at <= now);
if (tracker->connecting_at != 0 && tracker->connecting_at + TauRequestTtl < now)
{ {
auto empty_buf = libtransmission::Buffer{}; logdbg(this->key, fmt::format("sending request w/connection id {}", this->connection_id));
on_tracker_connection_response(*tracker, TAU_ACTION_ERROR, empty_buf);
}
if (auto& reqs = tracker->announces; !std::empty(reqs))
{
for (auto it = std::begin(reqs); it != std::end(reqs);)
{
auto& req = *it;
if (cancel_all || req.created_at + TauRequestTtl < now)
{
logtrace(tracker->key, fmt::format("timeout announce req {}", fmt::ptr(&req)));
req.fail(false, true, "");
it = reqs.erase(it);
}
else
{
++it;
}
}
}
if (auto& reqs = tracker->scrapes; !std::empty(reqs))
{
for (auto it = std::begin(reqs); it != std::end(reqs);)
{
auto& req = *it;
if (cancel_all || req.created_at + TauRequestTtl < now)
{
logtrace(tracker->key, fmt::format("timeout scrape req {}", fmt::ptr(&req)));
req.fail(false, true, "");
it = reqs.erase(it);
}
else
{
++it;
}
}
}
}
static void tau_tracker_upkeep_ex(struct tau_tracker* tracker, bool timeout_reqs)
{
time_t const now = tr_time();
bool const closing = tracker->close_at != 0;
/* if the address info is too old, expire it */
if (tracker->addr_ && (closing || tracker->addr_expires_at_ <= now))
{
logtrace(tracker->host, "Expiring old DNS result");
tracker->addr_.reset();
tracker->addr_expires_at_ = 0;
}
/* are there any requests pending? */
if (tracker->isIdle())
{
return;
}
// if DNS lookup *recently* failed for this host, do nothing
if (!tracker->addr_ && now < tracker->addr_expires_at_)
{
return;
}
/* if we don't have an address yet, try & get one now. */
if (!closing && !tracker->addr_ && (tracker->dns_request_ == 0U))
{
auto hints = libtransmission::Dns::Hints{};
hints.ai_family = AF_UNSPEC;
hints.ai_socktype = SOCK_DGRAM;
hints.ai_protocol = IPPROTO_UDP;
logtrace(tracker->host, "Trying a new DNS lookup");
tracker->dns_request_ = tracker->mediator_.dns().lookup(
tracker->host.sv(),
[tracker](sockaddr const* sa, socklen_t len, time_t expires_at)
{ tau_tracker_on_dns(tracker, sa, len, expires_at); },
hints);
return;
}
logtrace(
tracker->key,
fmt::format(
"connected {} ({} {}) -- connecting_at {}",
tracker->connection_expiration_time > now,
tracker->connection_expiration_time,
now,
tracker->connecting_at));
/* also need a valid connection ID... */
if (tracker->addr_ && tracker->connection_expiration_time <= now && tracker->connecting_at == 0)
{
tracker->connecting_at = now;
tracker->connection_transaction_id = tau_transaction_new();
logtrace(tracker->key, fmt::format("Trying to connect. Transaction ID is {}", tracker->connection_transaction_id));
auto buf = libtransmission::Buffer{}; auto buf = libtransmission::Buffer{};
buf.addUint64(0x41727101980LL); buf.addUint64(this->connection_id);
buf.addUint32(TAU_ACTION_CONNECT); buf.add(payload, payload_len);
buf.addUint32(tracker->connection_transaction_id);
auto const contiguous = std::vector<std::byte>(std::begin(buf), std::end(buf)); auto const contiguous = std::vector<std::byte>(std::begin(buf), std::end(buf));
tracker->sendto(std::data(contiguous), std::size(contiguous)); this->sendto(std::data(contiguous), std::size(contiguous));
return;
} }
if (timeout_reqs) public:
{ tr_interned_string const key;
tau_tracker_timeout_reqs(tracker); tr_interned_string const host;
} tr_port const port;
if (tracker->addr_ && tracker->connection_expiration_time > now) time_t connecting_at = 0;
{ time_t connection_expiration_time = 0;
tau_tracker_send_reqs(tracker); tau_connection_t connection_id = {};
} tau_transaction_t connection_transaction_id = {};
}
static void tau_tracker_upkeep(struct tau_tracker* tracker) time_t close_at = 0;
{
tau_tracker_upkeep_ex(tracker, true); std::list<tau_announce_request> announces;
} std::list<tau_scrape_request> scrapes;
private:
Mediator& mediator_;
std::optional<std::future<MaybeSockaddr>> addr_pending_dns_ = {};
MaybeSockaddr addr_ = {};
time_t addr_expires_at_ = 0;
static time_t constexpr DnsRetryIntervalSecs = 60 * 60;
static auto constexpr TauRequestTtl = int{ 60 };
};
/**** /****
*****
***** SESSION ***** SESSION
*****
****/ ****/
class tr_announcer_udp_impl final : public tr_announcer_udp class tr_announcer_udp_impl final : public tr_announcer_udp
@@ -652,8 +602,8 @@ public:
// Since size of IP field is only 4 bytes long, we can only announce IPv4 addresses // Since size of IP field is only 4 bytes long, we can only announce IPv4 addresses
auto const addr = mediator_.announceIP(); auto const addr = mediator_.announceIP();
uint32_t const announce_ip = addr && addr->isIPv4() ? addr->addr.addr4.s_addr : 0; uint32_t const announce_ip = addr && addr->isIPv4() ? addr->addr.addr4.s_addr : 0;
tracker->announces.push_back(make_tau_announce_request(announce_ip, request, response_func, user_data)); tracker->announces.emplace_back(announce_ip, request, response_func, user_data);
tau_tracker_upkeep_ex(tracker, false); tracker->upkeep(false);
} }
void scrape(tr_scrape_request const& request, tr_scrape_response_func response_func, void* user_data) override void scrape(tr_scrape_request const& request, tr_scrape_response_func response_func, void* user_data) override
@@ -664,15 +614,15 @@ public:
return; return;
} }
tracker->scrapes.push_back(make_tau_scrape_request(request, response_func, user_data)); tracker->scrapes.emplace_back(request, response_func, user_data);
tau_tracker_upkeep_ex(tracker, false); tracker->upkeep(false);
} }
void upkeep() override void upkeep() override
{ {
for (auto& tracker : trackers_) for (auto& tracker : trackers_)
{ {
tau_tracker_upkeep(&tracker); tracker.upkeep();
} }
} }
@@ -690,14 +640,8 @@ public:
for (auto& tracker : trackers_) for (auto& tracker : trackers_)
{ {
// if there's a pending DNS request, cancel it
if (tracker.dns_request_ != 0U)
{
mediator_.dns().cancel(tracker.dns_request_);
}
tracker.close_at = now + 3; tracker.close_at = now + 3;
tau_tracker_upkeep(&tracker); tracker.upkeep();
} }
} }
@@ -715,7 +659,7 @@ public:
buf.add(msg, msglen); buf.add(msg, msglen);
auto const action_id = static_cast<tau_action_t>(buf.toUint32()); auto const action_id = static_cast<tau_action_t>(buf.toUint32());
if (!is_tau_response_message(action_id, msglen)) if (!isResponseMessage(action_id, msglen))
{ {
return false; return false;
} }
@@ -729,7 +673,7 @@ public:
if (tracker.connecting_at != 0 && transaction_id == tracker.connection_transaction_id) if (tracker.connecting_at != 0 && transaction_id == tracker.connection_transaction_id)
{ {
logtrace(tracker.key, fmt::format("{} is my connection request!", transaction_id)); logtrace(tracker.key, fmt::format("{} is my connection request!", transaction_id));
on_tracker_connection_response(tracker, action_id, buf); tracker.on_connection_response(action_id, buf);
return true; return true;
} }
@@ -802,6 +746,31 @@ private:
return tracker; return tracker;
} }
[[nodiscard]] static constexpr bool isResponseMessage(tau_action_t action, size_t msglen) noexcept
{
if (action == TAU_ACTION_CONNECT)
{
return msglen == 16;
}
if (action == TAU_ACTION_ANNOUNCE)
{
return msglen >= 20;
}
if (action == TAU_ACTION_SCRAPE)
{
return msglen >= 20;
}
if (action == TAU_ACTION_ERROR)
{
return msglen >= 8;
}
return false;
}
std::list<tau_tracker> trackers_; std::list<tau_tracker> trackers_;
Mediator& mediator_; Mediator& mediator_;

View File

@@ -23,11 +23,6 @@
struct tr_announcer; struct tr_announcer;
struct tr_torrent_announcer; struct tr_torrent_announcer;
namespace libtransmission
{
class Dns;
} // namespace libtransmission
/** /**
* *** Tracker Publish / Subscribe * *** Tracker Publish / Subscribe
* **/ * **/
@@ -301,7 +296,6 @@ public:
public: public:
virtual ~Mediator() noexcept = default; virtual ~Mediator() noexcept = default;
virtual void sendto(void const* buf, size_t buflen, sockaddr const* addr, socklen_t addrlen) = 0; virtual void sendto(void const* buf, size_t buflen, sockaddr const* addr, socklen_t addrlen) = 0;
[[nodiscard]] virtual libtransmission::Dns& dns() = 0;
[[nodiscard]] virtual std::optional<tr_address> announceIP() const = 0; [[nodiscard]] virtual std::optional<tr_address> announceIP() const = 0;
}; };

View File

@@ -1,210 +0,0 @@
// This file Copyright 2022 Mnemosyne LLC.
// It may be used under GPLv2 (SPDX: GPL-2.0-only), GPLv3 (SPDX: GPL-3.0-only),
// or any future license endorsed by Mnemosyne LLC.
// License text can be found in the licenses/ folder.
#pragma once
#ifndef __TRANSMISSION__
#error only libtransmission should #include this header.
#endif
#include <cstring> // for std::memcpy()
#include <ctime>
#include <list>
#include <map>
#include <memory>
#include <utility>
#include <event2/dns.h>
#include <event2/event.h>
#include "dns.h"
#include "utils.h" // for tr_strlower()
namespace libtransmission
{
class EvDns final : public Dns
{
private:
using Key = std::pair<std::string, Hints>;
struct CacheEntry
{
sockaddr_storage ss_ = {};
socklen_t sslen_ = {};
time_t expires_at_ = {};
};
struct CallbackArg
{
Key key;
EvDns* self;
};
struct Request
{
evdns_getaddrinfo_request* request;
struct CallbackInfo
{
CallbackInfo(Tag tag, Callback callback)
: tag_{ tag }
, callback_{ std::move(callback) }
{
}
Tag tag_;
Callback callback_;
};
std::list<CallbackInfo> callbacks;
};
public:
using TimeFunc = time_t (*)();
EvDns(struct event_base* event_base, TimeFunc time_func)
: time_func_{ time_func }
, evdns_base_{ evdns_base_new(event_base, EVDNS_BASE_INITIALIZE_NAMESERVERS),
[](evdns_base* dns)
{
// if zero, active requests will be aborted
evdns_base_free(dns, 0);
} }
{
}
~EvDns() override
{
for (auto& [key, request] : requests_)
{
evdns_getaddrinfo_cancel(request.request);
}
}
std::optional<std::pair<sockaddr const*, socklen_t>> cached(std::string_view address, Hints hints = {}) const override
{
if (auto const* entry = cached(makeKey(address, hints)); entry != nullptr)
{
return std::make_pair(reinterpret_cast<sockaddr const*>(&entry->ss_), entry->sslen_);
}
return {};
}
Tag lookup(std::string_view address, Callback&& callback, Hints hints = {}) override
{
auto const key = makeKey(address, hints);
if (auto const* entry = cached(key); entry)
{
callback(reinterpret_cast<sockaddr const*>(&entry->ss_), entry->sslen_, entry->expires_at_);
return {};
}
auto& request = requests_[key];
auto const tag = next_tag_;
++next_tag_;
request.callbacks.emplace_back(tag, std::move(callback));
if (request.request == nullptr)
{
auto evhints = evutil_addrinfo{};
evhints.ai_family = hints.ai_family;
evhints.ai_socktype = hints.ai_socktype;
evhints.ai_protocol = hints.ai_protocol;
void* const arg = new CallbackArg{ key, this };
request.request = evdns_getaddrinfo(evdns_base_.get(), key.first.c_str(), nullptr, &evhints, evcallback, arg);
}
return tag;
}
void cancel(Tag tag) override
{
for (auto& [key, request] : requests_)
{
for (auto iter = std::begin(request.callbacks), end = std::end(request.callbacks); iter != end; ++iter)
{
if (iter->tag_ != tag)
{
continue;
}
iter->callback_(nullptr, 0, 0);
request.callbacks.erase(iter);
// if this was the last pending request for `key`, cancel the evdns request
if (std::empty(request.callbacks))
{
evdns_getaddrinfo_cancel(request.request);
requests_.erase(key);
}
return;
}
}
}
private:
[[nodiscard]] static Key makeKey(std::string_view address, Hints hints)
{
return Key{ tr_strlower(address), hints };
}
[[nodiscard]] CacheEntry const* cached(Key const& key) const
{
if (auto iter = cache_.find(key); iter != std::end(cache_))
{
auto const& entry = iter->second;
if (auto const now = time_func_(); entry.expires_at_ > now)
{
return &entry;
}
cache_.erase(iter); // expired
}
return nullptr;
}
static void evcallback(int /*result*/, struct evutil_addrinfo* res, void* varg)
{
auto* const arg = static_cast<CallbackArg*>(varg);
auto [key, self] = *arg;
delete arg;
auto& cache_entry = self->cache_[key];
if (res != nullptr)
{
cache_entry.expires_at_ = self->time_func_() + CacheTtlSecs;
cache_entry.sslen_ = res->ai_addrlen;
std::memcpy(&cache_entry.ss_, res->ai_addr, res->ai_addrlen);
evutil_freeaddrinfo(res);
}
if (auto request_entry = self->requests_.extract(key); request_entry)
{
for (auto& callback : request_entry.mapped().callbacks)
{
callback.callback_(
reinterpret_cast<sockaddr const*>(&cache_entry.ss_),
cache_entry.sslen_,
cache_entry.expires_at_);
}
}
}
TimeFunc const time_func_;
static time_t constexpr CacheTtlSecs = 3600U;
std::unique_ptr<evdns_base, void (*)(evdns_base*)> const evdns_base_;
mutable std::map<Key, CacheEntry> cache_;
std::map<Key, Request> requests_;
unsigned int next_tag_ = 1;
};
} // namespace libtransmission

View File

@@ -1,72 +0,0 @@
// This file Copyright 2022 Mnemosyne LLC.
// It may be used under GPLv2 (SPDX: GPL-2.0-only), GPLv3 (SPDX: GPL-3.0-only),
// or any future license endorsed by Mnemosyne LLC.
// License text can be found in the licenses/ folder.
#pragma once
#include <functional>
#include <string_view>
#include "transmission.h"
#include "net.h"
namespace libtransmission
{
class Dns
{
public:
virtual ~Dns() = default;
using Callback = std::function<void(struct sockaddr const*, socklen_t salen, time_t expires_at)>;
using Tag = unsigned int;
class Hints
{
public:
Hints()
{
}
int ai_family = AF_UNSPEC;
int ai_socktype = SOCK_DGRAM;
int ai_protocol = IPPROTO_UDP;
[[nodiscard]] constexpr int compare(Hints const& that) const noexcept // <=>
{
if (ai_family != that.ai_family)
{
return ai_family < that.ai_family ? -1 : 1;
}
if (ai_socktype != that.ai_socktype)
{
return ai_socktype < that.ai_socktype ? -1 : 1;
}
if (ai_protocol != that.ai_protocol)
{
return ai_protocol < that.ai_protocol ? -1 : 1;
}
return 0;
}
[[nodiscard]] constexpr bool operator<(Hints const& that) const noexcept
{
return compare(that) < 0;
}
};
[[nodiscard]] virtual std::optional<std::pair<struct sockaddr const*, socklen_t>> cached(
std::string_view address,
Hints hints = {}) const = 0;
virtual Tag lookup(std::string_view address, Callback&& callback, Hints hints = {}) = 0;
virtual void cancel(Tag) = 0;
};
} // namespace libtransmission

View File

@@ -26,7 +26,6 @@
#include <sys/stat.h> /* umask() */ #include <sys/stat.h> /* umask() */
#endif #endif
#include <event2/dns.h>
#include <event2/event.h> #include <event2/event.h>
#include <fmt/chrono.h> #include <fmt/chrono.h>
@@ -40,7 +39,6 @@
#include "blocklist.h" #include "blocklist.h"
#include "cache.h" #include "cache.h"
#include "crypto-utils.h" #include "crypto-utils.h"
#include "dns-ev.h"
#include "error-types.h" #include "error-types.h"
#include "error.h" #include "error.h"
#include "file.h" #include "file.h"
@@ -2175,7 +2173,6 @@ tr_session::tr_session(std::string_view config_dir, tr_variant* settings_dict)
, blocklist_dir_{ makeBlocklistDir(config_dir) } , blocklist_dir_{ makeBlocklistDir(config_dir) }
, session_thread_{ tr_session_thread::create() } , session_thread_{ tr_session_thread::create() }
, timer_maker_{ std::make_unique<libtransmission::EvTimerMaker>(eventBase()) } , timer_maker_{ std::make_unique<libtransmission::EvTimerMaker>(eventBase()) }
, dns_{ std::make_unique<libtransmission::EvDns>(eventBase(), tr_time) }
, settings_{ settings_dict } , settings_{ settings_dict }
, session_id_{ tr_time } , session_id_{ tr_time }
, peer_mgr_{ tr_peerMgrNew(this), tr_peerMgrFree } , peer_mgr_{ tr_peerMgrNew(this), tr_peerMgrFree }

View File

@@ -31,7 +31,6 @@
#include "bandwidth.h" #include "bandwidth.h"
#include "bitfield.h" #include "bitfield.h"
#include "cache.h" #include "cache.h"
#include "dns.h"
#include "interned-string.h" #include "interned-string.h"
#include "net.h" // tr_socket_t #include "net.h" // tr_socket_t
#include "open-files.h" #include "open-files.h"
@@ -148,11 +147,6 @@ private:
return tr_address::fromString(session_.announceIP()); return tr_address::fromString(session_.announceIP());
} }
[[nodiscard]] libtransmission::Dns& dns() override
{
return *session_.dns_.get();
}
private: private:
tr_session& session_; tr_session& session_;
}; };
@@ -1040,9 +1034,6 @@ private:
// depends-on: session_thread_ // depends-on: session_thread_
std::unique_ptr<libtransmission::TimerMaker> const timer_maker_; std::unique_ptr<libtransmission::TimerMaker> const timer_maker_;
// depends-on: event_base_
std::unique_ptr<libtransmission::Dns> const dns_;
/// trivial type fields /// trivial type fields
tr_session_settings settings_; tr_session_settings settings_;
@@ -1149,7 +1140,7 @@ private:
// depends-on: lpd_mediator_ // depends-on: lpd_mediator_
std::unique_ptr<tr_lpd> lpd_; std::unique_ptr<tr_lpd> lpd_;
// depends-on: dns_, udp_core_ // depends-on: udp_core_
AnnouncerUdpMediator announcer_udp_mediator_{ *this }; AnnouncerUdpMediator announcer_udp_mediator_{ *this };
// depends-on: timer_maker_, torrents_, peer_mgr_ // depends-on: timer_maker_, torrents_, peer_mgr_

View File

@@ -64,14 +64,14 @@ public:
virtual ~TimerMaker() = default; virtual ~TimerMaker() = default;
[[nodiscard]] virtual std::unique_ptr<Timer> create() = 0; [[nodiscard]] virtual std::unique_ptr<Timer> create() = 0;
[[nodiscard]] virtual std::unique_ptr<Timer> create(std::function<void()> callback) [[nodiscard]] std::unique_ptr<Timer> create(std::function<void()> callback)
{ {
auto timer = create(); auto timer = create();
timer->setCallback(std::move(callback)); timer->setCallback(std::move(callback));
return timer; return timer;
} }
[[nodiscard]] virtual std::unique_ptr<Timer> create(Timer::CStyleCallback callback, void* user_data) [[nodiscard]] std::unique_ptr<Timer> create(Timer::CStyleCallback callback, void* user_data)
{ {
auto timer = create(); auto timer = create();
timer->setCallback(callback, user_data); timer->setCallback(callback, user_data);

View File

@@ -14,8 +14,8 @@
#include "announcer.h" #include "announcer.h"
#include "crypto-utils.h" #include "crypto-utils.h"
#include "dns.h"
#include "peer-mgr.h" // for tr_pex #include "peer-mgr.h" // for tr_pex
#include "timer-ev.h"
#include "tr-buffer.h" #include "tr-buffer.h"
#include "test-fixtures.h" #include "test-fixtures.h"
@@ -32,31 +32,6 @@ private:
} }
protected: protected:
class MockDns final : public libtransmission::Dns
{
public:
[[nodiscard]] std::optional<std::pair<struct sockaddr const*, socklen_t>> cached(
std::string_view /*address*/,
Hints /*hints*/ = {}) const override
{
return {};
}
Tag lookup(std::string_view address, Callback&& callback, Hints /*hints*/) override
{
auto const addr = tr_address::fromString(address); // mock has no actual DNS, just parsing e.g. inet_pton
auto [ss, sslen] = addr->toSockaddr(Port);
callback(reinterpret_cast<sockaddr const*>(&ss), sslen, tr_time() + 3600); // 1hr ttl
return {};
}
void cancel(Tag /*tag*/) override
{
}
static auto constexpr Port = tr_port::fromHost(443);
};
class MockMediator final : public tr_announcer_udp::Mediator class MockMediator final : public tr_announcer_udp::Mediator
{ {
public: public:
@@ -77,11 +52,6 @@ protected:
return event_base_.get(); return event_base_.get();
} }
[[nodiscard]] libtransmission::Dns& dns() override
{
return dns_;
}
[[nodiscard]] std::optional<tr_address> announceIP() const override [[nodiscard]] std::optional<tr_address> announceIP() const override
{ {
return {}; return {};
@@ -106,8 +76,6 @@ protected:
std::deque<Sent> sent_; std::deque<Sent> sent_;
std::unique_ptr<event_base, void (*)(event_base*)> const event_base_; std::unique_ptr<event_base, void (*)(event_base*)> const event_base_;
MockDns dns_;
}; };
static void expectEqual(tr_scrape_response const& expected, tr_scrape_response const& actual) static void expectEqual(tr_scrape_response const& expected, tr_scrape_response const& actual)
@@ -199,7 +167,6 @@ protected:
[[nodiscard]] static auto waitForAnnouncerToSendMessage(MockMediator& mediator) [[nodiscard]] static auto waitForAnnouncerToSendMessage(MockMediator& mediator)
{ {
EXPECT_FALSE(std::empty(mediator.sent_));
libtransmission::test::waitFor(mediator.eventBase(), [&mediator]() { return !std::empty(mediator.sent_); }); libtransmission::test::waitFor(mediator.eventBase(), [&mediator]() { return !std::empty(mediator.sent_); });
auto buf = libtransmission::Buffer(mediator.sent_.back().buf_); auto buf = libtransmission::Buffer(mediator.sent_.back().buf_);
mediator.sent_.pop_back(); mediator.sent_.pop_back();
@@ -309,6 +276,16 @@ protected:
return req; return req;
} }
// emulate the upkeep timer that tr_announcer runs in production
static auto createUpkeepTimer(MockMediator& mediator, std::unique_ptr<tr_announcer_udp>& announcer)
{
auto timer_maker = libtransmission::EvTimerMaker{ mediator.eventBase() };
auto timer = timer_maker.create();
timer->setCallback([&announcer]() { announcer->upkeep(); });
timer->startRepeating(200ms);
return timer;
}
// https://www.bittorrent.org/beps/bep_0015.html // https://www.bittorrent.org/beps/bep_0015.html
static auto constexpr ProtocolId = uint64_t{ 0x41727101980ULL }; static auto constexpr ProtocolId = uint64_t{ 0x41727101980ULL };
static auto constexpr ConnectAction = uint32_t{ 0 }; static auto constexpr ConnectAction = uint32_t{ 0 };
@@ -330,6 +307,7 @@ TEST_F(AnnouncerUdpTest, canScrape)
{ {
auto mediator = MockMediator{}; auto mediator = MockMediator{};
auto announcer = tr_announcer_udp::create(mediator); auto announcer = tr_announcer_udp::create(mediator);
auto upkeep_timer = createUpkeepTimer(mediator, announcer);
// tell announcer to scrape // tell announcer to scrape
auto [request, expected_response] = buildSimpleScrapeRequestAndResponse(); auto [request, expected_response] = buildSimpleScrapeRequestAndResponse();
@@ -396,6 +374,7 @@ TEST_F(AnnouncerUdpTest, canDestructCleanlyEvenWhenBusy)
{ {
auto mediator = MockMediator{}; auto mediator = MockMediator{};
auto announcer = tr_announcer_udp::create(mediator); auto announcer = tr_announcer_udp::create(mediator);
auto upkeep_timer = createUpkeepTimer(mediator, announcer);
// tell announcer to scrape // tell announcer to scrape
auto [request, expected_response] = buildSimpleScrapeRequestAndResponse(); auto [request, expected_response] = buildSimpleScrapeRequestAndResponse();
@@ -420,6 +399,7 @@ TEST_F(AnnouncerUdpTest, canMultiScrape)
{ {
auto mediator = MockMediator{}; auto mediator = MockMediator{};
auto announcer = tr_announcer_udp::create(mediator); auto announcer = tr_announcer_udp::create(mediator);
auto upkeep_timer = createUpkeepTimer(mediator, announcer);
auto expected_response = tr_scrape_response{}; auto expected_response = tr_scrape_response{};
expected_response.did_connect = true; expected_response.did_connect = true;
@@ -491,6 +471,7 @@ TEST_F(AnnouncerUdpTest, canHandleScrapeError)
// build the announcer // build the announcer
auto mediator = MockMediator{}; auto mediator = MockMediator{};
auto announcer = tr_announcer_udp::create(mediator); auto announcer = tr_announcer_udp::create(mediator);
auto upkeep_timer = createUpkeepTimer(mediator, announcer);
// tell announcer to scrape // tell announcer to scrape
auto response = std::optional<tr_scrape_response>{}; auto response = std::optional<tr_scrape_response>{};
@@ -540,6 +521,7 @@ TEST_F(AnnouncerUdpTest, canHandleConnectError)
// build the announcer // build the announcer
auto mediator = MockMediator{}; auto mediator = MockMediator{};
auto announcer = tr_announcer_udp::create(mediator); auto announcer = tr_announcer_udp::create(mediator);
auto upkeep_timer = createUpkeepTimer(mediator, announcer);
// tell the announcer to scrape // tell the announcer to scrape
auto response = std::optional<tr_scrape_response>{}; auto response = std::optional<tr_scrape_response>{};
@@ -573,6 +555,7 @@ TEST_F(AnnouncerUdpTest, handleMessageReturnsFalseOnInvalidMessage)
// build the announcer // build the announcer
auto mediator = MockMediator{}; auto mediator = MockMediator{};
auto announcer = tr_announcer_udp::create(mediator); auto announcer = tr_announcer_udp::create(mediator);
auto upkeep_timer = createUpkeepTimer(mediator, announcer);
// tell the announcer to scrape // tell the announcer to scrape
auto response = std::optional<tr_scrape_response>{}; auto response = std::optional<tr_scrape_response>{};
@@ -658,6 +641,7 @@ TEST_F(AnnouncerUdpTest, canAnnounce)
// build the announcer // build the announcer
auto mediator = MockMediator{}; auto mediator = MockMediator{};
auto announcer = tr_announcer_udp::create(mediator); auto announcer = tr_announcer_udp::create(mediator);
auto upkeep_timer = createUpkeepTimer(mediator, announcer);
auto response = std::optional<tr_announce_response>{}; auto response = std::optional<tr_announce_response>{};
announcer->announce( announcer->announce(