diff --git a/algoperf/random_utils.py b/algoperf/random_utils.py index f40a98003..a579976ad 100644 --- a/algoperf/random_utils.py +++ b/algoperf/random_utils.py @@ -16,32 +16,32 @@ FLAGS = flags.FLAGS -# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an +# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 31 - 1] (an # unsigned int), while RandomState.randint only accepts and returns signed ints. -MAX_UINT32 = 2**32 - 1 -MIN_UINT32 = 0 +MAX_INT32 = 2**31 - 1 +MIN_INT32 = 0 SeedType = Union[int, list, np.ndarray] def _signed_to_unsigned(seed: SeedType) -> SeedType: if isinstance(seed, int): - return seed % MAX_UINT32 + return seed % MAX_INT32 if isinstance(seed, list): - return [s % MAX_UINT32 for s in seed] + return [s % MAX_INT32 for s in seed] if isinstance(seed, np.ndarray): - return np.array([s % MAX_UINT32 for s in seed.tolist()]) + return np.array([s % MAX_INT32 for s in seed.tolist()]) def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32) + new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32) return [new_seed, data] def _split(seed: SeedType, num: int = 2) -> SeedType: rng = np.random.RandomState(seed=_signed_to_unsigned(seed)) - return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2]) + return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2]) def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name