Changes to BCRPhylo to allow for mapping BCR sequence to affinity and expression #7

Status: Open. Wants to merge 25 commits into base: main (showing changes from 23 commits).
bin/GCutils.py (34 changes: 17 additions & 17 deletions)

@@ -3,7 +3,6 @@
 from __future__ import absolute_import
 from Bio.Seq import Seq
 from Bio.Seq import translate as bio_translate
-from Bio.Alphabet import generic_dna
 from ete3 import TreeNode, NodeStyle, TreeStyle, TextFace, CircleFace, PieChartFace, faces, SVG_COLORS
 import scipy
 import numpy as np
@@ -15,18 +14,18 @@
 except:
     import pickle

-try:
-    import jellyfish
-
-    def hamming_distance(s1, s2):
-        if s1 == s2:
-            return 0
-        else:
-            return jellyfish.hamming_distance(unicode(s1), unicode(s2))
-except:
-    def hamming_distance(seq1, seq2):
-        '''Hamming distance between two sequences of equal length'''
-        return sum(x != y for x, y in zip(seq1, seq2))
+#try:
+    #import jellyfish
+
+    #def hamming_distance(s1, s2):
+        #if s1 == s2:
+            #return 0
+        #else:
+            #return jellyfish.hamming_distance(s1, s2)
+#except:
+def hamming_distance(seq1, seq2):
+    '''Hamming distance between two sequences of equal length'''
+    return sum(x != y for x, y in zip(seq1, seq2))
 print('Couldn\'t find the python module "jellyfish" which is used for fast string comparison. Falling back to pure python function.')

 global ISO_TYPE_ORDER
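If the intent of commenting this out was Python 3 compatibility rather than dropping jellyfish entirely (the old `unicode()` calls were Python 2 only), a possible replacement is sketched below; it assumes a recent jellyfish release, whose `hamming_distance()` accepts `str` directly:

```python
# Sketch only, not part of the diff: a Python 3 friendly optional fast path.
try:
    import jellyfish  # optional C-backed string comparison

    def hamming_distance(s1, s2):
        '''Hamming distance between two sequences of equal length'''
        return 0 if s1 == s2 else jellyfish.hamming_distance(s1, s2)
except ImportError:
    def hamming_distance(seq1, seq2):
        '''Hamming distance between two sequences of equal length'''
        return sum(x != y for x, y in zip(seq1, seq2))
    print('Couldn\'t find the python module "jellyfish" which is used for fast string comparison. Falling back to pure python function.')
```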
@@ -44,10 +43,11 @@ def local_translate(seq):

 # ----------------------------------------------------------------------------------------
 def replace_codon_in_aa_seq(new_nuc_seq, old_aa_seq, inuc):  # <inuc>: single nucleotide position that was mutated from old nuc seq (which corresponds to old_aa_seq) to new_nuc_seq
-    istart = 3 * int(math.floor(inuc / 3.))  # nucleotide position of start of mutated codon
+    istart = 3 * math.floor(inuc / 3)  # nucleotide position of start of mutated codon
+    aa_new = math.floor(inuc / 3)
     new_codon = local_translate(new_nuc_seq[istart : istart + 3])
-    return old_aa_seq[:inuc / 3] + new_codon + old_aa_seq[inuc / 3 + 1:]  # would be nice to check for synonymity and not do any translation unless we need to
+    new_seq = old_aa_seq[:aa_new] + new_codon + old_aa_seq[aa_new + 1:]  # would be nice to check for synonymity and not do any translation unless we need to
+    return new_seq

 # ----------------------------------------------------------------------------------------
 class TranslatedSeq(object):
     # ----------------------------------------------------------------------------------------
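The indexing change above matters under Python 3, where `old_aa_seq[:inuc / 3]` raises a TypeError because `/` returns a float (`math.floor()` returns an int in Python 3, so the new slices are valid). A minimal self-contained check of the new logic, using Biopython's `translate` in place of `local_translate`:

```python
# Sketch: sanity-check the integer codon indexing (assumes Biopython is installed).
import math
from Bio.Seq import translate as bio_translate

def replace_codon_in_aa_seq(new_nuc_seq, old_aa_seq, inuc):
    istart = 3 * math.floor(inuc / 3)   # start of the mutated codon (int in Python 3)
    aa_new = math.floor(inuc / 3)       # index of the affected amino acid
    new_codon = bio_translate(new_nuc_seq[istart : istart + 3])
    return old_aa_seq[:aa_new] + new_codon + old_aa_seq[aa_new + 1:]

# 'ATGAAA' translates to 'MK'; mutating nucleotide 4 (A->G) gives 'ATGAGA' -> 'MR'
print(replace_codon_in_aa_seq('ATGAGA', 'MK', 4))  # -> 'MR'
```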
@@ -270,7 +270,7 @@ def my_layout(node):
         if idlabel:
             aln = MultipleSeqAlignment([])
             for node in self.tree.traverse():
-                aln.append(SeqRecord(Seq(str(node.nuc_seq), generic_dna), id=node.name, description='abundance={}'.format(node.frequency)))
+                aln.append(SeqRecord(Seq(str(node.nuc_seq)), id=node.name, description='abundance={}'.format(node.frequency)))
             AlignIO.write(aln, open(os.path.splitext(outfile)[0] + '.fasta', 'w'), 'fasta')

     def write(self, file_name):
bin/selection_utils.py (175 changes: 145 additions & 30 deletions)

@@ -92,16 +92,16 @@ def init_blosum():
(whitespace-only reindentation of the sdists block; content unchanged)
     blinfo = {aa1 : {aa2 : None for aa2 in all_amino_acids} for aa1 in all_amino_acids}
 init_blosum()
 sdists = {  # config info necessary for rescaling the two versions of aa similarity distance
     'ascii' : {
         'scale_min' : 0.,  # you want the mean to be kinda sorta around 1, so the --target_dist ends up being comparable
         'scale_max' : 4.7,
     },
     'blosum' : {
         'scale_min' : 0.,
         # 'scale_max' : 3.95,  # use this value if you go back to taking the exp() (note that now the exp() would need to be added in two places)
         'scale_max' : 1.55,
     },
 }
 for sdtype, sdinfo in sdists.items():
     if sdtype == 'ascii':
         dfcn = aa_ascii_code_distance
@@ -134,9 +134,9 @@ def plot_sdists():
(whitespace-only reindentation; content unchanged)
         print('      raw  rescaled      raw  rescaled')
         for aa2 in all_amino_acids:
             print('   %s   %5.2f  %5.2f   %5.1f  %5.2f' % (aa2,
                 aa_inverse_similarity(aa1, aa2, 'blosum', dont_rescale=True), aa_inverse_similarity(aa1, aa2, 'blosum'),
                 aa_inverse_similarity(aa1, aa2, 'ascii', dont_rescale=True), aa_inverse_similarity(aa1, aa2, 'ascii')))
     import plotutils
     for sdtype in ['ascii', 'blosum']:
         print(sdtype)
         all_rescaled_vals = [aa_inverse_similarity(aa1, aa2, sdtype=sdtype) for aa1, aa2 in itertools.combinations_with_replacement(all_amino_acids, 2)]
@@ -161,58 +161,98 @@ def target_distance_fcn(args, this_seq, target_seqs):
     return itarget, tdist

 # ----------------------------------------------------------------------------------------
+def count_aa(aa_seq,which_aa):
+    list_of_aa = list(aa_seq)
+    res1 = []
+    for i in range(len(list_of_aa)):
+        if list_of_aa[i] == which_aa:
+            res1.append(i)
+    return len(res1)

Collaborator: i could be mistaken, but i think this is equivalent to aa_seq.count(which_aa)?

Author: Ah, I missed that function entirely. Where is it located?

+def count_not_that_aa(aa_seq,which_aa):
+    list_of_aa = list(aa_seq)
+    res1 = []
+    for i in range(len(list_of_aa)):
+        if list_of_aa[i] != which_aa:
+            res1.append(i)
+    return len(res1)

Collaborator: it may not make a difference, but i think something like return len([a for a in aa_seq if a!=which_aa]) would be faster here (also maybe clearer?). Also just two general things -- it's best to have at least a few-word description for even small functions (so you don't have to read the whole function to figure out what's getting counted where), and could you add spaces after commas to match the style elsewhere?

Author: You're right, that is a better function. And yes, I can add spaces.
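Collected as a runnable sketch, the two simplifications suggested in the threads above, with the requested docstrings and comma spacing (`str.count` is a built-in string method, so no import is needed, which answers the question about where it lives):

```python
# Sketch of the reviewers' suggested replacements (behavior-equivalent).
def count_aa(aa_seq, which_aa):
    '''Number of positions in <aa_seq> equal to <which_aa>.'''
    return aa_seq.count(which_aa)

def count_not_that_aa(aa_seq, which_aa):
    '''Number of positions in <aa_seq> differing from <which_aa>.'''
    return len([a for a in aa_seq if a != which_aa])

assert count_aa('LLAK', 'L') == 2 and count_not_that_aa('LLAK', 'L') == 2
```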

 def calc_kd(node, args):
     if has_stop_aa(node.aa_seq):  # nonsense sequences have zero affinity/infinite kd
         return float('inf')
     if not args.selection:
         return 1.

-    assert args.mature_kd < args.naive_kd
-    tdist = node.target_distance if args.min_target_distance is None else max(node.target_distance, args.min_target_distance)
-    kd = args.mature_kd + (args.naive_kd - args.mature_kd) * (tdist / float(args.target_distance))**args.k_exp  # transformation from distance to kd
+    #tdist = node.target_distance if args.min_target_distance is None else max(node.target_distance, args.min_target_distance)
+    #kd = args.mature_kd + (args.naive_kd - args.mature_kd) * (tdist / float(args.target_distance))**args.k_exp  # transformation from distance to kd
+    #kd = 100 - 3*min(count_not_that_aa(node.aa_seq,'A'),33)
+    kd = 100 - 12*min(count_aa(node.aa_seq,'L'),9)

Collaborator: This seems like it completely changes the way kd is calculated? My understanding was that we were going to add some new options that allow turning on new features, but not modify existing behavior.

Author: As you can see by all the commented-out portions, this was my rough initial testing of some things. Perhaps the best way to implement it would be a flag (e.g. --torchDMS) so that if --torchDMS "true" were passed, it would use a torchDMS prediction; if --neutral were passed, it would return some nM affinity for every value that allows for GC expansion (i.e. lower than 100 nM); and if --target_distance were passed, it would return the target-distance-based Kd. Does that sound good?

Collaborator: ah ok, sorry, I misunderstood -- we should get this finished up before merging the pull request. Yes, something like that sounds great. I think the default should remain unchanged (i.e. if someone runs with the same command line args before and after we make these changes, it's important they get identical results). I think a string flag with several choices (like this) might be the way to go, perhaps with a name like --affinity-to-fitness-mapping?

     return kd
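A hypothetical sketch of the dispatch discussed in the thread above. The flag name `--affinity-to-fitness-mapping` and its choices are proposals from the conversation, not merged code, and `predict_kd` is an assumed torchDMS-style interface:

```python
# Hypothetical sketch of the proposed --affinity-to-fitness-mapping dispatch.
# Flag name and choices come from the review thread; nothing here is merged code.
def calc_kd(node, args):
    if has_stop_aa(node.aa_seq):  # nonsense sequences have zero affinity/infinite kd
        return float('inf')
    if not args.selection:
        return 1.
    if args.affinity_to_fitness_mapping == 'target-distance':  # unchanged default behavior
        assert args.mature_kd < args.naive_kd
        tdist = node.target_distance if args.min_target_distance is None else max(node.target_distance, args.min_target_distance)
        return args.mature_kd + (args.naive_kd - args.mature_kd) * (tdist / float(args.target_distance))**args.k_exp
    elif args.affinity_to_fitness_mapping == 'neutral':  # any kd below the GC-expansion threshold (100 nM per the thread)
        return 50.  # arbitrary placeholder value
    elif args.affinity_to_fitness_mapping == 'torchdms':  # assumed interface to a wrapped torchDMS predictor
        return args.dms_model.predict_kd(node.aa_seq)
    else:
        raise ValueError('unknown affinity-to-fitness mapping: %s' % args.affinity_to_fitness_mapping)
```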

+def calc_B_exp(node, args):
+    if has_stop_aa(node.aa_seq):  # nonsense sequences have no production
+        B_exp = 0
+    if not args.selection:
+        return 10000
+    #B_exp = 5000 + 250*min(count_aa(node.aa_seq,'T'),15)
+    #B_exp = 5000 + 250*min(count_aa(node.aa_seq,'L'),15)
+    B_exp = 5000
+    if B_exp == 0:
+        print(node.aa_seq)
+    return B_exp
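Purely illustrative: the commented-out expression variants above, gated the same way the thread proposes for kd. The flag name `expression_mapping` is hypothetical; note also the early return for stop codons, since in the committed version the `B_exp = 0` assignment is overwritten by `B_exp = 5000` before it can be returned:

```python
# Hypothetical sketch only; flag name and choices are not from the PR.
def calc_B_exp(node, args):
    if has_stop_aa(node.aa_seq):  # nonsense sequences have no production
        return 0                  # return immediately so the value is not overwritten below
    if not args.selection:
        return 10000
    if args.expression_mapping == 'count-T':  # count-based variant from the commented-out lines
        return 5000 + 250 * min(count_aa(node.aa_seq, 'T'), 15)
    return 5000  # constant-expression default
```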

 # ----------------------------------------------------------------------------------------
 def update_lambda_values(live_leaves, A_total, B_total, logi_params, selection_strength, lambda_min=10e-10):
     ''' update the lambda_ feature (parameter for the poisson progeny distribution) for each leaf in <live_leaves> '''

     # ----------------------------------------------------------------------------------------
-    def calc_BnA(Kd_n, A, B_total):
+    #Since each B cell now has its own BCR expression level, would need to calculate B_total
+    def sum_of_B(B_n):
+        B_total = numpy.sum(B_n)
+        return B_total
+
+    def calc_BnA(Kd_n, A, B_n):
         '''
         This calculates the fraction B:A (B bound to A), at equilibrium also referred to as "binding time",
         of all the different Bs in the population given the number of free As in solution.
         '''
-        BnA = B_total/(1+Kd_n/A)
+        BnA = B_n/(1+Kd_n/A)
         return(BnA)

     # ----------------------------------------------------------------------------------------
-    def return_objective_A(Kd_n, A_total, B_total):
+    def return_objective_A(Kd_n, A_total, B_n):
         '''
         The objective function that solves the set of differential equations set up to find the number of free As,
         at equilibrium, given a number of Bs with some affinity listed in Kd_n.
         '''
-        return lambda A: (A_total - (A + scipy.sum(B_total/(1+Kd_n/A))))**2
+        return lambda A: (A_total - (A + numpy.sum(B_n/(1+Kd_n/A))))**2

     # ----------------------------------------------------------------------------------------
-    def calc_binding_time(Kd_n, A_total, B_total):
+    def calc_binding_time(Kd_n, A_total, B_n):
         '''
         Solves the objective function to find the number of free As and then uses this
         to calculate the fraction B:A (B bound to A) for all the different Bs.
         '''
-        obj = return_objective_A(Kd_n, A_total, B_total)
+        obj = return_objective_A(Kd_n, A_total, B_n)
         # Different minimizers have been tested and 'L-BFGS-B' was significantly faster than anything else:
         obj_min = minimize(obj, A_total, bounds=[[1e-10, A_total]], method='L-BFGS-B', tol=1e-20)
-        BnA = calc_BnA(Kd_n, obj_min.x[0], B_total)
+        BnA = calc_BnA(Kd_n, obj_min.x[0], B_n)
+        #print(Kd_n)
+        #print(obj_min.x[0])
+        #print(B_n)

Collaborator: it's best not to commit commented lines, just to keep things clean.

         # Terminate if the precision is not good enough:
         assert(BnA.sum()+obj_min.x[0]-A_total < A_total/100)
         return BnA

     # ----------------------------------------------------------------------------------------
     def trans_BA(BnA):
         '''Transform the fraction B:A (B bound to A) to a poisson lambda between 0 and 2.'''
         # We keep alpha to enable the possibility that there is a minimum lambda_:
-        lambda_ = alpha + (2 - alpha) / (1 + Q*scipy.exp(-beta*BnA))
+        #print(beta)
+        #print(BnA)
+        power_val = -beta*BnA
+        power_val[power_val < -125] = -125
+        #lambda_ = 2*(alpha + (2 - alpha) / (1 + Q*numpy.exp(-beta*BnA)))
+        lambda_ = 2*(alpha + (2 - alpha) / (1 + Q*numpy.exp(power_val)))  ##multiplied by 2 here
         return [max(lambda_min, l) for l in lambda_]

 # ----------------------------------------------------------------------------------------
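For intuition, the equilibrium the solver targets is A_total = A + sum_n B_n/(1 + Kd_n/A), i.e. free antigen plus all bound antigen must equal the total. A minimal self-contained check on made-up numbers (assumes only numpy and scipy; the logistic parameters at the end are placeholders, not the simulation's values):

```python
# Toy check of the free-antigen solver above, on made-up numbers.
import numpy
from scipy.optimize import minimize

Kd_n = numpy.array([1., 10., 100.])       # per-cell affinities (lower = tighter binding)
B_n = numpy.array([5000., 5000., 5000.])  # per-cell BCR expression levels
A_total = 1000.

def obj(A):  # squared mass-balance residual as a function of free antigen A
    A = numpy.atleast_1d(A)[0]
    return (A_total - (A + numpy.sum(B_n / (1 + Kd_n / A))))**2

A_free = minimize(obj, A_total, bounds=[[1e-10, A_total]], method='L-BFGS-B', tol=1e-20).x[0]
BnA = B_n / (1 + Kd_n / A_free)  # amount of each B bound to A
assert BnA.sum() + A_free - A_total < A_total / 100  # same precision check as the diff's assert

# The clipped logistic then maps binding to a poisson lambda; flooring the
# exponent at -125 avoids numpy underflow warnings in exp().
alpha, beta, Q = 0.1, 1e-3, 1.  # placeholder logistic parameters
power_val = -beta * BnA
power_val[power_val < -125] = -125
lambda_ = 2 * (alpha + (2 - alpha) / (1 + Q * numpy.exp(power_val)))
print(A_free, lambda_)
```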
Expand Down Expand Up @@ -240,8 +280,9 @@ def getvar(lvals):

# ----------------------------------------------------------------------------------------
alpha, beta, Q = logi_params
Kd_n = scipy.array([l.Kd for l in live_leaves])
BnA = calc_binding_time(Kd_n, A_total, B_total) # get list of binding time values for each cell
Kd_n = numpy.array([l.Kd for l in live_leaves])
B_n = numpy.array([l.B_exp for l in live_leaves])
BnA = calc_binding_time(Kd_n, A_total, B_n) # get list of binding time values for each cell
new_lambdas = trans_BA(BnA) # convert binding time list to list of poisson lambdas for each cell (which determine number of offspring)
if selection_strength < 1:
new_lambdas = apply_selection_strength_scaling(new_lambdas)
@@ -260,7 +301,7 @@ def A_obj(carry_cap, B_total, f_full, Kd_n, U):
         def obj(A): return((carry_cap - C_A(A, A_total_fun(A, B_total, Kd_n), f_full, U))**2)
         return obj

-    Kd_n = scipy.array([mature_kd] * carry_cap)
+    Kd_n = numpy.array([mature_kd] * carry_cap)
     obj = A_obj(carry_cap, B_total, f_full, Kd_n, U)
     # Some funny "zero encountered in true_divide" errors are not affecting results so ignore them:
     old_settings = scipy.seterr(all='ignore')  # Keep old settings
@@ -360,6 +401,80 @@ def make_bounds(tdist_hists):  # tdist_hists: list (over generations) of scipy.h
     plt.title('population over time of cells grouped by min. distance to target')
     fig.savefig(outbase + '.selection_sim.runstats.pdf')

+def make_list(array):
+    #print(type(array))
+    dummy_list = []
+    for row in array:
+        dummy_list.append(numpy.ndarray.tolist(row))
+        #dummy_list.append(row)
+        #print(row)
+        #print(type(row))
+    return dummy_list
+
+def make_equal_lengths(a):
+    row_lengths = []
+    for row in a:
+        row_lengths.append(len(row))
+    max_length = max(row_lengths)
+    #print(max_length)
+    b = make_list(a)
+    for row in b:
+        while len(row) < max_length:
+            row.append(None)
+        #print(row)
+    return b
+
+def sort_it_out(bcell_index):
+    #inf = float("inf")
+    sorted_scatter = []
+    for element in bcell_index:
+        sort_row = sorted(element, key=lambda x: float('inf') if x is None else x)
+        sorted_scatter.append(sort_row)
+    return sorted_scatter
+
+def none_removed(index_sorted):
+    #inf = float("inf")
+    dummy = []
+    for element in index_sorted:
+        sort_row = []
+        for i in element:
+            if not math.isinf(i):
+                sort_row.append(i)
+        dummy.append(sort_row)
+    return dummy
+
+def plot_runstats2(scatter_value, scatter_index, desired_name):
+    fig = plt.figure()
+    ax = plt.subplot(111)
+    x_index = make_equal_lengths(scatter_index)
+    y_index = make_equal_lengths(scatter_value)
+    #print(x_index)
+    #print(y_index)
+    ax.scatter(x_index, y_index)
+    #plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0)
+    # Shrink current axis by 20% to make the legend fit:
+    #box = ax.get_position()
+    #ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
+
+    plt.ylabel(desired_name)
+    plt.xlabel('GC generation')
+    plt.title('distribution of BCR value at each generation')
+    fig.savefig('GCsim' + desired_name + '.sim_scatter.pdf')
+
+def plot_runstats3(bcr_value, scatter_index, desired_name):
+    numpy.seterr(under='ignore')
+    fig = plt.figure()
+    ax = plt.subplot(111)
+    what_i_want_to_plot = none_removed(sort_it_out(bcr_value))
+    y_index = []
+    for i in range(len(scatter_index)//5):
+        y_index.append(what_i_want_to_plot[5*i+5])
+    ax.violinplot(y_index)
+    plt.ylabel(desired_name)
+    plt.xlabel('Every 5th GC generation')
+    plt.title('distribution of BCR value at each generation')
+    fig.savefig('GCsim' + desired_name + '.sim_violin.pdf')
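A small usage sketch of the ragged-row helpers above, on toy values (assumes the functions are in scope). `sort_it_out()` orders each generation's values while treating `None` as infinitely large, `none_removed()` then strips the infinite-kd entries produced by stop-codon sequences, and `make_equal_lengths()` pads rows with `None` so matplotlib's scatter gets rectangular input:

```python
import numpy  # assumes the helper functions defined in the diff above are in scope

vals = [[numpy.inf, 2.0, 1.0], [3.0]]   # e.g. per-generation kd lists; inf = nonsense sequence
print(none_removed(sort_it_out(vals)))  # -> [[1.0, 2.0], [3.0]]

rows = [numpy.array([3.0, 1.0]), numpy.array([2.0])]
print(make_equal_lengths(rows))         # -> [[3.0, 1.0], [2.0, None]] (padded for plotting)
```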

 # ----------------------------------------------------------------------------------------
 # bash color codes
 Colors = {}