Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions include/session/random.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
#include <sodium/randombytes.h>

#include <algorithm>
#include <cstdint>
#include <limits>
#include <random>
#include <type_traits>

#include "util.hpp"

Expand Down Expand Up @@ -77,11 +81,16 @@ std::string unique_id(std::string_view prefix);
/// - A random integer in the specified range
template <typename T>
T get_uniform_distribution(T min, T max) {
static_assert(std::is_integral_v<T>, "get_uniform_distribution requires an integral type");

if (min > max)
return min;

const uint64_t range = static_cast<uint64_t>(max) - static_cast<uint64_t>(min) + 1;
return static_cast<T>(static_cast<uint64_t>(min) + (csrng() % range));
using dist_type = std::conditional_t<std::is_signed_v<T>, int64_t, uint64_t>;
std::uniform_int_distribution<dist_type> dist{
static_cast<dist_type>(min), static_cast<dist_type>(max)};

return static_cast<T>(dist(csrng));
}

} // namespace session::random
20 changes: 20 additions & 0 deletions tests/test_random.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include <catch2/catch_test_macros.hpp>
#include <cstdint>
#include <limits>

#include "session/random.h"
#include "session/random.hpp"
Expand All @@ -14,3 +16,21 @@ TEST_CASE("Random generation", "[random][random]") {
CHECK(rand3.size() == 20);
CHECK(rand1 != rand2);
}

TEST_CASE("Random uniform distribution", "[random][uniform]") {
CHECK(session::random::get_uniform_distribution<int>(7, 7) == 7);
CHECK(session::random::get_uniform_distribution<int>(9, 3) == 9);
CHECK_NOTHROW(session::random::get_uniform_distribution<uint64_t>(
0, std::numeric_limits<uint64_t>::max()));
CHECK_NOTHROW(session::random::get_uniform_distribution<int64_t>(
std::numeric_limits<int64_t>::min(), std::numeric_limits<int64_t>::max()));

for (int i = 0; i < 1000; ++i) {
const auto signed_value = session::random::get_uniform_distribution<int>(-5, 5);
CHECK(signed_value >= -5);
CHECK(signed_value <= 5);

const auto unsigned_value = session::random::get_uniform_distribution<size_t>(0, 10);
CHECK(unsigned_value <= 10);
}
}