From e63299ed714e1da4be53ac9a1b31ca5a56774764 Mon Sep 17 00:00:00 2001 From: Iluvmagick Date: Tue, 2 Apr 2024 18:43:44 +0400 Subject: [PATCH] Batching no longer fails when circuit/assignment gens are reordered. (#356) --- .../blueprint/blueprint/plonk/assignment.hpp | 25 +++- include/nil/blueprint/component_batch.hpp | 136 +++++++++++++----- .../pasta/plonk/variable_base_scalar_mul.hpp | 4 - .../algebra/fields/plonk/addition.hpp | 3 - .../algebra/fields/plonk/division.hpp | 4 - .../algebra/fields/plonk/division_or_zero.hpp | 4 - .../algebra/fields/plonk/exponentiation.hpp | 5 - .../fields/plonk/linear_interpolation.hpp | 2 +- .../algebra/fields/plonk/multiplication.hpp | 4 - .../plonk/multiplication_by_constant.hpp | 4 - .../fields/plonk/non_native/equality_flag.hpp | 4 - .../fields/plonk/quadratic_interpolation.hpp | 2 +- .../algebra/fields/plonk/subtraction.hpp | 4 - .../expression_evaluation_component.hpp | 35 +++-- .../systems/snark/plonk/verifier/verifier.hpp | 8 +- test/component_batch.cpp | 89 +++++++----- test/test_plonk_component.hpp | 4 +- 17 files changed, 205 insertions(+), 132 deletions(-) diff --git a/include/nil/blueprint/blueprint/plonk/assignment.hpp b/include/nil/blueprint/blueprint/plonk/assignment.hpp index 94220db71..d43a3a8d5 100644 --- a/include/nil/blueprint/blueprint/plonk/assignment.hpp +++ b/include/nil/blueprint/blueprint/plonk/assignment.hpp @@ -80,6 +80,9 @@ namespace nil { template struct has_finalize_batch; + template + struct has_name; + template struct input_type_v; @@ -119,6 +122,7 @@ namespace nil { boost::mpl::vector< has_add_input<_batch, _input_type, _result_type>, has_finalize_batch<_batch, ArithmetizationType, var>, + has_name<_batch>, boost::type_erasure::same_type>, _input_type>, boost::type_erasure::same_type>, _result_type>, boost::type_erasure::same_type>, _variadics>, @@ -155,22 +159,39 @@ namespace nil { desc.constant_columns, desc.selector_columns) { } + template + typename ComponentType::result_type add_input_to_batch_assignment( + const typename ComponentType::input_type &input, + ComponentParams... params) { + + return add_input_to_batch(input, false, params...); + } + + template + typename ComponentType::result_type add_input_to_batch_circuit( + const typename ComponentType::input_type &input, + ComponentParams... params) { + + return add_input_to_batch(input, true, params...); + } + template typename ComponentType::result_type add_input_to_batch( const typename ComponentType::input_type &input, + bool called_from_generate_circuit, ComponentParams... params) { using batching_type = component_batch; batching_type batch(*this, std::tuple(params...)); auto it = component_batches.find(batch); if (it == component_batches.end()) { - auto result = batch.add_input(input); + auto result = batch.add_input(input, called_from_generate_circuit); component_batches.insert(batch); return result; } else { // safe because the ordering doesn't depend on the batch inputs return boost::type_erasure::any_cast(const_cast(*it)) - .add_input(input); + .add_input(input, called_from_generate_circuit); } } diff --git a/include/nil/blueprint/component_batch.hpp b/include/nil/blueprint/component_batch.hpp index ef0bd3a2a..c55debdcb 100644 --- a/include/nil/blueprint/component_batch.hpp +++ b/include/nil/blueprint/component_batch.hpp @@ -25,6 +25,7 @@ #pragma once #include +#include #include #include #include @@ -135,6 +136,13 @@ namespace nil { } }; + template + struct has_name { + static std::string apply(const BatchType& batch) { + return batch.name(); + } + }; + // Generic-ish enough batching solution for single-line components // Lookups currently unsupported // Partially supports component prarameterization -- only if passed through template parameters @@ -150,7 +158,9 @@ namespace nil { using gate_type = crypto3::zk::snark::plonk_gate; using component_params_type = typename std::tuple; // input-output pairs for batched components - std::map> + // the second bool determines whether the result has been actually filled; is left unfilled if called + // from generate_circuit + std::map, comparison_for_inputs_results> inputs_results; // pointer to the assignment we are going to use in the end assignment &parent_assignment; @@ -176,17 +186,63 @@ namespace nil { ~component_batch() = default; + std::string name() const { + std::string result = typeid(ComponentType).name(); + std::apply([&result](auto... args) { + ((result += "__" + std::to_string(args)), ...); + }, params_tuple); + return result; + } + void variable_transform(std::reference_wrapper variable) { variable.get() = parent_assignment.add_batch_variable( var_value(internal_assignment, variable.get())); } + ComponentType build_component_instance(const std::size_t component_witness_amount, + const std::size_t start_value = 0) const { + const std::vector constants = {}, public_inputs = {}; + std::vector witness_columns(component_witness_amount); + std::iota(witness_columns.begin(), witness_columns.end(), start_value); + return std::apply(component_builder, std::make_tuple(witness_columns, constants, public_inputs, params_tuple)); + } + + std::size_t get_component_witness_amount() const { + const compiler_manifest assignment_manifest(parent_assignment.witnesses_amount(), false); + const auto component_manifest = std::apply(ComponentType::get_manifest, params_tuple); + const auto intersection = assignment_manifest.intersect(component_manifest); + BOOST_ASSERT_MSG(intersection.is_satisfiable(), "Component either has a constant or does not fit"); + const std::size_t component_witness_amount = intersection.witness_amount->max_value_if_sat(); + return component_witness_amount; + } + // call this in both generate_assignments and generate_circuit - result_type add_input(const input_type &input) { - // short-circuit if we are in generate_circuit and the input has already been through batching + result_type add_input(const input_type &input, bool called_from_generate_circuit = false) { + // short-circuit if the input has already been through batching + bool unassigned_result_found = false; if (inputs_results.find(input) != inputs_results.end()) { - return inputs_results.at(input); + auto result_pair = inputs_results.at(input); + if (result_pair.second || called_from_generate_circuit) { + return result_pair.first; + } + unassigned_result_found = true; + } + + std::size_t component_witness_amount = get_component_witness_amount(); + ComponentType component_instance = build_component_instance(component_witness_amount); + + if (called_from_generate_circuit) { + // if we found a result we have already returned before this point + // generating a dummy result + result_type result(component_instance, 0); + for (auto variable : result.all_vars()) { + variable.get() = parent_assignment.add_batch_variable(0); + } + bool insertion_result = inputs_results.insert({input, {result, false}}).second; + BOOST_ASSERT(insertion_result); + return result; } + // now we need to actually calculate the result without instantiating the component // luckily, we already have the mechanism for that input_type input_copy = input; @@ -195,18 +251,6 @@ namespace nil { for (const auto &var : vars) { values.push_back(var_value(parent_assignment, var.get())); } - // generate_empty_assignments is used to get the correctly filled result_type - const compiler_manifest assignment_manifest(parent_assignment.witnesses_amount(), false); - const auto component_manifest = std::apply(ComponentType::get_manifest, params_tuple); - const auto intersection = assignment_manifest.intersect(component_manifest); - BOOST_ASSERT_MSG(intersection.is_satisfiable(), "Component either has a constant or does not fit"); - const std::size_t component_witness_amount = intersection.witness_amount->max_value_if_sat(); - - const std::vector constants = {}, public_inputs = {}; - std::vector witness_columns(component_witness_amount); - std::iota(witness_columns.begin(), witness_columns.end(), 0); - ComponentType component_instance = - std::apply(component_builder, std::make_tuple(witness_columns, constants, public_inputs, params_tuple)); // safety resize for the case where parent assignment got resized during the lifetime internal_assignment.resize_witnesses(component_witness_amount); // move the variables to internal_assignment's public_input column @@ -216,12 +260,28 @@ namespace nil { } auto result = generate_empty_assignments(component_instance, internal_assignment, input_copy, 0); // and replace the variables with placeholders, while saving their values - for (auto variable : result.all_vars()) { - variable_transform(variable); + if (!unassigned_result_found) { + for (auto variable : result.all_vars()) { + variable_transform(variable); + } + bool insertion_result = inputs_results.insert({input, {result, true}}).second; + BOOST_ASSERT(insertion_result); + return result; + } else { + // already have some vars + auto unassigned_result = inputs_results.find(input)->second.first; + auto unsassigned_vars = unassigned_result.all_vars(); + auto result_vars = result.all_vars(); + BOOST_ASSERT(unsassigned_vars.size() == result_vars.size()); + for (std::size_t i = 0; i < unsassigned_vars.size(); i++) { + parent_assignment.batch_private_storage(unsassigned_vars[i].get().rotation) = + var_value(internal_assignment, result_vars[i].get()); + } + inputs_results.erase(input); + bool insertion_result = inputs_results.insert({input, {unassigned_result, true}}).second; + BOOST_ASSERT(insertion_result); + return unassigned_result; } - bool insertion_result = inputs_results.insert({input, result}).second; - BOOST_ASSERT(insertion_result); - return result; } // call this once in the end in assignment @@ -237,11 +297,7 @@ namespace nil { return start_row_index; } // First figure out how much we can scale the component - const compiler_manifest assignment_manifest(parent_assignment.witnesses_amount(), true); - const auto component_manifest = std::apply(ComponentType::get_manifest, params_tuple); - const auto intersection = assignment_manifest.intersect(component_manifest); - BOOST_ASSERT_MSG(intersection.is_satisfiable(), "Component does not fit"); - const std::size_t component_witness_amount = intersection.witness_amount->max_value_if_sat(); + const std::size_t component_witness_amount = get_component_witness_amount(); std::size_t row = start_row_index, col_offset = 0; const std::vector constants = {}, public_inputs = {}; @@ -249,14 +305,13 @@ namespace nil { bp, inputs_results.begin()->first, component_witness_amount); for (auto &input_result : inputs_results) { const input_type &input = input_result.first; - result_type &result = input_result.second; + result_type &result = input_result.second.first; + bool result_status = input_result.second.second; + BOOST_ASSERT(result_status); if (col_offset == 0) { parent_assignment.enable_selector(gate_id, row); } - std::vector witness_columns(component_witness_amount); - std::iota(witness_columns.begin(), witness_columns.end(), col_offset); - ComponentType component_instance = - std::apply(component_builder, std::make_tuple(witness_columns, constants, public_inputs, params_tuple)); + ComponentType component_instance = build_component_instance(component_witness_amount, col_offset); auto actual_result = generate_assignments(component_instance, parent_assignment, input, row); generate_copy_constraints(component_instance, bp, parent_assignment, input, row); std::size_t vars_amount = result.all_vars().size(); @@ -308,11 +363,7 @@ namespace nil { const std::size_t component_witness_amount) { circuit tmp_bp; - std::vector witness_columns(component_witness_amount); - const std::vector constants = {}, public_inputs = {}; - std::iota(witness_columns.begin(), witness_columns.end(), 0); - ComponentType component_instance = - std::apply(component_builder, std::make_tuple(witness_columns, constants, public_inputs, params_tuple)); + ComponentType component_instance = build_component_instance(component_witness_amount); generate_gates(component_instance, tmp_bp, parent_assignment, example_input); const auto &gates = tmp_bp.gates(); BOOST_ASSERT(gates.size() == 1); @@ -357,8 +408,10 @@ namespace boost { struct concept_interface, Base, BatchType> : Base { - ResultType add_input(typename as_param::type input) { - return call(nil::blueprint::has_add_input(), *this, input); + ResultType add_input(typename as_param::type input, + bool called_from_generate_circuit) { + return call(nil::blueprint::has_add_input(), *this, input, + called_from_generate_circuit); } }; @@ -374,5 +427,12 @@ namespace boost { bp, variable_map, start_row_index); } }; + + template + struct concept_interface, Base, BatchType> : Base { + std::string name() const { + return call(nil::blueprint::has_name(), *this); + } + }; } // namespace type_erasure } // namespace boost diff --git a/include/nil/blueprint/components/algebra/curves/pasta/plonk/variable_base_scalar_mul.hpp b/include/nil/blueprint/components/algebra/curves/pasta/plonk/variable_base_scalar_mul.hpp index 1b955c194..8ba602548 100644 --- a/include/nil/blueprint/components/algebra/curves/pasta/plonk/variable_base_scalar_mul.hpp +++ b/include/nil/blueprint/components/algebra/curves/pasta/plonk/variable_base_scalar_mul.hpp @@ -188,10 +188,6 @@ namespace nil { X = var(component.W(0), start_row_index + component.rows_amount - 1, false, var::column_type::witness); Y = var(component.W(1), start_row_index + component.rows_amount - 1, false, var::column_type::witness); } - result_type(const curve_element_variable_base_scalar_mul &component, std::size_t start_row_index) { - X = var(component.W(0), start_row_index + component.rows_amount - 1, false, var::column_type::witness); - Y = var(component.W(1), start_row_index + component.rows_amount - 1, false, var::column_type::witness); - } std::vector> all_vars() { return {X, Y}; diff --git a/include/nil/blueprint/components/algebra/fields/plonk/addition.hpp b/include/nil/blueprint/components/algebra/fields/plonk/addition.hpp index 32c3c1871..f322daa19 100644 --- a/include/nil/blueprint/components/algebra/fields/plonk/addition.hpp +++ b/include/nil/blueprint/components/algebra/fields/plonk/addition.hpp @@ -104,9 +104,6 @@ namespace nil { struct result_type { var output = var(0, 0, false); - result_type(const addition &component, std::uint32_t start_row_index) { - output = var(component.W(2), start_row_index, false, var::column_type::witness); - } result_type(const addition &component, std::size_t start_row_index) { output = var(component.W(2), start_row_index, false, var::column_type::witness); diff --git a/include/nil/blueprint/components/algebra/fields/plonk/division.hpp b/include/nil/blueprint/components/algebra/fields/plonk/division.hpp index e62b0e415..eed81c090 100644 --- a/include/nil/blueprint/components/algebra/fields/plonk/division.hpp +++ b/include/nil/blueprint/components/algebra/fields/plonk/division.hpp @@ -108,10 +108,6 @@ namespace nil { output = var(component.W(2), start_row_index, false, var::column_type::witness); } - result_type(const division &component, std::size_t start_row_index) { - output = var(component.W(2), start_row_index, false, var::column_type::witness); - } - std::vector> all_vars() { return {output}; } diff --git a/include/nil/blueprint/components/algebra/fields/plonk/division_or_zero.hpp b/include/nil/blueprint/components/algebra/fields/plonk/division_or_zero.hpp index 097b6b955..135187e44 100644 --- a/include/nil/blueprint/components/algebra/fields/plonk/division_or_zero.hpp +++ b/include/nil/blueprint/components/algebra/fields/plonk/division_or_zero.hpp @@ -105,10 +105,6 @@ namespace nil { output = var(component.W(2), start_row_index, false, var::column_type::witness); } - result_type(const division_or_zero &component, std::size_t start_row_index) { - output = var(component.W(2), start_row_index, false, var::column_type::witness); - } - std::vector> all_vars() { return {output}; } diff --git a/include/nil/blueprint/components/algebra/fields/plonk/exponentiation.hpp b/include/nil/blueprint/components/algebra/fields/plonk/exponentiation.hpp index 72699405e..ac826a036 100644 --- a/include/nil/blueprint/components/algebra/fields/plonk/exponentiation.hpp +++ b/include/nil/blueprint/components/algebra/fields/plonk/exponentiation.hpp @@ -154,11 +154,6 @@ namespace nil { start_row_index + component.rows_amount - 1, false); } - result_type(const exponentiation &component, std::size_t start_row_index) { - output = var(component.W(intermediate_start + component.intermediate_results_per_row - 1), - start_row_index + component.rows_amount - 1, false); - } - std::vector> all_vars() { return {output}; } diff --git a/include/nil/blueprint/components/algebra/fields/plonk/linear_interpolation.hpp b/include/nil/blueprint/components/algebra/fields/plonk/linear_interpolation.hpp index 75ec3c6e5..bdb39da0e 100644 --- a/include/nil/blueprint/components/algebra/fields/plonk/linear_interpolation.hpp +++ b/include/nil/blueprint/components/algebra/fields/plonk/linear_interpolation.hpp @@ -91,7 +91,7 @@ namespace nil { }; struct result_type { - std::array output; + std::array output; result_type(const linear_inter_coefs &component, std::uint32_t start_row_index) { output = { var(component.W(4), start_row_index, false, var::column_type::witness), diff --git a/include/nil/blueprint/components/algebra/fields/plonk/multiplication.hpp b/include/nil/blueprint/components/algebra/fields/plonk/multiplication.hpp index 7970b4e0f..198a572eb 100644 --- a/include/nil/blueprint/components/algebra/fields/plonk/multiplication.hpp +++ b/include/nil/blueprint/components/algebra/fields/plonk/multiplication.hpp @@ -109,10 +109,6 @@ namespace nil { output = var(component.W(2), start_row_index, false, var::column_type::witness); } - result_type(const multiplication &component, std::size_t start_row_index) { - output = var(component.W(2), start_row_index, false, var::column_type::witness); - } - std::vector> all_vars() { return {output}; } diff --git a/include/nil/blueprint/components/algebra/fields/plonk/multiplication_by_constant.hpp b/include/nil/blueprint/components/algebra/fields/plonk/multiplication_by_constant.hpp index 23caf3521..1d1855cd3 100644 --- a/include/nil/blueprint/components/algebra/fields/plonk/multiplication_by_constant.hpp +++ b/include/nil/blueprint/components/algebra/fields/plonk/multiplication_by_constant.hpp @@ -108,10 +108,6 @@ namespace nil { output = var(component.W(1), start_row_index, false, var::column_type::witness); } - result_type(const mul_by_constant &component, std::size_t start_row_index) { - output = var(component.W(1), start_row_index, false, var::column_type::witness); - } - std::vector> all_vars() { return {output}; } diff --git a/include/nil/blueprint/components/algebra/fields/plonk/non_native/equality_flag.hpp b/include/nil/blueprint/components/algebra/fields/plonk/non_native/equality_flag.hpp index cfcec7cc5..e4191eef8 100644 --- a/include/nil/blueprint/components/algebra/fields/plonk/non_native/equality_flag.hpp +++ b/include/nil/blueprint/components/algebra/fields/plonk/non_native/equality_flag.hpp @@ -108,10 +108,6 @@ namespace nil { output = var(component.W(3), start_row_index, false, var::column_type::witness); } - result_type(const equality_flag &component, std::size_t start_row_index) { - output = var(component.W(3), start_row_index, false, var::column_type::witness); - } - std::vector> all_vars() { return {output}; } diff --git a/include/nil/blueprint/components/algebra/fields/plonk/quadratic_interpolation.hpp b/include/nil/blueprint/components/algebra/fields/plonk/quadratic_interpolation.hpp index b954db110..b234036dd 100644 --- a/include/nil/blueprint/components/algebra/fields/plonk/quadratic_interpolation.hpp +++ b/include/nil/blueprint/components/algebra/fields/plonk/quadratic_interpolation.hpp @@ -103,7 +103,7 @@ namespace nil { }; struct result_type { - std::array output; + std::array output; result_type(const quadratic_inter_coefs &component, std::uint32_t start_row_index) { output = { var(component.W(6), start_row_index, false, var::column_type::witness), diff --git a/include/nil/blueprint/components/algebra/fields/plonk/subtraction.hpp b/include/nil/blueprint/components/algebra/fields/plonk/subtraction.hpp index aad708e81..8d597f347 100644 --- a/include/nil/blueprint/components/algebra/fields/plonk/subtraction.hpp +++ b/include/nil/blueprint/components/algebra/fields/plonk/subtraction.hpp @@ -108,10 +108,6 @@ namespace nil { output = var(component.W(2), start_row_index, false, var::column_type::witness); } - result_type(const subtraction &component, std::size_t start_row_index) { - output = var(component.W(2), start_row_index, false, var::column_type::witness); - } - std::vector> all_vars() { return {output}; } diff --git a/include/nil/blueprint/components/systems/snark/plonk/placeholder/detail/expression_evaluation_component.hpp b/include/nil/blueprint/components/systems/snark/plonk/placeholder/detail/expression_evaluation_component.hpp index b3fbc62b4..a5f233888 100644 --- a/include/nil/blueprint/components/systems/snark/plonk/placeholder/detail/expression_evaluation_component.hpp +++ b/include/nil/blueprint/components/systems/snark/plonk/placeholder/detail/expression_evaluation_component.hpp @@ -62,8 +62,10 @@ namespace nil { nil::blueprint::basic_non_native_policy>; expression_to_execution_simple(assignment_type &_assignment, - const std::unordered_map &_variable_map) - : assignment(_assignment), variable_map(_variable_map) + const std::unordered_map &_variable_map, + bool _generate_assignment_call) + : assignment(_assignment), variable_map(_variable_map), + generate_assignment_call(_generate_assignment_call) {} var visit(const nil::crypto3::math::expression &expr) { @@ -80,14 +82,16 @@ namespace nil { if (term.get_coeff() != value_type::one()) { auto coeff_var = assignment.add_batch_constant_variable(term.get_coeff()); result = assignment.template add_input_to_batch( - {coeff_var, variable_map.at(term.get_vars()[curr_term])}).output; + {coeff_var, variable_map.at(term.get_vars()[curr_term])}, + generate_assignment_call).output; } else { result = variable_map.at(term.get_vars()[curr_term]); } curr_term++; for (; curr_term < term_size; curr_term++) { result = assignment.template add_input_to_batch( - {result, variable_map.at(term.get_vars()[curr_term])}).output; + {result, variable_map.at(term.get_vars()[curr_term])}, + generate_assignment_call).output; } return result; } @@ -103,16 +107,19 @@ namespace nil { while (power > 1) { if (power % 2 == 0) { expr_res = assignment.template add_input_to_batch( - {expr_res, expr_res}).output; + {expr_res, expr_res}, + generate_assignment_call).output; power /= 2; } else { result = assignment.template add_input_to_batch( - {result, expr_res}).output; + {result, expr_res}, + generate_assignment_call).output; power -= 1; } } return assignment.template add_input_to_batch( - {result, expr_res}).output; + {result, expr_res}, + generate_assignment_call).output; } var operator()(const nil::crypto3::math::binary_arithmetic_operation& op) { @@ -121,13 +128,16 @@ namespace nil { switch (op.get_op()) { case crypto3::math::ArithmeticOperator::ADD: return assignment.template add_input_to_batch( - {res1, res2}).output; + {res1, res2}, + generate_assignment_call).output; case crypto3::math::ArithmeticOperator::SUB: return assignment.template add_input_to_batch( - {res1, res2}).output; + {res1, res2}, + generate_assignment_call).output; case crypto3::math::ArithmeticOperator::MULT: return assignment.template add_input_to_batch( - {res1, res2}).output; + {res1, res2}, + generate_assignment_call).output; default: throw std::runtime_error("Unsupported operation"); } @@ -135,6 +145,7 @@ namespace nil { private: assignment_type &assignment; const std::unordered_map &variable_map; + bool generate_assignment_call; }; template @@ -254,7 +265,7 @@ namespace nil { using component_type = plonk_expression_evaluation_component; using expression_evaluator_type = typename component_type::expression_evaluator_type; - expression_evaluator_type evaluator(assignment, instance_input.variable_mapping); + expression_evaluator_type evaluator(assignment, instance_input.variable_mapping, true); return typename component_type::result_type(evaluator.visit(component.constraint), start_row_index); } @@ -273,7 +284,7 @@ namespace nil { using component_type = plonk_expression_evaluation_component; using expression_evaluator_type = typename component_type::expression_evaluator_type; - expression_evaluator_type evaluator(assignment, instance_input.variable_mapping); + expression_evaluator_type evaluator(assignment, instance_input.variable_mapping, false); return typename component_type::result_type(evaluator.visit(component.constraint), start_row_index); } } // namespace components diff --git a/include/nil/blueprint/components/systems/snark/plonk/verifier/verifier.hpp b/include/nil/blueprint/components/systems/snark/plonk/verifier/verifier.hpp index fe3892203..f3b2fc84c 100644 --- a/include/nil/blueprint/components/systems/snark/plonk/verifier/verifier.hpp +++ b/include/nil/blueprint/components/systems/snark/plonk/verifier/verifier.hpp @@ -438,7 +438,7 @@ namespace nil { swap_input_type swap_input; swap_input.arr.push_back({instance_input.merkle_tree_positions[i][k], instance_input.initial_proof_hashes[i][cur_hash], hash_var}); - auto swap_result = assignment.template add_input_to_batch( + auto swap_result = assignment.template add_input_to_batch_assignment( swap_input, 1); poseidon_input = {zero_var, swap_result.output[0].first, swap_result.output[0].second}; poseidon_output = generate_assignments(poseidon_instance, assignment, poseidon_input, row); @@ -471,7 +471,7 @@ namespace nil { swap_input_type swap_input; swap_input.arr.push_back({instance_input.merkle_tree_positions[i][k], instance_input.round_proof_hashes[i][cur_hash], hash_var}); - auto swap_result = assignment.template add_input_to_batch( + auto swap_result = assignment.template add_input_to_batch_assignment( swap_input, 1); poseidon_input = {zero_var, swap_result.output[0].first, swap_result.output[0].second}; poseidon_output = generate_assignments(poseidon_instance, assignment, poseidon_input, row); @@ -669,7 +669,7 @@ namespace nil { swap_input_type swap_input; swap_input.arr.push_back({instance_input.merkle_tree_positions[i][k], instance_input.initial_proof_hashes[i][cur_hash], hash_var}); - auto swap_result = assignment.template add_input_to_batch( + auto swap_result = assignment.template add_input_to_batch_circuit( swap_input, 1); poseidon_input = {zero_var, swap_result.output[0].first, swap_result.output[0].second}; poseidon_output = generate_circuit(poseidon_instance, bp, assignment, poseidon_input, row); @@ -701,7 +701,7 @@ namespace nil { swap_input_type swap_input; swap_input.arr.push_back({instance_input.merkle_tree_positions[i][k], instance_input.round_proof_hashes[i][cur_hash], hash_var}); - auto swap_result = assignment.template add_input_to_batch( + auto swap_result = assignment.template add_input_to_batch_circuit( swap_input, 1); poseidon_input = {zero_var, swap_result.output[0].first, swap_result.output[0].second}; poseidon_output = generate_circuit(poseidon_instance, bp, assignment, poseidon_input, row); diff --git a/test/component_batch.cpp b/test/component_batch.cpp index c9336b6f0..28f357f00 100644 --- a/test/component_batch.cpp +++ b/test/component_batch.cpp @@ -116,8 +116,8 @@ BOOST_AUTO_TEST_CASE(component_batch_basic_test) { using component_type = components::multiplication< ArithmetizationType, field_type, nil::blueprint::basic_non_native_policy>; - assignment.add_input_to_batch({public_input_var_maker(), public_input_var_maker()}); - assignment.add_input_to_batch({public_input_var_maker(), public_input_var_maker()}); + assignment.add_input_to_batch_assignment({public_input_var_maker(), public_input_var_maker()}); + assignment.add_input_to_batch_assignment({public_input_var_maker(), public_input_var_maker()}); std::size_t row = assignment.finalize_component_batches(circuit, 0); BOOST_CHECK_EQUAL(row, 1); BOOST_CHECK_EQUAL(circuit.gates().size(), 1); @@ -161,18 +161,18 @@ BOOST_AUTO_TEST_CASE(component_batch_continuation_test) { using component_type = components::multiplication< ArithmetizationType, field_type, nil::blueprint::basic_non_native_policy>; - auto first_result = assignment.add_input_to_batch({public_input_var_maker(), public_input_var_maker()}); - auto second_result = assignment.add_input_to_batch({public_input_var_maker(), public_input_var_maker()}); - assignment.add_input_to_batch({public_input_var_maker(), public_input_var_maker()}); - auto third_result = assignment.add_input_to_batch({public_input_var_maker(), public_input_var_maker()}); - auto fourth_result = assignment.add_input_to_batch({first_result.output, second_result.output}); + auto first_result = assignment.add_input_to_batch_assignment({public_input_var_maker(), public_input_var_maker()}); + auto second_result = assignment.add_input_to_batch_assignment({public_input_var_maker(), public_input_var_maker()}); + assignment.add_input_to_batch_assignment({public_input_var_maker(), public_input_var_maker()}); + auto third_result = assignment.add_input_to_batch_assignment({public_input_var_maker(), public_input_var_maker()}); + auto fourth_result = assignment.add_input_to_batch_assignment({first_result.output, second_result.output}); using addition_type = components::addition< ArithmetizationType, field_type, nil::blueprint::basic_non_native_policy>; std::size_t row = 0; addition_type add_component({0, 1, 2}, {}, {}); auto addition_result = generate_assignments(add_component, assignment, {third_result.output, fourth_result.output}, row); generate_circuit(add_component, circuit, assignment, {third_result.output, fourth_result.output}, row++); - auto fifth_result = assignment.add_input_to_batch({addition_result.output, public_input_var_maker()}); + auto fifth_result = assignment.add_input_to_batch_assignment({addition_result.output, public_input_var_maker()}); generate_assignments(add_component, assignment, {addition_result.output, fifth_result.output}, row); generate_circuit(add_component, circuit, assignment, {addition_result.output, fifth_result.output}, row++); row = assignment.finalize_component_batches(circuit, row); @@ -235,26 +235,26 @@ BOOST_AUTO_TEST_CASE(component_batch_multibatch_test) { using add_component_type = components::addition< ArithmetizationType, field_type, nil::blueprint::basic_non_native_policy>; using div_or_zero_component_type = components::division_or_zero; - auto mul_result = assignment.add_input_to_batch( + auto mul_result = assignment.add_input_to_batch_assignment( {public_input_var_maker(), public_input_var_maker()}); - auto add_result = assignment.add_input_to_batch({mul_result.output, public_input_var_maker()}); - auto mul_result_2 = assignment.add_input_to_batch({add_result.output, mul_result.output}); - assignment.add_input_to_batch({public_input_var_maker(), public_input_var_maker()}); + auto add_result = assignment.add_input_to_batch_assignment({mul_result.output, public_input_var_maker()}); + auto mul_result_2 = assignment.add_input_to_batch_assignment({add_result.output, mul_result.output}); + assignment.add_input_to_batch_assignment({public_input_var_maker(), public_input_var_maker()}); div_or_zero_component_type div_or_zero_component({0, 1, 2, 3, 4}, {}, {}); var div_or_zero_var = public_input_var_maker(); auto div_or_zero_res = generate_assignments( div_or_zero_component, assignment, {mul_result_2.output, div_or_zero_var}, 0); generate_circuit(div_or_zero_component, circuit, assignment, {mul_result_2.output, div_or_zero_var}, 0); - assignment.add_input_to_batch({div_or_zero_res.output, public_input_var_maker()}); - assignment.add_input_to_batch({public_input_var_maker(), public_input_var_maker()}); - assignment.add_input_to_batch({add_result.output, mul_result.output}); + assignment.add_input_to_batch_assignment({div_or_zero_res.output, public_input_var_maker()}); + assignment.add_input_to_batch_assignment({public_input_var_maker(), public_input_var_maker()}); + assignment.add_input_to_batch_assignment({add_result.output, mul_result.output}); // duplicates, should not count! for (std::size_t i = 0; i < 5; i++) { - assignment.add_input_to_batch({add_result.output, mul_result.output}); + assignment.add_input_to_batch_assignment({add_result.output, mul_result.output}); } // not duplicates, should count for (std::size_t i = 0; i < 5; i++) { - assignment.add_input_to_batch({public_input_var_maker(), public_input_var_maker()}); + assignment.add_input_to_batch_assignment({public_input_var_maker(), public_input_var_maker()}); } std::size_t row = assignment.finalize_component_batches(circuit, 1); BOOST_CHECK_EQUAL(row, 4); @@ -371,11 +371,11 @@ BOOST_AUTO_TEST_CASE(component_batch_const_batch_test) { assignment.constant(0, row) = 1445; assignment.enable_selector(lookup_selector, row++); assignment.constant(0, row) = 1446; - auto mul_result = assignment.add_input_to_batch( + auto mul_result = assignment.add_input_to_batch_assignment( {assignment.add_batch_constant_variable(1), assignment.add_batch_constant_variable(2)}); // have to check lookup functionality manually - assignment.add_input_to_batch({public_input_var_maker(), mul_result.output}); - assignment.add_input_to_batch({mul_by_const_result.output, public_input_var_maker()}); + assignment.add_input_to_batch_assignment({public_input_var_maker(), mul_result.output}); + assignment.add_input_to_batch_assignment({mul_by_const_result.output, public_input_var_maker()}); assignment.finalize_component_batches(circuit, row); assignment.finalize_constant_batches(circuit, 0); @@ -429,16 +429,16 @@ BOOST_AUTO_TEST_CASE(component_batch_params_test) { input_type input; input.arr.push_back(std::make_tuple( public_input_var_maker.binary_var(), public_input_var_maker(), public_input_var_maker())); - auto res_1 = assignment.add_input_to_batch(input, size_small); + auto res_1 = assignment.add_input_to_batch_assignment(input, size_small); input.arr = {}; input.arr.push_back(std::make_tuple( public_input_var_maker.binary_var(), public_input_var_maker(), public_input_var_maker())); input.arr.push_back(std::make_tuple( public_input_var_maker.binary_var(), public_input_var_maker(), public_input_var_maker())); - auto res_2 = assignment.add_input_to_batch(input, size_big); + auto res_2 = assignment.add_input_to_batch_assignment(input, size_big); input.arr = {}; input.arr.push_back({public_input_var_maker.binary_var(), res_1.output[0].first, res_2.output[0].second}); - auto res_3 = assignment.add_input_to_batch(input, size_small); + auto res_3 = assignment.add_input_to_batch_assignment(input, size_small); assignment.finalize_component_batches(circuit, 0); BOOST_CHECK_EQUAL(circuit.gates().size(), 2); @@ -475,19 +475,6 @@ BOOST_AUTO_TEST_CASE(component_batch_params_test) { BOOST_CHECK_EQUAL(gate_2.constraints[i], expected_constraints[i]); } - // pub_0_abs w_0_abs - // pub_0_abs_rot(1) w_1_abs - // pub_0_abs_rot(2) w_2_abs - // pub_0_abs_rot(9) w_5_abs - // w_3_abs w_6_abs - // w_4_abs_rot(1) w_7_abs - // pub_0_abs_rot(3) w_0_abs_rot(1) - // pub_0_abs_rot(4) w_1_abs_rot(1) - // pub_0_abs_rot(5) w_2_abs_rot(1) - // pub_0_abs_rot(6) w_5_abs_rot(1) - // pub_0_abs_rot(7) w_6_abs_rot(1) - // pub_0_abs_rot(8) w_7_abs_rot(1) - const std::vector expected_copy_constraints = { {var(0, 0, false, var::column_type::public_input), var(0, 0, false, var::column_type::witness)}, {var(0, 1, false, var::column_type::public_input), var(1, 0, false, var::column_type::witness)}, @@ -509,4 +496,34 @@ BOOST_AUTO_TEST_CASE(component_batch_params_test) { // circuit.export_circuit(std::cout); } +BOOST_AUTO_TEST_CASE(component_batch_generate_circuit_variant_basic_test) { + using curve_type = nil::crypto3::algebra::curves::vesta; + using field_type = typename curve_type::scalar_field_type; + + using assignment_type = assignment>; + using circuit_type = circuit>; + using ArithmetizationType = nil::crypto3::zk::snark::plonk_constraint_system; + + assignment_type assignment(15, 1, 1, 3); + circuit_type circuit; + public_input_var_maker public_input_var_maker(assignment); + + using multiplication_type = components::multiplication< + ArithmetizationType, field_type, nil::blueprint::basic_non_native_policy>; + + typename multiplication_type::input_type input_1 = {public_input_var_maker(), public_input_var_maker()}; + typename multiplication_type::input_type input_2 = {public_input_var_maker(), public_input_var_maker()}; + auto res_1 = assignment.add_input_to_batch_circuit(input_1); + auto res_2 = assignment.add_input_to_batch_circuit(input_2); + BOOST_ASSERT(var_value(assignment, res_1.output) == 0); + BOOST_ASSERT(var_value(assignment, res_2.output) == 0); + res_1 = assignment.add_input_to_batch_assignment(input_1); + BOOST_ASSERT(var_value(assignment, res_1.output) == var_value(assignment, input_1.x) * var_value(assignment, input_1.y)); + BOOST_ASSERT(var_value(assignment, res_1.output) != 0); + BOOST_ASSERT(var_value(assignment, res_2.output) == 0); + res_2 = assignment.add_input_to_batch_assignment(input_2); + BOOST_ASSERT(var_value(assignment, res_2.output) == var_value(assignment, input_2.x) * var_value(assignment, input_2.y)); + BOOST_ASSERT(var_value(assignment, res_2.output) != 0); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/test/test_plonk_component.hpp b/test/test_plonk_component.hpp index d918b802f..1edccbdcd 100644 --- a/test/test_plonk_component.hpp +++ b/test/test_plonk_component.hpp @@ -212,10 +212,10 @@ namespace nil { } } - auto component_result = boost::get( - assigner(component_instance, assignment, instance_input, start_row)); blueprint::components::generate_circuit( component_instance, bp, assignment, instance_input, start_row); + auto component_result = boost::get( + assigner(component_instance, assignment, instance_input, start_row)); // Stretched components do not have a manifest, as they are dynamically generated. if constexpr (!blueprint::components::is_component_stretcher<