-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain.py
77 lines (61 loc) · 2.67 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from pyspark import *
from pyspark.sql import *
import sys
import json
from SciNeMCore.Graph import Graph
import time
from pyspark.sql.functions import col
import SciNeMCore.utils as utils
# Exactly one CLI argument is expected: the path to the JSON config file.
if len(sys.argv[1:]) != 1:
    print("Usage: spark-submit main.py config.json", file=sys.stderr)
    sys.exit(-1)
spark = (
    SparkSession.builder
    .appName('SciNeM')
    .getOrCreate()
)

# Suppress Spark INFO messages by raising the root log4j logger to WARN.
# This goes through PySpark's private `_jvm` handle to reach log4j on the JVM side.
log4j = spark._jvm.org.apache.log4j
root_logger = log4j.LogManager.getRootLogger()
root_logger.setLevel(log4j.Level.WARN)
# Load the run configuration from the JSON file named on the command line.
# JSON text is UTF-8 (RFC 8259), so decode it explicitly instead of relying
# on the platform's locale-dependent default encoding.
config_file = sys.argv[1]
with open(config_file, encoding="utf-8") as fd:
    config = json.load(fd)

# Input locations for node and relation files (handed to Graph.transform).
nodes_dir = config["indir"]
relations_dir = config["irdir"]
analyses = config["analyses"]

# Optional PageRank parameters; None when absent (passed as-is to graph.pagerank).
alpha = float(config["pr_alpha"]) if "pr_alpha" in config else None
tol = float(config["pr_tol"]) if "pr_tol" in config else None

# Minimum "numberOfPaths" an edge needs to survive filtering of the network.
edgesThreshold = int(config["edgesThreshold"])

# Output locations.
hin_out = config["hin_out"]
join_hin_out = config["join_hin_out"]
ranking_out = config["ranking_out"]
communities_out = config["communities_out"]

queries = config["queries"]
# Iteration limit and algorithm name used by the community-detection step.
community_detection_iter = int(config["maxSteps"])
community_detection_algorithm = config["community_algorithm"]
# Defaults to "MatrixMult" when the key is absent.
transformation_algorithm = config.get("transformation_algorithm", "MatrixMult")
verbose = True

# Ranking and community detection need a homogeneous graph; one is also
# built when a "MatrixMult" transformation is requested.
_needs_homogeneous_graph = (
    any(task in analyses for task in ("Ranking", "Community Detection"))
    or ("Transformation" in analyses and transformation_algorithm == "MatrixMult")
)

if _needs_homogeneous_graph:
    graph = Graph()
    res_hin = graph.transform(spark, queries, nodes_dir, relations_dir, verbose)

    # Intended to keep edges backed by at least `edgesThreshold` paths and to
    # drop self-loops (src == dst).
    # NOTE(review): the return values of filter() are discarded, so this relies
    # on res_hin.filter modifying res_hin in place -- confirm in SciNeMCore.
    res_hin.filter(col("numberOfPaths") >= edgesThreshold)
    res_hin.filter(col("src") != col("dst"))

    # Abort with exit status 100 when the resulting network has no edges.
    if res_hin.non_zero() == 0:
        sys.exit(100)

    # Sort, then persist the transformed network (to HDFS per the original note).
    res_hin.sort()
    res_hin.write(hin_out)

    if "Ranking" in analyses:
        ranks = graph.pagerank(res_hin, alpha, tol)
        # Go through a DataFrame so the CSV writer can use mode='overwrite'.
        ranks_df = ranks.coalesce(1).toDF()
        ranks_df.write.csv(ranking_out, sep='\t', mode='overwrite')

    use_lpa = community_detection_algorithm == "LPA (GraphFrames)"
    if "Community Detection" in analyses and use_lpa:
        communities = graph.lpa(res_hin, community_detection_iter)
        single_partition = communities.coalesce(1)
        single_partition.write.csv(communities_out, sep='\t', mode='overwrite')

    # A transformation has already run, so any later one runs with verbose=False.
    verbose = False
# Similarity join/search builds its own transformed network with a fresh Graph.
if any(task in analyses for task in ("Similarity Join", "Similarity Search")):
    graph = Graph()
    res_hin = graph.transform(spark, queries, nodes_dir, relations_dir, verbose)
    # Writing this network to `join_hin_out` is currently disabled.
    # The whole config dict is passed on (presumably for the join/search
    # settings -- confirm in SciNeMCore.Graph.similarities).
    graph.similarities(res_hin, config)