Skip to content

Commit

Permalink
fix: changing the dtype in random_utils to uint32
Browse files Browse the repository at this point in the history
  • Loading branch information
init-22 committed Nov 14, 2024
1 parent adc5ea9 commit e16ebe0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions algorithmic_efficiency/random_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

# Annoyingly, RandomState(seed) requires seed to be in [0, 2 ** 32 - 1] (an
# unsigned int), while RandomState.randint only accepts and returns signed ints.
MAX_INT32 = 2**31
MIN_INT32 = -MAX_INT32
MAX_UINT32 = 2**31
MIN_UINT32 = 0

SeedType = Union[int, list, np.ndarray]

Expand All @@ -35,13 +35,13 @@ def _signed_to_unsigned(seed: SeedType) -> SeedType:

def _fold_in(seed: SeedType, data: Any) -> List[Union[SeedType, Any]]:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
new_seed = rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32)
new_seed = rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32)
return [new_seed, data]


def _split(seed: SeedType, num: int = 2) -> SeedType:
rng = np.random.RandomState(seed=_signed_to_unsigned(seed))
return rng.randint(MIN_INT32, MAX_INT32, dtype=np.int32, size=[num, 2])
return rng.randint(MIN_UINT32, MAX_UINT32, dtype=np.uint32, size=[num, 2])


def _PRNGKey(seed: SeedType) -> SeedType: # pylint: disable=invalid-name
Expand Down

0 comments on commit e16ebe0

Please sign in to comment.