diff --git a/genparse/steer.py b/genparse/steer.py index c718a390..006349b9 100644 --- a/genparse/steer.py +++ b/genparse/steer.py @@ -29,9 +29,11 @@ def set_seed(seed): random.seed(seed) + np.random.seed(seed) torch.manual_seed(seed) transformers.set_seed(seed) - np.random.seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) # ____________________________________________________________________________________