diff --git a/main_test_fbcnn_color.py b/main_test_fbcnn_color.py index a45ae25..4dffde9 100644 --- a/main_test_fbcnn_color.py +++ b/main_test_fbcnn_color.py @@ -44,8 +44,13 @@ def main(): logger = logging.getLogger(logger_name) logger.info('--------------- quality factor: {:d} ---------------'.format(quality_factor)) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') border = 0 + if torch.cuda.is_available(): + device = torch.device('cuda') + elif torch.backends.mps.is_available(): + device = torch.device('mps') + else: + device = torch.device('cpu') # ---------------------------------------- diff --git a/main_test_fbcnn_color_real.py b/main_test_fbcnn_color_real.py index 82b0ab9..2e647e2 100644 --- a/main_test_fbcnn_color_real.py +++ b/main_test_fbcnn_color_real.py @@ -42,8 +42,13 @@ def main(): utils_logger.logger_info(logger_name, log_path=os.path.join(E_path, logger_name+'.log')) logger = logging.getLogger(logger_name) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') border = 0 + if torch.cuda.is_available(): + device = torch.device('cuda') + elif torch.backends.mps.is_available(): + device = torch.device('mps') + else: + device = torch.device('cpu') # ---------------------------------------- @@ -83,7 +88,7 @@ def main(): #img_E,QF = model(img_L, torch.tensor([[0.6]])) img_E,QF = model(img_L) - QF = 1- QF + QF = 1 - QF img_E = util.tensor2single(img_E) img_E = util.single2uint(img_E) logger.info('predicted quality factor: {:d}'.format(round(float(QF*100)))) @@ -93,7 +98,12 @@ def main(): for QF_set in QF_control: logger.info('Flexible control by QF = {:d}'.format(QF_set)) # from IPython import embed; embed() - qf_input = torch.tensor([[1-QF_set/100]]).cuda() if device == torch.device('cuda') else torch.tensor([[1-QF_set/100]]) + if device == torch.device('cuda'): + qf_input = torch.tensor([[1-QF_set/100]]).cuda() + elif device == torch.device('mps'): + qf_input = torch.tensor([[1-QF_set/100]]).to('mps') + else: + qf_input = torch.tensor([[1-QF_set/100]]) img_E,QF = model(img_L, qf_input) QF = 1- QF img_E = util.tensor2single(img_E) diff --git a/main_test_fbcnn_gray.py b/main_test_fbcnn_gray.py index 180a402..069a184 100644 --- a/main_test_fbcnn_gray.py +++ b/main_test_fbcnn_gray.py @@ -43,8 +43,13 @@ def main(): logger = logging.getLogger(logger_name) logger.info('--------------- quality factor: {:d} ---------------'.format(quality_factor)) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') border = 0 + if torch.cuda.is_available(): + device = torch.device('cuda') + elif torch.backends.mps.is_available(): + device = torch.device('mps') + else: + device = torch.device('cpu') # ---------------------------------------- diff --git a/main_test_fbcnn_gray_doublejpeg.py b/main_test_fbcnn_gray_doublejpeg.py index 401c02b..0a380dc 100644 --- a/main_test_fbcnn_gray_doublejpeg.py +++ b/main_test_fbcnn_gray_doublejpeg.py @@ -44,8 +44,13 @@ def main(): logger = logging.getLogger(logger_name) logger.info('--------------- QF1={:d}, QF2={:d} ---------------'.format(qf1,qf2)) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') border = 0 + if torch.cuda.is_available(): + device = torch.device('cuda') + elif torch.backends.mps.is_available(): + device = torch.device('mps') + else: + device = torch.device('cpu') # ----------------------------------------