diff --git a/app_svc.py b/app_svc.py index 4b18095..cd6e972 100644 --- a/app_svc.py +++ b/app_svc.py @@ -11,9 +11,9 @@ from pydub import AudioSegment import argparse # Load model and configuration -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") fp16 = False +device = None def load_models(args): global sr, hop_length, fp16 fp16 = args.fp16 @@ -433,5 +433,8 @@ def main(args): parser.add_argument("--config-path", type=str, help="Path to the config file", default=None) parser.add_argument("--share", type=str2bool, nargs="?", const=True, default=False, help="Whether to share the app") parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, help="Whether to use fp16", default=True) + parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0) args = parser.parse_args() + cuda_target = f"cuda:{args.gpu}" if args.gpu else "cuda" + device = torch.device(cuda_target if torch.cuda.is_available() else "cpu") main(args) diff --git a/app_vc.py b/app_vc.py index 9ac37cc..3a049a9 100644 --- a/app_vc.py +++ b/app_vc.py @@ -12,8 +12,8 @@ import argparse # Load model and configuration -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") fp16 = False +device = None def load_models(args): global sr, hop_length, fp16 fp16 = args.fp16 @@ -386,5 +386,8 @@ def main(args): parser.add_argument("--config-path", type=str, help="Path to the config file", default=None) parser.add_argument("--share", type=str2bool, nargs="?", const=True, default=False, help="Whether to share the app") parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, help="Whether to use fp16", default=True) + parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0) args = parser.parse_args() - main(args) + cuda_target = f"cuda:{args.gpu}" if args.gpu else "cuda" + device = torch.device(cuda_target if torch.cuda.is_available() else "cpu") + main(args) \ No newline at end of file diff --git a/real-time-gui.py b/real-time-gui.py index 0614ee8..db68b03 100644 --- a/real-time-gui.py +++ b/real-time-gui.py @@ -30,7 +30,7 @@ import torch from modules.commons import str2bool # Load model and configuration -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = None flag_vc = False @@ -328,7 +328,7 @@ def printt(strr, *args): class Config: def __init__(self): - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = device if __name__ == "__main__": @@ -1137,5 +1137,8 @@ def get_device_channels(self): parser.add_argument("--checkpoint-path", type=str, default=None, help="Path to the model checkpoint") parser.add_argument("--config-path", type=str, default=None, help="Path to the vocoder checkpoint") parser.add_argument("--fp16", type=str2bool, nargs="?", const=True, help="Whether to use fp16", default=True) + parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0) args = parser.parse_args() + cuda_target = f"cuda:{args.gpu}" if args.gpu else "cuda" + device = torch.device(cuda_target if torch.cuda.is_available() else "cpu") gui = GUI(args) \ No newline at end of file diff --git a/train.py b/train.py index a8cc54e..955664f 100644 --- a/train.py +++ b/train.py @@ -18,7 +18,6 @@ - class Trainer: def __init__(self, config_path, @@ -385,6 +384,7 @@ def main(args): max_epochs=args.max_epochs, save_interval=args.save_every, num_workers=args.num_workers, + device=args.device ) trainer.train() @@ -399,5 +399,7 @@ def main(args): parser.add_argument('--max-epochs', type=int, default=1000) parser.add_argument('--save-every', type=int, default=500) parser.add_argument('--num-workers', type=int, default=0) + parser.add_argument("--gpu", type=int, help="Which GPU id to use", default=0) args = parser.parse_args() + args.device = f"cuda:{args.gpu}" if args.gpu else "cuda:0" main(args) \ No newline at end of file