forked from google/distributed_point_functions
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdistributed_point_function.h
899 lines (822 loc) · 38.8 KB
/
distributed_point_function.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_
#include <glog/logging.h>
#include <openssl/cipher.h>
#include <memory>
#include <type_traits>
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/meta/type_traits.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "dpf/aes_128_fixed_key_hash.h"
#include "dpf/distributed_point_function.pb.h"
#include "dpf/internal/proto_validator.h"
#include "dpf/internal/value_type_helpers.h"
namespace distributed_point_functions {
// Type trait for all supported types. Used to provide meaningful error messages
// in std::enable_if template guards.
template <typename T>
using is_supported_type = dpf_internal::is_supported_type<T>;
template <typename T>
constexpr bool is_supported_type_v = is_supported_type<T>::value;
// Parses the `Value` proto `value` into an object of type T. Only enabled for
// types satisfying `is_supported_type_v<T>`; the actual conversion is
// delegated to `dpf_internal::ValueTypeHelper<T>::FromValue`.
//
// Returns INVALID_ARGUMENT if the conversion fails.
template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
absl::StatusOr<T> FromValue(const Value& value) {
  using Helper = dpf_internal::ValueTypeHelper<T>;
  return Helper::FromValue(value);
}
// Builds a `Value` proto from `input`. Only enabled for types satisfying
// `is_supported_type_v<T>`; the conversion is delegated to
// `dpf_internal::ValueTypeHelper<T>::ToValue`.
template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
Value ToValue(const T& input) {
  using Helper = dpf_internal::ValueTypeHelper<T>;
  return Helper::ToValue(input);
}
// Produces the `ValueType` message that describes T. Only enabled for types
// satisfying `is_supported_type_v<T>`; delegates to
// `dpf_internal::ValueTypeHelper<T>::ToValueType`.
template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
ValueType ToValueType() {
  using Helper = dpf_internal::ValueTypeHelper<T>;
  return Helper::ToValueType();
}
// Implements key generation and evaluation of distributed point functions.
// A distributed point function (DPF) is parameterized by an index `alpha` and a
// value `beta`. The key generation procedure produces two keys `k_a`, `k_b`.
// Evaluating each key on any point `x` in the DPF domain results in an additive
// secret share of `beta`, if `x == alpha`, and a share of 0 otherwise. This
// class also supports *incremental* DPFs that can additionally be evaluated on
// prefixes of points, resulting in different values `beta_i` for each prefix of
// `alpha`.
class DistributedPointFunction {
 public:
  // Creates a new instance of a distributed point function that can be
  // evaluated only at the output layer.
  //
  // Returns INVALID_ARGUMENT if the parameters are invalid.
  static absl::StatusOr<std::unique_ptr<DistributedPointFunction>> Create(
      const DpfParameters& parameters);

  // Creates a new instance of an *incremental* DPF that can be evaluated at
  // multiple layers. Each parameter set in `parameters` should specify the
  // domain size and element size at one of the layers to be evaluated, in
  // increasing domain size order. Element sizes must be non-decreasing.
  //
  // Returns INVALID_ARGUMENT if the parameters are invalid.
  static absl::StatusOr<std::unique_ptr<DistributedPointFunction>>
  CreateIncremental(absl::Span<const DpfParameters> parameters);

  // DistributedPointFunction is neither copyable nor movable. Deleting the copy
  // operations also suppresses the implicitly-declared move operations.
  DistributedPointFunction(const DistributedPointFunction&) = delete;
  DistributedPointFunction& operator=(const DistributedPointFunction&) = delete;

  // Converts the argument to a `Value` proto. Also registers the corresponding
  // value type with the DPF by calling `RegisterValueType<T>()`.
  //
  // Returns the error status of `RegisterValueType<T>()` if registration fails.
  template <typename T>
  absl::StatusOr<Value> ToValue(const T& in) {
    absl::Status status = RegisterValueType<T>();
    if (!status.ok()) {
      return status;
    }
    return distributed_point_functions::ToValue(in);
  }

  // Registers the template parameter type with this DPF. Note that it is rarely
  // necessary to call this function by hand: It is called by `Create` and
  // `CreateIncremental` for all unsigned integer types, including
  // absl::uint128, and on every call to ToValue<T>. Only call this function
  // when passing `Value`s created by other means than ToValue<T>.
  //
  // Returns OK on success and otherwise an INTERNAL status describing the
  // failure.
  template <typename T>
  absl::Status RegisterValueType() {
    return RegisterValueTypeImpl<T>(value_correction_functions_);
  }

  // Generates a pair of keys for a DPF that evaluates to `beta` when evaluated
  // at `alpha`. The type of `beta` must match the ValueType passed in
  // `parameters` at construction.
  //
  // This function provides three overloads: One with `absl::uint128` for
  // `beta`, which implies the output type is a simple integer; One with a
  // `Value` proto for `beta`, which can be used for all supported value types;
  // And a templated version that computes the Value by calling ToValue<T> on
  // the argument.
  //
  // Example Usages (assuming a std::unique_ptr<DistributedPointFunction> dpf):
  //
  //   // Simple integer:
  //   dpf->GenerateKeys(23, 42);
  //
  //   // Explicit `Value` proto:
  //   Value value;
  //   value.mutable_tuple()->add_elements()
  //       ->mutable_integer()->set_value_uint64(12);
  //   value.mutable_tuple()->add_elements()
  //       ->mutable_integer()->set_value_uint64(34);
  //   // Must be called once before calling GenerateKeys for any type that is
  //   // not a simple integer. The type should match the one in the
  //   // DpfParameters passed at construction.
  //   dpf->RegisterValueType<Tuple<uint32_t, uint64_t>>();
  //   dpf->GenerateKeys(23, value);
  //
  //   // Templated version (no call to RegisterValueType needed):
  //   dpf->GenerateKeys(23, Tuple<uint32_t, uint64_t>{12, 34});
  //
  // Returns INVALID_ARGUMENT if used on an incremental DPF with more
  // than one set of parameters, if `alpha` is outside of the domain specified
  // at construction, or if `beta` does not match the value type passed at
  // construction.
  // Returns FAILED_PRECONDITION if `RegisterValueType<T>` has not been called
  // for the type in the `DpfParameters` passed at construction.

  // Overload for simple integers.
  absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
                                                         absl::uint128 beta) {
    return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&beta, 1));
  }

  // Overload for explicit Value proto.
  absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
                                                         Value beta) {
    return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&beta, 1));
  }

  // Template for automatic conversion to Value proto. Disabled if the argument
  // is convertible to `absl::uint128` or `Value` to make overloading
  // unambiguous.
  template <typename T, typename = absl::enable_if_t<
                            !std::is_convertible<T, absl::uint128>::value &&
                            !std::is_convertible<T, Value>::value &&
                            is_supported_type_v<T>>>
  absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
                                                         const T& beta) {
    absl::StatusOr<Value> value = ToValue<T>(beta);
    if (!value.ok()) {
      return value.status();
    }
    return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&(*value), 1));
  }

  // Generates a pair of keys for an incremental DPF. For each parameter i
  // passed at construction, the DPF evaluates to `beta[i]` at the lowest
  // `parameters_[i].log_domain_size()` bits of `alpha`.
  //
  // Similar to `GenerateKeys`, supports three overloads: One for simple
  // integers, passed as an `absl::Span<const absl::uint128>`; One for a span of
  // `Value` protos; And a variadic function template that automatically
  // converts the passed arguments to a vector of `Value`s.
  //
  // Example Usages (assuming a std::unique_ptr<DistributedPointFunction> dpf):
  //
  //   // Simple integers:
  //   std::vector<absl::uint128> beta{123, 456};
  //   dpf->GenerateKeysIncremental(23, beta);
  //
  //   // Explicit Value protos:
  //   std::vector<Value> beta(2);
  //   beta[0].mutable_integer()->set_value_uint128(42);
  //   beta[1].mutable_tuple()->add_elements()
  //       ->mutable_integer()->set_value_uint64(12);
  //   beta[1].mutable_tuple()->add_elements()
  //       ->mutable_integer()->set_value_uint64(34);
  //   // Must be called once before calling GenerateKeys for any type that is
  //   // not a simple integer. The type should match the one in the
  //   // DpfParameters passed at construction.
  //   dpf->RegisterValueType<Tuple<uint32_t, uint64_t>>();
  //   dpf->GenerateKeysIncremental(23, beta);
  //
  //   // Templated version (equivalent to the one above):
  //   dpf->GenerateKeysIncremental(23, 42, Tuple<uint32_t, uint64_t>{12, 34});
  //
  // Returns INVALID_ARGUMENT if `beta.size() != parameters_.size()`, if `alpha`
  // is outside of the domain specified at construction, or if `beta` does not
  // match the element type passed at construction.
  // Returns FAILED_PRECONDITION if `RegisterValueType<T>` has not been called
  // for all types in the `DpfParameters` passed at construction.

  // Overload for simple integers. Each element of `beta` is converted with
  // `ToValue` (which also registers absl::uint128 as a value type); the first
  // conversion error is returned unchanged.
  absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
      absl::uint128 alpha, absl::Span<const absl::uint128> beta) {
    std::vector<Value> values;
    values.reserve(beta.size());
    for (const absl::uint128& beta_i : beta) {
      absl::StatusOr<Value> value = ToValue(beta_i);
      if (!value.ok()) {
        return value.status();
      }
      values.push_back(std::move(*value));
    }
    return GenerateKeysIncremental(alpha, values);
  }

  // Overload for Value protos.
  absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
      absl::uint128 alpha, absl::Span<const Value> beta);

  // Variadic template version. Disabled if the first argument is convertible to
  // a span of `absl::uint128`s or `Value`s to make overloading unambiguous.
  template <
      typename T0, typename... Tn,
      typename = absl::enable_if_t<
          !std::is_convertible<T0, absl::Span<const Value>>::value &&
          !std::is_convertible<T0, absl::Span<const absl::uint128>>::value &&
          absl::conjunction<is_supported_type<T0>,
                            is_supported_type<Tn>...>::value>>
  absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
      absl::uint128 alpha, T0&& beta_0, Tn&&... beta_n);

  // Returns an `EvaluationContext` for incrementally evaluating the given
  // DpfKey.
  //
  // Returns INVALID_ARGUMENT if `key` doesn't match the parameters given at
  // construction.
  absl::StatusOr<EvaluationContext> CreateEvaluationContext(DpfKey key) const;

  // Evaluates the given `hierarchy_level` of the DPF under all `prefixes`
  // passed to this function. If `prefixes` is empty, evaluation starts from the
  // seed of `ctx.key`. Otherwise, each element of `prefixes` must fit in the
  // domain size of `ctx.previous_hierarchy_level`. Further, `prefixes` may only
  // contain extensions of the prefixes passed in the previous call. For
  // example, in the following sequence of calls, for each element p2 of
  // `prefixes2`, there must be an element p1 of `prefixes1` such that p1 is a
  // prefix of p2:
  //
  //   DPF_ASSIGN_OR_RETURN(EvaluationContext ctx,
  //                        dpf->CreateEvaluationContext(key));
  //   using T0 = ...;
  //   DPF_ASSIGN_OR_RETURN(std::vector<T0> evaluations0,
  //                        dpf->EvaluateUntil<T0>(0, {}, ctx));
  //
  //   std::vector<absl::uint128> prefixes1 = ...;
  //   using T1 = ...;
  //   DPF_ASSIGN_OR_RETURN(std::vector<T1> evaluations1,
  //                        dpf->EvaluateUntil<T1>(1, prefixes1, ctx));
  //   ...
  //   std::vector<absl::uint128> prefixes2 = ...;
  //   using T2 = ...;
  //   DPF_ASSIGN_OR_RETURN(std::vector<T2> evaluations2,
  //                        dpf->EvaluateUntil<T2>(3, prefixes2, ctx));
  //
  // The prefixes are read from the lowest-order bits of the corresponding
  // absl::uint128. The number of bits used for each prefix depends on the
  // output domain size of the previously evaluated hierarchy level. For
  // example, if `ctx` was last evaluated on a hierarchy level with output
  // domain size 2**20, then the 20 lowest-order bits of each element in
  // `prefixes` are used.
  //
  // Returns `INVALID_ARGUMENT` if
  //   - `ctx` fails validation,
  //   - `hierarchy_level` is negative, not less than `parameters().size()`, or
  //     not greater than `ctx.previous_hierarchy_level()`,
  //   - `prefixes` is empty on a call other than the first one with `ctx`, or
  //     non-empty on the first call,
  //   - any element of `prefixes` does not fit in the domain of
  //     `ctx.previous_hierarchy_level()`,
  //   - `prefixes` contains elements that are not extensions of previous
  //     prefixes,
  //   - T doesn't match the value type (or, if none is set, the
  //     element_bitsize) of `hierarchy_level`, or
  //   - the evaluation would produce more than 2**62 outputs per prefix.
  template <typename T>
  absl::StatusOr<std::vector<T>> EvaluateUntil(
      int hierarchy_level, absl::Span<const absl::uint128> prefixes,
      EvaluationContext& ctx) const;

  // Convenience wrapper around `EvaluateUntil` that evaluates the hierarchy
  // level directly following `ctx.previous_hierarchy_level()`. If `prefixes` is
  // empty, evaluates hierarchy level 0 instead, as needed for the first call
  // on a fresh `ctx`. Returns the same errors as `EvaluateUntil`.
  template <typename T>
  absl::StatusOr<std::vector<T>> EvaluateNext(
      absl::Span<const absl::uint128> prefixes, EvaluationContext& ctx) const {
    if (prefixes.empty()) {
      return EvaluateUntil<T>(0, prefixes, ctx);
    } else {
      return EvaluateUntil<T>(ctx.previous_hierarchy_level() + 1, prefixes,
                              ctx);
    }
  }

  // Evaluates a single key at one or multiple points, up to the given
  // hierarchy_level. Each element of `evaluation_points` must be within the
  // domain of this DPF at `hierarchy_level`.
  //
  // Example:
  //
  //   DpfKey key = ...;
  //   std::vector<absl::uint128> evaluation_points = {1, 23, 42};
  //   // Evaluate `key` on {1, 23, 42}.
  //   DPF_ASSIGN_OR_RETURN(std::vector<T> result,
  //                        dpf->EvaluateAt<T>(key, 0, evaluation_points));
  //
  // Returns INVALID_ARGUMENT if `key` is malformed, or if `hierarchy_level` or
  // any element of `evaluation_points` is out of range.
  template <typename T>
  absl::StatusOr<std::vector<T>> EvaluateAt(
      const DpfKey& key, int hierarchy_level,
      absl::Span<const absl::uint128> evaluation_points) const;

  // Returns the DpfParameters of this DPF.
  absl::Span<const DpfParameters> parameters() const { return parameters_; }

 private:
  // BitVector is a vector of bools. Allows for faster access times than
  // std::vector<bool>, as well as inlining if the size is small.
  using BitVector =
      absl::InlinedVector<bool,
                          std::max<size_t>(1, sizeof(bool*) / sizeof(bool))>;

  // Seeds and control bits resulting from a DPF expansion. This type is
  // returned by `EvaluateSeeds`, `ExpandSeeds`, `ComputePartialEvaluations`
  // and `ExpandAndUpdateContext`.
  struct DpfExpansion {
    std::vector<absl::uint128> seeds;
    BitVector control_bits;
  };

  // A function for computing value corrections. Used as return type in
  // `GetValueCorrectionFunction`.
  using ValueCorrectionFunction = absl::StatusOr<std::vector<Value>> (*)(
      absl::string_view, absl::string_view, int block_index, const Value&,
      bool);

  // Private constructor, called by `CreateIncremental`.
  DistributedPointFunction(
      std::unique_ptr<dpf_internal::ProtoValidator> proto_validator,
      std::vector<int> blocks_needed, Aes128FixedKeyHash prg_left,
      Aes128FixedKeyHash prg_right, Aes128FixedKeyHash prg_value,
      absl::flat_hash_map<std::string, ValueCorrectionFunction>
          value_correction_functions);

  // Computes the value correction for the given `hierarchy_level`, `seeds`,
  // index `alpha` and value `beta`. If `invert` is true, the individual values
  // in the returned block are multiplied element-wise by -1. Expands `seeds`
  // using `prg_value_`, then calls the function returned by
  // `GetValueCorrectionFunction(parameters_[hierarchy_level])` to obtain the
  // value correction words.
  //
  // Returns multiple values in the case of packing, and a single Value
  // otherwise.
  //
  // Returns INTERNAL in case the PRG expansion fails, and UNIMPLEMENTED if
  // `element_bitsize` is not supported.
  absl::StatusOr<std::vector<Value>> ComputeValueCorrection(
      int hierarchy_level, absl::Span<const absl::uint128> seeds,
      absl::uint128 alpha, const Value& beta, bool invert) const;

  // Expands the PRG seeds at the next `tree_level` for an incremental DPF with
  // index `alpha` and values `beta`, updates `seeds` and `control_bits`, and
  // writes the next correction word to `keys`. Called from
  // `GenerateKeysIncremental`.
  absl::Status GenerateNext(int tree_level, absl::uint128 alpha,
                            absl::Span<const Value> beta,
                            absl::Span<absl::uint128> seeds,
                            absl::Span<bool> control_bits,
                            absl::Span<DpfKey> keys) const;

  // Computes the tree index (representing a path in the FSS tree) from the
  // given `domain_index` and `hierarchy_level`. Does NOT check whether the
  // given domain index fits in the domain at `hierarchy_level`.
  absl::uint128 DomainToTreeIndex(absl::uint128 domain_index,
                                  int hierarchy_level) const;

  // Computes the block index (pointing to an element in a batched 128-bit
  // block) from the given `domain_index` and `hierarchy_level`. Does NOT check
  // whether the given domain index fits in the domain at `hierarchy_level`.
  int DomainToBlockIndex(absl::uint128 domain_index, int hierarchy_level) const;

  // Performs DPF evaluation of the given `partial_evaluations` using
  // prg_left_ or prg_right_, and the given `correction_words`. At each
  // level `l < correction_words.size()`, the evaluation for the i-th seed in
  // `partial_evaluations` continues along the left or right path depending on
  // the l-th most significant bit among the lowest `correction_words.size()`
  // bits of `paths[i]`.
  //
  // Returns INVALID_ARGUMENT if the input sizes don't match.
  // Returns INTERNAL in case of OpenSSL errors.
  absl::StatusOr<DpfExpansion> EvaluateSeeds(
      DpfExpansion partial_evaluations, absl::Span<const absl::uint128> paths,
      absl::Span<const CorrectionWord* const> correction_words) const;

  // Performs DPF expansion of the given `partial_evaluations` using
  // prg_left_ and prg_right_, and the given `correction_words`. In more
  // detail, each of the partial evaluations is subjected to a full subtree
  // expansion of `correction_words.size()` levels, and the concatenated result
  // is provided in the response. The result contains
  // `partial_evaluations.size() * 2^correction_words.size()` evaluations in
  // a single `DpfExpansion`.
  //
  // Returns INTERNAL in case of OpenSSL errors.
  absl::StatusOr<DpfExpansion> ExpandSeeds(
      const DpfExpansion& partial_evaluations,
      absl::Span<const CorrectionWord* const> correction_words) const;

  // Computes partial evaluations of the paths to `prefixes` to be used as the
  // starting point of the expansion of `ctx`. If `update_ctx == true`, saves
  // the partial evaluations of `ctx.previous_hierarchy_level` to `ctx` and sets
  // `ctx.partial_evaluations_level` to `ctx.previous_hierarchy_level`.
  // Called by `ExpandAndUpdateContext`.
  //
  // Returns INVALID_ARGUMENT if any element of `prefixes` is not found in
  // `ctx.partial_evaluations()`, or `ctx.partial_evaluations()` contains
  // duplicate seeds.
  absl::StatusOr<DpfExpansion> ComputePartialEvaluations(
      absl::Span<const absl::uint128> prefixes, bool update_ctx,
      EvaluationContext& ctx) const;

  // Extracts the seeds for the given `prefixes` from `ctx` and expands them as
  // far as needed for the next hierarchy level. Returns the result as a
  // `DpfExpansion`. Called by `EvaluateUntil`, where the expanded seeds are
  // corrected to obtain output values.
  // After expansion, `ctx.previous_hierarchy_level()` is increased. If this
  // isn't the last expansion, the expanded seeds are also saved in `ctx` for
  // the next expansion.
  //
  // Returns INVALID_ARGUMENT if any element of `prefixes` is not found in
  // `ctx.partial_evaluations()`, or `ctx.partial_evaluations()` contains
  // duplicate seeds. Returns INTERNAL in case of OpenSSL errors.
  absl::StatusOr<DpfExpansion> ExpandAndUpdateContext(
      int hierarchy_level, absl::Span<const absl::uint128> prefixes,
      EvaluationContext& ctx) const;

  // Computes the output PRG value of expanded seeds using prg_value_.
  // Returns blocks_needed_[hierarchy_level] * expansion.seeds.size() blocks,
  // where every blocks_needed_[hierarchy_level] correspond to the hash of an
  // input seed.
  //
  // Returns INTERNAL in case of OpenSSL errors.
  absl::StatusOr<std::vector<absl::uint128>> HashExpandedSeeds(
      int hierarchy_level, absl::Span<const absl::uint128> expansion) const;

  // Deterministically serializes the given value_type.
  //
  // Returns OK on success and INTERNAL in case serialization fails.
  static absl::StatusOr<std::string> SerializeValueTypeDeterministically(
      const ValueType& value_type);

  // Returns the value correction function for the given parameters.
  // For all value types except unsigned integers, these functions have to be
  // first registered using RegisterValueType<T>.
  //
  // Returns UNIMPLEMENTED if no matching function was registered.
  absl::StatusOr<ValueCorrectionFunction> GetValueCorrectionFunction(
      const DpfParameters& parameters) const;

  // Static implementation of RegisterValueType<T>, so we can call it from
  // `Create`.
  template <typename T>
  static absl::Status RegisterValueTypeImpl(
      absl::flat_hash_map<std::string, ValueCorrectionFunction>&
          value_correction_functions);

  // Used to validate DpfParameters, DpfKey and EvaluationContext protos.
  const std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_;

  // DPF parameters passed to the factory function. Contains the domain size and
  // element size for each hierarchy level of the incremental DPF. Owned by
  // proto_validator_.
  const absl::Span<const DpfParameters> parameters_;

  // Number of levels in the evaluation tree. This is always less than or equal
  // to the largest log_domain_size in parameters_.
  const int tree_levels_needed_;

  // Maps levels of the FSS evaluation tree to hierarchy levels (i.e., elements
  // of parameters_).
  const absl::flat_hash_map<int, int>& tree_to_hierarchy_;

  // The inverse of tree_to_hierarchy_.
  const std::vector<int>& hierarchy_to_tree_;

  // Cached numbers of AES blocks needed for value correction at each hierarchy
  // level.
  const std::vector<int> blocks_needed_;

  // Pseudorandom generator used for seed expansion (left and right), and value
  // correction. The PRG G(x) for hierarchy level i is defined as the
  // concatenation of
  //
  //   H_left(x), H_right(x), H_value(x + 0), ..., H_value(x + k-1)
  //
  // where k is equal to blocks_needed_[i], and H_*(x) is the evaluation of
  // prg_*_ on input x.
  const Aes128FixedKeyHash prg_left_;
  const Aes128FixedKeyHash prg_right_;
  const Aes128FixedKeyHash prg_value_;

  // Maps serialized `ValueType` messages to the correct value correction
  // functions. Map values are instantiations of
  // `dpf_internal::ComputeValueCorrectionFor`. Relies on protobuf's
  // deterministic serialization feature. This has the caveat that messages with
  // unknown fields are not supported. However, as long as `ValueType` consists
  // of a single `oneof` field, this is fine, since we either know the value
  // type and have deterministic serialization because the `ValueType` can only
  // contain one field, or we don't know the type and wouldn't be able to
  // correct values for it anyway.
  absl::flat_hash_map<std::string, ValueCorrectionFunction>
      value_correction_functions_;
};
//========================//
// Implementation Details //
//========================//
// Registers `dpf_internal::ComputeValueCorrectionFor<T>` in
// `value_correction_functions`, keyed by the deterministic serialization of
// T's `ValueType`. An existing entry for the same key is overwritten.
//
// Returns the serialization error if `SerializeValueTypeDeterministically`
// fails, in which case the map is left untouched.
template <typename T>
absl::Status DistributedPointFunction::RegisterValueTypeImpl(
    absl::flat_hash_map<std::string, ValueCorrectionFunction>&
        value_correction_functions) {
  const absl::StatusOr<std::string> key =
      SerializeValueTypeDeterministically(ToValueType<T>());
  if (key.ok()) {
    value_correction_functions[*key] =
        dpf_internal::ComputeValueCorrectionFor<T>;
    return absl::OkStatus();
  }
  return key.status();
}
// Variadic overload of GenerateKeysIncremental: converts each of `beta_0`,
// `beta_n...` to a `Value` via the member `ToValue` (which also registers the
// corresponding value type), then forwards to the `absl::Span<const Value>`
// overload.
//
// Conversion stops at the first failing argument, and that argument's error
// status is returned.
template <typename T0, typename... Tn, typename /*= absl::enable_if_t<...>*/>
absl::StatusOr<std::pair<DpfKey, DpfKey>>
DistributedPointFunction::GenerateKeysIncremental(absl::uint128 alpha,
                                                  T0&& beta_0, Tn&&... beta_n) {
  // Convert the first element of beta. We need to treat it separately to be
  // able to check its type in the enable_if on the declaration.
  absl::StatusOr<Value> value = ToValue(beta_0);
  if (!value.ok()) {
    return value.status();
  }
  // Reserve first and then move the converted value in. Initializing the
  // vector from a braced list would copy the `Value`, because the elements of
  // a std::initializer_list are const and cannot be moved from.
  std::vector<Value> values;
  values.reserve(1 + sizeof...(beta_n));
  values.push_back(std::move(*value));

  // Converts one element of `beta_n` and appends it to `values`. After the
  // first failure every further call is a no-op, so `status` keeps the first
  // error encountered.
  absl::Status status = absl::OkStatus();
  auto convert_and_append = [this, &status, &values](const auto& beta_i) {
    if (!status.ok()) {
      return;
    }
    absl::StatusOr<Value> converted = this->ToValue(beta_i);
    if (converted.ok()) {
      values.push_back(std::move(*converted));
    } else {
      status = converted.status();
    }
  };
  // The initializer-clauses of a braced-init-list are evaluated strictly left
  // to right, so this applies `convert_and_append` to `beta_n` in order (a
  // pre-C++17 substitute for a fold expression). Unlike constructing a
  // throwaway std::tuple<Tn...>, this places no requirements on Tn such as
  // default-constructibility. The leading 0 keeps the array non-empty when
  // `beta_n` is an empty pack.
  int sequencer[] = {0, (convert_and_append(beta_n), 0)...};
  static_cast<void>(sequencer);

  // Return if there was an error during conversion, otherwise generate keys.
  if (!status.ok()) {
    return status;
  }
  return GenerateKeysIncremental(alpha, values);
}
template <typename T>
absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateUntil(
int hierarchy_level, absl::Span<const absl::uint128> prefixes,
EvaluationContext& ctx) const {
absl::Status status = proto_validator_->ValidateEvaluationContext(ctx);
if (!status.ok()) {
return status;
}
if (hierarchy_level < 0 ||
hierarchy_level >= static_cast<int>(parameters_.size())) {
return absl::InvalidArgumentError(
"`hierarchy_level` must be non-negative and less than "
"parameters_.size()");
}
if (parameters_[hierarchy_level].has_value_type()) {
absl::StatusOr<bool> types_are_equal = dpf_internal::ValueTypesAreEqual(
ToValueType<T>(), parameters_[hierarchy_level].value_type());
if (!types_are_equal.ok()) {
return types_are_equal.status();
} else if (!*types_are_equal) {
return absl::InvalidArgumentError(
"Value type T doesn't match parameters at `hierarchy_level`");
}
} else if (sizeof(T) * 8 != parameters_[hierarchy_level].element_bitsize()) {
return absl::InvalidArgumentError(
"Size of template parameter T doesn't match the element size of "
"`hierarchy_level`");
}
if (hierarchy_level <= ctx.previous_hierarchy_level()) {
return absl::InvalidArgumentError(
"`hierarchy_level` must be greater than "
"`ctx.previous_hierarchy_level`");
}
if ((ctx.previous_hierarchy_level() < 0) != (prefixes.empty())) {
return absl::InvalidArgumentError(
"`prefixes` must be empty if and only if this is the first call with "
"`ctx`.");
}
int previous_log_domain_size = 0;
int previous_hierarchy_level = ctx.previous_hierarchy_level();
if (!prefixes.empty()) {
DCHECK(ctx.previous_hierarchy_level() >= 0);
previous_log_domain_size =
parameters_[previous_hierarchy_level].log_domain_size();
for (absl::uint128 prefix : prefixes) {
if (previous_log_domain_size < 128 &&
prefix >= (absl::uint128{1} << previous_log_domain_size)) {
return absl::InvalidArgumentError(
absl::StrFormat("Index %d out of range for hierarchy level %d",
prefix, previous_hierarchy_level));
}
}
}
int64_t prefixes_size = static_cast<int64_t>(prefixes.size());
int log_domain_size = parameters_[hierarchy_level].log_domain_size();
if (log_domain_size - previous_log_domain_size > 62) {
return absl::InvalidArgumentError(
"Output size would be larger than 2**62. Please evaluate fewer "
"hierarchy levels at once.");
}
// The `prefixes` passed in by the caller refer to the domain of the previous
// hierarchy level. However, because we batch multiple elements of type T in a
// single uint128 block, multiple prefixes can actually refer to the same
// block in the FSS evaluation tree. On a high level, our approach is as
// follows:
//
// 1. Split up each element of `prefixes` into a tree index, pointing to a
// block in the FSS tree, and a block index, pointing to an element of type
// T in that block.
//
// 2. Compute a list of unique `tree_indices`, and for each original prefix,
// remember the position of the corresponding tree index in `tree_indices`.
//
// 3. After expanding the unique `tree_indices`, use the positions saved in
// Step (2) together with the corresponding block index to retrieve the
// expanded values for each prefix, and return them in the same order as
// `prefixes`.
//
// `tree_indices` holds the unique tree indices from `prefixes`, to be passed
// to `ExpandAndUpdateContext`.
std::vector<absl::uint128> tree_indices;
tree_indices.reserve(prefixes_size);
// `tree_indices_inverse` is the inverse of `tree_indices`, used for
// deduplicating and constructing `prefix_map`. Use a btree_map because we
// expect `prefixes` (and thus `tree_indices`) to be sorted.
absl::btree_map<absl::uint128, int64_t> tree_indices_inverse;
// `prefix_map` maps each i < prefixes.size() to an element of `tree_indices`
// and a block index. Used to select which elements to return after the
// expansion, to ensure the result is ordered the same way as `prefixes`.
std::vector<std::pair<int64_t, int>> prefix_map;
prefix_map.reserve(prefixes_size);
for (int64_t i = 0; i < prefixes_size; ++i) {
absl::uint128 tree_index =
DomainToTreeIndex(prefixes[i], previous_hierarchy_level);
int block_index = DomainToBlockIndex(prefixes[i], previous_hierarchy_level);
// Check if `tree_index` already exists in `tree_indices`.
int64_t previous_size = tree_indices_inverse.size();
auto it = tree_indices_inverse.try_emplace(tree_indices_inverse.end(),
tree_index, tree_indices.size());
if (tree_indices_inverse.size() > previous_size) {
tree_indices.push_back(tree_index);
}
prefix_map.push_back(std::make_pair(it->second, block_index));
}
// Perform expansion of unique `tree_indices`.
absl::StatusOr<DpfExpansion> expansion =
ExpandAndUpdateContext(hierarchy_level, tree_indices, ctx);
if (!expansion.ok()) {
return expansion.status();
}
// Hash the expanded seeds.
absl::StatusOr<std::vector<absl::uint128>> hashed_expansion =
HashExpandedSeeds(hierarchy_level, expansion->seeds);
if (!hashed_expansion.ok()) {
return hashed_expansion.status();
}
// Get output correction word from `ctx`.
constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>();
const ::google::protobuf::RepeatedPtrField<Value>* value_correction = nullptr;
if (hierarchy_level < static_cast<int>(parameters_.size()) - 1) {
value_correction =
&(ctx.key()
.correction_words(hierarchy_to_tree_[hierarchy_level])
.value_correction());
} else {
// Last level value correction is stored in an extra proto field, since we
// have one less correction word than tree levels.
value_correction = &(ctx.key().last_level_value_correction());
}
// Split output correction into elements of type T.
absl::StatusOr<std::array<T, elements_per_block>> correction_ints =
dpf_internal::ValuesToArray<T>(*value_correction);
if (!correction_ints.ok()) {
return correction_ints.status();
}
// Compute value corrections for each block in `expanded_seeds`. We have to
// account for the fact that blocks might not be full (i.e., have less than
// elements_per_block elements).
const int corrected_elements_per_block =
1 << (parameters_[hierarchy_level].log_domain_size() -
hierarchy_to_tree_[hierarchy_level]);
const auto expansion_size = static_cast<int64_t>(expansion->seeds.size());
const int blocks_needed = blocks_needed_[hierarchy_level];
DCHECK(corrected_elements_per_block <= elements_per_block);
std::vector<T> corrected_expansion(expansion_size *
corrected_elements_per_block);
for (int64_t i = 0; i < expansion_size; ++i) {
std::array<T, elements_per_block> current_elements =
dpf_internal::ConvertBytesToArrayOf<T>(
absl::string_view(reinterpret_cast<const char*>(
&(*hashed_expansion)[i * blocks_needed]),
blocks_needed * sizeof(absl::uint128)));
for (int j = 0; j < corrected_elements_per_block; ++j) {
if (expansion->control_bits[i]) {
current_elements[j] += (*correction_ints)[j];
}
if (ctx.key().party() == 1) {
current_elements[j] = -current_elements[j];
}
corrected_expansion[i * corrected_elements_per_block + j] =
current_elements[j];
}
}
// Compute the number of outputs we will have. For each prefix, we will have a
// full expansion from the previous heirarchy level to the current heirarchy
// level.
DCHECK(log_domain_size - previous_log_domain_size < 63);
int64_t outputs_per_prefix = int64_t{1}
<< (log_domain_size - previous_log_domain_size);
if (prefixes.empty()) {
// If prefixes is empty (i.e., this is the first evaluation of `ctx`), just
// return the expansion.
DCHECK(static_cast<int>(corrected_expansion.size()) == outputs_per_prefix);
return corrected_expansion;
} else {
// Otherwise, only return elements under `prefixes`.
int blocks_per_tree_prefix = expansion->seeds.size() / tree_indices.size();
std::vector<T> result(prefixes_size * outputs_per_prefix);
for (int64_t i = 0; i < prefixes_size; ++i) {
int64_t prefix_expansion_start =
prefix_map[i].first * blocks_per_tree_prefix *
corrected_elements_per_block +
prefix_map[i].second * outputs_per_prefix;
std::copy_n(&corrected_expansion[prefix_expansion_start],
outputs_per_prefix, &result[i * outputs_per_prefix]);
}
return result;
}
}
// Evaluates `key` at every point in `evaluation_points` on hierarchy level
// `hierarchy_level` and returns one output share of type T per point, in the
// same order as `evaluation_points`.
//
// Each point is evaluated independently: every evaluation starts from the
// key's root seed and party bit. It walks the key's correction words for the
// tree levels below `hierarchy_to_tree_[hierarchy_level]`.
//
// Returns InvalidArgumentError if `hierarchy_level` is negative or not less
// than the number of parameters. Errors from key validation,
// value-correction parsing, seed evaluation and seed hashing are propagated
// unchanged.
//
// NOTE(review): points are not range-checked against the level's domain size
// in this function; confirm that callers or the helpers used here do so.
template <typename T>
absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateAt(
    const DpfKey& key, int hierarchy_level,
    absl::Span<const absl::uint128> evaluation_points) const {
  const auto point_count = static_cast<int64_t>(evaluation_points.size());
  const int level_count = static_cast<int>(parameters_.size());
  constexpr int values_per_block = dpf_internal::ElementsPerBlock<T>();

  // Reject out-of-range levels and malformed keys before doing any work.
  if (hierarchy_level < 0) {
    return absl::InvalidArgumentError("`hierarchy_level` must be non-negative");
  }
  if (hierarchy_level >= level_count) {
    return absl::InvalidArgumentError(
        "`hierarchy_level` must be less than the number of parameters passed "
        "at construction");
  }
  absl::Status validation = proto_validator_->ValidateDpfKey(key);
  if (!validation.ok()) {
    return validation;
  }

  // Every level except the last keeps its value correction in the correction
  // word of its tree level. The last level uses a dedicated proto field,
  // because the key has one correction word fewer than there are tree levels.
  const bool is_last_level = hierarchy_level >= level_count - 1;
  const ::google::protobuf::RepeatedPtrField<Value>& level_value_correction =
      is_last_level
          ? key.last_level_value_correction()
          : key.correction_words(hierarchy_to_tree_[hierarchy_level])
                .value_correction();
  // One correction value per slot of type T inside a 128-bit block.
  absl::StatusOr<std::array<T, values_per_block>> correction_ints =
      dpf_internal::ValuesToArray<T>(level_value_correction);
  if (!correction_ints.ok()) {
    return correction_ints.status();
  }

  // When several values of T share one block, each point must first be mapped
  // to the tree index of its block. Otherwise the points themselves are the
  // tree indices and are used without copying.
  std::vector<absl::uint128> packed_tree_indices;
  absl::Span<const absl::uint128> tree_indices = evaluation_points;
  if (values_per_block > 1) {
    packed_tree_indices.reserve(point_count);
    for (const absl::uint128& point : evaluation_points) {
      packed_tree_indices.push_back(DomainToTreeIndex(point, hierarchy_level));
    }
    tree_indices = absl::MakeConstSpan(packed_tree_indices);
  }

  // Start every evaluation from the key's root seed and party bit.
  const absl::uint128 root_seed =
      absl::MakeUint128(key.seed().high(), key.seed().low());
  const bool party = key.party();
  DpfExpansion start;
  start.seeds.resize(point_count, root_seed);
  start.control_bits.resize(point_count, party);

  // Walk down the tree using only the correction words above the target level.
  const int tree_level = hierarchy_to_tree_[hierarchy_level];
  auto level_correction_words =
      absl::MakeConstSpan(key.correction_words()).subspan(0, tree_level);
  absl::StatusOr<DpfExpansion> evaluated =
      EvaluateSeeds(std::move(start), tree_indices, level_correction_words);
  if (!evaluated.ok()) {
    return evaluated.status();
  }
  DCHECK(static_cast<int64_t>(evaluated->seeds.size()) == point_count);

  absl::StatusOr<std::vector<absl::uint128>> hashed =
      HashExpandedSeeds(hierarchy_level, evaluated->seeds);
  if (!hashed.ok()) {
    return hashed.status();
  }

  // Each point owns `blocks_per_point` consecutive hashed blocks. Unpack them
  // into values of T and select the slot the point addresses. Then add the
  // value correction when the point's control bit is set, and negate the
  // share for party 1.
  const int blocks_per_point = blocks_needed_[hierarchy_level];
  std::vector<T> shares;
  shares.reserve(point_count);
  for (int64_t i = 0; i < point_count; ++i) {
    const absl::uint128* hash_start = &(*hashed)[i * blocks_per_point];
    std::array<T, values_per_block> unpacked =
        dpf_internal::ConvertBytesToArrayOf<T>(absl::string_view(
            reinterpret_cast<const char*>(hash_start),
            blocks_per_point * sizeof(absl::uint128)));
    int slot = 0;
    if (values_per_block > 1) {
      slot = DomainToBlockIndex(evaluation_points[i], hierarchy_level);
    }
    T share = unpacked[slot];
    if (evaluated->control_bits[i]) {
      share += (*correction_ints)[slot];
    }
    if (party) {
      share = -share;
    }
    shares.push_back(share);
  }
  return shares;
}
} // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_