Skip to content

Commit

Permalink
Update tests for clang-format
Browse files Browse the repository at this point in the history
  • Loading branch information
srvasude committed May 30, 2024
1 parent d58349c commit d796ca1
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 45 deletions.
24 changes: 12 additions & 12 deletions cxx/tests/test_hirm_animals.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "util_io.hh"
#include "util_math.hh"

int main(int argc, char **argv) {
int main(int argc, char** argv) {
srand(1);
std::mt19937 prng(1);

Expand All @@ -30,8 +30,8 @@ int main(int argc, char **argv) {
incorporate_observations(hirm, encoding_unary, observations_unary);
printf("--- incorporated observations --- \n");
int n_obs_unary = 0;
for (const auto &[z, irm] : hirm.irms) {
for (const auto &[r, relation] : irm->relations) {
for (const auto& [z, irm] : hirm.irms) {
for (const auto& [r, relation] : irm->relations) {
n_obs_unary += relation->data.size();
}
}
Expand All @@ -45,15 +45,15 @@ int main(int argc, char **argv) {
printf("--- set cluster assignments --- \n");
for (int i = 0; i < 20; i++) {
hirm.transition_cluster_assignments_all();
for (const auto &[t, irm] : hirm.irms) {
for (const auto& [t, irm] : hirm.irms) {
irm->transition_cluster_assignments_all();
for (const auto &[d, domain] : irm->domains) {
for (const auto& [d, domain] : irm->domains) {
domain->crp.transition_alpha();
}
}
hirm.crp.transition_alpha();
printf("%d %f [", i, hirm.logp_score());
for (const auto &[t, customers] : hirm.crp.tables) {
for (const auto& [t, customers] : hirm.crp.tables) {
printf("%ld ", customers.size());
}
printf("]\n");
Expand All @@ -71,7 +71,7 @@ int main(int argc, char **argv) {
std::string path_clusters = path_base + ".hirm";
to_txt(path_clusters, hirm, encoding_unary);

auto &enc = std::get<0>(encoding_unary);
auto& enc = std::get<0>(encoding_unary);

// Marginally normalized.
int persiancat = enc["animal"]["persiancat"];
Expand Down Expand Up @@ -118,16 +118,16 @@ int main(int argc, char **argv) {

assert(hirm.irms.size() == hirx.irms.size());
// Check IRMs agree.
for (const auto &[table, irm] : hirm.irms) {
for (const auto& [table, irm] : hirm.irms) {
auto irx = hirx.irms.at(table);
// Check log scores agree.
for (const auto &[d, dm] : irm->domains) {
for (const auto& [d, dm] : irm->domains) {
auto dx = irx->domains.at(d);
dx->crp.alpha = dm->crp.alpha;
}
assert(abs(irx->logp_score() - irm->logp_score()) < 1e-8);
// Check domains agree.
for (const auto &[d, dm] : irm->domains) {
for (const auto& [d, dm] : irm->domains) {
auto dx = irx->domains.at(d);
assert(dm->items == dx->items);
assert(dm->crp.assignments == dx->crp.assignments);
Expand All @@ -136,12 +136,12 @@ int main(int argc, char **argv) {
assert(dm->crp.alpha == dx->crp.alpha);
}
// Check relations agree.
for (const auto &[r, rm] : irm->relations) {
for (const auto& [r, rm] : irm->relations) {
auto rx = irx->relations.at(r);
assert(rm->data == rx->data);
assert(rm->data_r == rx->data_r);
assert(rm->clusters.size() == rx->clusters.size());
for (const auto &[z, clusterm] : rm->clusters) {
for (const auto& [z, clusterm] : rm->clusters) {
auto clusterx = rx->clusters.at(z);
assert(clusterm->N == clusterx->N);
}
Expand Down
20 changes: 10 additions & 10 deletions cxx/tests/test_irm_two_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "util_io.hh"
#include "util_math.hh"

int main(int argc, char **argv) {
int main(int argc, char** argv) {
std::string path_base = "assets/two_relations";
int seed = 1;
int iters = 2;
Expand All @@ -25,11 +25,11 @@ int main(int argc, char **argv) {
std::string path_schema = path_base + ".schema";
std::cout << "loading schema from " << path_schema << std::endl;
auto schema = load_schema(path_schema);
for (auto const &[relation_name, relation] : schema) {
for (auto const& [relation_name, relation] : schema) {
printf("relation: %s, ", relation_name.c_str());
printf("distribution: %s, ", relation.distribution.c_str());
printf("domains: ");
for (auto const &domain : relation.domains) {
for (auto const& domain : relation.domains) {
printf("%s ", domain.c_str());
}
printf("\n");
Expand All @@ -45,7 +45,7 @@ int main(int argc, char **argv) {
printf("running for %d iterations\n", iters);
for (int i = 0; i < iters; i++) {
irm.transition_cluster_assignments_all();
for (auto const &[d, domain] : irm.domains) {
for (auto const& [d, domain] : irm.domains) {
domain->crp.transition_alpha();
}
double x = irm.logp_score();
Expand Down Expand Up @@ -75,7 +75,7 @@ int main(int argc, char **argv) {
std::vector<std::vector<int>> indexes{
{code_item_0_D1, code_item_10_D1, code_item_novel},
{code_item_0_D1, code_item_10_D2, code_item_novel}};
for (const auto &l : product(indexes)) {
for (const auto& l : product(indexes)) {
assert(l.size() == 2);
auto x1 = l.at(0);
auto x2 = l.at(1);
Expand All @@ -88,7 +88,7 @@ int main(int argc, char **argv) {
assert(abs(exp(p0) - expected_p0[x1].at(x2)) < .1);
}

for (const auto &l :
for (const auto& l :
std::vector<std::vector<int>>{{0, 10, 100}, {110, 10, 100}}) {
auto x1 = l.at(0);
auto x2 = l.at(1);
Expand All @@ -104,14 +104,14 @@ int main(int argc, char **argv) {
IRM irx({}, &prng);
from_txt(&irx, path_schema, path_obs, path_clusters);
// Check log scores agree.
for (const auto &d : {"D1", "D2"}) {
for (const auto& d : {"D1", "D2"}) {
auto dm = irm.domains.at(d);
auto dx = irx.domains.at(d);
dx->crp.alpha = dm->crp.alpha;
}
assert(abs(irx.logp_score() - irm.logp_score()) < 1e-8);
// Check domains agree.
for (const auto &d : {"D1", "D2"}) {
for (const auto& d : {"D1", "D2"}) {
auto dm = irm.domains.at(d);
auto dx = irx.domains.at(d);
assert(dm->items == dx->items);
Expand All @@ -121,13 +121,13 @@ int main(int argc, char **argv) {
assert(dm->crp.alpha == dx->crp.alpha);
}
// Check relations agree.
for (const auto &r : {"R1", "R2"}) {
for (const auto& r : {"R1", "R2"}) {
auto rm = irm.relations.at(r);
auto rx = irx.relations.at(r);
assert(rm->data == rx->data);
assert(rm->data_r == rx->data_r);
assert(rm->clusters.size() == rx->clusters.size());
for (const auto &[z, clusterm] : rm->clusters) {
for (const auto& [z, clusterm] : rm->clusters) {
auto clusterx = rx->clusters.at(z);
assert(clusterm->N == clusterx->N);
}
Expand Down
44 changes: 22 additions & 22 deletions cxx/tests/test_misc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "util_io.hh"
#include "util_math.hh"

int main(int argc, char **argv) {
int main(int argc, char** argv) {
srand(1);
std::mt19937 prng(1);

Expand Down Expand Up @@ -68,10 +68,10 @@ int main(int argc, char **argv) {
printf("%f\n", crp.logp_score());
crp.incorporate(ali, 0);
std::cout << "tables count 10 " << crp.tables.count(10) << std::endl;
for (auto const &i : crp.tables[0]) {
for (auto const& i : crp.tables[0]) {
std::cout << i << " ";
}
for (auto const &i : crp.tables[1]) {
for (auto const& i : crp.tables[1]) {
std::cout << i << " ";
}
printf("\n");
Expand All @@ -82,18 +82,18 @@ int main(int argc, char **argv) {

printf("=== tables_weights\n");
auto tables_weights = crp.tables_weights();
for (auto &tw : tables_weights) {
for (auto& tw : tables_weights) {
printf("table %d weight %f\n", tw.first, tw.second);
}

printf("=== tables_weights_gibbs\n");
auto tables_weights_gibbs = crp.tables_weights_gibbs(1);
for (auto &tw : tables_weights_gibbs) {
for (auto& tw : tables_weights_gibbs) {
printf("table %d weight %f\n", tw.first, tw.second);
}
printf("==== tables_weights_gibbs_singleton\n");
auto tables_weights_gibbs_singleton = crp.tables_weights_gibbs(12);
for (auto &tw : tables_weights_gibbs_singleton) {
for (auto& tw : tables_weights_gibbs_singleton) {
printf("table %d weight %f\n", tw.first, tw.second);
}
printf("==== log probability\n");
Expand All @@ -106,17 +106,17 @@ int main(int argc, char **argv) {
T_item salman = 1;
T_item mansour = 2;
d.incorporate(salman);
for (auto &item : d.items) {
for (auto& item : d.items) {
printf("item %d: ", item);
}
d.set_cluster_assignment_gibbs(salman, 12);
d.incorporate(salman);
d.incorporate(mansour, 5);
for (auto &item : d.items) {
for (auto& item : d.items) {
printf("item %d: ", item);
}
// d.unincorporate(salman);
for (auto &item : d.items) {
for (auto& item : d.items) {
printf("item %d: ", item);
}
// d.unincorporate(relation2, salman);
Expand All @@ -126,9 +126,9 @@ int main(int argc, char **argv) {
std::unordered_map<int, std::unordered_set<int>> m;
m[1].insert(10);
m[1] = std::unordered_set<int>();
for (auto &ir : m) {
for (auto& ir : m) {
printf("%d\n", ir.first);
for (auto &x : ir.second) {
for (auto& x : ir.second) {
printf("%d\n", x);
}
}
Expand Down Expand Up @@ -193,14 +193,14 @@ int main(int argc, char **argv) {
};
IRM irm(schema1, &prng);

for (auto const &kv : irm.domains) {
for (auto const& kv : irm.domains) {
printf("%s %s; ", kv.first.c_str(), kv.second->name.c_str());
for (auto const r : irm.domain_to_relations.at(kv.first)) {
printf("%s ", r.c_str());
}
printf("\n");
}
for (auto const &kv : irm.relations) {
for (auto const& kv : irm.relations) {
printf("%s ", kv.first.c_str());
for (auto const d : kv.second->domains) {
printf("%s ", d->name.c_str());
Expand All @@ -210,11 +210,11 @@ int main(int argc, char **argv) {

printf("==== READING IO ===== \n");
auto schema = load_schema("assets/animals.binary.schema");
for (auto const &i : schema) {
for (auto const& i : schema) {
printf("relation: %s\n", i.first.c_str());
printf("distribution: %s\n", i.second.distribution.c_str());
printf("domains: ");
for (auto const &j : i.second.domains) {
for (auto const& j : i.second.domains) {
printf("%s ", j.c_str());
}
printf("\n");
Expand All @@ -224,15 +224,15 @@ int main(int argc, char **argv) {
auto observations = load_observations("assets/animals.binary.obs");
auto encoding = encode_observations(schema, observations);
auto item_to_code = std::get<0>(encoding);
for (auto const &i : observations) {
for (auto const& i : observations) {
auto relation = std::get<0>(i);
auto value = std::get<2>(i);
auto item = std::get<1>(i);
printf("incorporating %s ", relation.c_str());
printf("%1.f ", value);
int counter = 0;
T_items items_code;
for (auto const &item : std::get<1>(i)) {
for (auto const& item : std::get<1>(i)) {
auto domain = schema.at(relation).domains[counter];
counter += 1;
auto code = item_to_code.at(domain).at(item);
Expand All @@ -246,7 +246,7 @@ int main(int argc, char **argv) {
for (int i = 0; i < 4; i++) {
irm3.transition_cluster_assignments({"animal", "feature"});
irm3.transition_cluster_assignments_all();
for (auto const &[d, domain] : irm3.domains) {
for (auto const& [d, domain] : irm3.domains) {
domain->crp.transition_alpha();
}
double x = irm3.logp_score();
Expand All @@ -257,7 +257,7 @@ int main(int argc, char **argv) {
to_txt(path_clusters, irm3, encoding);

auto rel = irm3.relations.at("has");
auto &enc = std::get<0>(encoding);
auto& enc = std::get<0>(encoding);
auto lp0 = rel->logp({enc["animal"]["tail"], enc["animal"]["bat"]}, 0);
auto lp1 = rel->logp({enc["animal"]["tail"], enc["animal"]["bat"]}, 1);
auto lp_01 = logsumexp({lp0, lp1});
Expand All @@ -272,7 +272,7 @@ int main(int argc, char **argv) {
irm4.domains.at("animal")->crp.alpha = irm3.domains.at("animal")->crp.alpha;
irm4.domains.at("feature")->crp.alpha = irm3.domains.at("feature")->crp.alpha;
assert(abs(irm3.logp_score() - irm4.logp_score()) < 1e-8);
for (const auto &d : {"animal", "feature"}) {
for (const auto& d : {"animal", "feature"}) {
auto d3 = irm3.domains.at(d);
auto d4 = irm4.domains.at(d);
assert(d3->items == d4->items);
Expand All @@ -281,13 +281,13 @@ int main(int argc, char **argv) {
assert(d3->crp.N == d4->crp.N);
assert(d3->crp.alpha == d4->crp.alpha);
}
for (const auto &r : {"has"}) {
for (const auto& r : {"has"}) {
auto r3 = irm3.relations.at(r);
auto r4 = irm4.relations.at(r);
assert(r3->data == r4->data);
assert(r3->data_r == r4->data_r);
assert(r3->clusters.size() == r4->clusters.size());
for (const auto &[z, cluster3] : r3->clusters) {
for (const auto& [z, cluster3] : r3->clusters) {
auto cluster4 = r4->clusters.at(z);
assert(cluster3->N == cluster4->N);
}
Expand Down
2 changes: 1 addition & 1 deletion cxx/tests/test_util_math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

#include "util_math.hh"

int main(int argc, char **argv) {
int main(int argc, char** argv) {
std::vector<std::vector<int>> x{{1}, {2, 3}, {1, 10, 11}};

auto cartesian = product(x);
Expand Down

0 comments on commit d796ca1

Please sign in to comment.