-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathlwe.go
99 lines (80 loc) · 2.08 KB
/
lwe.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
package client
import (
"errors"
"io"
"github.com/si-co/vpir-code/lib/database"
"github.com/si-co/vpir-code/lib/matrix"
"github.com/si-co/vpir-code/lib/utils"
)
// LWE-based authenticated single-server PIR client
// LWE is the client of the LWE-based authenticated single-server PIR scheme.
// It is created with NewLWE; Query stores per-query data in state, which
// Reconstruct reads afterwards.
type LWE struct {
// dbInfo describes the database; this file reads its DigestLWE and
// NumColumns fields.
dbInfo *database.Info
// state holds the data of the most recent Query call. NewLWE leaves it nil.
state *StateLWE
// params are the scheme parameters; N, L, M, B and SeedA are read in this file.
params *utils.ParamsLWE
// rnd is the randomness source from which the secret and the scalar t are sampled.
rnd io.Reader
}
// StateLWE holds the values that Query generates and that Reconstruct needs
// in order to process the answer.
type StateLWE struct {
// A is the matrix expanded from params.SeedA, created with dimensions
// params.N and params.L.
A *matrix.Matrix
// digest is copied from the DigestLWE field of the database info.
digest *matrix.Matrix
// secret is the random matrix of dimensions 1 and params.N sampled from the
// client's randomness source.
secret *matrix.Matrix
// i is the first index passed to Query: the position in the query at which
// t is placed.
i int
// j is the second index passed to Query: the position of the answer entry
// that Reconstruct returns.
j int
// t is the random scalar placed in the query; Reconstruct subtracts it to
// recognise entries equal to 1.
t uint32
}
// NewLWE returns an LWE client that uses rnd as its randomness source, info
// as the database description and params as the scheme parameters. The
// per-query state is left unset until Query is called.
func NewLWE(rnd io.Reader, info *database.Info, params *utils.ParamsLWE) *LWE {
	client := new(LWE)
	client.rnd = rnd
	client.dbInfo = info
	client.params = params
	return client
}
// Query builds the query matrix for the indices i and j and records in
// c.state everything Reconstruct needs later (A, digest, secret, i, j, t).
//
// The returned matrix is secret*A, to which an error term and a message
// carrying the scalar t at position (0, i) are added. Its dimension is 1 x L.
func (c *LWE) Query(i, j int) *matrix.Matrix {
	// A 1 x 1 random matrix is a lazy way to obtain a random scalar. It is
	// sampled before the secret so c.rnd is consumed in a fixed order.
	scalar := matrix.NewRandom(c.rnd, 1, 1)

	// The digest is already available from the database info.
	st := new(StateLWE)
	st.A = matrix.NewRandom(utils.NewPRG(c.params.SeedA), c.params.N, c.params.L)
	st.digest = c.dbInfo.DigestLWE
	st.secret = matrix.NewRandom(c.rnd, 1, c.params.N)
	st.i, st.j = i, j
	st.t = scalar.Get(0, 0)
	c.state = st

	// secret * A has dimension 1 x L.
	q := matrix.Mul(st.secret, st.A)

	// The error term has dimension 1 x L.
	noise := matrix.NewGauss(1, c.params.L)

	// The message places t at position i.
	encoded := matrix.New(1, c.params.L)
	encoded.Set(0, i, st.t)

	q.Add(noise)
	q.Add(encoded)

	return q
}
// QueryBytes converts the flat index into a pair of indices using the
// database's NumColumns, builds the corresponding query with Query and
// returns it serialised with matrix.MatrixToBytes. The returned error is
// always nil.
func (c *LWE) QueryBytes(index int) ([]byte, error) {
	q := c.Query(utils.VectorToMatrixIndices(index, c.dbInfo.NumColumns))
	return matrix.MatrixToBytes(q), nil
}
// Reconstruct verifies the server's answer and returns the entry selected by
// the preceding Query call.
//
// It subtracts secret*digest from answers (answers is modified in place) and
// then checks each of the M entries: an entry within the bound B of 0 decodes
// to 0, an entry within the bound of t decodes to 1, and anything else makes
// the whole answer invalid.
//
// It returns the decoded value at position j of the stored state. An error is
// returned when Query has not been called yet, when answers is nil, when the
// stored j lies outside [0, M), or, with the message "REJECT", when any entry
// fails the range check.
func (c *LWE) Reconstruct(answers *matrix.Matrix) (uint32, error) {
	// The state only exists after Query; fail cleanly instead of panicking
	// on a nil dereference.
	if c.state == nil {
		return 0, errors.New("reconstruct called before query")
	}
	if answers == nil {
		return 0, errors.New("nil answers matrix")
	}
	if c.state.j < 0 || c.state.j >= c.params.M {
		return 0, errors.New("queried index out of range")
	}

	mask := matrix.Mul(c.state.secret, c.state.digest)
	answers.Sub(mask)

	// Every entry must be validated, but only the one at position j is
	// returned, so there is no need to store the others.
	var out uint32
	for col := 0; col < c.params.M; col++ {
		v := answers.Get(0, col)

		var bit uint32
		switch {
		case c.inRange(v):
			bit = 0
		case c.inRange(v - c.state.t):
			bit = 1
		default:
			return 0, errors.New("REJECT")
		}

		if col == c.state.j {
			out = bit
		}
	}

	return out, nil
}
// ReconstructBytes decodes a with matrix.BytesToMatrix and passes the
// resulting matrix to Reconstruct, returning its result unchanged.
func (c *LWE) ReconstructBytes(a []byte) (uint32, error) {
	answers := matrix.BytesToMatrix(a)
	return c.Reconstruct(answers)
}
// inRange reports whether val is strictly below the bound B or strictly above
// -B. Because the operands are uint32, -B wraps around to 2^32-B, so the
// second comparison accepts the values just below 2^32.
func (c *LWE) inRange(val uint32) bool {
	bound := c.params.B
	if val < bound {
		return true
	}
	return val > -bound
}