-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun_inference.sh
295 lines (273 loc) · 16 KB
/
run_inference.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
#!/bin/bash
# MetaSQL testing / inference pipeline.
#
# Usage: run_inference.sh <nl2sql_model> <test_json> <tables_json> <db_dir>
#   $1  NL2SQL model name; the stages below handle lgesql, resdsql and gap.
#   $2  test set json file.
#   $3  schema (tables) json file.
#   $4  directory containing the databases.
#
# Every stage is guarded by an existence check on its output, so the script
# can be re-run and will skip stages whose results are already on disk.
[[ -n "$1" ]] || { echo "First argument is the NL2SQL model name." >&2; exit 1; }
MODEL_NAME="$1"
[[ -n "$2" ]] || { echo "Second argument is the test json file." >&2; exit 1; }
TEST_FILE="$2"
[[ -n "$3" ]] || { echo "Third argument is the schema file." >&2; exit 1; }
TABLES_FILE="$3"
[[ -n "$4" ]] || { echo "Fourth argument is the directory of the databases." >&2; exit 1; }
DB_DIR="$4"
# Directory containing this script (NOTE(review): not referenced elsewhere in
# this script).
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
echo "===================================================================================================================================="
echo "INFO ****** MetaSQL Testing Pipeline Start ******"
# Hand the schema file and database directory to the config module through
# params.txt in the current directory (NOTE(review): presumably read by
# configs.get_config_for_test_bash — confirm).
printf '%s\n' "$TABLES_FILE" "$DB_DIR" > params.txt
# The config module prints one '@'-separated record. Abort if it fails:
# otherwise every field below would silently be empty and later stages would
# run against blank paths.
output=$(python3 -m configs.get_config_for_test_bash) || exit $?
# Field N of the record -> variable; the order must match what the config
# module emits.
OUTPUT_DIR_RERANKER=$(cut -d'@' -f1 <<< "$output")
RETRIEVAL_EMBEDDING_MODEL_NAME=$(cut -d'@' -f2 <<< "$output")
RERANKER_MODEL_DIR=$(cut -d'@' -f3 <<< "$output")
RERANKER_EMBEDDING_MODEL_NAME=$(cut -d'@' -f4 <<< "$output")
RERANKER_MODEL_NAME=$(cut -d'@' -f5 <<< "$output")
RERANKER_INPUT_FILE_NAME=$(cut -d'@' -f6 <<< "$output")
PRED_FILE_NAME=$(cut -d'@' -f7 <<< "$output")
RERANKER_MISS_FILE_NAME=$(cut -d'@' -f8 <<< "$output")
MODEL_TAR_GZ=$(cut -d'@' -f9 <<< "$output")
PRED_TOPK_FILE_NAME=$(cut -d'@' -f10 <<< "$output")
CANDIDATE_NUM=$(cut -d'@' -f11 <<< "$output")
MODE=$(cut -d'@' -f12 <<< "$output")
DEBUG=$(cut -d'@' -f13 <<< "$output")
CLASSIFIER_PREDS_FILE=$(cut -d'@' -f14 <<< "$output")
CLASSIFIER_MODEL_DIR=$(cut -d'@' -f15 <<< "$output")
META_DICT_FILE=$(cut -d'@' -f16 <<< "$output")
META_FORMAT_OUTPUT_FILE=$(cut -d'@' -f17 <<< "$output")
NL2SQL_META_PREDS_FILE=$(cut -d'@' -f18 <<< "$output")
NL2SQL_PREDS_FILE=$(cut -d'@' -f19 <<< "$output")
NL2SQL_META_MODEL_DIR=$(cut -d'@' -f20 <<< "$output")
NL2SQL_MODEL_DIR=$(cut -d'@' -f21 <<< "$output")
MODEL_BIN=$(cut -d'@' -f22 <<< "$output")
SCHEMA_CLASSIFIER_MODEL_DIR=$(cut -d'@' -f23 <<< "$output")
SERIALIZE_DATA_DIR=$(cut -d'@' -f24 <<< "$output")
USE_ORIGINAL_PREDS=$(cut -d'@' -f25 <<< "$output")
# Database path passed to the resdsql text2sql step (NOTE(review): differs
# from the DB_DIR argument — confirm this is intended).
db_path="./data/database"
# The two model-dir config values are used as printf format strings with the
# model name as the argument (presumably they contain a %s placeholder).
# shellcheck disable=SC2059
printf -v NL2SQL_MODEL_DIR "$NL2SQL_MODEL_DIR" "$MODEL_NAME"
# shellcheck disable=SC2059
printf -v NL2SQL_META_MODEL_DIR "$NL2SQL_META_MODEL_DIR" "$MODEL_NAME"
DATASET_NAME="spider"
TEST_META_FILE="$(dirname "$TEST_FILE")/meta_gap_test.json"
# One experiment directory per (dataset, model, candidate count, embeddings,
# re-ranker) combination; mkdir -p is a no-op when it already exists.
EXPERIMENT_DIR_NAME="${OUTPUT_DIR_RERANKER}/${DATASET_NAME}_${MODEL_NAME}_${CANDIDATE_NUM}_${RETRIEVAL_EMBEDDING_MODEL_NAME}_${RERANKER_EMBEDDING_MODEL_NAME}_${RERANKER_MODEL_NAME}"
mkdir -p "$EXPERIMENT_DIR_NAME" || exit $?
RERANKER_INPUT_FILE="$EXPERIMENT_DIR_NAME/$RERANKER_INPUT_FILE_NAME"
RERANKER_MODEL_FILE="$RERANKER_MODEL_DIR/$MODEL_TAR_GZ"
RERANKER_MODEL_OUTPUT_FILE="$EXPERIMENT_DIR_NAME/$PRED_FILE_NAME"
RERANKER_MODEL_OUTPUT_TOPK_FILE="$EXPERIMENT_DIR_NAME/$PRED_TOPK_FILE_NAME"
# The names below are derived from RERANKER_MODEL_OUTPUT_FILE by replacing its
# first ".txt" occurrence (including the top-k SQL file, which is NOT derived
# from the top-k prediction file).
RERANKER_MODEL_OUTPUT_SQL_FILE="${RERANKER_MODEL_OUTPUT_FILE/.txt/_sql.txt}"
RERANKER_MODEL_OUTPUT_TOPK_SQL_FILE="${RERANKER_MODEL_OUTPUT_FILE/.txt/_sql_topk.txt}"
EVALUATE_OUTPUT_FILE="${RERANKER_MODEL_OUTPUT_FILE/.txt/_evaluate.txt}"
VALUE_FILTERED_OUTPUT_SQL_FILE="${RERANKER_MODEL_OUTPUT_FILE/.txt/_sql_value_filtered.txt}"
VALUE_FILTERED_OUTPUT_TOPK_SQL_FILE="${RERANKER_MODEL_OUTPUT_FILE/.txt/_sql_topk_value_filtered.txt}"
CLASSIFIER_MODEL_FILE="$CLASSIFIER_MODEL_DIR/$MODEL_BIN"
LGESQL_META_MODEL_FILE="$NL2SQL_META_MODEL_DIR/$MODEL_BIN"
LGESQL_MODEL_FILE="$NL2SQL_MODEL_DIR/$MODEL_BIN"
# Stage 1: run the multi-label classifier on the test set, then format its
# predictions (using META_DICT_FILE) into META_FORMAT_OUTPUT_FILE.
# Runs only when the classifier weights exist and no predictions exist yet.
echo "INFO [Stage 1] Multi-label classification model inferencing ......"
if [[ -f "$CLASSIFIER_MODEL_FILE" && ! -f "$CLASSIFIER_PREDS_FILE" ]]; then
  python3 -m scripts.multi_label_classifier_infer_script --db_dir "$DB_DIR" --table_path "$TABLES_FILE" \
    --dataset_path "$TEST_FILE" --saved_model "$CLASSIFIER_MODEL_DIR" --output_path "$CLASSIFIER_PREDS_FILE" --use_gpu || exit $?
  # Positional args are quoted so an empty value cannot shift the others.
  python3 -m scripts.multi_label_classifier_output_format_script "$CLASSIFIER_PREDS_FILE" \
    "$META_DICT_FILE" "$META_FORMAT_OUTPUT_FILE" || exit $?
  echo "INFO Multi-label classification model inference complete & outputs formatting done!"
  echo "===================================================================================================================================="
else
  echo "WARNING \`$CLASSIFIER_MODEL_FILE\` not exists or \`$CLASSIFIER_PREDS_FILE\` exists."
fi
# Prints a bare marker whenever the original (non-meta) NL2SQL predictions are
# disabled via USE_ORIGINAL_PREDS.
# NOTE(review): looks like a leftover debug line; consider a descriptive INFO.
[[ "$USE_ORIGINAL_PREDS" == "True" ]] || echo "!!!!"
# Stage 2-(1): run the selected NL2SQL model with the stage-1 metadata and
# write its predictions to NL2SQL_META_PREDS_FILE. Skipped if that file exists.
echo "INFO [Stage 2-(1)] NL2SQL+meta model inferencing ......"
if [[ ! -f "$NL2SQL_META_PREDS_FILE" ]]; then
  case "$MODEL_NAME" in
    "lgesql")
      python3 -m nl2sql_models.lgesql.infer_with_meta_script --metadata_dict_path "$META_DICT_FILE" \
        --metadata_output_path "$META_FORMAT_OUTPUT_FILE" --db_dir "$DB_DIR" --table_path "$TABLES_FILE" \
        --saved_table_path "" --dataset_path "$TEST_FILE" --saved_model "$NL2SQL_META_MODEL_DIR" \
        --output_path "$NL2SQL_META_PREDS_FILE" || exit $?
      ;;
    "resdsql")
      # Four chained steps, each reading the previous step's output file, so a
      # failure must stop the pipeline instead of feeding stale/missing data on.
      # NOTE(review): input/original dev paths are hard-coded rather than
      # derived from TEST_FILE — confirm.
      python -m nl2sql_models.resdsql.preprocessing --mode test --table_path "$TABLES_FILE" --input_dataset_path "./data/dev_for_resd.json" \
        --output_dataset_path "$EXPERIMENT_DIR_NAME/preprocessed_test.json" --db_path "$DB_DIR" --target_type sql --metadata_dict_path "$META_DICT_FILE" \
        --metadata_output_path "$META_FORMAT_OUTPUT_FILE" --meta 1 || exit $?
      python -m nl2sql_models.resdsql.schema_item_classifier --batch_size 32 --device 0 --seed 42 \
        --save_path "./saved_models/nl2sql_models/resdsql/resdsql_schema_item_classifier" \
        --dev_filepath "$EXPERIMENT_DIR_NAME/preprocessed_test.json" \
        --output_filepath "$EXPERIMENT_DIR_NAME/test_with_probs.json" \
        --use_contents --add_fk_info --mode "test" || exit $?
      python -m nl2sql_models.resdsql.text2sql_data_generator \
        --input_dataset_path "$EXPERIMENT_DIR_NAME/test_with_probs.json" \
        --output_dataset_path "$EXPERIMENT_DIR_NAME/resdsql_test.json" \
        --topk_table_num 4 --topk_column_num 5 --mode test --use_contents \
        --add_fk_info --output_skeleton --target_type sql || exit $?
      python -m nl2sql_models.resdsql.text2sql \
        --batch_size 4 --device 0 \
        --seed 42 --save_path "./saved_models/nl2sql_models/resdsql/text2sql-t5-large/checkpoint-30576" \
        --mode "eval" --dev_filepath "$EXPERIMENT_DIR_NAME/resdsql_test.json" \
        --original_dev_filepath "./data/dev.json" \
        --num_beams 12 \
        --num_return_sequences 8 \
        --target_type "sql" \
        --output "$NL2SQL_META_PREDS_FILE" \
        --db_path "$db_path" || exit $?
      ;;
    "gap")
      python -m nl2sql_models.gap.generate_meta_data "$TEST_FILE" "$TEST_META_FILE" || exit $?
      python -m nl2sql_models.gap.run preprocess nl2sql_models/gap/experiments/spider-configs/gap-meta-run.jsonnet || exit $?
      TOKENIZERS_PARALLELISM=false python -m nl2sql_models.gap.run eval nl2sql_models/gap/experiments/spider-configs/gap-meta-run.jsonnet || exit $?
      python -m nl2sql_models.gap.format_meta_output "$NL2SQL_META_PREDS_FILE" || exit $?
      ;;
    *)
      # A bare `exit` here would return echo's status (0) and report success.
      echo "unknown NL2SQL model!" >&2
      exit 1
      ;;
  esac
  echo "INFO NL2SQL+meta model inference complete!"
  echo "===================================================================================================================================="
else
  # Only the predictions file is checked above, so only it is reported.
  echo "WARNING \`$NL2SQL_META_PREDS_FILE\` exists."
fi
# Stage 2-(2): run the selected NL2SQL model WITHOUT metadata and write its
# predictions to NL2SQL_PREDS_FILE. Runs only when USE_ORIGINAL_PREDS is
# "True" and the predictions file does not exist yet.
echo "INFO [Stage 2-(2)] NL2SQL model inferencing ......"
if [[ ! -f "$NL2SQL_PREDS_FILE" && "$USE_ORIGINAL_PREDS" == "True" ]]; then
  case "$MODEL_NAME" in
    "lgesql")
      python3 -m nl2sql_models.lgesql.infer_script --db_dir "$DB_DIR" --table_path "$TABLES_FILE" \
        --saved_table_path "output/tables.bin" --dataset_path "$TEST_FILE" \
        --saved_model "$NL2SQL_MODEL_DIR" --output_path "$NL2SQL_PREDS_FILE" || exit $?
      ;;
    "resdsql")
      # Same four chained steps as stage 2-(1) (without --meta); each step
      # reads the previous one's output, so any failure aborts the run.
      # These overwrite the intermediate files in EXPERIMENT_DIR_NAME.
      python -m nl2sql_models.resdsql.preprocessing --mode test --table_path "$TABLES_FILE" --input_dataset_path "./data/dev.json" \
        --output_dataset_path "$EXPERIMENT_DIR_NAME/preprocessed_test.json" --db_path "$DB_DIR" --target_type sql --metadata_dict_path "$META_DICT_FILE" \
        --metadata_output_path "$META_FORMAT_OUTPUT_FILE" || exit $?
      python -m nl2sql_models.resdsql.schema_item_classifier --batch_size 32 --device 0 --seed 42 \
        --save_path "./saved_models/nl2sql_models/resdsql/resdsql_schema_item_classifier" \
        --dev_filepath "$EXPERIMENT_DIR_NAME/preprocessed_test.json" \
        --output_filepath "$EXPERIMENT_DIR_NAME/test_with_probs.json" \
        --use_contents --add_fk_info --mode "test" || exit $?
      python -m nl2sql_models.resdsql.text2sql_data_generator \
        --input_dataset_path "$EXPERIMENT_DIR_NAME/test_with_probs.json" \
        --output_dataset_path "$EXPERIMENT_DIR_NAME/resdsql_test.json" \
        --topk_table_num 4 --topk_column_num 5 --mode test --use_contents \
        --add_fk_info --output_skeleton --target_type sql || exit $?
      python -m nl2sql_models.resdsql.text2sql \
        --batch_size 4 --device 0 \
        --seed 42 --save_path "./saved_models/nl2sql_models/resdsql/text2sql-t5-large/checkpoint-30576" \
        --mode "eval" --dev_filepath "$EXPERIMENT_DIR_NAME/resdsql_test.json" \
        --original_dev_filepath "./data/dev.json" \
        --num_beams 12 \
        --num_return_sequences 8 \
        --target_type "sql" \
        --output "$NL2SQL_PREDS_FILE" \
        --db_path "$db_path" || exit $?
      ;;
    "gap")
      python -m nl2sql_models.gap.run preprocess nl2sql_models/gap/experiments/spider-configs/gap-run.jsonnet || exit $?
      TOKENIZERS_PARALLELISM=false python -m nl2sql_models.gap.run eval nl2sql_models/gap/experiments/spider-configs/gap-run.jsonnet || exit $?
      python -m nl2sql_models.gap.format_output "$NL2SQL_PREDS_FILE" || exit $?
      ;;
    *)
      # A bare `exit` here would return echo's status (0) and report success.
      echo "unknown NL2SQL model!" >&2
      exit 1
      ;;
  esac
  echo "INFO NL2SQL model inference complete!"
  echo "===================================================================================================================================="
else
  # Report the conditions actually tested above.
  echo "WARNING \`$NL2SQL_PREDS_FILE\` exists or USE_ORIGINAL_PREDS is not \`True\`."
fi
# Stage 2-(3): serialize the meta predictions together with the test set,
# formatted metadata and schema. Skipped when SERIALIZE_DATA_DIR exists
# (NOTE(review): presumably the directory this script creates — confirm).
echo "INFO [Stage 2-(3)] Outputs Serialization ......"
if [[ ! -d "$SERIALIZE_DATA_DIR" ]]; then
  # Positional args are quoted so an empty value cannot shift the others.
  python3 -m scripts.serialization_script "$TEST_FILE" "$NL2SQL_META_PREDS_FILE" \
    "$META_FORMAT_OUTPUT_FILE" "$TABLES_FILE" "$DB_DIR" || exit $?
  echo "INFO serialization complete!"
  echo "===================================================================================================================================="
else
  echo "WARNING \`$SERIALIZE_DATA_DIR\` exists."
fi
# Stage 3-(1): first-stage ranking over the NL2SQL predictions; produces the
# input file for the second-stage re-ranker. Skipped if that file exists.
echo "INFO [Stage 3-(1)] First-stage ranking inferencing ......"
if [[ ! -f "$RERANKER_INPUT_FILE" ]]; then
  # Eleven positional args: quoting keeps an empty config value (e.g. DEBUG)
  # in its own slot instead of shifting every later argument.
  python3 -m scripts.second_ranker_data_gen_script "$DATASET_NAME" "$MODEL_NAME" "$RETRIEVAL_EMBEDDING_MODEL_NAME" \
    "$TEST_FILE" "$NL2SQL_PREDS_FILE" "$TABLES_FILE" "$DB_DIR" "$CANDIDATE_NUM" \
    "$MODE" "$DEBUG" "$RERANKER_INPUT_FILE" || exit $?
  echo "INFO First-stage ranking inference complete & Second-stage re-ranking data done!"
  echo "===================================================================================================================================="
else
  echo "WARNING \`$RERANKER_INPUT_FILE\` exists!"
fi
# Stage 3-(2): top-1 inference with the AllenNLP listwise re-ranker archive.
echo "INFO [Stage 3-(2)] Second-stage re-ranking top-1 inferencing ......"
# Require the model archive too, as the warning below states and as the top-k
# stage already does; otherwise allennlp is launched on a missing archive.
if [[ -f "$RERANKER_MODEL_FILE" && ! -f "$RERANKER_MODEL_OUTPUT_FILE" ]]; then
  allennlp predict "$RERANKER_MODEL_FILE" "$RERANKER_INPUT_FILE" \
    --output-file "$RERANKER_MODEL_OUTPUT_FILE" \
    --file-friendly-logging --silent --predictor listwise-ranker --use-dataset-reader --cuda-device 0 \
    --include-package allenmodels.dataset_readers.listwise_ranker_reader_distributed \
    --include-package allenmodels.models.semantic_matcher.listwise_ranker_multigrained \
    --include-package allenmodels.predictors.ranker_predictor || exit $?
  echo "INFO Second-stage re-ranking inference complete!"
  echo "===================================================================================================================================="
else
  echo "WARNING \`$RERANKER_MODEL_FILE\` not exists or \`$RERANKER_MODEL_OUTPUT_FILE\` exists."
  # exit;
fi
# Stage 3-(3): evaluate the top-1 re-ranker output. Runs when the prediction
# file exists and RERANKER_MODEL_OUTPUT_SQL_FILE does not (NOTE(review):
# presumably that SQL file is what second_ranker_evaluate writes — confirm).
echo "INFO [Stage 3-(3)] Second-stage re-ranking (top-1 results) evaluating ......"
if [[ -f "$RERANKER_MODEL_OUTPUT_FILE" && ! -f "$RERANKER_MODEL_OUTPUT_SQL_FILE" ]]; then
  python3 -m scripts.second_ranker_evaluate "$TABLES_FILE" "$DB_DIR" "$RERANKER_MODEL_OUTPUT_FILE" \
    "$RERANKER_INPUT_FILE" "$EXPERIMENT_DIR_NAME" || exit $?
  echo "INFO Second-stage re-ranking evaluation complete!"
  echo "===================================================================================================================================="
else
  echo "WARNING \`$RERANKER_MODEL_OUTPUT_FILE\` not exists or \`$RERANKER_MODEL_OUTPUT_SQL_FILE\` exists"
  # exit;
fi
# Stage 3-(4): top-k inference with the same re-ranker archive, using the
# top-k predictor package. Runs when the archive exists and no top-k output
# has been written yet.
echo "INFO [Stage 3-(4)] Second-stage re-ranking top-k inferencing ....."
if [[ -f "$RERANKER_MODEL_FILE" && ! -f "$RERANKER_MODEL_OUTPUT_TOPK_FILE" ]]; then
  allennlp predict "$RERANKER_MODEL_FILE" "$RERANKER_INPUT_FILE" \
    --output-file "$RERANKER_MODEL_OUTPUT_TOPK_FILE" \
    --file-friendly-logging --silent --predictor listwise-ranker --use-dataset-reader --cuda-device 0 \
    --include-package allenmodels.dataset_readers.listwise_ranker_reader_distributed \
    --include-package allenmodels.models.semantic_matcher.listwise_ranker_multigrained \
    --include-package allenmodels.predictors.ranker_predictor_topk || exit $?
  echo "INFO Second-stage re-ranking inference (top-k) complete!"
  echo "===================================================================================================================================="
else
  echo "WARNING \`$RERANKER_MODEL_FILE\` not exists or \`$RERANKER_MODEL_OUTPUT_TOPK_FILE\` exists."
  # exit;
fi
# Stage 3-(5): evaluate the top-k re-ranker output. Runs when the top-k
# prediction file exists and RERANKER_MODEL_OUTPUT_TOPK_SQL_FILE does not.
echo "INFO [Stage 3-(5)] Second-stage re-ranking (top-k results) evaluating ......"
if [[ -f "$RERANKER_MODEL_OUTPUT_TOPK_FILE" && ! -f "$RERANKER_MODEL_OUTPUT_TOPK_SQL_FILE" ]]; then
  python3 -m scripts.second_ranker_evaluate_topk "$TABLES_FILE" "$DB_DIR" \
    "$RERANKER_MODEL_OUTPUT_TOPK_FILE" "$RERANKER_INPUT_FILE" "$EXPERIMENT_DIR_NAME" || exit $?
  echo "INFO Second-stage re-ranking evaluation complete!"
  echo "===================================================================================================================================="
else
  # Both skip reasons are reported, matching the condition above.
  echo "WARNING \`$RERANKER_MODEL_OUTPUT_TOPK_FILE\` not exists or \`$RERANKER_MODEL_OUTPUT_TOPK_SQL_FILE\` exists"
  # exit;
fi
# Stage 3-(6): post-process the re-ranked SQL. The last two arguments are
# presumably the outputs written by postprocessing_script (value-filtered
# top-1 and top-k SQL) — confirm.
echo "INFO [Stage 3-(6)] Postprocessing ......"
if [[ -f "$RERANKER_MODEL_OUTPUT_SQL_FILE" && ! -f "$VALUE_FILTERED_OUTPUT_TOPK_SQL_FILE" ]]; then
  python3 -m scripts.postprocessing_script "$TEST_FILE" "$NL2SQL_PREDS_FILE" "$RERANKER_INPUT_FILE" \
    "$RERANKER_MODEL_OUTPUT_TOPK_SQL_FILE" "$TABLES_FILE" "$DB_DIR" \
    "$VALUE_FILTERED_OUTPUT_SQL_FILE" "$VALUE_FILTERED_OUTPUT_TOPK_SQL_FILE" || exit $?
  echo "INFO Postprocessing complete!"
  echo "===================================================================================================================================="
else
  # Either the top-1 SQL input is missing or the output already exists.
  echo "WARNING \`$RERANKER_MODEL_OUTPUT_SQL_FILE\` not exists or \`$VALUE_FILTERED_OUTPUT_TOPK_SQL_FILE\` exists!"
  # Stop here with success (the original bare `exit` also returned 0); the
  # final Spider evaluation that follows in this file is commented out.
  exit 0
fi
# # Final Evaluation
# echo "INFO Spider script evaluating ......"
# if [ -f $VALUE_FILTERED_OUTPUT_SQL_FILE -a ! -f $EVALUATE_OUTPUT_FILE ]; then
# python3 -m utils.spider_utils.evaluation.evaluate --gold "data/dev_gold.sql" --pred "$VALUE_FILTERED_OUTPUT_SQL_FILE" \
# --etype "match" --db "$DB_DIR" --table "$TABLES_FILE" \
# --candidates "$VALUE_FILTERED_OUTPUT_TOPK_SQL_FILE" > "$EVALUATE_OUTPUT_FILE"
# echo "Spider evaluation complete! Results are saved in \`$EVALUATE_OUTPUT_FILE\`"
# echo "===================================================================================================================================="
# else
# echo "\`$EVALUATE_OUTPUT_FILE\` exist!"
# echo "===================================================================================================================================="
# # exit
# fi