Skip to content

Commit

Permalink
Merge pull request #221 from probcomp/100124-emilyaf-unique-entities
Browse files Browse the repository at this point in the history
Confine the unique_entities flag to where it's relevant in incorporate.
  • Loading branch information
emilyfertig authored Oct 4, 2024
2 parents b7a4b6b + 7aa4920 commit b846860
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 38 deletions.
81 changes: 50 additions & 31 deletions cxx/gendb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ void GenDB::incorporate(
bool new_rows_have_unique_entities) {
int id = row.first;

// TODO: Consider not walking the DAG when new_rows_have_unique_entities =
// True.

// Maps a query relation name to an observed value.
std::map<std::string, ObservationVariant> vals = row.second;

Expand All @@ -64,31 +61,64 @@ void GenDB::incorporate(
// Sample a set of items to be incorporated into the query relation.
const std::vector<std::string>& class_path =
schema.query.fields.at(query_rel).class_path;
T_items items = sample_entities_relation(
prng, schema.query.record_class, class_path.cbegin(), class_path.cend(),
id, new_rows_have_unique_entities);

T_items items;
if (new_rows_have_unique_entities) {
const std::vector<std::string>& domains = std::visit(
[&](auto tr) { return tr.domains; }, hirm->schema.at(query_rel));
items.resize(domains.size());
get_unique_entities_relation(query_rel, items.size() - 1, id, items);
} else {
items =
sample_entities_relation(prng, schema.query.record_class,
class_path.cbegin(), class_path.cend(), id);
}

// Incorporate the items/value into the query relation.
incorporate_query_relation(prng, query_rel, items, val);
}
}

void GenDB::get_unique_entities_relation(const std::string& rel_name,
const int ind, const int class_item,
T_items& items) {
const std::vector<std::string>& domains = std::visit(
[&](auto tr) { return tr.domains; }, hirm->schema.at(rel_name));
items[ind] = class_item;
auto& ref_indices = relation_reference_indices;
if (ref_indices.contains(rel_name)) {
if (ref_indices.at(rel_name).contains(ind)) {
for (const auto& [rf_name, rf_ind] : ref_indices.at(rel_name).at(ind)) {
if (!reference_values.at(domains[ind])
.contains({rf_name, class_item})) {
const std::string& ref_class = domains.at(rf_ind);
int new_val = entity_crps.at(ref_class).max_table() + 1;
int new_id = get_reference_id(domains[ind], rf_name, class_item);
reference_values.at(domains[ind])[{rf_name, class_item}] = new_val;
entity_crps.at(ref_class).incorporate(new_id, new_val);
}
int refval =
reference_values.at(domains[ind]).at({rf_name, class_item});
get_unique_entities_relation(rel_name, rf_ind, refval, items);
}
}
}
}

// This function walks the class_path of the query, populates the global
// reference_values table if necessary, and returns a sampled set of items
// for the query relation that corresponds to the class path. class_path_start
// is an attribute of the Class named class_name.
T_items GenDB::sample_entities_relation(
std::mt19937* prng, const std::string& class_name,
std::vector<std::string>::const_iterator class_path_start,
std::vector<std::string>::const_iterator class_path_end, int class_item,
bool new_rows_have_unique_entities) {
std::vector<std::string>::const_iterator class_path_end, int class_item) {
if (class_path_end - class_path_start == 1) {
// The last item in class_path is the class from which the queried attribute
// is observed (for which there's a corresponding clean relation, observing
// the attribute from the class). We need to DFS-traverse the class's
// parents, similar to PCleanSchemaHelper::compute_domains_for.
return sample_class_ancestors(prng, class_name, class_item,
new_rows_have_unique_entities);
return sample_class_ancestors(prng, class_name, class_item);
}

// These are noisy relation domains along the path from the latent cleanly-
Expand All @@ -102,13 +132,11 @@ T_items GenDB::sample_entities_relation(
.class_name;
std::pair<std::string, int> ref_key = {ref_field, class_item};
if (!reference_values.at(class_name).contains(ref_key)) {
sample_and_incorporate_reference(prng, class_name, ref_key, ref_class,
new_rows_have_unique_entities);
sample_and_incorporate_reference(prng, class_name, ref_key, ref_class);
}
T_items items = sample_entities_relation(
prng, ref_class, ++class_path_start, class_path_end,
reference_values.at(class_name).at(ref_key),
new_rows_have_unique_entities);
reference_values.at(class_name).at(ref_key));
// The order of the items corresponds to the order of the relation's domains,
// with the class (domain) corresponding to the primary key placed last on the
// list.
Expand All @@ -131,21 +159,15 @@ int GenDB::get_reference_id(const std::string& class_name,
// and stores the value in reference_values.
void GenDB::sample_and_incorporate_reference(
std::mt19937* prng, const std::string& class_name,
const std::pair<std::string, int>& ref_key, const std::string& ref_class,
bool new_rows_have_unique_entities) {
const std::pair<std::string, int>& ref_key, const std::string& ref_class) {
auto [ref_field, class_item] = ref_key;
int new_val;
if (new_rows_have_unique_entities) {
new_val = entity_crps[ref_class].max_table() + 1;
} else {
new_val = entity_crps[ref_class].sample(prng);
}
int new_val = entity_crps.at(ref_class).sample(prng);

// Generate a unique ID for the sample and incorporate it into the
// entity CRP.
int new_id = get_reference_id(class_name, ref_field, class_item);
reference_values.at(class_name)[ref_key] = new_val;
entity_crps[ref_class].incorporate(new_id, new_val);
entity_crps.at(ref_class).incorporate(new_id, new_val);
}

// Incorporates an observed value into a query relation. Recursively
Expand Down Expand Up @@ -180,7 +202,7 @@ void GenDB::sample_and_incorporate_for_class(std::mt19937* prng,
const std::string& class_name,
const T_item& item) {
for (const std::string& rel_name : class_to_relations.at(class_name)) {
sample_class_ancestors(prng, class_name, item, false);
sample_class_ancestors(prng, class_name, item);
const std::vector<std::string>& domains = std::visit(
[&](auto tr) { return tr.domains; }, hirm->schema.at(rel_name));
T_items rel_items(domains.size());
Expand Down Expand Up @@ -216,8 +238,7 @@ void GenDB::sample_and_incorporate_for_class(std::mt19937* prng,
// reference_values table/entity CRPs) if necessary.
T_items GenDB::sample_class_ancestors(std::mt19937* prng,
const std::string& class_name,
int class_item,
bool new_rows_have_unique_entities) {
int class_item) {
T_items items;
assert(schema.classes.contains(class_name));
PCleanClass c = schema.classes.at(class_name);
Expand All @@ -230,12 +251,10 @@ T_items GenDB::sample_class_ancestors(std::mt19937* prng,
if (!reference_values.at(class_name).contains(ref_key)) {
assert(prng != nullptr);
sample_and_incorporate_reference(prng, class_name, ref_key,
cv->class_name,
new_rows_have_unique_entities);
cv->class_name);
}
T_items ref_items = sample_class_ancestors(
prng, cv->class_name, reference_values.at(class_name).at(ref_key),
new_rows_have_unique_entities);
prng, cv->class_name, reference_values.at(class_name).at(ref_key));
items.insert(items.end(), ref_items.begin(), ref_items.end());
}
}
Expand Down Expand Up @@ -727,7 +746,7 @@ void GenDB::transition_reference(std::mt19937* prng,
// Sample and incorporate a new row into the ref_class table. Update
// reference_values and entity_crps.
T_items unused_base_items =
sample_class_ancestors(prng, ref_class, table, false);
sample_class_ancestors(prng, ref_class, table);

// Sample and incorporate values into the relations corresponding to
// the reference class. This may also incorporate new values into the IRM
Expand Down
13 changes: 7 additions & 6 deletions cxx/gendb.hh
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,18 @@ class GenDB {
void sample_and_incorporate_reference(

std::mt19937* prng, const std::string& class_name,
const std::pair<std::string, int>& ref_key, const std::string& ref_class,
bool new_rows_have_unique_entities);
const std::pair<std::string, int>& ref_key,
const std::string& ref_class);

void get_unique_entities_relation(const std::string& rel_name, const int ind,
const int class_item, T_items& items);

// Samples a set of entities in the domains of the relation corresponding to
// class_path.
T_items sample_entities_relation(
std::mt19937* prng, const std::string& class_name,
std::vector<std::string>::const_iterator class_path_start,
std::vector<std::string>::const_iterator class_path_end, int class_item,
bool new_rows_have_unique_entities);
std::vector<std::string>::const_iterator class_path_end, int class_item);

// Samples and incorporates a value into all relations belonging to class_name
// (including class attributes and noisy observations of ancestor class
Expand All @@ -62,8 +64,7 @@ class GenDB {

// Sample items from a class' ancestors (recursive reference fields).
T_items sample_class_ancestors(std::mt19937* prng,
const std::string& class_name, int class_item,
bool new_rows_have_unique_entities);
const std::string& class_name, int class_item);

// Populates "items" with entities by walking the DAG of reference indices,
// starting with "ind".
Expand Down
2 changes: 1 addition & 1 deletion cxx/pclean/pclean_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void make_pclean_sample(
T_items entities = gendb->sample_entities_relation(
prng, gendb->schema.query.record_class,
query_field.class_path.begin(), query_field.class_path.end(),
class_item, false);
class_item);

(*query_values)[query_field.name] = gendb->hirm->sample_and_incorporate_relation(
prng, query_field.name, entities);
Expand Down

0 comments on commit b846860

Please sign in to comment.