refactor: use tr_rand_buf() in tr_bandwidth::phaseOne() (#4404) (#4411)

This commit is contained in:
Charles Kerr
2022-12-19 17:31:24 -06:00
committed by GitHub
parent d290ece0c8
commit 5493ed644e
4 changed files with 41 additions and 9 deletions

View File

@@ -4,7 +4,6 @@
// License text can be found in the licenses/ folder.
#include <algorithm>
#include <random> // for std::mt19937
#include <utility> // for std::swap()
#include <vector>
@@ -13,6 +12,7 @@
#include "transmission.h"
#include "bandwidth.h"
#include "crypto-utils.h"
#include "log.h"
#include "peer-io.h"
#include "tr-assert.h"
@@ -164,19 +164,19 @@ void tr_bandwidth::allocateBandwidth(
}
}
void tr_bandwidth::phaseOne(std::vector<tr_peerIo*>& peer_array, tr_direction dir)
void tr_bandwidth::phaseOne(std::vector<tr_peerIo*>& peers, tr_direction dir)
{
// First phase of IO. Tries to distribute bandwidth fairly to keep faster
// peers from starving the others.
tr_logAddTrace(fmt::format("{} peers to go round-robin for {}", peer_array.size(), dir == TR_UP ? "upload" : "download"));
tr_logAddTrace(fmt::format("{} peers to go round-robin for {}", peers.size(), dir == TR_UP ? "upload" : "download"));
// Shuffle the peers so they all have equal chance to be first in line.
thread_local auto random_engine = std::mt19937{ std::random_device{}() };
std::shuffle(std::begin(peer_array), std::end(peer_array), random_engine);
thread_local auto urbg = tr_urbg<size_t>{};
std::shuffle(std::begin(peers), std::end(peers), urbg);
// Give each peer `Increment` bandwidth bytes to use. Repeat this
// process until we run out of bandwidth and/or peers that can use it.
for (size_t n_unfinished = std::size(peer_array); n_unfinished > 0U;)
for (size_t n_unfinished = std::size(peers); n_unfinished > 0U;)
{
for (size_t i = 0; i < n_unfinished;)
{
@@ -185,13 +185,13 @@ void tr_bandwidth::phaseOne(std::vector<tr_peerIo*>& peer_array, tr_direction di
// out in a timely manner.
static auto constexpr Increment = size_t{ 3000 };
auto const bytes_used = peer_array[i]->flush(dir, Increment);
auto const bytes_used = peers[i]->flush(dir, Increment);
tr_logAddTrace(fmt::format("peer #{} of {} used {} bytes in this pass", i, n_unfinished, bytes_used));
if (bytes_used != Increment)
{
// peer is done writing for now; move it to the end of the list
std::swap(peer_array[i], peer_array[n_unfinished - 1]);
std::swap(peers[i], peers[n_unfinished - 1]);
--n_unfinished;
}
else

View File

@@ -252,7 +252,7 @@ private:
[[nodiscard]] size_t clamp(uint64_t now, tr_direction dir, size_t byte_count) const;
static void phaseOne(std::vector<tr_peerIo*>& peer_array, tr_direction dir);
static void phaseOne(std::vector<tr_peerIo*>& peers, tr_direction dir);
void allocateBandwidth(
tr_priority_t parent_priority,

View File

@@ -9,6 +9,7 @@
#include <array>
#include <cstddef> // size_t
#include <cstdint>
#include <limits>
#include <memory>
#include <optional>
#include <string>
@@ -187,6 +188,34 @@ private:
std::array<T, N> buf;
};
// UniformRandomBitGenerator impl that uses `tr_rand_buffer()`.
// See https://en.cppreference.com/w/cpp/named_req/UniformRandomBitGenerator
template<typename T, size_t N = 1024U>
class tr_urbg
{
public:
using result_type = T;
static_assert(!std::numeric_limits<T>::is_signed);
[[nodiscard]] static constexpr T min() noexcept
{
return std::numeric_limits<T>::min();
}
[[nodiscard]] static constexpr T max() noexcept
{
return std::numeric_limits<T>::max();
}
[[nodiscard]] T operator()() noexcept
{
return buf_();
}
private:
tr_salt_shaker<T, N> buf_;
};
/** @} */
#endif /* TR_CRYPTO_UTILS_H */

View File

@@ -30,6 +30,7 @@
#define tr_ssha1_test tr_ssha1_test_
#define tr_ssl_ctx_t tr_ssl_ctx_t_
#define tr_ssl_get_x509_store tr_ssl_get_x509_store_
#define tr_urbg tr_urbg_
#define tr_x509_cert_free tr_x509_cert_free_
#define tr_x509_cert_new tr_x509_cert_new_
#define tr_x509_cert_t tr_x509_cert_t_
@@ -63,6 +64,7 @@
#undef tr_ssha1_test
#undef tr_ssl_ctx_t
#undef tr_ssl_get_x509_store
#undef tr_urbg
#undef tr_x509_cert_free
#undef tr_x509_cert_new
#undef tr_x509_cert_t
@@ -99,6 +101,7 @@
#define tr_ssha1_test_ tr_ssha1_test
#define tr_ssl_ctx_t_ tr_ssl_ctx_t
#define tr_ssl_get_x509_store_ tr_ssl_get_x509_store
#define tr_urbg_ tr_urbg
#define tr_x509_cert_free_ tr_x509_cert_free
#define tr_x509_cert_new_ tr_x509_cert_new
#define tr_x509_cert_t_ tr_x509_cert_t