-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
executable file
·58 lines (43 loc) · 1.51 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from argparse import ArgumentError
import sys
from ae_anom import AEAnom
from ae_basic import AEBasic
from ae_gen import AEGen
from vae_anom import VAEAnom
from vae_basic import VAEBasic
from vae_gen import VAEGen
"""
Command-line entry point: build one AE or VAE task and run it.

Usage: main.py {ae|vae} {basic|gen|anom} [mono|stacked]
"""
def get_task(encoder_type, task_type, three_colors):
    """Instantiate the task selected by ``encoder_type`` and ``task_type``.

    Args:
        encoder_type: ``"ae"`` or ``"vae"``. Matching is case-sensitive; the
            command-line entry point lower-cases the value before calling.
        task_type: ``"basic"``, ``"gen"`` or ``"anom"`` (case-sensitive).
        three_colors: Forwarded unchanged as the ``three_colors`` keyword of
            the task constructor (the CLI sets it for the "stacked" variant).

    Returns:
        A new instance of the matching task class, e.g. ``AEBasic`` for
        ``("ae", "basic")`` or ``VAEAnom`` for ``("vae", "anom")``.

    Raises:
        ValueError: if ``encoder_type`` or ``task_type`` is not recognised.
    """
    encoder_types = ("ae", "vae")
    task_types = ("basic", "gen", "anom")

    # Validate before touching the task classes so every unknown combination
    # fails loudly with the offending value, instead of falling through the
    # selection and handing ``None`` back to the caller.
    if encoder_type not in encoder_types:
        raise ValueError(
            "unknown encoder type %r; expected one of: %s"
            % (encoder_type, ", ".join(encoder_types)))
    if task_type not in task_types:
        raise ValueError(
            "unknown task type %r; expected one of: %s"
            % (task_type, ", ".join(task_types)))

    # Each class handles both the mono and the stacked variant via
    # ``three_colors``.
    task_classes = {
        ("ae", "basic"): AEBasic,
        ("ae", "gen"): AEGen,
        ("ae", "anom"): AEAnom,
        ("vae", "basic"): VAEBasic,
        ("vae", "gen"): VAEGen,
        ("vae", "anom"): VAEAnom,
    }
    return task_classes[(encoder_type, task_type)](three_colors=three_colors)
USAGE = "usage: main.py {ae|vae} {basic|gen|anom} [mono|stacked]"


def parse_cli_args(args):
    """Parse the positional command-line arguments.

    Args:
        args: The arguments after the program name, i.e. ``sys.argv[1:]``.

    Returns:
        A tuple ``(encoder_type, task_type, three_colors)``. The first two are
        lower-cased. ``three_colors`` is True only when an optional third
        argument equals "stacked" (case-insensitive); any other value, or no
        third argument, gives False. Arguments beyond the third are ignored.

    Raises:
        ValueError: if fewer than two arguments are given.
    """
    if len(args) < 2:
        raise ValueError(
            "expected at least 2 arguments, got %d" % len(args))
    encoder_type = args[0].lower()
    task_type = args[1].lower()
    three_colors = len(args) >= 3 and args[2].lower() == "stacked"
    return encoder_type, task_type, three_colors


if __name__ == "__main__":
    try:
        encoder_type, task_type, three_colors = parse_cli_args(sys.argv[1:])
    except ValueError as err:
        # A string passed to sys.exit is printed to stderr, exit status 1.
        sys.exit("%s\n%s" % (err, USAGE))
    task = get_task(encoder_type, task_type, three_colors)
    task.run()