-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmain.py
36 lines (26 loc) · 1.17 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
__author__ = 'zhenanye'
import tensorflow as tf
import joint_model as jm
import argparse
import config
if __name__ == '__main__':
    # Command-line entry point: parse options, then build and train the joint
    # multi-task model, writing logs under logs/<version>/.
    parser = argparse.ArgumentParser(description="Activity Recognition using Deep Multi-task Learning")
    parser.add_argument('--test', type=int, default=0, help='select the test day. Max num is 6')
    # --version names the log directory below. It has no sensible default:
    # without it args.version is None and the path concatenation raises a
    # TypeError. Making it required gives a clear argparse usage error instead.
    parser.add_argument('--version', type=str, required=True, help='model version')
    parser.add_argument('--gpu', type=int, default=0, help='assign task to selected gpu')
    args = parser.parse_args()

    LOGS = "logs/"
    # Trailing slash is kept because the path is handed to JointModel as-is;
    # NOTE(review): JointModel presumably appends file names directly — confirm.
    log_path = LOGS + args.version + '/'

    # joint model for Huynh dataset
    def run_joint_model():
        """Build input placeholders, then load data, build and train the joint model.

        Reads ``args`` (test day, GPU index, version) and ``log_path`` from the
        enclosing script scope.
        """
        cfg = config.get_config()
        cfg.gpu = args.gpu
        # Input windows: (batch, c_win_size, s_win_size, channels).
        X = tf.placeholder(dtype=tf.float32, shape=[None, cfg.c_win_size, cfg.s_win_size, cfg.channels])
        # One label vector of size c_labels_num per sample.
        # NOTE(review): presumably the coarse/window-level task labels — confirm.
        YC = tf.placeholder(dtype=tf.float32, shape=[None, cfg.c_labels_num])
        # One label vector of size s_labels_num per each of the c_win_size
        # sub-windows. NOTE(review): presumably the fine-grained task labels — confirm.
        YS = tf.placeholder(dtype=tf.float32, shape=[None, cfg.c_win_size, cfg.s_labels_num])
        model = jm.JointModel(X, YS, YC, cfg, log_path, args.version)
        # load data for the selected test day, then build the graph and train
        model.load_data(args.test)
        model.build_model()
        model.train_model()

    run_joint_model()