Skip to content

Commit

Permalink
ENH add numpy random number generator for reproducibility
Browse files Browse the repository at this point in the history
  • Loading branch information
htwangtw committed Jul 16, 2024
1 parent a8bfce0 commit 7a88325
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions general_class_balancer/general_class_balancer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
from scipy import stats
import random


def prime(i, primes):
Expand Down Expand Up @@ -69,7 +68,9 @@ def get_prime_form(confounds, n_buckets, sorted_confounds=None):


# Given buckets, selects values that fall into each one
def get_class_selection(classes, primed, unique_classes=None):
def get_class_selection(
classes, primed, unique_classes=None, rng=np.random.default_rng()
):
if len(classes) != len(primed):
raise ValueError("Classes and primed must be the same length")
if unique_classes is None:
Expand All @@ -79,7 +80,7 @@ def get_class_selection(classes, primed, unique_classes=None):
selection = np.zeros(classes.shape, dtype=bool)
hasher = {}
rr = list(range(len(classes)))
random.shuffle(rr)
rng.shuffle(rr)

for i in rr:
if True:
Expand Down Expand Up @@ -185,6 +186,7 @@ def class_balance(
recurse=True,
exclude_none=True,
unique_classes=None,
random_seed=None,
):
"""Main function.
Takes as input classes (as integers starting from 0 in a 1D numpy array)
Expand All @@ -198,6 +200,16 @@ def class_balance(
Method returns an array of logicals that selects a subset of the given
data, also forcing equal ratios between each class.
"""
# set random seed
if isinstance(random_seed, int) or random_seed is None:
if random_seed is not None and random_seed < 0:
random_seed = None
rng = np.random.default_rng(random_seed)
elif isinstance(random_seed, np.random.Generator):
rng = random_seed
else:
rng = np.random.default_rng()

classes = np.array(classes)
confounds = np.array(confounds)
if len(confounds) == 0:
Expand Down Expand Up @@ -259,10 +271,10 @@ def class_balance(
primed = get_prime_form(confounds, n_buckets, sorted_confounds)
primed = np.prod(primed, axis=0, dtype=int)
selection = get_class_selection(
classes, primed, unique_classes=unique_classes
classes, primed, unique_classes=unique_classes, rng=rng
)
rr = list(range(confounds.shape[0]))
random.shuffle(rr)
rng.shuffle(rr)
for i in rr:
# print("h " + str(i))
if not isinstance(confounds[i, 0], str):
Expand Down Expand Up @@ -298,6 +310,7 @@ def class_balance(
plim=plim,
exclude_none=False,
unique_classes=unique_classes,
random_seed=rng,
),
)
selection = np.logical_or(selection, recurse_selection)
Expand All @@ -310,12 +323,15 @@ def class_balance(
return selection


def separate_set(selections, set_divisions=[0.5, 0.5], IDs=None):
def separate_set(
selections, set_divisions=[0.5, 0.5], IDs=None, rng=np.random.default_rng()
):

if not isinstance(set_divisions, list):
raise TypeError
set_divisions = [i / np.sum(set_divisions) for i in set_divisions]
rr = list(range(len(selections)))
random.shuffle(rr)
rng.shuffle(rr)
if IDs is None:
IDs = np.array(list(range(len(selections))))
# if len(IDs.shape) == 1:
Expand Down

0 comments on commit 7a88325

Please sign in to comment.