Skip to content

Commit

Permalink
Batching no longer fails when circuit/assignment gens are reordered. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Iluvmagick authored Apr 2, 2024
1 parent 45fe773 commit e63299e
Show file tree
Hide file tree
Showing 17 changed files with 205 additions and 132 deletions.
25 changes: 23 additions & 2 deletions include/nil/blueprint/blueprint/plonk/assignment.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ namespace nil {
template<typename BatchType, typename ArithmetizationType, typename VariableType>
struct has_finalize_batch;

template<typename BatchType>
struct has_name;

template<typename ComponentType>
struct input_type_v;

Expand Down Expand Up @@ -119,6 +122,7 @@ namespace nil {
boost::mpl::vector<
has_add_input<_batch, _input_type, _result_type>,
has_finalize_batch<_batch, ArithmetizationType, var>,
has_name<_batch>,
boost::type_erasure::same_type<boost::type_erasure::deduced<input_type_v<_batch>>, _input_type>,
boost::type_erasure::same_type<boost::type_erasure::deduced<result_type_v<_batch>>, _result_type>,
boost::type_erasure::same_type<boost::type_erasure::deduced<component_params_type_v<_batch>>, _variadics>,
Expand Down Expand Up @@ -155,22 +159,39 @@ namespace nil {
desc.constant_columns, desc.selector_columns) {
}

template<typename ComponentType, typename... ComponentParams>
typename ComponentType::result_type add_input_to_batch_assignment(
const typename ComponentType::input_type &input,
ComponentParams... params) {

return add_input_to_batch<ComponentType>(input, false, params...);
}

template<typename ComponentType, typename... ComponentParams>
typename ComponentType::result_type add_input_to_batch_circuit(
const typename ComponentType::input_type &input,
ComponentParams... params) {

return add_input_to_batch<ComponentType>(input, true, params...);
}

template<typename ComponentType, typename... ComponentParams>
typename ComponentType::result_type add_input_to_batch(
const typename ComponentType::input_type &input,
bool called_from_generate_circuit,
ComponentParams... params) {
using batching_type = component_batch<ArithmetizationType, BlueprintFieldType, ComponentType,
ComponentParams...>;
batching_type batch(*this, std::tuple<ComponentParams...>(params...));
auto it = component_batches.find(batch);
if (it == component_batches.end()) {
auto result = batch.add_input(input);
auto result = batch.add_input(input, called_from_generate_circuit);
component_batches.insert(batch);
return result;
} else {
// safe because the ordering doesn't depend on the batch inputs
return boost::type_erasure::any_cast<batching_type&>(const_cast<batcher_type&>(*it))
.add_input(input);
.add_input(input, called_from_generate_circuit);
}
}

Expand Down
136 changes: 98 additions & 38 deletions include/nil/blueprint/component_batch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#pragma once

#include <functional>
#include <string>
#include <vector>
#include <numeric>
#include <utility>
Expand Down Expand Up @@ -135,6 +136,13 @@ namespace nil {
}
};

template<typename BatchType>
struct has_name {
static std::string apply(const BatchType& batch) {
return batch.name();
}
};

// Generic-ish enough batching solution for single-line components
// Lookups currently unsupported
// Partially supports component prarameterization -- only if passed through template parameters
Expand All @@ -150,7 +158,9 @@ namespace nil {
using gate_type = crypto3::zk::snark::plonk_gate<BlueprintFieldType, constraint_type>;
using component_params_type = typename std::tuple<ComponentParams...>;
// input-output pairs for batched components
std::map<input_type, result_type, comparison_for_inputs_results<BlueprintFieldType, input_type>>
// the second bool determines whether the result has been actually filled; is left unfilled if called
// from generate_circuit
std::map<input_type, std::pair<result_type, bool>, comparison_for_inputs_results<BlueprintFieldType, input_type>>
inputs_results;
// pointer to the assignment we are going to use in the end
assignment<ArithmetizationType> &parent_assignment;
Expand All @@ -176,17 +186,63 @@ namespace nil {

~component_batch() = default;

std::string name() const {
std::string result = typeid(ComponentType).name();
std::apply([&result](auto... args) {
((result += "__" + std::to_string(args)), ...);
}, params_tuple);
return result;
}

void variable_transform(std::reference_wrapper<var> variable) {
variable.get() = parent_assignment.add_batch_variable(
var_value(internal_assignment, variable.get()));
}

ComponentType build_component_instance(const std::size_t component_witness_amount,
const std::size_t start_value = 0) const {
const std::vector<std::size_t> constants = {}, public_inputs = {};
std::vector<std::size_t> witness_columns(component_witness_amount);
std::iota(witness_columns.begin(), witness_columns.end(), start_value);
return std::apply(component_builder, std::make_tuple(witness_columns, constants, public_inputs, params_tuple));
}

std::size_t get_component_witness_amount() const {
const compiler_manifest assignment_manifest(parent_assignment.witnesses_amount(), false);
const auto component_manifest = std::apply(ComponentType::get_manifest, params_tuple);
const auto intersection = assignment_manifest.intersect(component_manifest);
BOOST_ASSERT_MSG(intersection.is_satisfiable(), "Component either has a constant or does not fit");
const std::size_t component_witness_amount = intersection.witness_amount->max_value_if_sat();
return component_witness_amount;
}

// call this in both generate_assignments and generate_circuit
result_type add_input(const input_type &input) {
// short-circuit if we are in generate_circuit and the input has already been through batching
result_type add_input(const input_type &input, bool called_from_generate_circuit = false) {
// short-circuit if the input has already been through batching
bool unassigned_result_found = false;
if (inputs_results.find(input) != inputs_results.end()) {
return inputs_results.at(input);
auto result_pair = inputs_results.at(input);
if (result_pair.second || called_from_generate_circuit) {
return result_pair.first;
}
unassigned_result_found = true;
}

std::size_t component_witness_amount = get_component_witness_amount();
ComponentType component_instance = build_component_instance(component_witness_amount);

if (called_from_generate_circuit) {
// if we found a result we have already returned before this point
// generating a dummy result
result_type result(component_instance, 0);
for (auto variable : result.all_vars()) {
variable.get() = parent_assignment.add_batch_variable(0);
}
bool insertion_result = inputs_results.insert({input, {result, false}}).second;
BOOST_ASSERT(insertion_result);
return result;
}

// now we need to actually calculate the result without instantiating the component
// luckily, we already have the mechanism for that
input_type input_copy = input;
Expand All @@ -195,18 +251,6 @@ namespace nil {
for (const auto &var : vars) {
values.push_back(var_value(parent_assignment, var.get()));
}
// generate_empty_assignments is used to get the correctly filled result_type
const compiler_manifest assignment_manifest(parent_assignment.witnesses_amount(), false);
const auto component_manifest = std::apply(ComponentType::get_manifest, params_tuple);
const auto intersection = assignment_manifest.intersect(component_manifest);
BOOST_ASSERT_MSG(intersection.is_satisfiable(), "Component either has a constant or does not fit");
const std::size_t component_witness_amount = intersection.witness_amount->max_value_if_sat();

const std::vector<std::size_t> constants = {}, public_inputs = {};
std::vector<std::size_t> witness_columns(component_witness_amount);
std::iota(witness_columns.begin(), witness_columns.end(), 0);
ComponentType component_instance =
std::apply(component_builder, std::make_tuple(witness_columns, constants, public_inputs, params_tuple));
// safety resize for the case where parent assignment got resized during the lifetime
internal_assignment.resize_witnesses(component_witness_amount);
// move the variables to internal_assignment's public_input column
Expand All @@ -216,12 +260,28 @@ namespace nil {
}
auto result = generate_empty_assignments(component_instance, internal_assignment, input_copy, 0);
// and replace the variables with placeholders, while saving their values
for (auto variable : result.all_vars()) {
variable_transform(variable);
if (!unassigned_result_found) {
for (auto variable : result.all_vars()) {
variable_transform(variable);
}
bool insertion_result = inputs_results.insert({input, {result, true}}).second;
BOOST_ASSERT(insertion_result);
return result;
} else {
// already have some vars
auto unassigned_result = inputs_results.find(input)->second.first;
auto unsassigned_vars = unassigned_result.all_vars();
auto result_vars = result.all_vars();
BOOST_ASSERT(unsassigned_vars.size() == result_vars.size());
for (std::size_t i = 0; i < unsassigned_vars.size(); i++) {
parent_assignment.batch_private_storage(unsassigned_vars[i].get().rotation) =
var_value(internal_assignment, result_vars[i].get());
}
inputs_results.erase(input);
bool insertion_result = inputs_results.insert({input, {unassigned_result, true}}).second;
BOOST_ASSERT(insertion_result);
return unassigned_result;
}
bool insertion_result = inputs_results.insert({input, result}).second;
BOOST_ASSERT(insertion_result);
return result;
}

// call this once in the end in assignment
Expand All @@ -237,26 +297,21 @@ namespace nil {
return start_row_index;
}
// First figure out how much we can scale the component
const compiler_manifest assignment_manifest(parent_assignment.witnesses_amount(), true);
const auto component_manifest = std::apply(ComponentType::get_manifest, params_tuple);
const auto intersection = assignment_manifest.intersect(component_manifest);
BOOST_ASSERT_MSG(intersection.is_satisfiable(), "Component does not fit");
const std::size_t component_witness_amount = intersection.witness_amount->max_value_if_sat();
const std::size_t component_witness_amount = get_component_witness_amount();
std::size_t row = start_row_index,
col_offset = 0;
const std::vector<std::size_t> constants = {}, public_inputs = {};
std::size_t gate_id = generate_batch_gate(
bp, inputs_results.begin()->first, component_witness_amount);
for (auto &input_result : inputs_results) {
const input_type &input = input_result.first;
result_type &result = input_result.second;
result_type &result = input_result.second.first;
bool result_status = input_result.second.second;
BOOST_ASSERT(result_status);
if (col_offset == 0) {
parent_assignment.enable_selector(gate_id, row);
}
std::vector<std::size_t> witness_columns(component_witness_amount);
std::iota(witness_columns.begin(), witness_columns.end(), col_offset);
ComponentType component_instance =
std::apply(component_builder, std::make_tuple(witness_columns, constants, public_inputs, params_tuple));
ComponentType component_instance = build_component_instance(component_witness_amount, col_offset);
auto actual_result = generate_assignments(component_instance, parent_assignment, input, row);
generate_copy_constraints(component_instance, bp, parent_assignment, input, row);
std::size_t vars_amount = result.all_vars().size();
Expand Down Expand Up @@ -308,11 +363,7 @@ namespace nil {
const std::size_t component_witness_amount) {

circuit<ArithmetizationType> tmp_bp;
std::vector<std::size_t> witness_columns(component_witness_amount);
const std::vector<std::size_t> constants = {}, public_inputs = {};
std::iota(witness_columns.begin(), witness_columns.end(), 0);
ComponentType component_instance =
std::apply(component_builder, std::make_tuple(witness_columns, constants, public_inputs, params_tuple));
ComponentType component_instance = build_component_instance(component_witness_amount);
generate_gates(component_instance, tmp_bp, parent_assignment, example_input);
const auto &gates = tmp_bp.gates();
BOOST_ASSERT(gates.size() == 1);
Expand Down Expand Up @@ -357,8 +408,10 @@ namespace boost {
struct concept_interface<nil::blueprint::has_add_input<BatchType, InputType, ResultType>, Base, BatchType>
: Base {

ResultType add_input(typename as_param<Base, const InputType&>::type input) {
return call(nil::blueprint::has_add_input<BatchType, InputType, ResultType>(), *this, input);
ResultType add_input(typename as_param<Base, const InputType&>::type input,
bool called_from_generate_circuit) {
return call(nil::blueprint::has_add_input<BatchType, InputType, ResultType>(), *this, input,
called_from_generate_circuit);
}
};

Expand All @@ -374,5 +427,12 @@ namespace boost {
bp, variable_map, start_row_index);
}
};

template<typename BatchType, typename Base>
struct concept_interface<nil::blueprint::has_name<BatchType>, Base, BatchType> : Base {
std::string name() const {
return call(nil::blueprint::has_name<BatchType>(), *this);
}
};
} // namespace type_erasure
} // namespace boost
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,6 @@ namespace nil {
X = var(component.W(0), start_row_index + component.rows_amount - 1, false, var::column_type::witness);
Y = var(component.W(1), start_row_index + component.rows_amount - 1, false, var::column_type::witness);
}
result_type(const curve_element_variable_base_scalar_mul &component, std::size_t start_row_index) {
X = var(component.W(0), start_row_index + component.rows_amount - 1, false, var::column_type::witness);
Y = var(component.W(1), start_row_index + component.rows_amount - 1, false, var::column_type::witness);
}

std::vector<std::reference_wrapper<var>> all_vars() {
return {X, Y};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,6 @@ namespace nil {

struct result_type {
var output = var(0, 0, false);
result_type(const addition &component, std::uint32_t start_row_index) {
output = var(component.W(2), start_row_index, false, var::column_type::witness);
}

result_type(const addition &component, std::size_t start_row_index) {
output = var(component.W(2), start_row_index, false, var::column_type::witness);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,6 @@ namespace nil {
output = var(component.W(2), start_row_index, false, var::column_type::witness);
}

result_type(const division &component, std::size_t start_row_index) {
output = var(component.W(2), start_row_index, false, var::column_type::witness);
}

std::vector<std::reference_wrapper<var>> all_vars() {
return {output};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,6 @@ namespace nil {
output = var(component.W(2), start_row_index, false, var::column_type::witness);
}

result_type(const division_or_zero &component, std::size_t start_row_index) {
output = var(component.W(2), start_row_index, false, var::column_type::witness);
}

std::vector<std::reference_wrapper<var>> all_vars() {
return {output};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,6 @@ namespace nil {
start_row_index + component.rows_amount - 1, false);
}

result_type(const exponentiation &component, std::size_t start_row_index) {
output = var(component.W(intermediate_start + component.intermediate_results_per_row - 1),
start_row_index + component.rows_amount - 1, false);
}

std::vector<std::reference_wrapper<var>> all_vars() {
return {output};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ namespace nil {
};

struct result_type {
std::array<var,2> output;
std::array<var,2> output;

result_type(const linear_inter_coefs &component, std::uint32_t start_row_index) {
output = { var(component.W(4), start_row_index, false, var::column_type::witness),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,6 @@ namespace nil {
output = var(component.W(2), start_row_index, false, var::column_type::witness);
}

result_type(const multiplication &component, std::size_t start_row_index) {
output = var(component.W(2), start_row_index, false, var::column_type::witness);
}

std::vector<std::reference_wrapper<var>> all_vars() {
return {output};
}
Expand Down
Loading

0 comments on commit e63299e

Please sign in to comment.