resolved underflow-to-NaN issue when computing total probability and log likelihood with very small BMR PMF estimates, such as with SKCM MUCH16_M or TTN_M.

The procedure now replaces NaN values in responsibilities during EM with a very small value (2e-100) to represent the small probability for those samples.
Additionally, I added new helper methods to get BMR PMF values for a specific count, defaulting to a very small value (1e-100) when the count is too large to be found in the categorical distribution.
The case in which the count is 0 is still handled the same, such that the probability of -1 defaults to 0.
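
A minimal, standalone sketch of the two fixes described above (the names and the toy PMF here are illustrative, not the exact helpers added in this commit):

import numpy as np

def safe_get(pmf, count, min_val=1e-100):
    # counts beyond the largest key of the categorical PMF get a tiny
    # probability instead of a missing entry
    if count > max(pmf.keys()):
        return min_val
    return pmf.get(count, 0)

pmf = {0: 0.9, 1: 0.099, 2: 0.001}                   # toy BMR PMF
print(safe_get(pmf, 5))                              # 1e-100 rather than 0
print(safe_get(pmf, -1))                             # 0, unchanged handling for the count == 0 case

responsibilities = np.array([0.7, np.nan, 0.3])      # NaN from 0/0 underflow
print(np.nan_to_num(responsibilities, nan=2e-100))   # NaN -> 2e-100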
ashuaibi7 committed Jan 21, 2025
1 parent aee2f88 commit 51ec207
Showing 2 changed files with 88 additions and 26 deletions.
24 changes: 20 additions & 4 deletions src/dialect/models/gene.py
@@ -35,7 +35,9 @@ def __str__(self):
for k, v in itertools.islice(self.bmr_pmf.items(), 3)
) # Format the first three key-value pairs in bmr_pmf
pi_info = (
f"Pi: {self.pi:.3e}" if self.pi is not None else "Pi: Not estimated"
f"Pi: {self.pi:.3e}"
if self.pi is not None
else "Pi: Not estimated"
)
total_mutations = np.sum(self.counts)
return (
@@ -139,18 +141,32 @@ def compute_log_likelihood(self, pi):
**Raises**:
:raises ValueError: If `bmr_pmf`, `counts`, or `pi` is not properly defined.
"""

# logging.info(
# f"Computing log likelihood for gene {self.name}. Pi: {pi:.3e}. "
# f"BMR PMF: {{ {', '.join(f'{k}: {v:.3e}' for k, v in itertools.islice(self.bmr_pmf.items(), 3))} }}"
# )
# TODO: move to unified helper scripts file
def safe_get_no_default(pmf, c, min_val=1e-100):
# if c is greater than the max count in pmf
if c > max(pmf.keys()):
return min_val
return pmf.get(c)

def safe_get_with_default(pmf, c, min_val=1e-100):
# if c is greater than the max count in pmf
if c > max(pmf.keys()):
return min_val
return pmf.get(c, 0)

self.verify_pi_is_valid(pi)
self.verify_bmr_pmf_and_counts_exist()
self.verify_bmr_pmf_contains_all_count_keys()
# self.verify_bmr_pmf_contains_all_count_keys()

log_likelihood = sum(
np.log(
self.bmr_pmf.get(c) * (1 - pi) + self.bmr_pmf.get(c - 1, 0) * pi
safe_get_no_default(self.bmr_pmf, c) * (1 - pi)
+ safe_get_with_default(self.bmr_pmf, c - 1) * pi
)
for c in self.counts
)
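
As a quick, illustrative check of why the 1e-100 floor matters in the per-sample term log((1 - pi) * P(c) + pi * P(c - 1)) (values below are made up; this is not repo code):

import numpy as np

pi = 0.05
p_c, p_c_minus_1 = 0.0, 0.0        # count falls outside the PMF's support
print(np.log((1 - pi) * p_c + pi * p_c_minus_1))    # -inf (with a RuntimeWarning)

p_c, p_c_minus_1 = 1e-100, 1e-100  # floored lookups from the safe_get helpers
print(np.log((1 - pi) * p_c + pi * p_c_minus_1))    # ~ -230.3, finite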
@@ -284,7 +300,7 @@ def estimate_pi_with_em_from_scratch(
)

self.verify_bmr_pmf_and_counts_exist()
self.verify_bmr_pmf_contains_all_count_keys()
# self.verify_bmr_pmf_contains_all_count_keys()

nonzero_probability_counts = [
c for c in self.counts if c in self.bmr_pmf and self.bmr_pmf[c] > 0
90 changes: 68 additions & 22 deletions src/dialect/models/interaction.py
@@ -19,7 +19,9 @@ def __init__(self, gene_a, gene_b):
:param gene_b (Gene): The second gene in the interaction.
"""
if not isinstance(gene_a, Gene) or not isinstance(gene_b, Gene):
raise ValueError("Both inputs must be instances of the Gene class.")
raise ValueError(
"Both inputs must be instances of the Gene class."
)

self.gene_a = gene_a
self.gene_b = gene_b
@@ -85,7 +87,9 @@ def compute_contingency_table(self):
"""
gene_a_mutations = (self.gene_a.counts > 0).astype(int)
gene_b_mutations = (self.gene_b.counts > 0).astype(int)
cm = confusion_matrix(gene_a_mutations, gene_b_mutations, labels=[0, 1])
cm = confusion_matrix(
gene_a_mutations, gene_b_mutations, labels=[0, 1]
)
return cm

def get_set_of_cooccurring_samples(self):
@@ -142,7 +146,7 @@ def verify_bmr_pmf_and_counts_exist(self):
if self.gene_a.counts is None or self.gene_b.counts is None:
raise ValueError("Counts are not defined for one or both genes.")

def verify_taus_are_valid(self, taus, tol=1e-6):
def verify_taus_are_valid(self, taus, tol=1e-2):
"""
Verify that tau parameters are valid (0 <= tau_i <= 1 and sum(tau) == 1).
@@ -189,33 +193,47 @@ def verify_pi_values(self, pi_a, pi_b):
# TODO: (LOW PRIORITY): Add additional metrics (KL, MI, etc.)

def compute_joint_probability(self, tau, u, v):
# TODO: move to unified helper scripts file
def safe_get(pmf, c, min_val=1e-100):
# if c is greater than the max count in pmf
if c > max(pmf.keys()):
return min_val
return pmf.get(c, 0)

joint_probability = np.array(
[
tau
* self.gene_a.bmr_pmf.get(c_a - u, 0)
* self.gene_b.bmr_pmf.get(c_b - v, 0)
* safe_get(self.gene_a.bmr_pmf, c_a - u, 0)
* safe_get(self.gene_b.bmr_pmf, c_b - v, 0)
for c_a, c_b in zip(self.gene_a.counts, self.gene_b.counts)
]
)
return joint_probability

def compute_total_probability(self, tau_00, tau_01, tau_10, tau_11):
# TODO: move to helper scripts file
def safe_get(pmf, c, min_val=1e-100):
# if c is greater than the max count in pmf
if c > max(pmf.keys()):
return min_val
return pmf.get(c, 0)

total_probabilities = np.array(
[
sum(
(
tau_00
* self.gene_a.bmr_pmf.get(c_a, 0)
* self.gene_b.bmr_pmf.get(c_b, 0),
* safe_get(self.gene_a.bmr_pmf, c_a, 0)
* safe_get(self.gene_b.bmr_pmf, c_b, 0),
tau_01
* self.gene_a.bmr_pmf.get(c_a, 0)
* self.gene_b.bmr_pmf.get(c_b - 1, 0),
* safe_get(self.gene_a.bmr_pmf, c_a, 0)
* safe_get(self.gene_b.bmr_pmf, c_b - 1, 0),
tau_10
* self.gene_a.bmr_pmf.get(c_a - 1, 0)
* self.gene_b.bmr_pmf.get(c_b, 0),
* safe_get(self.gene_a.bmr_pmf, c_a - 1, 0)
* safe_get(self.gene_b.bmr_pmf, c_b, 0),
tau_11
* self.gene_a.bmr_pmf.get(c_a - 1, 0)
* self.gene_b.bmr_pmf.get(c_b - 1, 0),
* safe_get(self.gene_a.bmr_pmf, c_a - 1, 0)
* safe_get(self.gene_b.bmr_pmf, c_b - 1, 0),
)
)
for c_a, c_b in zip(self.gene_a.counts, self.gene_b.counts)
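
The E-step divides each class's joint probability by this total. A small, illustrative sketch (not repo code) of how a sample whose PMF terms all underflow yields 0/0 and therefore NaN responsibilities, which the nan_to_num step further down handles:

import numpy as np

joint = np.array([0.0, 0.0, 0.0, 0.0])     # all four mixture components underflow
total = joint.sum()                         # 0.0
with np.errstate(invalid="ignore"):
    z = joint / total                       # 0/0 -> nan for every class
print(z)                                    # [nan nan nan nan]
print(np.nan_to_num(z, nan=2e-100))         # tiny finite floor used downstream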
@@ -257,6 +275,18 @@ def compute_log_likelihood(self, taus):

# TODO: add verbose option for logging
# logging.info(f"Computing log likelihood for {self.name}. Taus: {taus}")
# TODO: move to unified helper scripts file
def safe_get_no_default(pmf, c, min_val=1e-100):
# if c is greater than the max count in pmf
if c > max(pmf.keys()):
return min_val
return pmf.get(c)

def safe_get_with_default(pmf, c, min_val=1e-100):
# if c is greater than the max count in pmf
if c > max(pmf.keys()):
return min_val
return pmf.get(c, 0)

self.verify_bmr_pmf_and_counts_exist()
self.verify_taus_are_valid(taus)
@@ -266,10 +296,18 @@ def compute_log_likelihood(self, taus):
tau_00, tau_01, tau_10, tau_11 = taus
log_likelihood = sum(
np.log(
a_bmr_pmf.get(c_a) * b_bmr_pmf.get(c_b) * tau_00
+ a_bmr_pmf.get(c_a) * b_bmr_pmf.get(c_b - 1, 0) * tau_01
+ a_bmr_pmf.get(c_a - 1, 0) * b_bmr_pmf.get(c_b) * tau_10
+ a_bmr_pmf.get(c_a - 1, 0) * b_bmr_pmf.get(c_b - 1, 0) * tau_11
safe_get_no_default(a_bmr_pmf, c_a)
* safe_get_no_default(b_bmr_pmf, c_b)
* tau_00
+ safe_get_no_default(a_bmr_pmf, c_a)
* safe_get_with_default(b_bmr_pmf, c_b - 1)
* tau_01
+ safe_get_with_default(a_bmr_pmf, c_a - 1)
* safe_get_no_default(b_bmr_pmf, c_b)
* tau_10
+ safe_get_with_default(a_bmr_pmf, c_a - 1)
* safe_get_with_default(b_bmr_pmf, c_b - 1)
* tau_11
)
for c_a, c_b in zip(a_counts, b_counts)
)
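
For reference, the quantity summed above is the per-sample log of a four-component mixture, with \(P_A\) and \(P_B\) the two genes' BMR PMFs; this simply restates the four tau terms in the code:

\[
\ell(\tau) \;=\; \sum_{i} \log \sum_{u \in \{0,1\}} \sum_{v \in \{0,1\}} \tau_{uv}\, P_A(c_{a,i} - u)\, P_B(c_{b,i} - v)
\]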
@@ -295,7 +333,9 @@ def compute_likelihood_ratio(self, taus):
:return: (float) The computed likelihood ratio test statistic (\( \lambda_{LR} \)).
"""

logging.info(f"Computing likelihood ratio for interaction {self.name}.")
logging.info(
f"Computing likelihood ratio for interaction {self.name}."
)

tau_00, tau_01, tau_10, tau_11 = taus
driver_a_marginal = tau_10 + tau_11
@@ -562,11 +602,17 @@ def estimate_tau_with_em_from_scratch(
/ total_probabilities
)

# TODO: map out other reasons for nan values and standardize handling
# remove nans to avoid underflow issues in bmr estimates
z_i_00_no_nan = np.nan_to_num(z_i_00, nan=2e-100)
z_i_01_no_nan = np.nan_to_num(z_i_01, nan=2e-100)
z_i_10_no_nan = np.nan_to_num(z_i_10, nan=2e-100)
z_i_11_no_nan = np.nan_to_num(z_i_11, nan=2e-100)
# M-Step: Update tau parameters
curr_tau_00 = np.mean(z_i_00)
curr_tau_01 = np.mean(z_i_01)
curr_tau_10 = np.mean(z_i_10)
curr_tau_11 = np.mean(z_i_11)
curr_tau_00 = np.mean(z_i_00_no_nan)
curr_tau_01 = np.mean(z_i_01_no_nan)
curr_tau_10 = np.mean(z_i_10_no_nan)
curr_tau_11 = np.mean(z_i_11_no_nan)
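
Illustrative only (not repo code): a single NaN responsibility would otherwise poison np.mean, turning the updated taus themselves into NaN, which is why the NaNs are replaced before the M-step averages:

import numpy as np

z_i_11 = np.array([0.2, np.nan, 0.4])                  # one fully underflowed sample
print(np.mean(z_i_11))                                 # nan
print(np.mean(np.nan_to_num(z_i_11, nan=2e-100)))      # 0.2, finite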

# Check for convergence
prev_log_likelihood = self.compute_log_likelihood(
