diff --git a/mindwalc/datastructures.py b/mindwalc/datastructures.py index a69f0fb..30cb0fe 100644 --- a/mindwalc/datastructures.py +++ b/mindwalc/datastructures.py @@ -22,9 +22,10 @@ class Vertex(object): - def __init__(self, name, predicate=False, _from=None, _to=None): + def __init__(self, name, predicate=False, _from=None, _to=None, relation_modified=False): self.name = name self.predicate = predicate + self.relation_modified = relation_modified self._from = _from self._to = _to @@ -115,41 +116,167 @@ def extract_neighborhood(self, instance, depth=8): for neighbor in self.get_neighbors(v): new_explore.add(neighbor) to_explore = new_explore - + return neighborhood @staticmethod - def rdflib_to_graph(rdflib_g, label_predicates=[]): + def rdflib_to_graph(rdflib_g, label_predicates=[], relation_tail_merging=False, skip_literals=False): + ''' + Converts an rdflib graph to a Graph object. + During the conversion, a multi-relation graph (head)-[relation]->(tail) (aka subject, predicate, object) is converted to a non-relational graph, + e.g. converting it to (head)-->(relation)-->(tail), or, if relation_tail_merging is True, to (head)-->(relation_tail)-->(tail). + + :param rdflib_g: An rdflib graph, e.g. loaded with rdflib.Graph().parse('file.n3') + :param label_predicates: a list of predicates that are used as labels; triples with one of these predicates are skipped during the conversion. + :param relation_tail_merging: If true, relation-tail-merging is applied, as described in the paper + "Investigating and Optimizing MINDWALC Node Classification to Extract Interpretable DTs from KGs": + The process of relation-tail merging works as follows: First, a specific tail node is + selected, t, as well as a set of nr relations of identical type, r, where the topological + form (*)-r->(t) is given. The process of relation-tail merging then involves inserting + a new node, rt, so that (*)-r->(t) turns into (*)-->(rt)-->(t). 
The new directional + edges, -->, are now typeless, and the new inserted node, rt, represents a relation-modified node and is + named accordingly in the form {relation}_MODIFIED_{tail}. + :param skip_literals: If True, literals (=node properties/attributes) are skipped during the conversion. + Otherwise, they are converted to nodes, so that a node (n: {name: 'John'}) becomes (n)-->(name)-->(John). + :return: A Graph object of type datastructures::Graph + ''' + kg = Graph() - for (s, p, o) in rdflib_g: + for (s, p, o) in rdflib_g: if p not in label_predicates: + + if skip_literals and isinstance(o, rdflib.term.Literal): + continue + + # Literals are attribute values in RDF, for instance, a person’s name, the date of birth, height, etc. Literals whose string form is empty are replaced by the placeholder "EmptyLiteral". + if isinstance(s, rdflib.term.Literal) and not str(s): + s = "EmptyLiteral" + if isinstance(p, rdflib.term.Literal) and not str(p): + p = "EmptyLiteral" + if isinstance(o, rdflib.term.Literal) and not str(o): + o = "EmptyLiteral" + s = str(s) p = str(p) o = str(o) - if isinstance(s, rdflib.term.BNode): - s_v = Vertex(str(s), wildcard=True) - elif isinstance(s, rdflib.term.Literal): - s_v = Vertex(str(s), literal=True) - else: - s_v = Vertex(str(s)) - - if isinstance(o, rdflib.term.BNode): - o_v = Vertex(str(o), wildcard=True) - elif isinstance(s, rdflib.term.Literal): - o_v = Vertex(str(o), literal=True) + s_v = Vertex(s) + + if relation_tail_merging: + o_v_relation_mod = Vertex(f"{p}_MODIFIED_{o}", relation_modified=True) + o_v = Vertex(o) + kg.add_vertex(s_v) + kg.add_vertex(o_v_relation_mod) + kg.add_vertex(o_v) + kg.add_edge(s_v, o_v_relation_mod) + kg.add_edge(o_v_relation_mod, o_v) else: - o_v = Vertex(str(o)) - - p_v = Vertex(str(p), predicate=True, _from=s_v, _to=o_v) - kg.add_vertex(s_v) - kg.add_vertex(p_v) - kg.add_vertex(o_v) - kg.add_edge(s_v, p_v) - kg.add_edge(p_v, + o_v = Vertex(o) + p_v = Vertex(p, predicate=True, _from=s_v, _to=o_v) + kg.add_vertex(s_v) + kg.add_vertex(p_v) + kg.add_vertex(o_v) + kg.add_edge(s_v, p_v) + kg.add_edge(p_v, 
o_v) return kg + def graph_to_neo4j(self, uri='bolt://localhost', user='neo4j', password='password'): + ''' + Converts the graph to a neo4j database. Needs an empty running neo4j db. Non-predicate vertices become nodes labelled Node (plus RelationModified where flagged); predicate vertices become Predicate nodes linked by R relationships. + :param uri: address where neo4j db is running + :param user: username of neo4j db + :param password: password of neo4j db + :return: None. If the database already contains nodes, a message is printed and nothing is written. Raises ImportError if the neo4j package cannot be imported and Exception if a vertex name matches no node or several nodes. NOTE(review): driver.close() is skipped on the early return and when an exception is raised; names are concatenated into the Cypher text with only single quotes removed, so query parameters should be considered. + ''' + + try: + from neo4j import GraphDatabase + except ImportError: + raise ImportError("Please install the neo4j-driver package to use this function.") + from tqdm import tqdm + + use_nodes_for_predicates = True # if false, the predicates are used as edges. Otherwise as nodes. + relation_name = 'R' + + driver = GraphDatabase.driver(uri, auth=(user, password)) + with driver.session() as session: + # check if db is empty: + node_count = session.run("MATCH (n) return count(n)").single().value() + if node_count > 0: + print("Neo4j database is not empty, aborting graph to neo4h db convertion to avoid data loss.") + return + + for v in self.vertices: + if not v.predicate: + # name = v.name.split('/')[-1] + name = v.name.replace("'", "") + session.run(f"CREATE (a:Node" + (":RelationModified" if v.relation_modified else "") + + " {name: '" + name + "'})") # .split(' ')[0] + '_' + vertex.__hash__() + + for v in tqdm(self.vertices): + if not v.predicate: + # v_name = v.name.split('/')[-1] + v_name = v.name.replace("'", "") + + node_type = "Node" + (":RelationModified" if v.relation_modified else "") + + ids_v = [r["id(v)"] for r in + session.run( + "MATCH (v:" + node_type + " {name: '" + v_name + "'}) where not (v:Predicate) RETURN id(v)")] + if len(ids_v) == 0: + raise Exception(f"no id found for {v_name}") + elif len(ids_v) == 1: + id_v = ids_v[0] + else: + raise Exception(f"multiple ids found for {v_name}: {ids_v}") + + for pred in self.get_neighbors(v): + + if pred.predicate: + pred_name = "".join( + [c for c in pred.name.split('/')[-1].replace("#", "_").replace('-', '_') if + not c.isdigit()]) + pred_name = pred_name[1:] if pred_name[0] in ["_", 
"-"] else pred_name + + for obj in self.get_neighbors(pred): + # obj_name = obj.name.split('/')[-1] + obj_name = obj.name.replace("'", "") + + ids_obj = [r["id(obj)"] for r in + session.run( + "MATCH (obj:Node {name: '" + obj_name + "'}) where not (obj:Predicate) RETURN id(obj)")] + if len(ids_obj) == 0: + raise Exception(f"no id found for {obj_name}") + elif len(ids_obj) == 1: + id_obj = ids_obj[0] + else: + raise Exception(f"multiple ids found for {obj_name}: {ids_obj}") + + if use_nodes_for_predicates: + q = (f"MATCH (a), (b) WHERE ID(a)={id_v} AND ID(b)={id_obj} " + "MERGE (a)-[:") + relation_name + "]->(c:Predicate {name: '" + pred_name + "'})-[:" + relation_name + "]->(b)" + else: + q = f"MATCH (a), (b) WHERE ID(a)={id_v} AND ID(b)={id_obj} MERGE (a)-[:" + pred_name + "]->(b)" + session.run(q) + + else: + obj_name = pred.name.replace("'", "") + + ids_obj = [r["id(obj)"] for r in + session.run( + "MATCH (obj:Node {name: '" + obj_name + "'}) RETURN id(obj)")] + if len(ids_obj) == 0: + raise Exception(f"no id found for {obj_name}") + elif len(ids_obj) == 1: + id_obj = ids_obj[0] + else: + raise Exception(f"multiple ids found for {obj_name}: {ids_obj}") + + q = f"MATCH (a), (b) WHERE ID(a)={id_v} AND ID(b)={id_obj} MERGE (a)-[:" + relation_name + "]->(b)" + session.run(q) + + driver.close() class Neighborhood(object): def __init__(self): @@ -326,3 +453,74 @@ def _convert_node_to_dot(self, node_vis_props): s += 'Node' + str(num) + ' -> ' + 'Node' + str(num + amount_subnodes_left + 1) + ' [label="true"];\n' return s + +if __name__ == "__main__": + from tree_builder import MINDWALCTree, MINDWALCForest, MINDWALCTransform + import pandas as pd + from sklearn.metrics import accuracy_score, confusion_matrix + import sys + + # load graph: + rdf_file = 'data/AIFB/aifb.n3' + _format = 'n3' + label_predicates = [ # these predicates will be deleted, otherwise clf task would get too easy? 
+ rdflib.URIRef('http://swrc.ontoware.org/ontology#affiliation'), + rdflib.URIRef('http://swrc.ontoware.org/ontology#employs'), + rdflib.URIRef('http://swrc.ontoware.org/ontology#carriedOutBy') + ] + g = rdflib.Graph() + g.parse(rdf_file, format=_format) + skip_literals = True + path_max_depth = 10 + + # load train and test data (NOTE(review): train_file and test_file hold each other's paths, but the two reads below swap them again, so train_data is read from AIFB_train.tsv and test_data from AIFB_test.tsv): + train_file = 'data/AIFB/AIFB_test.tsv' + test_file = 'data/AIFB/AIFB_train.tsv' + entity_col = 'person' + label_col = 'label_affiliation' + test_data = pd.read_csv(train_file, sep='\t') + train_data = pd.read_csv(test_file, sep='\t') + + train_entities = [rdflib.URIRef(x) for x in train_data[entity_col]] + train_labels = train_data[label_col] + + test_entities = [rdflib.URIRef(x) for x in test_data[entity_col]] + test_labels = test_data[label_col] + + + # convert to non-relational graphs using relation-to-node conversion: + kg = Graph.rdflib_to_graph(g, label_predicates=label_predicates, relation_tail_merging=False, + skip_literals=skip_literals) + #kg.graph_to_neo4j(password=sys.argv[1]) + verts_a = len(kg.vertices) + edges_a = sum([len(x) for x in kg.transition_matrix.values()]) + print(f"generated graph using relation-to-node-convertion has " + f"{str(float(verts_a)/1000).replace('.', ',')} vertices") + print(f"and {str(float(edges_a) / 1000).replace('.', ',')} edges") + clf = MINDWALCTree(path_max_depth=path_max_depth, min_samples_leaf=1, max_tree_depth=None, n_jobs=1) + clf.fit(kg, train_entities, train_labels) + clf.tree_.visualize("./data/AIFB/aifb_MINDWALCtree1", _view=False, + meta_infos="Training method: MINDWALCTree, relation-to-node-converted graph") + preds = clf.predict(kg, test_entities) + print(f"accuracy: {accuracy_score(test_labels, preds)}") + + print() + + # convert to non-relational graphs using relation-tail-merging: + kg = Graph.rdflib_to_graph(g, label_predicates=label_predicates, relation_tail_merging=True, + skip_literals=skip_literals) + verts_b = len(kg.vertices) + edges_b = sum([len(x) for x in 
kg.transition_matrix.values()]) + print(f"generated graph using relation_tail_merging has " + f"{str(float(verts_b)/1000).replace('.', ',')} vertices") + print(f"and {str(float(edges_b) / 1000).replace('.', ',')} edges") + clf = MINDWALCTree(path_max_depth=path_max_depth, min_samples_leaf=1, max_tree_depth=None, n_jobs=1) + clf.fit(kg, train_entities, train_labels) + clf.tree_.visualize("./data/AIFB/aifb_MINDWALCtree2", _view=False, + meta_infos="Training method: MINDWALCTree, relation-tail merged graph") + preds = clf.predict(kg, test_entities) + print(f"accuracy: {accuracy_score(test_labels, preds)}") + + print(f"\nrelation-tail merging reduced the number of vertices by {verts_a - verts_b} ({round((verts_a - verts_b)/verts_a *100, 2)} %)") + print(f"relation-tail merging reduced the number of edges by {edges_a - edges_b} ({round((edges_a - edges_b) / edges_a * 100, 2)} %)") + diff --git a/mindwalc/tree_builder.py b/mindwalc/tree_builder.py index 5dead8d..b159602 100644 --- a/mindwalc/tree_builder.py +++ b/mindwalc/tree_builder.py @@ -47,8 +47,9 @@ def _generate_candidates(self, neighborhoods, sample_frac=None, """Generates an iterable with all possible walk candidates.""" # Generate a set of all possible (vertex, depth) combinations walks = set() - for d in range(2, self.path_max_depth + 1, 2): - for neighborhood in neighborhoods: + #for d in range(2, self.path_max_depth + 1, 2): + for neighborhood in neighborhoods: + for d in neighborhood.depth_map.keys(): for vertex in neighborhood.depth_map[d]: walks.add((vertex, d))