-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathesm_inference.py
107 lines (91 loc) · 3.35 KB
/
esm_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import re
import torch
import pandas as pd
from lib.pipeline import Pipeline
from lib.utils import dict_tensor_to_num, read_initial_csv
from atpbind_main import ALL_PARAMS
def _boosted_weights(model_name):
    """Return per-fold checkpoint names for a boosted model.

    The result is a list of 5 folds. Each fold is a list of the 10 checkpoint
    file names ``atpbind3d_<model_name>_<fold>_<round>.pt``.
    """
    return [
        [f'atpbind3d_{model_name}_{fold}_{i}.pt' for i in range(10)]
        for fold in range(5)
    ]


# Ensemble configurations evaluated by main(): each entry names a model
# variant and, per fold, the list of checkpoint files whose predictions are
# averaged.
params = [
    {
        'model': 'esm-t33-gearnet',
        # One checkpoint per fold.
        'weights': [[f'atpbind3d_esm-t33-gearnet_{i}.pt'] for i in range(5)],
    },
    *(
        {'model': model_name, 'weights': _boosted_weights(model_name)}
        for model_name in (
            'esm-t33-gearnet-adaboost-r90',
            'esm-t33-gearnet-resiboost-r90',
        )
    ),
]
def esm_inference(model_key, weights, weight_base, threshold, gpu):
    """Evaluate an averaged ensemble of checkpoints on the ESM test set.

    Builds a ``Pipeline`` on the ``'atpbind3d-esm'`` dataset using the model
    configuration registered under ``model_key`` in ``ALL_PARAMS``. It then
    loads each checkpoint in ``weights`` into that pipeline in turn, averages
    the per-checkpoint test-set predictions, and evaluates the average.

    Args:
        model_key: Key into ``ALL_PARAMS`` selecting the model config.
        weights: Checkpoint file names, relative to ``weight_base``.
        weight_base: Directory containing the checkpoint files.
        threshold: Passed as ``threshold=`` to ``pipeline.task.evaluate``.
        gpu: GPU index given to the pipeline and the model kwargs.

    Returns:
        The metrics returned by ``pipeline.task.evaluate``, converted with
        ``dict_tensor_to_num``.

    Raises:
        ValueError: If ``weights`` is empty.
    """
    weights = list(weights)
    if not weights:
        raise ValueError('esm_inference: `weights` must contain at least one checkpoint')

    # Slicing settings used both to build the dataset and in the prediction
    # call below; defined once so the two cannot drift apart.
    max_slice_length = 500
    padding = 50

    model_config = ALL_PARAMS[model_key]
    pipeline = Pipeline(
        gpus=[gpu],
        dataset='atpbind3d-esm',
        model=model_config['model'],
        model_kwargs={
            'gpu': gpu,
            **model_config['model_kwargs']
        },
        dataset_kwargs={
            'to_slice': True,
            'max_slice_length': max_slice_length,
            'padding': padding,
        }
    )

    # Running sum of predictions, so only one extra tensor is kept alive
    # instead of every checkpoint's output.
    pred_sum = None
    target = None
    for weight in weights:
        # Load onto CPU so checkpoints saved on any device can be read here;
        # load_state_dict copies the values into the model's own parameters.
        model_weight = torch.load(
            os.path.join(weight_base, weight), map_location='cpu')
        # strict=False: missing/unexpected keys in the checkpoint are ignored
        # rather than raising.
        pipeline.task.load_state_dict(model_weight, strict=False)
        pred, target = pipeline.predict_and_target_dataset(
            pipeline.test_set, max_slice_length, padding)
        pred_sum = pred if pred_sum is None else pred_sum + pred
    pred_mean = pred_sum / len(weights)

    # `target` comes from the last pass; every pass predicts on the same
    # pipeline.test_set.
    result = pipeline.task.evaluate(pred_mean, target, threshold=threshold)
    return dict_tensor_to_num(result)
def _lookup_threshold(stat_df, model_key, fold):
    """Return the recorded ``best_threshold`` for (model_key, fold), or None if absent."""
    matches = stat_df[(stat_df['model_key'] == model_key) & (stat_df['valid_fold'] == fold)]
    if matches.empty:
        return None
    return matches.iloc[0]['best_threshold']


def main(gpu):
    """Run ensemble inference for every configuration in ``params``.

    For each model entry and each fold:
      1. Look up the best threshold stored in ``result/atpbind3d_stats.csv``
         during training. Fall back to -1.0 when no matching row exists.
      2. Evaluate the fold's checkpoint ensemble with ``esm_inference``.
      3. Append a row (model, fold, threshold, metrics) to
         ``result/atpbind3d_esm_stats.csv``.

    Args:
        gpu: GPU index forwarded to ``esm_inference``.
    """
    esm_stats_path = 'result/atpbind3d_esm_stats.csv'
    # Thresholds chosen on the validation set during training. The file does
    # not change during this run, so read it once instead of once per fold.
    stat_df = pd.read_csv('result/atpbind3d_stats.csv')

    for param in params:
        print(f'Running {param["model"]}')
        for fold, weights in enumerate(param['weights']):
            threshold = _lookup_threshold(stat_df, param['model'], fold)
            if threshold is None:
                threshold = -1.0
                print(f'Fold {fold}: best threshold not found, using default {threshold}')
            else:
                print(f'Fold {fold}: found best threshold {threshold}')

            result = esm_inference(
                # NOTE(review): every variant is run with the 'esm-t33-gearnet'
                # config; presumably the boosted variants share that
                # architecture -- confirm against ALL_PARAMS.
                model_key='esm-t33-gearnet',
                weights=weights,
                weight_base='weight',
                threshold=threshold,
                gpu=gpu,
            )
            print(f'Fold {fold}: result {result}')

            # Read-append-write after every fold, so each finished fold's
            # row is on disk before the next fold starts.
            record_df = read_initial_csv(esm_stats_path)
            record_df = pd.concat([record_df, pd.DataFrame([
                {'model': param['model'], 'fold': fold, 'threshold': threshold, **result}
            ])])
            record_df.to_csv(esm_stats_path, index=False)
if __name__ == '__main__':
    # Command-line entry point: the only option is which GPU index to run on.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--gpu', type=int, default=0)
    main(gpu=cli.parse_args().gpu)