refactor: make sha1, sha256 RAII safe (#3556)

This commit is contained in:
Charles Kerr
2022-07-31 15:58:14 -05:00
committed by GitHub
parent 154d5cf497
commit 18e2a04f88
22 changed files with 463 additions and 498 deletions

View File

@@ -47,7 +47,7 @@ tr_auto_option(USE_SYSTEM_UTP "Use system utp library" AUTO)
tr_auto_option(USE_SYSTEM_B64 "Use system b64 library" AUTO)
tr_auto_option(USE_SYSTEM_PSL "Use system psl library" AUTO)
tr_list_option(USE_QT_VERSION "Use specific Qt version" AUTO 5 6)
tr_list_option(WITH_CRYPTO "Use specified crypto library" AUTO openssl cyassl polarssl ccrypto)
tr_list_option(WITH_CRYPTO "Use specified crypto library" AUTO ccrypto cyassl mbedtls openssl polarssl wolfssl)
tr_auto_option(WITH_INOTIFY "Enable inotify support (on systems that support it)" AUTO)
tr_auto_option(WITH_KQUEUE "Enable kqueue support (on systems that support it)" AUTO)
tr_auto_option(WITH_LIBAPPINDICATOR "Use libappindicator in GTK+ client" AUTO)
@@ -189,7 +189,7 @@ if(WITH_CRYPTO STREQUAL "AUTO" OR WITH_CRYPTO STREQUAL "openssl")
set(CRYPTO_LIBRARIES ${OPENSSL_LIBRARIES})
endif()
endif()
if(WITH_CRYPTO STREQUAL "AUTO" OR WITH_CRYPTO STREQUAL "cyassl")
if(WITH_CRYPTO STREQUAL "AUTO" OR WITH_CRYPTO STREQUAL "cyassl" OR WITH_CRYPTO STREQUAL "wolfssl")
tr_get_required_flag(WITH_CRYPTO CYASSL_IS_REQUIRED)
find_package(CyaSSL ${CYASSL_MINIMUM} ${CYASSL_IS_REQUIRED})
tr_fixup_list_option(WITH_CRYPTO "cyassl" CYASSL_FOUND "AUTO" CYASSL_IS_REQUIRED)
@@ -199,7 +199,7 @@ if(WITH_CRYPTO STREQUAL "AUTO" OR WITH_CRYPTO STREQUAL "cyassl")
set(CRYPTO_LIBRARIES ${CYASSL_LIBRARIES})
endif()
endif()
if(WITH_CRYPTO STREQUAL "AUTO" OR WITH_CRYPTO STREQUAL "polarssl")
if(WITH_CRYPTO STREQUAL "AUTO" OR WITH_CRYPTO STREQUAL "polarssl" OR WITH_CRYPTO STREQUAL "mbedtls")
tr_get_required_flag(WITH_CRYPTO POLARSSL_IS_REQUIRED)
find_package(PolarSSL ${POLARSSL_MINIMUM} ${POLARSSL_IS_REQUIRED})
tr_fixup_list_option(WITH_CRYPTO "polarssl" POLARSSL_FOUND "AUTO" POLARSSL_IS_REQUIRED)

View File

@@ -127,7 +127,7 @@ enum tr_announce_event
TR_ANNOUNCE_EVENT_STOPPED,
};
char const* tr_announce_event_get_string(tr_announce_event);
std::string_view tr_announce_event_get_string(tr_announce_event);
struct tr_announce_request
{

View File

@@ -44,9 +44,10 @@ using namespace std::literals;
*****
****/
static char const* get_event_string(tr_announce_request const* req)
static std::string_view get_event_string(tr_announce_request const* req)
{
return req->partial_seed && (req->event != TR_ANNOUNCE_EVENT_STOPPED) ? "paused" : tr_announce_event_get_string(req->event);
return req->partial_seed && (req->event != TR_ANNOUNCE_EVENT_STOPPED) ? "paused"sv :
tr_announce_event_get_string(req->event);
}
static tr_urlbuf announce_url_new(tr_session const* session, tr_announce_request const* req)
@@ -91,7 +92,7 @@ static tr_urlbuf announce_url_new(tr_session const* session, tr_announce_request
fmt::format_to(out, "&corrupt={}", req->corrupt);
}
if (char const* str = get_event_string(req); !tr_str_is_empty(str))
if (auto const str = get_event_string(req); !std::empty(str))
{
fmt::format_to(out, "&event={}", str);
}

View File

@@ -70,21 +70,21 @@ static auto constexpr TrMultiscrapeStep = int{ 5 };
****
***/
char const* tr_announce_event_get_string(tr_announce_event e)
std::string_view tr_announce_event_get_string(tr_announce_event e)
{
switch (e)
{
case TR_ANNOUNCE_EVENT_COMPLETED:
return "completed";
return "completed"sv;
case TR_ANNOUNCE_EVENT_STARTED:
return "started";
return "started"sv;
case TR_ANNOUNCE_EVENT_STOPPED:
return "stopped";
return "stopped"sv;
default:
return "";
return ""sv;
}
}

View File

@@ -99,78 +99,89 @@ bool check_ccrypto_result(CCCryptorStatus result, char const* file, int line)
****
***/
tr_sha1_ctx_t tr_sha1_init(void)
namespace
{
auto* handle = new CC_SHA1_CTX();
CC_SHA1_Init(handle);
return handle;
}
bool tr_sha1_update(tr_sha1_ctx_t handle, void const* data, size_t data_length)
class Sha1Impl final : public tr_sha1
{
TR_ASSERT(handle != nullptr);
if (data_length == 0)
public:
Sha1Impl()
{
return true;
clear();
}
TR_ASSERT(data != nullptr);
~Sha1Impl() override = default;
CC_SHA1_Update(static_cast<CC_SHA1_CTX*>(handle), data, data_length);
return true;
}
void clear() override
{
CC_SHA1_Init(&handle_);
}
std::optional<tr_sha1_digest_t> tr_sha1_final(tr_sha1_ctx_t raw_handle)
{
TR_ASSERT(raw_handle != nullptr);
auto* handle = static_cast<CC_SHA1_CTX*>(raw_handle);
void add(void const* data, size_t data_length) override
{
if (data_length > 0U)
{
CC_SHA1_Update(&handle_, data, data_length);
}
}
[[nodiscard]] tr_sha1_digest_t final() override
{
auto digest = tr_sha1_digest_t{};
auto* const digest_as_uchar = reinterpret_cast<unsigned char*>(std::data(digest));
CC_SHA1_Final(digest_as_uchar, handle);
delete handle;
CC_SHA1_Final(reinterpret_cast<unsigned char*>(std::data(digest)), &handle_);
clear();
return digest;
}
/***
****
***/
tr_sha256_ctx_t tr_sha256_init(void)
{
auto* handle = new CC_SHA256_CTX();
CC_SHA256_Init(handle);
return handle;
}
bool tr_sha256_update(tr_sha256_ctx_t handle, void const* data, size_t data_length)
{
TR_ASSERT(handle != nullptr);
if (data_length == 0)
{
return true;
}
TR_ASSERT(data != nullptr);
private:
CC_SHA1_CTX handle_ = {};
};
CC_SHA256_Update(static_cast<CC_SHA256_CTX*>(handle), data, data_length);
return true;
class Sha256Impl final : public tr_sha256
{
public:
Sha256Impl()
{
clear();
}
~Sha256Impl() override = default;
void clear() override
{
CC_SHA256_Init(&handle_);
}
void add(void const* data, size_t data_length) override
{
if (data_length > 0U)
{
CC_SHA256_Update(&handle_, data, data_length);
}
}
[[nodiscard]] tr_sha256_digest_t final() override
{
auto digest = tr_sha256_digest_t{};
CC_SHA256_Final(reinterpret_cast<unsigned char*>(std::data(digest)), &handle_);
clear();
return digest;
}
private:
CC_SHA256_CTX handle_;
};
} // namespace
std::unique_ptr<tr_sha1> tr_sha1::create()
{
return std::make_unique<Sha1Impl>();
}
std::optional<tr_sha256_digest_t> tr_sha256_final(tr_sha256_ctx_t raw_handle)
std::unique_ptr<tr_sha256> tr_sha256::create()
{
TR_ASSERT(raw_handle != nullptr);
auto* handle = static_cast<CC_SHA256_CTX*>(raw_handle);
auto digest = tr_sha256_digest_t{};
auto* const digest_as_uchar = reinterpret_cast<unsigned char*>(std::data(digest));
CC_SHA256_Final(digest_as_uchar, handle);
delete handle;
return digest;
return std::make_unique<Sha256Impl>();
}
/***

View File

@@ -6,8 +6,10 @@
#include <mutex>
#if defined(CYASSL_IS_WOLFSSL)
// NOLINTBEGIN bugprone-macro-parentheses
#define API_HEADER(x) <wolfssl/x>
#define API_HEADER_CRYPT(x) API_HEADER(wolfcrypt/x)
// NOLINTEND
#define API(x) wc_##x
#define API_VERSION_HEX LIBWOLFSSL_VERSION_HEX
#else
@@ -32,6 +34,12 @@
#include "tr-assert.h"
#include "utils.h"
#if LIBWOLFSSL_VERSION_HEX >= 0x04000000 // 4.0.0
using TR_WC_RNG = WC_RNG;
#else
using TR_WC_RNG = RNG;
#endif
#define TR_CRYPTO_X509_FALLBACK
#include "crypto-utils-fallback.cc" // NOLINT(bugprone-suspicious-include)
@@ -82,9 +90,9 @@ static bool check_cyassl_result(int result, char const* file, int line)
****
***/
static RNG* get_rng(void)
static TR_WC_RNG* get_rng()
{
static RNG rng;
static TR_WC_RNG rng;
static bool rng_initialized = false;
if (!rng_initialized)
@@ -106,90 +114,89 @@ static std::mutex rng_mutex_;
****
***/
tr_sha1_ctx_t tr_sha1_init(void)
namespace
{
Sha* handle = tr_new(Sha, 1);
if (check_result(API(InitSha)(handle)))
class Sha1Impl final : public tr_sha1
{
public:
Sha1Impl()
{
return handle;
clear();
}
tr_free(handle);
return nullptr;
}
~Sha1Impl() override = default;
bool tr_sha1_update(tr_sha1_ctx_t raw_handle, void const* data, size_t data_length)
{
auto* handle = static_cast<Sha*>(raw_handle);
TR_ASSERT(handle != nullptr);
if (data_length == 0)
void clear() override
{
return true;
API(InitSha)(&handle_);
}
TR_ASSERT(data != nullptr);
return check_result(API(ShaUpdate)(handle, static_cast<byte const*>(data), data_length));
}
std::optional<tr_sha1_digest_t> tr_sha1_final(tr_sha1_ctx_t raw_handle)
{
auto* handle = static_cast<Sha*>(raw_handle);
TR_ASSERT(handle != nullptr);
void add(void const* data, size_t data_length) override
{
if (data_length > 0U)
{
API(ShaUpdate)(&handle_, static_cast<byte const*>(data), data_length);
}
}
[[nodiscard]] tr_sha1_digest_t final() override
{
auto digest = tr_sha1_digest_t{};
auto* const digest_as_uchar = reinterpret_cast<unsigned char*>(std::data(digest));
auto const ok = check_result(API(ShaFinal)(handle, digest_as_uchar));
tr_free(handle);
return ok ? std::make_optional(digest) : std::nullopt;
}
/***
****
***/
tr_sha256_ctx_t tr_sha256_init(void)
{
Sha256* handle = tr_new(Sha256, 1);
if (check_result(API(InitSha256)(handle)))
{
return handle;
API(ShaFinal)(&handle_, reinterpret_cast<byte*>(std::data(digest)));
clear();
return digest;
}
tr_free(handle);
return nullptr;
}
private:
API(Sha) handle_ = {};
};
bool tr_sha256_update(tr_sha256_ctx_t raw_handle, void const* data, size_t data_length)
class Sha256Impl final : public tr_sha256
{
auto* handle = static_cast<Sha256*>(raw_handle);
TR_ASSERT(handle != nullptr);
if (data_length == 0)
public:
Sha256Impl()
{
return true;
clear();
}
TR_ASSERT(data != nullptr);
~Sha256Impl() override = default;
return check_result(API(Sha256Update)(handle, static_cast<byte const*>(data), data_length));
}
void clear() override
{
API(InitSha256)(&handle_);
}
std::optional<tr_sha256_digest_t> tr_sha256_final(tr_sha256_ctx_t raw_handle)
{
auto* handle = static_cast<Sha256*>(raw_handle);
TR_ASSERT(handle != nullptr);
void add(void const* data, size_t data_length) override
{
if (data_length > 0U)
{
API(Sha256Update)(&handle_, static_cast<byte const*>(data), data_length);
}
}
[[nodiscard]] tr_sha256_digest_t final() override
{
auto digest = tr_sha256_digest_t{};
auto* const digest_as_uchar = reinterpret_cast<unsigned char*>(std::data(digest));
auto const ok = check_result(API(Sha256Final)(handle, digest_as_uchar));
tr_free(handle);
API(Sha256Final)(&handle_, reinterpret_cast<byte*>(std::data(digest)));
clear();
return digest;
}
return ok ? std::make_optional(digest) : std::nullopt;
private:
API(Sha256) handle_ = {};
};
} // namespace
std::unique_ptr<tr_sha1> tr_sha1::create()
{
return std::make_unique<Sha1Impl>();
}
std::unique_ptr<tr_sha256> tr_sha256::create()
{
return std::make_unique<Sha256Impl>();
}
/***

View File

@@ -88,96 +88,119 @@ static bool check_openssl_result(int result, int expected_result, bool expected_
****
***/
tr_sha1_ctx_t tr_sha1_init()
namespace
{
EVP_MD_CTX* handle = EVP_MD_CTX_create();
if (check_result(EVP_DigestInit_ex(handle, EVP_sha1(), nullptr)))
class ShaHelper
{
public:
using EvpFunc = decltype((EVP_sha1));
ShaHelper(EvpFunc evp_func)
: evp_func_{ evp_func }
{
return handle;
clear();
}
EVP_MD_CTX_destroy(handle);
return nullptr;
}
bool tr_sha1_update(tr_sha1_ctx_t raw_handle, void const* data, size_t data_length)
{
auto* const handle = static_cast<EVP_MD_CTX*>(raw_handle);
TR_ASSERT(handle != nullptr);
if (data_length == 0)
void clear()
{
return true;
EVP_DigestInit_ex(handle_.get(), evp_func_(), nullptr);
}
TR_ASSERT(data != nullptr);
void update(void const* data, size_t data_length)
{
if (data_length != 0U)
{
EVP_DigestUpdate(handle_.get(), data, data_length);
}
}
return check_result(EVP_DigestUpdate(handle, data, data_length));
}
std::optional<tr_sha1_digest_t> tr_sha1_final(tr_sha1_ctx_t raw_handle)
{
auto* handle = static_cast<EVP_MD_CTX*>(raw_handle);
TR_ASSERT(handle != nullptr);
template<typename DigestType>
[[nodiscard]] DigestType digest()
{
TR_ASSERT(handle_ != nullptr);
unsigned int hash_length = 0;
auto digest = tr_sha1_digest_t{};
auto digest = DigestType{};
auto* const digest_as_uchar = reinterpret_cast<unsigned char*>(std::data(digest));
bool const ok = check_result(EVP_DigestFinal_ex(handle, digest_as_uchar, &hash_length));
bool const ok = check_result(EVP_DigestFinal_ex(handle_.get(), digest_as_uchar, &hash_length));
TR_ASSERT(!ok || hash_length == std::size(digest));
EVP_MD_CTX_destroy(handle);
return ok ? std::make_optional(digest) : std::nullopt;
}
/***
****
***/
tr_sha256_ctx_t tr_sha256_init()
{
EVP_MD_CTX* handle = EVP_MD_CTX_create();
if (check_result(EVP_DigestInit_ex(handle, EVP_sha256(), nullptr)))
{
return handle;
clear();
return digest;
}
EVP_MD_CTX_destroy(handle);
return nullptr;
}
bool tr_sha256_update(tr_sha256_ctx_t raw_handle, void const* data, size_t data_length)
{
auto* const handle = static_cast<EVP_MD_CTX*>(raw_handle);
TR_ASSERT(handle != nullptr);
if (data_length == 0)
private:
struct MessageDigestDeleter
{
return true;
void operator()(EVP_MD_CTX* ctx) const noexcept
{
EVP_MD_CTX_destroy(ctx);
}
};
EvpFunc evp_func_;
std::unique_ptr<EVP_MD_CTX, MessageDigestDeleter> const handle_{ EVP_MD_CTX_create() };
};
class Sha1Impl final : public tr_sha1
{
public:
~Sha1Impl() override = default;
void clear() override
{
helper_.clear();
}
TR_ASSERT(data != nullptr);
void add(void const* data, size_t data_length) override
{
helper_.update(data, data_length);
}
return check_result(EVP_DigestUpdate(handle, data, data_length));
[[nodiscard]] tr_sha1_digest_t final() override
{
return helper_.digest<tr_sha1_digest_t>();
}
private:
ShaHelper helper_{ EVP_sha1 };
};
class Sha256Impl final : public tr_sha256
{
public:
~Sha256Impl() override = default;
void clear() override
{
helper_.clear();
}
void add(void const* data, size_t data_length) override
{
helper_.update(data, data_length);
}
[[nodiscard]] tr_sha256_digest_t final() override
{
return helper_.digest<tr_sha256_digest_t>();
}
private:
ShaHelper helper_{ EVP_sha256 };
};
} // namespace
std::unique_ptr<tr_sha1> tr_sha1::create()
{
return std::make_unique<Sha1Impl>();
}
std::optional<tr_sha256_digest_t> tr_sha256_final(tr_sha1_ctx_t raw_handle)
std::unique_ptr<tr_sha256> tr_sha256::create()
{
auto* handle = static_cast<EVP_MD_CTX*>(raw_handle);
TR_ASSERT(handle != nullptr);
unsigned int hash_length = 0;
auto digest = tr_sha256_digest_t{};
auto* const digest_as_uchar = reinterpret_cast<unsigned char*>(std::data(digest));
bool const ok = check_result(EVP_DigestFinal_ex(handle, digest_as_uchar, &hash_length));
TR_ASSERT(!ok || hash_length == std::size(digest));
EVP_MD_CTX_destroy(handle);
return ok ? std::make_optional(digest) : std::nullopt;
return std::make_unique<Sha256Impl>();
}
/***
@@ -186,9 +209,9 @@ std::optional<tr_sha256_digest_t> tr_sha256_final(tr_sha1_ctx_t raw_handle)
#if OPENSSL_VERSION_NUMBER < 0x0090802fL
static EVP_CIPHER_CTX* openssl_evp_cipher_context_new(void)
static EVP_CIPHER_CTX* openssl_evp_cipher_context_new()
{
EVP_CIPHER_CTX* handle = tr_new(EVP_CIPHER_CTX, 1);
auto* const handle = new EVP_CIPHER_CTX{};
if (handle != nullptr)
{
@@ -206,7 +229,7 @@ static void openssl_evp_cipher_context_free(EVP_CIPHER_CTX* handle)
}
EVP_CIPHER_CTX_cleanup(handle);
tr_free(handle);
delete handle;
}
#define EVP_CIPHER_CTX_new() openssl_evp_cipher_context_new()

View File

@@ -6,8 +6,10 @@
#include <mutex>
#if defined(POLARSSL_IS_MBEDTLS)
// NOLINTBEGIN bugprone-macro-parentheses
#define API_HEADER(x) <mbedtls/x>
#define API(x) mbedtls_##x
// NOLINTEND
#define API_VERSION_NUMBER MBEDTLS_VERSION_NUMBER
#else
#define API_HEADER(x) <polarssl/x>
@@ -98,7 +100,7 @@ static int my_rand(void* /*context*/, unsigned char* buffer, size_t buffer_size)
return 0;
}
static api_ctr_drbg_context* get_rng(void)
static api_ctr_drbg_context* get_rng()
{
static api_ctr_drbg_context rng;
static bool rng_initialized = false;
@@ -129,96 +131,131 @@ static std::recursive_mutex rng_mutex_;
****
***/
tr_sha1_ctx_t tr_sha1_init(void)
namespace
{
api_sha1_context* handle = tr_new0(api_sha1_context, 1);
#if API_VERSION_NUMBER >= 0x01030800
API(sha1_init)(handle);
#endif
API(sha1_starts)(handle);
return handle;
}
bool tr_sha1_update(tr_sha1_ctx_t raw_handle, void const* data, size_t data_length)
class Sha1Impl final : public tr_sha1
{
auto* handle = static_cast<api_sha1_context*>(raw_handle);
TR_ASSERT(handle != nullptr);
if (data_length == 0)
public:
Sha1Impl()
{
return true;
clear();
}
TR_ASSERT(data != nullptr);
~Sha1Impl() override = default;
API(sha1_update)(handle, static_cast<unsigned char const*>(data), data_length);
return true;
}
void clear() override
{
#if API_VERSION_NUMBER >= 0x01030800
API(sha1_init)(&handle_);
#endif
std::optional<tr_sha1_digest_t> tr_sha1_final(tr_sha1_ctx_t raw_handle)
{
auto* handle = static_cast<api_sha1_context*>(raw_handle);
TR_ASSERT(handle != nullptr);
#if API_VERSION_NUMBER >= 0x02070000
mbedtls_sha1_starts_ret(&handle_);
#else
API(sha1_starts)(&handle_);
#endif
}
void add(void const* data, size_t data_length) override
{
if (data_length > 0U)
{
#if API_VERSION_NUMBER >= 0x02070000
mbedtls_sha1_update_ret(&handle_, static_cast<unsigned char const*>(data), data_length);
#else
API(sha1_update)(&handle_, static_cast<unsigned char const*>(data), data_length);
#endif
}
}
[[nodiscard]] tr_sha1_digest_t final() override
{
auto digest = tr_sha1_digest_t{};
auto* const digest_as_uchar = reinterpret_cast<unsigned char*>(std::data(digest));
API(sha1_finish)(handle, digest_as_uchar);
#if API_VERSION_NUMBER >= 0x01030800
API(sha1_free)(handle);
#if API_VERSION_NUMBER >= 0x02070000
mbedtls_sha1_finish_ret(&handle_, digest_as_uchar);
#else
API(sha1_finish)(&handle_, digest_as_uchar);
#endif
#if API_VERSION_NUMBER >= 0x01030800
API(sha1_free)(&handle_);
#endif
tr_free(handle);
return digest;
}
/***
****
***/
tr_sha256_ctx_t tr_sha256_init(void)
{
api_sha256_context* handle = tr_new0(api_sha256_context, 1);
#if API_VERSION_NUMBER >= 0x01030800
API(sha256_init)(handle);
#endif
API(sha256_starts)(handle, 0);
return handle;
}
bool tr_sha256_update(tr_sha256_ctx_t raw_handle, void const* data, size_t data_length)
{
auto* handle = static_cast<api_sha256_context*>(raw_handle);
TR_ASSERT(handle != nullptr);
if (data_length == 0)
{
return true;
}
TR_ASSERT(data != nullptr);
private:
mbedtls_sha1_context handle_ = {};
};
API(sha256_update)(handle, static_cast<unsigned char const*>(data), data_length);
return true;
}
std::optional<tr_sha256_digest_t> tr_sha256_final(tr_sha256_ctx_t raw_handle)
class Sha256Impl final : public tr_sha256
{
auto* handle = static_cast<api_sha256_context*>(raw_handle);
TR_ASSERT(handle != nullptr);
public:
Sha256Impl()
{
clear();
}
auto digest = tr_sha256_digest_t{};
auto* const digest_as_uchar = reinterpret_cast<unsigned char*>(std::data(digest));
API(sha256_finish)(handle, digest_as_uchar);
~Sha256Impl() override = default;
void clear() override
{
#if API_VERSION_NUMBER >= 0x01030800
API(sha256_free)(handle);
API(sha256_init)(&handle_);
#endif
#if API_VERSION_NUMBER >= 0x02070000
mbedtls_sha256_starts_ret(&handle_, 0);
#else
API(sha256_starts)(&handle_);
#endif
}
void add(void const* data, size_t data_length) override
{
if (data_length > 0U)
{
#if API_VERSION_NUMBER >= 0x02070000
mbedtls_sha256_update_ret(&handle_, static_cast<unsigned char const*>(data), data_length);
#else
API(sha256_update)(&handle_, static_cast<unsigned char const*>(data), data_length);
#endif
}
}
[[nodiscard]] tr_sha256_digest_t final() override
{
auto digest = tr_sha256_digest_t{};
auto* const digest_as_uchar = reinterpret_cast<unsigned char*>(std::data(digest));
#if API_VERSION_NUMBER >= 0x02070000
mbedtls_sha256_finish_ret(&handle_, digest_as_uchar);
#else
API(sha256_finish)(&handle_, digest_as_uchar);
#endif
#if API_VERSION_NUMBER >= 0x01030800
API(sha256_free)(&handle_);
#endif
tr_free(handle);
return digest;
}
private:
mbedtls_sha256_context handle_ = {};
};
} // namespace
std::unique_ptr<tr_sha1> tr_sha1::create()
{
return std::make_unique<Sha1Impl>();
}
std::unique_ptr<tr_sha256> tr_sha256::create()
{
return std::make_unique<Sha256Impl>();
}
/***

View File

@@ -74,11 +74,11 @@ std::string tr_salt(std::string_view plaintext, std::string_view salt)
static_assert(DigestStringSize == 40);
// build a sha1 digest of the original content and the salt
auto const digest = tr_sha1(plaintext, salt);
auto const digest = tr_sha1::digest(plaintext, salt);
// convert it to a string. string holds three parts:
// DigestPrefix, stringified digest of plaintext + salt, and the salt.
return fmt::format(FMT_STRING("{:s}{:s}{:s}"), SaltedPrefix, (digest ? tr_sha1_to_string(*digest) : ""sv), salt);
return fmt::format(FMT_STRING("{:s}{:s}{:s}"), SaltedPrefix, tr_sha1_to_string(digest), salt);
}
} // namespace

View File

@@ -8,6 +8,7 @@
#include <array>
#include <cstddef> // size_t
#include <memory>
#include <optional>
#include <string>
#include <string_view>
@@ -19,10 +20,44 @@
*** @{
**/
/** @brief Opaque SHA1 context type. */
using tr_sha1_ctx_t = void*;
/** @brief Opaque SHA256 context type. */
using tr_sha256_ctx_t = void*;
class tr_sha1
{
public:
static std::unique_ptr<tr_sha1> create();
virtual ~tr_sha1() = default;
virtual void clear() = 0;
virtual void add(void const* data, size_t data_length) = 0;
[[nodiscard]] virtual tr_sha1_digest_t final() = 0;
template<typename... T>
[[nodiscard]] static tr_sha1_digest_t digest(T... args)
{
auto context = tr_sha1::create();
(context->add(std::data(args), std::size(args)), ...);
return context->final();
}
};
class tr_sha256
{
public:
static std::unique_ptr<tr_sha256> create();
virtual ~tr_sha256() = default;
virtual void clear() = 0;
virtual void add(void const* data, size_t data_length) = 0;
[[nodiscard]] virtual tr_sha256_digest_t final() = 0;
template<typename... T>
[[nodiscard]] static tr_sha256_digest_t digest(T... args)
{
auto context = tr_sha256::create();
(context->add(std::data(args), std::size(args)), ...);
return context->final();
}
};
/** @brief Opaque SSL context type. */
using tr_ssl_ctx_t = void*;
/** @brief Opaque X509 certificate store type. */
@@ -30,82 +65,6 @@ using tr_x509_store_t = void*;
/** @brief Opaque X509 certificate type. */
using tr_x509_cert_t = void*;
/**
* @brief Allocate and initialize new SHA1 hasher context.
*/
tr_sha1_ctx_t tr_sha1_init(void);
/**
* @brief Update SHA1 hash.
*/
bool tr_sha1_update(tr_sha1_ctx_t handle, void const* data, size_t data_length);
/**
* @brief Finalize and export SHA1 hash, free hasher context.
*/
std::optional<tr_sha1_digest_t> tr_sha1_final(tr_sha1_ctx_t handle);
/**
* @brief Generate a SHA1 hash from one or more chunks of memory.
*/
template<typename... T>
std::optional<tr_sha1_digest_t> tr_sha1(T... args)
{
auto ctx = tr_sha1_init();
if (ctx == nullptr)
{
return std::nullopt;
}
if ((tr_sha1_update(ctx, std::data(args), std::size(args)) && ...))
{
return tr_sha1_final(ctx);
}
// one of the update() calls failed so we will return nullopt,
// but we need to call final() first to ensure ctx is released
tr_sha1_final(ctx);
return std::nullopt;
}
/**
* @brief Allocate and initialize new SHA256 hasher context.
*/
tr_sha256_ctx_t tr_sha256_init(void);
/**
* @brief Update SHA256 hash.
*/
bool tr_sha256_update(tr_sha256_ctx_t handle, void const* data, size_t data_length);
/**
* @brief Finalize and export SHA256 hash, free hasher context.
*/
std::optional<tr_sha256_digest_t> tr_sha256_final(tr_sha256_ctx_t handle);
/**
* @brief generate a SHA256 hash from some memory
*/
template<typename... T>
std::optional<tr_sha256_digest_t> tr_sha256(T... args)
{
auto ctx = tr_sha256_init();
if (ctx == nullptr)
{
return std::nullopt;
}
if ((tr_sha256_update(ctx, std::data(args), std::size(args)) && ...))
{
return tr_sha256_final(ctx);
}
// one of the update() calls failed so we will return nullopt,
// but we need to call final() first to ensure ctx is released
tr_sha256_final(ctx);
return std::nullopt;
}
/**
* @brief Get X509 certificate store from SSL context.
*/

View File

@@ -410,15 +410,8 @@ static ReadState readYb(tr_handshake* handshake, struct evbuffer* inbuf)
evbuffer* const outbuf = evbuffer_new();
/* HASH('req1', S) */
if (auto const req1 = tr_sha1("req1"sv, handshake->dh.secret()); req1)
{
evbuffer_add(outbuf, std::data(*req1), std::size(*req1));
}
else
{
tr_logAddTraceHand(handshake, "error while computing req1 hash after Yb");
return tr_handshakeDone(handshake, false);
}
auto const req1 = tr_sha1::digest("req1"sv, handshake->dh.secret());
evbuffer_add(outbuf, std::data(req1), std::size(req1));
auto const info_hash = handshake->io->torrentHash();
if (!info_hash)
@@ -429,18 +422,12 @@ static ReadState readYb(tr_handshake* handshake, struct evbuffer* inbuf)
/* HASH('req2', SKEY) xor HASH('req3', S) */
{
auto const req2 = tr_sha1("req2"sv, *info_hash);
auto const req3 = tr_sha1("req3"sv, handshake->dh.secret());
if (!req2 || !req3)
{
tr_logAddTraceHand(handshake, "error while computing req2/req3 hash after Yb");
return tr_handshakeDone(handshake, false);
}
auto const req2 = tr_sha1::digest("req2"sv, *info_hash);
auto const req3 = tr_sha1::digest("req3"sv, handshake->dh.secret());
auto buf = tr_sha1_digest_t{};
for (size_t i = 0, n = std::size(buf); i < n; ++i)
{
buf[i] = (*req2)[i] ^ (*req3)[i];
buf[i] = req2[i] ^ req3[i];
}
evbuffer_add(outbuf, std::data(buf), std::size(buf));
@@ -732,7 +719,7 @@ static ReadState readYa(tr_handshake* handshake, struct evbuffer* inbuf)
static ReadState readPadA(tr_handshake* handshake, struct evbuffer* inbuf)
{
// find the end of PadA by looking for HASH('req1', S)
auto const needle = *tr_sha1("req1"sv, handshake->dh.secret());
auto const needle = tr_sha1::digest("req1"sv, handshake->dh.secret());
for (size_t i = 0; i < PadA_MAXLEN; ++i)
{
@@ -779,17 +766,11 @@ static ReadState readCryptoProvide(tr_handshake* handshake, struct evbuffer* inb
auto req2 = tr_sha1_digest_t{};
evbuffer_remove(inbuf, std::data(req2), std::size(req2));
auto const req3 = tr_sha1("req3"sv, handshake->dh.secret());
if (!req3)
{
tr_logAddTraceHand(handshake, "error while computing req3 hash after req2");
return tr_handshakeDone(handshake, false);
}
auto const req3 = tr_sha1::digest("req3"sv, handshake->dh.secret());
auto obfuscated_hash = tr_sha1_digest_t{};
for (size_t i = 0; i < std::size(obfuscated_hash); ++i)
{
obfuscated_hash[i] = req2[i] ^ (*req3)[i];
obfuscated_hash[i] = req2[i] ^ req3[i];
}
if (auto const info = handshake->mediator->torrentInfoFromObfuscated(obfuscated_hash); info)

View File

@@ -236,23 +236,22 @@ std::optional<tr_sha1_digest_t> recalculateHash(tr_torrent* tor, tr_piece_index_
auto loc = tor->pieceLoc(piece);
tr_ioPrefetch(tor, loc, bytes_left);
auto sha = tr_sha1_init();
auto sha = tr_sha1::create();
auto buffer = std::vector<uint8_t>(tr_block_info::BlockSize);
while (bytes_left != 0)
{
size_t const len = std::min(bytes_left, std::size(buffer));
if (auto const success = tor->session->cache->readBlock(tor, loc, len, std::data(buffer)) == 0; !success)
{
tr_sha1_final(sha);
return {};
}
tr_sha1_update(sha, std::data(buffer), len);
sha->add(std::data(buffer), len);
loc = tor->byteLoc(loc.byte + len);
bytes_left -= len;
}
return tr_sha1_final(sha);
return sha->final();
}
} // namespace

View File

@@ -321,16 +321,8 @@ static std::vector<std::byte> getHashInfo(tr_metainfo_builder* b)
TR_ASSERT(bufptr - std::data(buf) == (int)this_piece_size);
TR_ASSERT(leftInPiece == 0);
auto const digest = tr_sha1(buf);
if (!digest)
{
b->my_errno = EIO;
*fmt::format_to_n(b->errfile, sizeof(b->errfile) - 1, "error hashing piece {:d}", b->pieceIndex).out = '\0';
b->result = TrMakemetaResult::ERR_IO_READ;
break;
}
walk = std::copy(std::begin(*digest), std::end(*digest), walk);
auto const digest = tr_sha1::digest(buf);
walk = std::copy(std::begin(digest), std::end(digest), walk);
if (b->abortFlag)
{

View File

@@ -111,8 +111,8 @@ void Filter::decryptInit(bool is_incoming, DH const& dh, tr_sha1_digest_t const&
auto const key = is_incoming ? "keyA"sv : "keyB"sv;
dec_key_ = std::make_shared<struct arc4_context>();
auto const buf = tr_sha1(key, dh.secret(), info_hash);
arc4_init(dec_key_.get(), std::data(*buf), std::size(*buf));
auto const buf = tr_sha1::digest(key, dh.secret(), info_hash);
arc4_init(dec_key_.get(), std::data(buf), std::size(buf));
arc4_discard(dec_key_.get(), 1024);
}
@@ -129,8 +129,8 @@ void Filter::encryptInit(bool is_incoming, DH const& dh, tr_sha1_digest_t const&
auto const key = is_incoming ? "keyB"sv : "keyA"sv;
enc_key_ = std::make_shared<struct arc4_context>();
auto const buf = tr_sha1(key, dh.secret(), info_hash);
arc4_init(enc_key_.get(), std::data(*buf), std::size(*buf));
auto const buf = tr_sha1::digest(key, dh.secret(), info_hash);
arc4_init(enc_key_.get(), std::data(buf), std::size(buf));
arc4_discard(enc_key_.get(), 1024);
}

View File

@@ -255,8 +255,7 @@ bool tr_torrentUseMetainfoFromFile(
static bool useNewMetainfo(tr_torrent* tor, tr_incomplete_metadata const* m, tr_error** error)
{
// test the info_dict checksum
auto const sha1 = tr_sha1(m->metadata);
if (bool const checksum_passed = sha1 && *sha1 == tor->infoHash(); !checksum_passed)
if (tr_sha1::digest(m->metadata) != tor->infoHash())
{
return false;
}

View File

@@ -573,17 +573,12 @@ private:
char const* const begin = &info_dict_begin_.front();
char const* const end = &context.raw().back() + 1;
auto const info_dict_benc = std::string_view{ begin, size_t(end - begin) };
auto const hash = tr_sha1(info_dict_benc);
auto const hash2 = tr_sha256(info_dict_benc);
if (!hash)
{
tr_error_set(context.error, EINVAL, "bad info_dict checksum");
return false;
}
auto const hash = tr_sha1::digest(info_dict_benc);
auto const hash2 = tr_sha256::digest(info_dict_benc);
tm_.info_hash_ = *hash;
tm_.info_hash_ = hash;
tm_.info_hash_str_ = tr_sha1_to_string(tm_.info_hash_);
tm_.info_hash2_ = *hash2;
tm_.info_hash2_ = hash2;
tm_.info_hash2_str_ = tr_sha256_to_string(tm_.info_hash2_);
tm_.info_dict_size_ = std::size(info_dict_benc);
return true;

View File

@@ -456,7 +456,7 @@ static bool tr_torrentIsSeedIdleLimitDone(tr_torrent const* tor)
difftime(tr_time(), std::max(tor->startDate, tor->activityDate)) >= idleMinutes * 60U;
}
static void torrentCallScript(tr_torrent const* tor, char const* script);
static void torrentCallScript(tr_torrent const* tor, std::string const& script);
static void callScriptIfEnabled(tr_torrent const* tor, TrScript type)
{
@@ -464,7 +464,7 @@ static void callScriptIfEnabled(tr_torrent const* tor, TrScript type)
if (tr_sessionIsScriptEnabled(session, type))
{
torrentCallScript(tor, tr_sessionGetScript(session, type));
torrentCallScript(tor, session->script(type));
}
}
@@ -575,17 +575,7 @@ static void torrentStart(tr_torrent* tor, torrent_start_opts opts);
static void torrentInitFromInfoDict(tr_torrent* tor)
{
tor->completion = tr_completion{ tor, &tor->blockInfo() };
if (auto const obfuscated = tr_sha1("req2"sv, tor->infoHash()); obfuscated)
{
tor->obfuscated_hash = *obfuscated;
}
else
{
// lookups by obfuscated hash will fail for this torrent
tr_logAddErrorTor(tor, _("Couldn't compute obfuscated info hash"));
tor->obfuscated_hash = tr_sha1_digest_t{};
}
tor->obfuscated_hash = tr_sha1::digest("req2"sv, tor->infoHash());
tor->fpm_.reset(tor->metainfo_);
tor->file_mtimes_.resize(tor->fileCount());
tor->file_priorities_.reset(&tor->fpm_);
@@ -1725,9 +1715,9 @@ static std::string buildTrackersString(tr_torrent const* tor)
return buf.str();
}
static void torrentCallScript(tr_torrent const* tor, char const* script)
static void torrentCallScript(tr_torrent const* tor, std::string const& script)
{
if (tr_str_is_empty(script))
if (std::empty(script))
{
return;
}
@@ -1735,7 +1725,7 @@ static void torrentCallScript(tr_torrent const* tor, char const* script)
auto torrent_dir = tr_pathbuf{ tor->currentDir() };
tr_sys_path_native_separators(std::data(torrent_dir));
auto const cmd = std::array<char const*, 2>{ script, nullptr };
auto const cmd = std::array<char const*, 2>{ script.c_str(), nullptr };
auto const id_str = std::to_string(tr_torrentId(tor));
auto const labels_str = buildLabelsString(tor);

View File

@@ -741,11 +741,8 @@ void dht_hash(void* hash_return, int hash_size, void const* v1, int len1, void c
auto const sv1 = std::string_view{ static_cast<char const*>(v1), size_t(len1) };
auto const sv2 = std::string_view{ static_cast<char const*>(v2), size_t(len2) };
auto const sv3 = std::string_view{ static_cast<char const*>(v3), size_t(len3) };
auto const digest = tr_sha1(sv1, sv2, sv3);
if (digest)
{
std::copy_n(std::data(*digest), std::min(size_t(hash_size), std::size(*digest)), setme);
}
auto const digest = tr_sha1::digest(sv1, sv2, sv3);
std::copy_n(std::data(digest), std::min(size_t(hash_size), std::size(digest)), setme);
}
int dht_random_bytes(void* buf, size_t size)

View File

@@ -43,7 +43,7 @@ static bool verifyTorrent(tr_torrent* tor, bool const* stopFlag)
tr_file_index_t prev_file_index = ~file_index;
tr_piece_index_t piece = 0;
auto buffer = std::vector<std::byte>(1024 * 256);
auto sha = tr_sha1_init();
auto sha = tr_sha1::create();
tr_logAddDebugTor(tor, "verifying torrent...");
@@ -78,7 +78,7 @@ static bool verifyTorrent(tr_torrent* tor, bool const* stopFlag)
if (tr_sys_file_read_at(fd, std::data(buffer), bytes_this_pass, file_pos, &num_read) && num_read > 0)
{
bytes_this_pass = num_read;
tr_sha1_update(sha, std::data(buffer), bytes_this_pass);
sha->add(std::data(buffer), bytes_this_pass);
tr_sys_file_advise(fd, file_pos, bytes_this_pass, TR_SYS_FILE_ADVICE_DONT_NEED);
}
}
@@ -92,8 +92,7 @@ static bool verifyTorrent(tr_torrent* tor, bool const* stopFlag)
/* if we're finishing a piece... */
if (left_in_piece == 0)
{
auto const hash = tr_sha1_final(sha);
auto const has_piece = hash && *hash == tor->pieceHash(piece);
auto const has_piece = sha->final() == tor->pieceHash(piece);
if (has_piece || had_piece)
{
@@ -112,7 +111,7 @@ static bool verifyTorrent(tr_torrent* tor, bool const* stopFlag)
tr_wait_msec(MsecToSleepPerSecondDuringVerify);
}
sha = tr_sha1_init();
sha->clear();
++piece;
tor->setVerifyProgress(piece / float(tor->pieceCount()));
piece_pos = 0;
@@ -138,8 +137,6 @@ static bool verifyTorrent(tr_torrent* tor, bool const* stopFlag)
tr_sys_file_close(fd);
}
tr_sha1_final(sha);
/* stopwatch */
time_t const end = tr_time();
tr_logAddDebugTor(

View File

@@ -19,19 +19,11 @@
#define tr_rand_int_weak tr_rand_int_weak_
#define tr_salt_shaker tr_salt_shaker_
#define tr_sha1 tr_sha1_
#define tr_sha1_ctx_t tr_sha1_ctx_t_
#define tr_sha1_final tr_sha1_final_
#define tr_sha1_from_string tr_sha1_from_string_
#define tr_sha1_init tr_sha1_init_
#define tr_sha1_to_string tr_sha1_to_string_
#define tr_sha1_update tr_sha1_update_
#define tr_sha256 tr_sha256_
#define tr_sha256_ctx_t tr_sha256_ctx_t_
#define tr_sha256_final tr_sha256_final_
#define tr_sha256_from_string tr_sha256_from_string_
#define tr_sha256_init tr_sha256_init_
#define tr_sha256_to_string tr_sha256_to_string_
#define tr_sha256_update tr_sha256_update_
#define tr_ssha1 tr_ssha1_
#define tr_ssha1_matches tr_ssha1_matches_
#define tr_ssha1_test tr_ssha1_test_
@@ -59,19 +51,11 @@
#undef tr_rand_int_weak
#undef tr_salt_shaker
#undef tr_sha1
#undef tr_sha1_ctx_t
#undef tr_sha1_final
#undef tr_sha1_from_string
#undef tr_sha1_init
#undef tr_sha1_to_string
#undef tr_sha1_update
#undef tr_sha256
#undef tr_sha256_ctx_t
#undef tr_sha256_final
#undef tr_sha256_from_string
#undef tr_sha256_init
#undef tr_sha256_to_string
#undef tr_sha256_update
#undef tr_ssha1
#undef tr_ssha1_matches
#undef tr_ssha1_test

View File

@@ -103,42 +103,35 @@ TEST(Crypto, encryptDecrypt)
TEST(Crypto, sha1)
{
auto hash1 = tr_sha1("test"sv);
EXPECT_TRUE(hash1);
auto hash1 = tr_sha1::digest("test"sv);
EXPECT_EQ(
0,
memcmp(
std::data(*hash1),
std::data(hash1),
"\xa9\x4a\x8f\xe5\xcc\xb1\x9b\xa6\x1c\x4c\x08\x73\xd3\x91\xe9\x87\x98\x2f\xbb\xd3",
std::size(*hash1)));
std::size(hash1)));
auto hash2 = tr_sha1("test"sv);
EXPECT_TRUE(hash1);
EXPECT_EQ(*hash1, *hash2);
auto hash2 = tr_sha1::digest("test"sv);
EXPECT_EQ(hash1, hash2);
hash1 = tr_sha1("1"sv, "22"sv, "333"sv);
hash2 = tr_sha1("1"sv, "22"sv, "333"sv);
EXPECT_TRUE(hash1);
EXPECT_TRUE(hash2);
EXPECT_EQ(*hash1, *hash2);
hash1 = tr_sha1::digest("1"sv, "22"sv, "333"sv);
hash2 = tr_sha1::digest("1"sv, "22"sv, "333"sv);
EXPECT_EQ(hash1, hash2);
EXPECT_EQ(
0,
memcmp(
std::data(*hash1),
std::data(hash1),
"\x1f\x74\x64\x8e\x50\xa6\xa6\x70\x8e\xc5\x4a\xb3\x27\xa1\x63\xd5\x53\x6b\x7c\xed",
std::size(*hash1)));
std::size(hash1)));
auto const hash3 = tr_sha1("test"sv);
EXPECT_TRUE(hash3);
EXPECT_EQ("a94a8fe5ccb19ba61c4c0873d391e987982fbbd3"sv, tr_sha1_to_string(*hash3));
auto const hash3 = tr_sha1::digest("test"sv);
EXPECT_EQ("a94a8fe5ccb19ba61c4c0873d391e987982fbbd3"sv, tr_sha1_to_string(hash3));
auto const hash4 = tr_sha1("te"sv, "st"sv);
EXPECT_TRUE(hash4);
EXPECT_EQ("a94a8fe5ccb19ba61c4c0873d391e987982fbbd3"sv, tr_sha1_to_string(*hash4));
auto const hash4 = tr_sha1::digest("te"sv, "st"sv);
EXPECT_EQ("a94a8fe5ccb19ba61c4c0873d391e987982fbbd3"sv, tr_sha1_to_string(hash4));
auto const hash5 = tr_sha1("t"sv, "e"sv, std::string{ "s" }, std::array<char, 1>{ { 't' } });
EXPECT_TRUE(hash5);
EXPECT_EQ("a94a8fe5ccb19ba61c4c0873d391e987982fbbd3"sv, tr_sha1_to_string(*hash5));
auto const hash5 = tr_sha1::digest("t"sv, "e"sv, std::string{ "s" }, std::array<char, 1>{ { 't' } });
EXPECT_EQ("a94a8fe5ccb19ba61c4c0873d391e987982fbbd3"sv, tr_sha1_to_string(hash5));
}
TEST(Crypto, ssha1)

View File

@@ -60,7 +60,7 @@ public:
{
for (auto const& [info_hash, info] : torrents)
{
if (obfuscated == *tr_sha1("req2"sv, info.info_hash))
if (obfuscated == tr_sha1::digest("req2"sv, info.info_hash))
{
return info;
}
@@ -142,7 +142,7 @@ auto constexpr ReservedBytesNoExtensions = std::array<uint8_t, 8>{ 0, 0, 0, 0, 0
auto constexpr PlaintextProtocolName = "\023BitTorrent protocol"sv;
auto const DefaultPeerAddr = *tr_address::fromString("127.0.0.1"sv);
auto const DefaultPeerPort = tr_port::fromHost(8080);
auto const TorrentWeAreSeeding = tr_handshake_mediator::torrent_info{ *tr_sha1("abcde"sv),
auto const TorrentWeAreSeeding = tr_handshake_mediator::torrent_info{ tr_sha1::digest("abcde"sv),
tr_peerIdInit(),
tr_torrent_id_t{ 100 },
true /*is_done*/ };
@@ -267,7 +267,7 @@ TEST_F(HandshakeTest, incomingPlaintextUnknownInfoHash)
auto [io, sock] = createIncomingIo(session_);
sendToClient(sock, PlaintextProtocolName);
sendToClient(sock, ReservedBytesNoExtensions);
sendToClient(sock, *tr_sha1("some other torrent unknown to us"sv));
sendToClient(sock, tr_sha1::digest("some other torrent unknown to us"sv));
sendToClient(sock, makeRandomPeerId());
auto const res = runHandshake(mediator, io);