-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathlukai.go
149 lines (118 loc) · 2.9 KB
/
lukai.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package lukai
import (
"context"
"log"
"os"
"sync"
"time"
lru "github.com/hashicorp/golang-lru"
"github.com/pkg/errors"
"golang.org/x/time/rate"
"github.com/luk-ai/lukai/debounce"
"github.com/luk-ai/lukai/protobuf/aggregatorpb"
"github.com/luk-ai/lukai/protobuf/clientpb"
"github.com/luk-ai/lukai/tf"
)
// Package-level settings. These are variables rather than constants, so they
// can be reassigned by callers; MakeModelType reads ModelCacheSize and
// ErrorRateLimit at the time it is called.
var (
// ErrNotImplemented is a sentinel error carrying the message
// "not implemented".
ErrNotImplemented = errors.New("not implemented")
// EdgeAddress is a "dns://" target string.
// NOTE(review): presumably the address dialed to reach the luk.ai edge
// service; the dialing code is outside this section — confirm.
EdgeAddress = "dns://edge.luk.ai"
// ModelCacheSize controls how many training models are cached between
// training iterations. It is the capacity of the LRU cache built in
// MakeModelType.
ModelCacheSize = 3
// DialTimeout is a 60 second duration.
// NOTE(review): presumably bounds how long connecting to EdgeAddress may
// take; its use is outside this section — confirm.
DialTimeout = 60 * time.Second
// outOfDateModelTimeout is a 24 hour duration.
// NOTE(review): presumably the age after which a production model is
// treated as stale; its use is outside this section — confirm.
outOfDateModelTimeout = 24 * time.Hour
// ErrorRateLimit controls how often the client should report errors to the
// server. MakeModelType turns it into a rate limiter allowing one event per
// ErrorRateLimit with a burst of 1.
ErrorRateLimit = 1 * time.Minute
// MaxQueuedErrors controls how many errors can be queued before they start
// getting discarded.
MaxQueuedErrors = 10
)
// ModelType holds the state for one (domain, model type) pair whose data is
// stored under DataDir. Construct it with MakeModelType and release it with
// Close. It embeds mutexes, so it must be used through a pointer and never
// copied.
type ModelType struct {
// Domain, ModelType and DataDir are the values passed to MakeModelType,
// which rejects empty strings and creates DataDir if it does not exist.
Domain, ModelType, DataDir string
// prod groups model state behind its own embedded RWMutex.
// NOTE(review): presumably the model currently used for serving, with
// lastUpdate recording when it was last refreshed — confirm against the
// code that populates it.
prod struct {
sync.RWMutex
modelID aggregatorpb.ModelID
model *tf.Model
cache tfOpCache
lastUpdate time.Time
}
// training groups training-job state behind its own embedded Mutex.
// NOTE(review): presumably running/stop/err describe the active training
// run, its cancel function and its last error — confirm.
training struct {
sync.Mutex
running bool
stop context.CancelFunc
err error
}
// examplesMeta groups the example index and its persistence hooks behind
// an embedded RWMutex.
examplesMeta struct {
sync.RWMutex
index clientpb.ExampleIndex
// saveIndex and stop are the two functions returned by debounce.Debounce
// in MakeModelType (300ms interval, wrapping saveExamplesMeta); Close
// calls stop.
saveIndex func()
stop func()
// err is set to the error from a failed debounced saveExamplesMeta call.
err error
}
// modelCache is an LRU cache of capacity ModelCacheSize whose eviction
// callback closes the evicted *tf.Model.
modelCache *lru.Cache
// errorLimiter allows one event per ErrorRateLimit with a burst of 1.
errorLimiter *rate.Limiter
// errors holds aggregatorpb.Error values behind an embedded Mutex.
// NOTE(review): presumably the queue bounded by MaxQueuedErrors — confirm.
errors struct {
sync.Mutex
errors []aggregatorpb.Error
}
// ctx is created in MakeModelType from context.Background; cancel is its
// cancel function and is invoked by Close.
ctx context.Context
cancel context.CancelFunc
}
// MakeModelType creates a new model type with a specified domain and model type
// and stores all training data in dataDir.
//
// domain, modelType and dataDir must all be non-empty; dataDir is created
// (including parents) with DirPerm if it does not exist. On success the
// returned ModelType has a background gcLoop goroutine running and must be
// released with Close.
//
// On failure it returns a nil ModelType and the error. Nothing acquired during
// construction is left behind: arguments are validated and the model cache is
// built before the cancellable context and the debounced index saver are
// created, and both of those are released if loading the examples index fails.
func MakeModelType(domain, modelType, dataDir string) (*ModelType, error) {
	// Validate before acquiring anything so a rejected call has nothing to
	// clean up.
	switch {
	case domain == "":
		return nil, errors.Errorf("domain required")
	case modelType == "":
		return nil, errors.Errorf("modelType required")
	case dataDir == "":
		return nil, errors.Errorf("dataDir required")
	}
	if err := os.MkdirAll(dataDir, DirPerm); err != nil {
		return nil, err
	}

	// Evicted models are closed so their resources are not kept alive by the
	// cache.
	modelCache, err := lru.NewWithEvict(ModelCacheSize, func(key, val interface{}) {
		val.(*tf.Model).Close()
	})
	if err != nil {
		return nil, err
	}

	ctx, cancel := context.WithCancel(context.Background())
	mt := &ModelType{
		Domain:       domain,
		ModelType:    modelType,
		DataDir:      dataDir,
		modelCache:   modelCache,
		errorLimiter: rate.NewLimiter(rate.Every(ErrorRateLimit), 1),
		ctx:          ctx,
		cancel:       cancel,
	}

	// Index saves are debounced; a failed save is logged and remembered in
	// examplesMeta.err because the debounced callback has no caller to return
	// the error to.
	mt.examplesMeta.saveIndex, mt.examplesMeta.stop = debounce.Debounce(
		300*time.Millisecond,
		func() {
			if err := mt.saveExamplesMeta(); err != nil {
				log.Printf("saveExamplesMeta error: %+v", err)
				mt.examplesMeta.Lock()
				defer mt.examplesMeta.Unlock()
				mt.examplesMeta.err = err
			}
		},
	)

	if err := mt.loadExamplesMeta(); err != nil {
		// The caller never receives mt, so Close can never be called on it:
		// release the debouncer (same hook Close uses) and the context here.
		mt.examplesMeta.stop()
		cancel()
		return nil, err
	}

	go mt.gcLoop()
	return mt, nil
}
// Close shuts the model type down: it stops the debounced index saver, writes
// the examples index one final time, and cancels the context created by
// MakeModelType.
//
// It returns the error from the final save, if any. The context is cancelled
// even when that save fails, so a failed Close does not leave the context (and
// anything waiting on it) alive.
func (mt *ModelType) Close() error {
	// Deferred so cancellation still happens after the final save, but on
	// every return path rather than only the successful one.
	defer mt.cancel()

	mt.examplesMeta.stop()
	return mt.saveExamplesMeta()
}