main.go
package main

import (
    "context"
    "encoding/csv"
    "fmt"
    "io"
    "log"
    "os"
    "strconv"

    "github.com/jackc/pgx/v5"
    "github.com/olekukonko/tablewriter"

    "main/colfi"
)
// Prediction pairs an item with its predicted rating.
type Prediction struct {
    Item       string
    Prediction float64
}
func main() {
    // Load up to 10M ratings from Postgres and split them 80/20 into train/test sets.
    u, i, r := loadRatings("host="+os.Getenv("PGHOST"), 10000000)
    trainset, testset, err := colfi.DatasetsFromSlices(u, i, r, 0.2)
    if err != nil {
        log.Fatalf("error loading datasets: %v", err)
    }

    // Sweep over the number of latent factors, holding the other hyperparameters fixed.
    testParams := colfi.GridSearchParams{
        NumEpochs:  []int{20},
        NumFactors: []int{25, 35, 45},
        Reg:        []float64{0.02},
        LR:         []float64{0.01},
        InitStdDev: []float64{0.1},
    }
    results := colfi.GridSearch(trainset, testset, testParams)

    // Render one table row per hyperparameter combination.
    var data [][]string
    for _, r := range results {
        row := []string{
            strconv.Itoa(r.NumEpochs),
            strconv.Itoa(r.NumFactors),
            fmt.Sprintf("%.3f", r.Reg),
            fmt.Sprintf("%.3f", r.LR),
            fmt.Sprintf("%.1f", r.InitStdDev),
            fmt.Sprintf("%.4f", r.Loss),
            fmt.Sprintf("%v", r.Runtime),
        }
        data = append(data, row)
    }

    table := tablewriter.NewWriter(os.Stdout)
    table.SetHeader([]string{"NumEpochs", "NumFactors", "Reg", "LR", "InitStdDev", "Loss", "Runtime"})
    for _, v := range data {
        table.Append(v)
    }
    table.Render()
}
// loadRatings pulls (user, item, rating) triples from Postgres, restricted to
// films with at least 500 ratings, and returns them as parallel slices.
func loadRatings(connString string, limit int) ([]string, []string, []float32) {
    ctx := context.Background()
    conn, err := pgx.Connect(ctx, connString)
    if err != nil {
        log.Fatalf("Unable to connect to database: %v\n", err)
    }
    defer conn.Close(ctx)

    // Only keep ratings for films that have at least 500 ratings in total.
    queryFilter := ` INNER JOIN (SELECT film_id
        FROM ratings
        GROUP BY film_id
        HAVING COUNT(*) >= 500) f
        ON r.film_id = f.film_id`
    if limit > 0 {
        queryFilter += fmt.Sprintf(" LIMIT %d", limit)
    }

    rows, err := conn.Query(ctx, "SELECT r.user_name, r.film_id, r.rating FROM ratings r"+queryFilter)
    if err != nil {
        log.Fatalf("Unable to get ratings: %v\n", err)
    }
    defer rows.Close()

    j := 0
    var u, i string
    var r float64
    var us, is []string
    var rs []float32
    for rows.Next() {
        if j%1000000 == 0 {
            log.Printf("loaded %d rows", j)
        }
        if err := rows.Scan(&u, &i, &r); err != nil {
            log.Fatalf("Unable to scan rating row: %v\n", err)
        }
        us = append(us, u)
        is = append(is, i)
        rs = append(rs, float32(r))
        j++
    }
    if err := rows.Err(); err != nil {
        log.Fatalf("Error iterating ratings: %v\n", err)
    }
    return us, is, rs
}
/*
// Earlier version: fit a single SVD model on the full dataset and print the
// top-50 predictions for one user.
func main() {
    dataset := colfi.NewDataset()
    start := time.Now()
    loadRatings(dataset, "host="+os.Getenv("PGHOST"), 100000)
    loadRatings(dataset, "host="+os.Getenv("PGHOST")+" dbname=soothsayer", 0)
    log.Printf("loading ratings took %s\n", time.Since(start))

    config := colfi.SVDConfig{
        NumFactors: 20,
        Verbose:    true,
    }
    m := colfi.NewSVD(dataset, &config)
    start = time.Now()
    m.Fit(5)
    log.Printf("svd took %s\n", time.Since(start))

    user := "freeth"
    var predictions []Prediction
    start = time.Now()
    for film := range dataset.ItemMap {
        predictions = append(predictions, Prediction{film, m.Predict(user, film)})
    }
    log.Printf("predictions took %s\n", time.Since(start))

    sort.Slice(predictions, func(i, j int) bool {
        return predictions[i].Prediction > predictions[j].Prediction
    })
    for _, pred := range predictions[:50] {
        fmt.Printf("%s: %f\n", pred.Item, pred.Prediction)
    }
}

func loadRatings(dataset *colfi.Dataset, connString string, limit int) {
    ctx := context.Background()
    conn, err := pgx.Connect(ctx, connString)
    if err != nil {
        log.Fatalf("Unable to connect to database: %v\n", err)
    }
    defer conn.Close(ctx)

    queryFilter := ""
    if limit > 0 {
        queryFilter += fmt.Sprintf(" LIMIT %d", limit)
    }
    rows, err := conn.Query(ctx, "SELECT user_name, film_id, rating FROM ratings"+queryFilter)
    if err != nil {
        log.Fatalf("Unable to get ratings: %v\n", err)
    }

    j := 0
    var u, i string
    var r float64
    for rows.Next() {
        if j%1000000 == 0 {
            log.Printf("loaded %d rows", j)
        }
        rows.Scan(&u, &i, &r)
        dataset.Append(u, i, float32(r))
        j++
    }
}
*/
// loadRatingsFromCSV reads (user, item, rating) triples from a CSV file as an
// alternative to loading them from Postgres.
func loadRatingsFromCSV(fileName string) *colfi.Dataset {
    f, err := os.Open(fileName)
    if err != nil {
        log.Fatalf("Unable to open %s: %v\n", fileName, err)
    }
    defer f.Close()

    r := csv.NewReader(f)
    dataset := colfi.NewDataset()
    for {
        record, err := r.Read()
        if err == io.EOF {
            break
        }
        if err != nil {
            log.Fatalf("Unable to read CSV record: %v\n", err)
        }
        rating, err := strconv.ParseFloat(record[2], 32)
        if err != nil {
            log.Printf("skipping record with bad rating %q: %v", record[2], err)
            continue
        }
        dataset.Append(record[0], record[1], float32(rating))
    }
    return dataset
}
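
// Example usage (a sketch, not exercised by this program): loading ratings from
// a hypothetical "ratings.csv" export instead of Postgres and fitting a single
// SVD model, mirroring the colfi calls in the commented-out main above.
//
//	dataset := loadRatingsFromCSV("ratings.csv")
//	m := colfi.NewSVD(dataset, &colfi.SVDConfig{NumFactors: 20, Verbose: true})
//	m.Fit(5)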