Skip to content

Commit

Permalink
Added FFT cache to domains [SyncWith: crypto3-zk#277] (#73)
Browse files Browse the repository at this point in the history
Added caching to basic_radix2_domain_aux and it's calls.
  • Loading branch information
Iluvmagick authored Jan 31, 2024
1 parent 6a554ff commit 2c3d8bb
Show file tree
Hide file tree
Showing 12 changed files with 424 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ namespace nil {
*/
template<typename FieldType, typename ValueType = typename FieldType::value_type>
std::shared_ptr<evaluation_domain<FieldType, ValueType>> make_evaluation_domain(std::size_t m) {

typedef std::shared_ptr<evaluation_domain<FieldType, ValueType>> result_type;

const std::size_t big = 1ul << (std::size_t(std::ceil(std::log2(m))) - 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,11 @@ namespace nil {
}

std::vector<value_type> evaluate_all_lagrange_polynomials(const typename std::vector<value_type>::const_iterator &t_powers_begin,
const typename std::vector<value_type>::const_iterator &t_powers_end) override {
const typename std::vector<value_type>::const_iterator &t_powers_end) override {
if(std::distance(t_powers_begin, t_powers_end) < this->m) {
throw std::invalid_argument("arithmetic_sequence_radix2: expected std::distance(t_powers_begin, t_powers_end) >= this->m");
}

/* Compute Lagrange polynomial of size m, with m+1 points (x_0, y_0), ... ,(x_m, y_m) */
/* Evaluate for x = t */
/* Return coeffs for each l_j(x) = (l / l_i[j]) * w[j] */
Expand Down Expand Up @@ -263,14 +263,14 @@ namespace nil {

for(std::size_t j = 0; j < l[i].size(); ++j) {
result[i] = result[i] + t_powers_begin[j] * l[i][j];
}
}
result[i] = result[i] * w[i];
}

return result;
}

// This one is not the unity root actually, but it's ok for our purposes.
// This one is not the unity root actually, but it's ok for our purposes.
const field_value_type& get_unity_root() override {
return arithmetic_generator;
}
Expand Down
25 changes: 19 additions & 6 deletions include/nil/crypto3/math/domains/basic_radix2_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,22 @@ namespace nil {
class basic_radix2_domain : public evaluation_domain<FieldType, ValueType> {
typedef typename FieldType::value_type field_value_type;
typedef ValueType value_type;
typedef std::pair<std::vector<field_value_type>, std::vector<field_value_type>> cache_type;
std::unique_ptr<cache_type> fft_cache;

void create_fft_cache() {
fft_cache = std::make_unique<cache_type>(std::vector<field_value_type>(), std::vector<field_value_type>());
detail::create_fft_cache<FieldType>(this->m, omega, fft_cache->first);
detail::create_fft_cache<FieldType>(this->m, omega.inversed(), fft_cache->second);
}
public:
typedef FieldType field_type;

field_value_type omega;
const field_value_type omega;

basic_radix2_domain(const std::size_t m) : evaluation_domain<FieldType, ValueType>(m) {
basic_radix2_domain(const std::size_t m)
: evaluation_domain<FieldType, ValueType>(m),
omega(unity_root<FieldType>(m)) {
if (m <= 1)
throw std::invalid_argument("basic_radix2(): expected m > 1");

Expand All @@ -64,8 +73,6 @@ namespace nil {
throw std::invalid_argument(
"basic_radix2(): expected logm <= fields::arithmetic_params<FieldType>::s");
}

omega = unity_root<FieldType>(m);
}

void fft(std::vector<value_type> &a) override {
Expand All @@ -77,7 +84,10 @@ namespace nil {
}
}

detail::basic_radix2_fft<FieldType>(a, omega);
if (fft_cache == nullptr) {
create_fft_cache();
}
detail::basic_radix2_fft_cached<FieldType>(a, fft_cache->first);
}

void inverse_fft(std::vector<value_type> &a) override {
Expand All @@ -89,7 +99,10 @@ namespace nil {
}
}

detail::basic_radix2_fft<FieldType>(a, omega.inversed());
if (fft_cache == nullptr) {
create_fft_cache();
}
detail::basic_radix2_fft_cached<FieldType>(a, fft_cache->second);

const field_value_type sconst = field_value_type(a.size()).inversed();
for (std::size_t i = 0; i < a.size(); ++i) {
Expand Down
62 changes: 46 additions & 16 deletions include/nil/crypto3/math/domains/detail/basic_radix2_domain_aux.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//---------------------------------------------------------------------------//
// Copyright (c) 2020-2021 Mikhail Komarov <[email protected]>
// Copyright (c) 2020-2021 Nikita Kaskov <[email protected]>
// Copyright (c) 2024 Dmitrii Tabalin <[email protected]>
//
// MIT License
//
Expand All @@ -27,6 +28,7 @@
#define CRYPTO3_MATH_BASIC_RADIX2_DOMAIN_AUX_HPP

#include <algorithm>
#include <memory>
#include <vector>

#include <nil/crypto3/algebra/type_traits.hpp>
Expand All @@ -39,17 +41,32 @@ namespace nil {
namespace math {
namespace detail {

/*
* Building caches for fft operations
*/
template<typename FieldType>
void create_fft_cache(
const std::size_t size,
const typename FieldType::value_type &omega,
std::vector<typename FieldType::value_type> &cache) {
typedef typename FieldType::value_type value_type;
cache.resize(size);
cache[0] = value_type::one();
for (std::size_t i = 1; i < size; ++i) {
cache[i] = cache[i - 1] * omega;
}
}

/*
* Below we make use of pseudocode from [CLRS 2n Ed, pp. 864].
* Also, note that it's the caller's responsibility to multiply by 1/N.
*/
template<typename FieldType, typename Range>
void basic_radix2_fft(Range &a, const typename FieldType::value_type &omega) {
void basic_radix2_fft_cached(Range &a, const std::vector<typename FieldType::value_type> &omega_cache) {
typedef typename std::iterator_traits<decltype(std::begin(std::declval<Range>()))>::value_type
value_type;
typedef typename FieldType::value_type field_value_type;
BOOST_STATIC_ASSERT(algebra::is_field<FieldType>::value);

// It now supports curve elements too, should probably some other assertion about the field type and value type
// BOOST_STATIC_ASSERT(std::is_same<typename FieldType::value_type, value_type>::value);

Expand All @@ -64,23 +81,36 @@ namespace nil {
std::swap(a[k], a[rk]);
}

std::size_t m = 1; // invariant: m = 2^{s-1}
for (std::size_t s = 1; s <= logn; ++s) {
// invariant: m = 2^{s-1}
value_type t;
for (std::size_t s = 1, m = 1, inc = n / 2; s <= logn; ++s, m <<= 1, inc >>= 1) {
// w_m is 2^s-th root of unity now
const field_value_type w_m = omega.pow(n / (2 * m));

asm volatile("/* pre-inner */");
for (std::size_t k = 0; k < n; k += 2 * m) {
field_value_type w = field_value_type::one();
for (std::size_t j = 0; j < m; ++j) {
const value_type t = a[k + j + m] * w;
a[k + j + m] = a[k + j] - t;
a[k + j] = a[k + j] + t;
w *= w_m;
for (std::size_t j = 0, idx = 0; j < m; ++j, idx += inc) {
t = std::move(a[k + j + m]);
t = t * omega_cache[idx];
a[k + j + m] = a[k + j];
a[k + j + m] -= t;
a[k + j] += t;
}
}
asm volatile("/* post-inner */");
m *= 2;
}
}

/**
* Note that it's the caller's responsibility to multiply by 1/N.
*/
template<typename FieldType, typename Range>
void basic_radix2_fft(
Range &a, const typename FieldType::value_type &omega,
std::shared_ptr<std::vector<typename FieldType::value_type>> omega_cache = nullptr) {

if (omega_cache == nullptr) {
std::vector<typename FieldType::value_type> omega_powers;
create_fft_cache<FieldType>(a.size(), omega, omega_powers);
basic_radix2_fft_cached<FieldType>(a, omega_powers);
} else {
basic_radix2_fft_cached<FieldType>(a, *omega_cache);
}
}

Expand Down
13 changes: 8 additions & 5 deletions include/nil/crypto3/math/domains/evaluation_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ namespace nil {
public:
typedef FieldType field_type;

std::size_t m;
std::size_t log2_size;
std::size_t generator_size;
const std::size_t m;
const std::size_t log2_size;

/**
* Construct an evaluation domain S of size m, if possible.
Expand All @@ -62,6 +61,11 @@ namespace nil {
return m;
}

/*
* Virtual destructor.
*/
virtual ~evaluation_domain() {};

/**
* Get the unity root.
*/
Expand Down Expand Up @@ -127,8 +131,7 @@ namespace nil {
virtual void divide_by_z_on_coset(std::vector<field_value_type> &P) = 0;

bool operator==(const evaluation_domain &rhs) const {
return m == rhs.m && log2_size == rhs.log2_size &&
generator_size == rhs.generator_size;
return m == rhs.m && log2_size == rhs.log2_size;
}
};
} // namespace math
Expand Down
49 changes: 31 additions & 18 deletions include/nil/crypto3/math/domains/extended_radix2_domain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <vector>

#include <nil/crypto3/math/domains/evaluation_domain.hpp>
#include <nil/crypto3/math/domains/basic_radix2_domain.hpp>
#include <nil/crypto3/math/domains/detail/basic_radix2_domain_aux.hpp>
#include <nil/crypto3/math/algorithms/unity_root.hpp>
#include <nil/crypto3/math/polynomial/polynomial.hpp>
Expand All @@ -46,15 +47,28 @@ namespace nil {
class extended_radix2_domain : public evaluation_domain<FieldType, ValueType> {
typedef typename FieldType::value_type field_value_type;
typedef ValueType value_type;
typedef std::pair<std::vector<field_value_type>, std::vector<field_value_type>> cache_type;

std::unique_ptr<cache_type> fft_cache;

void create_fft_cache() {
fft_cache = std::make_unique<cache_type>(std::vector<field_value_type>(),
std::vector<field_value_type>());
detail::create_fft_cache<FieldType>(small_m, omega, fft_cache->first);
detail::create_fft_cache<FieldType>(small_m, omega.inversed(), fft_cache->second);
}
public:
typedef FieldType field_type;

std::size_t small_m;
field_value_type omega;
field_value_type shift;
const std::size_t small_m;
const field_value_type omega;
const field_value_type shift;

extended_radix2_domain(const std::size_t m) : evaluation_domain<FieldType, ValueType>(m) {
extended_radix2_domain(const std::size_t m)
: evaluation_domain<FieldType, ValueType>(m),
small_m(m / 2),
omega(unity_root<FieldType>(small_m)),
shift(detail::coset_shift<FieldType>()) {
if (m <= 1)
throw std::invalid_argument("extended_radix2(): expected m > 1");

Expand All @@ -64,12 +78,6 @@ namespace nil {
throw std::invalid_argument(
"extended_radix2(): expected logm == fields::arithmetic_params<FieldType>::s + 1");
}

small_m = m / 2;

omega = unity_root<FieldType>(small_m);

shift = detail::coset_shift<FieldType>();
}

void fft(std::vector<value_type> &a) override {
Expand All @@ -94,8 +102,11 @@ namespace nil {
shift_i *= shift;
}

detail::basic_radix2_fft<FieldType>(a0, omega);
detail::basic_radix2_fft<FieldType>(a1, omega);
if (fft_cache == nullptr) {
create_fft_cache();
}
detail::basic_radix2_fft_cached<FieldType>(a0, fft_cache->first);
detail::basic_radix2_fft_cached<FieldType>(a1, fft_cache->first);

for (std::size_t i = 0; i < small_m; ++i) {
a[i] = a0[i];
Expand All @@ -116,9 +127,11 @@ namespace nil {
std::vector<value_type> a0(a.begin(), a.begin() + small_m);
std::vector<value_type> a1(a.begin() + small_m, a.end());

const field_value_type omega_inverse = omega.inversed();
detail::basic_radix2_fft<FieldType>(a0, omega_inverse);
detail::basic_radix2_fft<FieldType>(a1, omega_inverse);
if (fft_cache == nullptr) {
create_fft_cache();
}
detail::basic_radix2_fft_cached<FieldType>(a0, fft_cache->second);
detail::basic_radix2_fft_cached<FieldType>(a1, fft_cache->second);

const field_value_type shift_to_small_m = shift.pow(small_m);
const field_value_type sconst = (field_value_type(small_m) * (field_value_type::one() - shift_to_small_m)).inversed();
Expand Down Expand Up @@ -161,14 +174,14 @@ namespace nil {
if(std::distance(t_powers_begin, t_powers_end) < this->m) {
throw std::invalid_argument("extended_radix2: expected std::distance(t_powers_begin, t_powers_end) >= this->m");
}

basic_radix2_domain<FieldType, ValueType> basic_domain(small_m);

std::vector<value_type> T0 =
basic_domain.evaluate_all_lagrange_polynomials(t_powers_begin, t_powers_end);
std::vector<value_type> T0_times_t_to_small_m =
basic_domain.evaluate_all_lagrange_polynomials(t_powers_begin + small_m, t_powers_end);

field_value_type shift_inverse = shift.inversed();
std::vector<value_type> shift_inv_t_powers(small_m);
std::vector<value_type> shift_inv_t_powers_times_t_to_small_m(small_m);
Expand Down
Loading

0 comments on commit 2c3d8bb

Please sign in to comment.