diff --git a/general_class_balancer/general_class_balancer.py b/general_class_balancer/general_class_balancer.py index 415d233..030bab7 100644 --- a/general_class_balancer/general_class_balancer.py +++ b/general_class_balancer/general_class_balancer.py @@ -1,6 +1,5 @@ import numpy as np from scipy import stats -import random def prime(i, primes): @@ -69,7 +68,9 @@ def get_prime_form(confounds, n_buckets, sorted_confounds=None): # Given buckets, selects values that fall into each one -def get_class_selection(classes, primed, unique_classes=None): +def get_class_selection( + classes, primed, unique_classes=None, rng=np.random.default_rng() +): if len(classes) != len(primed): raise ValueError("Classes and primed must be the same length") if unique_classes is None: @@ -79,7 +80,7 @@ def get_class_selection(classes, primed, unique_classes=None): selection = np.zeros(classes.shape, dtype=bool) hasher = {} rr = list(range(len(classes))) - random.shuffle(rr) + rng.shuffle(rr) for i in rr: if True: @@ -185,6 +186,7 @@ def class_balance( recurse=True, exclude_none=True, unique_classes=None, + random_seed=None, ): """Main function. Takes as input classes (as integers starting from 0 in a 1D numpy array) @@ -198,6 +200,16 @@ def class_balance( Method returns an array of logicals that selects a subset of the given data, also forcing equal ratios between each class. """ + # set random seed + if isinstance(random_seed, int) or random_seed is None: + if random_seed is not None and random_seed < 0: + random_seed = None + rng = np.random.default_rng(random_seed) + elif isinstance(random_seed, np.random.Generator): + rng = random_seed + else: + rng = np.random.default_rng() + classes = np.array(classes) confounds = np.array(confounds) if len(confounds) == 0: @@ -259,10 +271,10 @@ def class_balance( primed = get_prime_form(confounds, n_buckets, sorted_confounds) primed = np.prod(primed, axis=0, dtype=int) selection = get_class_selection( - classes, primed, unique_classes=unique_classes + classes, primed, unique_classes=unique_classes, rng=rng ) rr = list(range(confounds.shape[0])) - random.shuffle(rr) + rng.shuffle(rr) for i in rr: # print("h " + str(i)) if not isinstance(confounds[i, 0], str): @@ -298,6 +310,7 @@ def class_balance( plim=plim, exclude_none=False, unique_classes=unique_classes, + random_seed=rng, ), ) selection = np.logical_or(selection, recurse_selection) @@ -310,12 +323,15 @@ def class_balance( return selection -def separate_set(selections, set_divisions=[0.5, 0.5], IDs=None): +def separate_set( + selections, set_divisions=[0.5, 0.5], IDs=None, rng=np.random.default_rng() +): + if not isinstance(set_divisions, list): raise TypeError set_divisions = [i / np.sum(set_divisions) for i in set_divisions] rr = list(range(len(selections))) - random.shuffle(rr) + rng.shuffle(rr) if IDs is None: IDs = np.array(list(range(len(selections)))) # if len(IDs.shape) == 1: