diff --git a/src/ml_model.py b/src/ml_model.py index 109282e..90c0eca 100755 --- a/src/ml_model.py +++ b/src/ml_model.py @@ -88,13 +88,13 @@ def load_model(self) -> None: return if torch.cuda.is_available(): - safety_checker = CustomStableDiffusionSafetyChecker.from_pretrained( - os.path.join(model_settings.model_name_or_path, "safety_checker"), torch_dtype=torch.float16 - ) + # safety_checker = CustomStableDiffusionSafetyChecker.from_pretrained( + # os.path.join(model_settings.model_name_or_path, "safety_checker"), torch_dtype=torch.float16 + # ) self.diffusion_pipeline = DiffusionPipeline.from_pretrained( model_settings.model_name_or_path, torch_dtype=torch.float16, - safety_checker=safety_checker, + # safety_checker=safety_checker, ).to("cuda") self.diffusion_pipeline.enable_xformers_memory_efficient_attention() self.dimm_scheduler = DDIMScheduler.from_pretrained(