-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added FFT cache to domains [SyncWith: crypto3-zk#277] (#73)
Added caching to basic_radix2_domain_aux and it's calls.
- Loading branch information
1 parent
6a554ff
commit 2c3d8bb
Showing
12 changed files
with
424 additions
and
114 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
//---------------------------------------------------------------------------// | ||
// Copyright (c) 2020-2021 Mikhail Komarov <[email protected]> | ||
// Copyright (c) 2020-2021 Nikita Kaskov <[email protected]> | ||
// Copyright (c) 2024 Dmitrii Tabalin <[email protected]> | ||
// | ||
// MIT License | ||
// | ||
|
@@ -27,6 +28,7 @@ | |
#define CRYPTO3_MATH_BASIC_RADIX2_DOMAIN_AUX_HPP | ||
|
||
#include <algorithm> | ||
#include <memory> | ||
#include <vector> | ||
|
||
#include <nil/crypto3/algebra/type_traits.hpp> | ||
|
@@ -39,17 +41,32 @@ namespace nil { | |
namespace math { | ||
namespace detail { | ||
|
||
/* | ||
* Building caches for fft operations | ||
*/ | ||
template<typename FieldType> | ||
void create_fft_cache( | ||
const std::size_t size, | ||
const typename FieldType::value_type &omega, | ||
std::vector<typename FieldType::value_type> &cache) { | ||
typedef typename FieldType::value_type value_type; | ||
cache.resize(size); | ||
cache[0] = value_type::one(); | ||
for (std::size_t i = 1; i < size; ++i) { | ||
cache[i] = cache[i - 1] * omega; | ||
} | ||
} | ||
|
||
/* | ||
* Below we make use of pseudocode from [CLRS 2n Ed, pp. 864]. | ||
* Also, note that it's the caller's responsibility to multiply by 1/N. | ||
*/ | ||
template<typename FieldType, typename Range> | ||
void basic_radix2_fft(Range &a, const typename FieldType::value_type &omega) { | ||
void basic_radix2_fft_cached(Range &a, const std::vector<typename FieldType::value_type> &omega_cache) { | ||
typedef typename std::iterator_traits<decltype(std::begin(std::declval<Range>()))>::value_type | ||
value_type; | ||
typedef typename FieldType::value_type field_value_type; | ||
BOOST_STATIC_ASSERT(algebra::is_field<FieldType>::value); | ||
|
||
// It now supports curve elements too, should probably some other assertion about the field type and value type | ||
// BOOST_STATIC_ASSERT(std::is_same<typename FieldType::value_type, value_type>::value); | ||
|
||
|
@@ -64,23 +81,36 @@ namespace nil { | |
std::swap(a[k], a[rk]); | ||
} | ||
|
||
std::size_t m = 1; // invariant: m = 2^{s-1} | ||
for (std::size_t s = 1; s <= logn; ++s) { | ||
// invariant: m = 2^{s-1} | ||
value_type t; | ||
for (std::size_t s = 1, m = 1, inc = n / 2; s <= logn; ++s, m <<= 1, inc >>= 1) { | ||
// w_m is 2^s-th root of unity now | ||
const field_value_type w_m = omega.pow(n / (2 * m)); | ||
|
||
asm volatile("/* pre-inner */"); | ||
for (std::size_t k = 0; k < n; k += 2 * m) { | ||
field_value_type w = field_value_type::one(); | ||
for (std::size_t j = 0; j < m; ++j) { | ||
const value_type t = a[k + j + m] * w; | ||
a[k + j + m] = a[k + j] - t; | ||
a[k + j] = a[k + j] + t; | ||
w *= w_m; | ||
for (std::size_t j = 0, idx = 0; j < m; ++j, idx += inc) { | ||
t = std::move(a[k + j + m]); | ||
t = t * omega_cache[idx]; | ||
a[k + j + m] = a[k + j]; | ||
a[k + j + m] -= t; | ||
a[k + j] += t; | ||
} | ||
} | ||
asm volatile("/* post-inner */"); | ||
m *= 2; | ||
} | ||
} | ||
|
||
/** | ||
* Note that it's the caller's responsibility to multiply by 1/N. | ||
*/ | ||
template<typename FieldType, typename Range> | ||
void basic_radix2_fft( | ||
Range &a, const typename FieldType::value_type &omega, | ||
std::shared_ptr<std::vector<typename FieldType::value_type>> omega_cache = nullptr) { | ||
|
||
if (omega_cache == nullptr) { | ||
std::vector<typename FieldType::value_type> omega_powers; | ||
create_fft_cache<FieldType>(a.size(), omega, omega_powers); | ||
basic_radix2_fft_cached<FieldType>(a, omega_powers); | ||
} else { | ||
basic_radix2_fft_cached<FieldType>(a, *omega_cache); | ||
} | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.