-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathprocess_gbc.py
82 lines (73 loc) · 2.31 KB
/
process_gbc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
import argparse
import os
from omegaconf import OmegaConf
from hydra.utils import instantiate
from gbc.utils import setup_gbc_logger
from gbc.data import GbcGraphFull
from gbc.processing import local_process_data
def _build_parser():
    """Create the command-line parser for this script.

    Every option defaults to ``None`` so that "not given on the command line"
    can be told apart from an explicit value; unset options fall back to the
    values found in the merged config files.
    """
    parser = argparse.ArgumentParser(
        description="Convert JSONL/JSON/Parquet files to JSON/JSONL/Parquet."
    )
    parser.add_argument(
        "--input_paths",
        nargs="+",
        default=None,
        help="List of input files or directories.",
    )
    parser.add_argument(
        "--input_formats",
        nargs="+",
        default=None,
        help="List of input formats to look for (e.g., .json, .jsonl, .parquet).",
    )
    parser.add_argument(
        "--save_format",
        default=None,
        help="Desired output format (e.g., .json, .jsonl, .parquet).",
    )
    parser.add_argument(
        "--save_dir",
        default=None,
        help="Directory to save the converted files.",
    )
    parser.add_argument(
        "--configs",
        type=str,
        nargs="*",
        default=None,
        help="List of configs to be used. Latter ones override former ones.",
    )
    return parser


def _load_config(config_paths):
    """Merge the given config files and instantiate the processing config.

    Args:
        config_paths: Paths of config files loadable by ``OmegaConf.load``,
            or ``None`` / an empty list. Later files override earlier ones.

    Returns:
        The result of ``hydra.utils.instantiate`` applied to the merged
        config. If the merged config has a top-level ``processing_config``
        key, only that sub-tree is used.
    """
    # ``--configs`` uses nargs="*", so it can be an empty list; OmegaConf.merge
    # needs at least one config, hence a truthiness test rather than `is None`.
    if config_paths:
        loaded = [OmegaConf.load(path) for path in config_paths]
        config = OmegaConf.merge(*loaded)
    else:
        config = OmegaConf.create()
    if "processing_config" in config:
        config = config.processing_config
    # instantiate() builds objects for nodes that carry a ``_target_`` key;
    # presumably this is how data_transform / name_transform become callables.
    return instantiate(config)


def main():
    """Parse CLI options, resolve the config, and run local GBC processing."""
    parser = _build_parser()
    args = parser.parse_args()
    config = _load_config(args.configs)

    # Command-line values override config-file values. Every option other
    # than --configs must end up set by at least one of the two sources.
    for key, value in vars(args).items():
        if key == "configs":
            continue
        if value is not None:
            config[key] = value
        if key not in config:
            # parser.error instead of assert: asserts are stripped under -O.
            parser.error(
                f"--{key} was not given on the command line or in the config"
            )

    setup_gbc_logger()
    os.makedirs(config.save_dir, exist_ok=True)
    local_process_data(
        config.input_paths,
        save_dir=config.save_dir,
        save_format=config.save_format,
        input_formats=config.input_formats,
        data_class=GbcGraphFull,
        # Transforms are optional; None is passed when the config omits them.
        data_transform=config.get("data_transform", None),
        name_transform=config.get("name_transform", None),
    )


if __name__ == "__main__":
    main()