Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

relation-tail-merging-1.0 #9

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 221 additions & 23 deletions mindwalc/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

class Vertex(object):

def __init__(self, name, predicate=False, _from=None, _to=None):
def __init__(self, name, predicate=False, _from=None, _to=None, relation_modified=False):
# Create a vertex called `name`; every argument is stored unchanged on the instance.
# predicate: rdflib_to_graph passes True for the vertex that stands for a relation, together
#            with _from/_to set to the subject and object vertices of the triple.
# relation_modified: rdflib_to_graph passes True for the "<p>_MODIFIED_<o>" vertices it inserts
#                    when relation-tail merging is enabled; graph_to_neo4j reads the flag to add
#                    the RelationModified label to the exported node.
self.name = name
self.predicate = predicate
self.relation_modified = relation_modified
self._from = _from
self._to = _to

Expand Down Expand Up @@ -115,41 +116,167 @@ def extract_neighborhood(self, instance, depth=8):
for neighbor in self.get_neighbors(v):
new_explore.add(neighbor)
to_explore = new_explore

return neighborhood

@staticmethod
def rdflib_to_graph(rdflib_g, label_predicates=[]):
def rdflib_to_graph(rdflib_g, label_predicates=[], relation_tail_merging=False, skip_literals=False):
'''
Converts an rdflib graph to a Graph object.
During the conversion, a multi-relation graph (head)-[relation]->(tail) (aka subject, predicate, object)
is converted to a non-relational graph, i.e. to (head)-->(relation)-->(tail), or, if
relation_tail_merging is True, to (head)-->(relation_tail)-->(tail).

:param rdflib_g: An rdflib graph, e.g. loaded with rdflib.Graph().parse('file.n3')
:param label_predicates: A list of predicates; every triple whose predicate is in this list is
skipped entirely and does not appear in the returned graph.
:param relation_tail_merging: If True, relation-tail merging is applied, as described in the paper
"Investigating and Optimizing MINDWALC Node Classification to Extract Interpretable DTs from KGs":
The process of relation-tail merging works as follows: First, a specific tail node is
selected, t, as well as a set of nr relations of identical type, r, where the topological
form (*)-r->(t) is given. The process of relation-tail merging then involves inserting
a new node, rt, so that (*)-r->(t) turns into (*)-->(rt)-->(t). The new directional
edges, -->, are now typeless, and the newly inserted node, rt, represents a relation-modified node
and is named accordingly in the form <type_of_r>_MODIFIED_<name_of_t>.
:param skip_literals: If True, triples whose object is a literal (= node properties/attributes) are
skipped during the conversion. Otherwise, literals are converted to nodes, so that, without
relation-tail merging, a node (n: {name: 'John'}) becomes (n)-->(name)-->(John).
Literals whose string form is empty are named "EmptyLiteral".
:return: A Graph object of type datastructures::Graph
'''

kg = Graph()
for (s, p, o) in rdflib_g:

for (s, p, o) in rdflib_g:
if p not in label_predicates:

if skip_literals and isinstance(o, rdflib.term.Literal):
continue

# Literals are attribute values in RDF, for instance, a person’s name, the date of birth, height, etc.
if isinstance(s, rdflib.term.Literal) and not str(s):
s = "EmptyLiteral"
if isinstance(p, rdflib.term.Literal) and not str(p):
p = "EmptyLiteral"
if isinstance(o, rdflib.term.Literal) and not str(o):
o = "EmptyLiteral"

s = str(s)
p = str(p)
o = str(o)

if isinstance(s, rdflib.term.BNode):
s_v = Vertex(str(s), wildcard=True)
elif isinstance(s, rdflib.term.Literal):
s_v = Vertex(str(s), literal=True)
else:
s_v = Vertex(str(s))

if isinstance(o, rdflib.term.BNode):
o_v = Vertex(str(o), wildcard=True)
elif isinstance(s, rdflib.term.Literal):
o_v = Vertex(str(o), literal=True)
s_v = Vertex(s)

if relation_tail_merging:
o_v_relation_mod = Vertex(f"{p}_MODIFIED_{o}", relation_modified=True)
o_v = Vertex(o)
kg.add_vertex(s_v)
kg.add_vertex(o_v_relation_mod)
kg.add_vertex(o_v)
kg.add_edge(s_v, o_v_relation_mod)
kg.add_edge(o_v_relation_mod, o_v)
else:
o_v = Vertex(str(o))

p_v = Vertex(str(p), predicate=True, _from=s_v, _to=o_v)
kg.add_vertex(s_v)
kg.add_vertex(p_v)
kg.add_vertex(o_v)
kg.add_edge(s_v, p_v)
kg.add_edge(p_v, o_v)
o_v = Vertex(o)
p_v = Vertex(p, predicate=True, _from=s_v, _to=o_v)
kg.add_vertex(s_v)
kg.add_vertex(p_v)
kg.add_vertex(o_v)
kg.add_edge(s_v, p_v)
kg.add_edge(p_v, o_v)
return kg

def graph_to_neo4j(self, uri='bolt://localhost', user='neo4j', password='password'):
    '''
    Exports the graph to a neo4j database. Needs an empty running neo4j db.

    Every non-predicate vertex becomes a node labelled ``Node`` (plus ``RelationModified`` when
    ``vertex.relation_modified`` is set) whose ``name`` property is the vertex name with single
    quotes removed. A predicate vertex between two such nodes becomes a ``Predicate`` node that is
    linked to both of them by ``R`` relationships; a direct vertex-to-vertex edge becomes a single
    ``R`` relationship. If the database already contains nodes, a message is printed and nothing
    is written.

    :param uri: address where neo4j db is running
    :param user: username of neo4j db
    :param password: password of neo4j db
    :return: None
    :raises ImportError: if the neo4j driver package is not installed
    :raises Exception: if a node name matches no node, or more than one node, in the database
    '''

    try:
        from neo4j import GraphDatabase
    except ImportError as err:
        raise ImportError("Please install the neo4j-driver package to use this function.") from err
    from tqdm import tqdm

    use_nodes_for_predicates = True  # if false, the predicates are used as edges. Otherwise as nodes.
    relation_name = 'R'

    def _db_name(vertex):
        # Name under which a vertex is stored in neo4j (single quotes are dropped).
        return vertex.name.replace("'", "")

    def _labels(vertex):
        # Cypher label expression for a non-predicate vertex.
        return "Node:RelationModified" if vertex.relation_modified else "Node"

    def _single_id(session, query, column, name):
        # Run a lookup query parameterised by $name and return the one id it must yield.
        ids = [record[column] for record in session.run(query, {"name": name})]
        if len(ids) == 0:
            raise Exception(f"no id found for {name}")
        if len(ids) > 1:
            raise Exception(f"multiple ids found for {name}: {ids}")
        return ids[0]

    # Names are passed as query parameters instead of being concatenated into the Cypher text, so
    # backslashes or other special characters in a vertex name cannot break or alter the query.
    match_endpoints = "MATCH (a), (b) WHERE ID(a)=$id_a AND ID(b)=$id_b "
    link_direct = match_endpoints + "MERGE (a)-[:" + relation_name + "]->(b)"
    link_via_predicate_node = (match_endpoints + "MERGE (a)-[:" + relation_name +
                               "]->(c:Predicate {name: $pred_name})-[:" + relation_name + "]->(b)")

    driver = GraphDatabase.driver(uri, auth=(user, password))
    try:
        with driver.session() as session:
            # check if db is empty:
            node_count = session.run("MATCH (n) return count(n)").single().value()
            if node_count > 0:
                print("Neo4j database is not empty, aborting graph to neo4h db convertion to avoid data loss.")
                return

            # First pass: create one node per non-predicate vertex.
            for v in self.vertices:
                if v.predicate:
                    continue
                session.run("CREATE (a:" + _labels(v) + " {name: $name})", {"name": _db_name(v)})

            # Second pass: connect the created nodes.
            for v in tqdm(self.vertices):
                if v.predicate:
                    continue

                id_v = _single_id(
                    session,
                    "MATCH (v:" + _labels(v) + " {name: $name}) WHERE NOT (v:Predicate) RETURN id(v)",
                    "id(v)", _db_name(v))

                for pred in self.get_neighbors(v):

                    if pred.predicate:
                        # Keep only the last URI segment, map '#' and '-' to '_' and drop digits.
                        pred_name = "".join(
                            c for c in pred.name.split('/')[-1].replace("#", "_").replace('-', '_')
                            if not c.isdigit())
                        # startswith() is safe for an empty name, unlike indexing pred_name[0].
                        if pred_name.startswith(("_", "-")):
                            pred_name = pred_name[1:]

                        for obj in self.get_neighbors(pred):
                            id_obj = _single_id(
                                session,
                                "MATCH (obj:Node {name: $name}) WHERE NOT (obj:Predicate) RETURN id(obj)",
                                "id(obj)", _db_name(obj))

                            if use_nodes_for_predicates:
                                session.run(link_via_predicate_node,
                                            {"id_a": id_v, "id_b": id_obj, "pred_name": pred_name})
                            else:
                                # Relationship types cannot be query parameters, so the type is
                                # backtick-quoted (embedded backticks doubled) instead.
                                rel_type = "`" + pred_name.replace("`", "``") + "`"
                                session.run(match_endpoints + "MERGE (a)-[:" + rel_type + "]->(b)",
                                            {"id_a": id_v, "id_b": id_obj})

                    else:
                        # Direct edge to a non-predicate vertex (e.g. a relation-modified node).
                        id_obj = _single_id(
                            session,
                            "MATCH (obj:Node {name: $name}) RETURN id(obj)",
                            "id(obj)", _db_name(pred))
                        session.run(link_direct, {"id_a": id_v, "id_b": id_obj})
    finally:
        # Release the connection pool on every exit path, including the early return above.
        driver.close()

class Neighborhood(object):
def __init__(self):
Expand Down Expand Up @@ -326,3 +453,74 @@ def _convert_node_to_dot(self, node_vis_props):
s += 'Node' + str(num) + ' -> ' + 'Node' + str(num + amount_subnodes_left + 1) + ' [label="true"];\n'

return s

if __name__ == "__main__":
    # Demo script: trains a MINDWALCTree on the AIFB dataset twice -- once on the
    # relation-to-node converted graph and once on the relation-tail merged graph -- and prints
    # the graph sizes and the test accuracy of both runs.
    from tree_builder import MINDWALCTree, MINDWALCForest, MINDWALCTransform
    import pandas as pd
    from sklearn.metrics import accuracy_score, confusion_matrix
    import sys

    # load graph:
    rdf_file = 'data/AIFB/aifb.n3'
    _format = 'n3'
    label_predicates = [  # triples with these predicates are skipped during the graph conversion
        rdflib.URIRef('http://swrc.ontoware.org/ontology#affiliation'),
        rdflib.URIRef('http://swrc.ontoware.org/ontology#employs'),
        rdflib.URIRef('http://swrc.ontoware.org/ontology#carriedOutBy')
    ]
    g = rdflib.Graph()
    g.parse(rdf_file, format=_format)
    skip_literals = True
    path_max_depth = 10

    # load train and test data (each variable now names the file it actually holds):
    train_file = 'data/AIFB/AIFB_train.tsv'
    test_file = 'data/AIFB/AIFB_test.tsv'
    entity_col = 'person'
    label_col = 'label_affiliation'
    train_data = pd.read_csv(train_file, sep='\t')
    test_data = pd.read_csv(test_file, sep='\t')

    train_entities = [rdflib.URIRef(x) for x in train_data[entity_col]]
    train_labels = train_data[label_col]

    test_entities = [rdflib.URIRef(x) for x in test_data[entity_col]]
    test_labels = test_data[label_col]

    def _run_experiment(relation_tail_merging, conversion_name, tree_file, meta_infos):
        # Convert the rdflib graph, report its size, fit and evaluate one tree.
        # Returns the (vertex count, edge count) of the converted graph.
        kg = Graph.rdflib_to_graph(g, label_predicates=label_predicates,
                                   relation_tail_merging=relation_tail_merging,
                                   skip_literals=skip_literals)
        # Optional neo4j export of the converted graph: kg.graph_to_neo4j(password=sys.argv[1])
        n_vertices = len(kg.vertices)
        n_edges = sum(len(x) for x in kg.transition_matrix.values())
        # The ',' format option inserts real thousands separators (8000 -> "8,000").
        print(f"generated graph using {conversion_name} has "
              f"{n_vertices:,} vertices")
        print(f"and {n_edges:,} edges")
        clf = MINDWALCTree(path_max_depth=path_max_depth, min_samples_leaf=1, max_tree_depth=None, n_jobs=1)
        clf.fit(kg, train_entities, train_labels)
        clf.tree_.visualize(tree_file, _view=False, meta_infos=meta_infos)
        preds = clf.predict(kg, test_entities)
        print(f"accuracy: {accuracy_score(test_labels, preds)}")
        return n_vertices, n_edges

    def _reduction(before, after):
        # Absolute and relative reduction, guarding against an empty baseline graph.
        saved = before - after
        percent = round(saved / before * 100, 2) if before else 0.0
        return f"{saved} ({percent} %)"

    # convert to non relational graph using relation-to-node conversion:
    verts_a, edges_a = _run_experiment(
        False, "relation-to-node-convertion", "./data/AIFB/aifb_MINDWALCtree1",
        "Training method: MINDWALCTree, relation-to-node-converted graph")

    print()

    # convert to non relational graph using relation-tail merging:
    verts_b, edges_b = _run_experiment(
        True, "relation_tail_merging", "./data/AIFB/aifb_MINDWALCtree2",
        "Training method: MINDWALCTree, relation-tail merged graph")

    print(f"\nrelation-tail merging reduced the number of vertices by {_reduction(verts_a, verts_b)}")
    print(f"relation-tail merging reduced the number of edges by {_reduction(edges_a, edges_b)}")

5 changes: 3 additions & 2 deletions mindwalc/tree_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ def _generate_candidates(self, neighborhoods, sample_frac=None,
"""Generates an iterable with all possible walk candidates."""
# Generate a set of all possible (vertex, depth) combinations
walks = set()
for d in range(2, self.path_max_depth + 1, 2):
for neighborhood in neighborhoods:
#for d in range(2, self.path_max_depth + 1, 2):
for neighborhood in neighborhoods:
for d in neighborhood.depth_map.keys():
for vertex in neighborhood.depth_map[d]:
walks.add((vertex, d))

Expand Down