Metrics for CL #68

Draft: wants to merge 27 commits into base: main
27 commits
94c7d79
rebase
Muedi Aug 12, 2023
41ae914
Add function to get base sequences will remove again, but kept it for…
Muedi Aug 12, 2023
b138c4b
rebase
Muedi Aug 13, 2023
29bb2ce
function to check and set cwd as base folder for eval script
Muedi Aug 22, 2023
13f620a
runnable again
Muedi Aug 22, 2023
917d91b
runs with esm pip package, so I removed the code again :) error stil…
Muedi Sep 7, 2023
4e1e7d1
Merge branch 'OpenBioML:main' into main
Muedi Sep 7, 2023
bece51f
wt margins works (label_rows)
Muedi Sep 7, 2023
5123fbc
eval script, runs for ESM baseline, PR draft
Muedi Sep 11, 2023
aad47d2
add __init__.py in getters folder
Muedi Sep 13, 2023
207cfef
add AUTOREG variable, move train/test split to supervised block
Muedi Sep 13, 2023
ac0b743
APT zero-shot eval
Muedi Sep 13, 2023
2240890
small changes for vars
Muedi Sep 13, 2023
d3ba1a1
add cuda yaml and info on readme, Issue (#53)
Muedi Sep 15, 2023
194b121
break up eval script into multiple smaller ones
Muedi Sep 26, 2023
fb83e81
added ESM and APT zero-shot as per the comments in the PR.
Muedi Sep 26, 2023
7f878c8
removed old eval script, added changes for Issues (#4) and (#53)
Muedi Sep 26, 2023
0d5a24d
Merge branch 'OpenBioML:main' into main
Muedi Oct 12, 2023
3ca4460
combine folders in eval (#60)
Muedi Oct 12, 2023
57f89f5
update yml and gitignore
Muedi Oct 12, 2023
db4e96d
add rope config args to toy_hf
Muedi Oct 12, 2023
05af87e
Merge branch 'OpenBioML:main' into main
Muedi Jan 30, 2024
1a82f12
Merge branch 'main' of github.com:Muedi/protein-lm-scaling
Muedi Jan 30, 2024
48992a4
script for sequence complexity metrics
Muedi Jan 31, 2024
dc96ab3
scripts to read a fasta and compute entropy and KLD
Muedi Feb 19, 2024
34b6965
added Amelie's intrinsic dimension code, need to add fasta support
Muedi Feb 19, 2024
ddfc596
typo in script name :D
Muedi Mar 25, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,11 +1,12 @@
protein_lm/dataset/ProteinGym/
protein_lm/evaluation/output/
/esm2*/
/toy/
/toy*/
*.lock
*.pyc
wandb/
checkpoints/
__pycache__/
protein_lm.egg-info/
*.DS_Store
/.vs/
2 changes: 2 additions & 0 deletions protein_lm/configs/train/toy_hf.yaml
@@ -49,6 +49,8 @@ model:
nn_model_type: "APT"
nn_model_config_args:
position_embedding: "learned"
rope_scaling_factor: 1.0
rope_theta: 10000
max_sequence_length: 10
pretrained_checkpoint: null

251 changes: 251 additions & 0 deletions protein_lm/dataset/clustering_metrics.py
@@ -0,0 +1,251 @@
# %%
from math import log2
from Bio import SeqIO

import numpy as np
from sklearn.linear_model import LinearRegression
from transformers import AutoTokenizer, AutoModel, EsmModel
import torch
from scipy.sparse.csgraph import minimum_spanning_tree



# %%
# simple metrics
def get_AA_counts(seq, alphabet):
    """computes the counts of AAs in a seq"""
    AA_counts = {x: 0 for x in alphabet}
    for AA in alphabet:
        AA_counts[AA] = seq.count(AA)
    return AA_counts

def get_frequency(seq, alphabet):
    """computes the frequency of AAs in seq"""
    AA_counts = get_AA_counts(seq, alphabet)
    return {k: v / len(seq) for k, v in AA_counts.items()}


def compute_entropy(seq, alphabet):
    """
    Computes the Shannon entropy of a single sequence.
    Since zero-frequency AAs are possible, zero frequencies are skipped in the
    summation; this seemed preferable to smoothing with 1 or another constant,
    and a paper that handles it the same way:
    https://academic.oup.com/mbe/article/40/4/msad084/7111731
    """
    AA_freqs = get_frequency(seq, alphabet)
    E_dict = {k: v * log2(v) for k, v in AA_freqs.items() if v != 0}
    E = -sum(E_dict.values())
    return E
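
# Quick worked example (illustrative, using a throwaway local alphabet): with the
# standard 20 AAs, a sequence containing each AA exactly once has the maximal
# entropy log2(20) ~= 4.32 bits, while a homopolymer has entropy 0.
#   aas = list("ARNDCEQGHILKMFPSTWYV")
#   compute_entropy("ARNDCEQGHILKMFPSTWYV", aas)  # ~4.32
#   compute_entropy("AAAAAAAAAA", aas)            # 0.0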

def compute_kullback_leibler(seq, alphabet, background_freq=None):
    """
    Computes the KL divergence of a sequence's AA frequencies against a background frequency.
    Since zero-frequency AAs are possible, zero frequencies are skipped in the summation.
    If an AA in the background has frequency zero, this will still throw an error!
    TODO: either drop those AAs as is done for the sequence frequencies, or smooth?
    """
    # default background: uniform 0.05 for each of the 20 standard AAs if none given
    if background_freq is None:
        background_freq = {aa: 0.05 for aa in "ARNDCEQGHILKMFPSTWYV"}

    AA_freqs = get_frequency(seq, alphabet)
    KLD_dict = {k: v * log2(v / background_freq[k]) for k, v in AA_freqs.items() if v != 0}
    KLD = sum(KLD_dict.values())
    return KLD
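
# Note (illustrative): for a sequence made up only of the 20 standard AAs, the KLD
# against the uniform 0.05 background equals log2(20) - compute_entropy(seq, alphabet),
# so a maximally diverse sequence scores ~0 and a homopolymer scores ~4.32.
#   compute_kullback_leibler("ARNDCEQGHILKMFPSTWYV", list("ARNDCEQGHILKMFPSTWYV"))  # ~0.0
#   compute_kullback_leibler("AAAAAAAAAA", list("ARNDCEQGHILKMFPSTWYV"))            # ~4.32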


def process_sequence(rec, alphabet):
    """
    Process a single SeqRecord and return its AA counts and frequencies.
    """
    seq = str(rec.seq)
    counts = get_AA_counts(seq, alphabet)
    frequency = get_frequency(seq, alphabet)
    return counts, frequency


def get_background_from_fasta_no_alignment(fasta_file, alphabet, num_seqs):
    """
    Iterates over a FASTA file and averages the per-sequence AA frequencies
    to obtain a background distribution (no alignment involved).
    """
    fasta_iterator = SeqIO.parse(fasta_file, "fasta")

    # Initialize dictionary to accumulate frequencies
    total_frequency = {x: 0 for x in alphabet}

    # Iterate over the records in the FASTA file
    for record in fasta_iterator:
        # Get sequence as a string
        seq = str(record.seq)

        # Compute frequencies for this sequence and accumulate
        frequency = get_frequency(seq, alphabet)
        for aa in alphabet:
            total_frequency[aa] += frequency[aa]

    # Normalize frequencies by the number of sequences
    for aa in alphabet:
        total_frequency[aa] /= num_seqs

    return total_frequency

def compute_KLD_fasta(fasta_file, alphabet, background_freq):
    """
    Computes the KLD of every sequence in the given FASTA against the background.
    """
    fasta_iterator = SeqIO.parse(fasta_file, "fasta")

    KLDs = {}
    for rec in fasta_iterator:
        # get ID and seq, pack into dict as id: KLD
        KLDs[rec.id] = compute_kullback_leibler(str(rec.seq), alphabet, background_freq)

    return KLDs

# %%
# intrinsic dimension as suggested by @Amelie-Schreiber
# https://huggingface.co/blog/AmelieSchreiber/intrinsic-dimension-of-proteins

def estimate_persistent_homology_dimension_avg(sequence, model, tokenizer, num_subsets=5, num_iterations=10):
    """
    Estimate the persistent homology dimension of a given protein sequence.

    Parameters:
    - sequence: A string representing the protein sequence.
    - model: a model that computes embeddings from the protein sequence (only tested with ESM so far).
    - tokenizer: the tokenizer matching the model.
    - num_subsets: A positive integer indicating the number of subset sizes to sample from the embedding vectors.
    - num_iterations: A positive integer indicating the number of iterations for averaging.

    Returns:
    - avg_phd: Average estimated persistent homology dimension.
    """
    phd_values = []  # PHD estimate for each iteration

    for _ in range(num_iterations):

        # Tokenize the input and convert to tensors
        inputs = tokenizer(sequence, return_tensors='pt')

        # Get the per-residue embeddings
        with torch.no_grad():
            outputs = model(**inputs)
            embeddings = outputs.last_hidden_state[0].numpy()

        # Remove the first and last embeddings (<CLS> and <EOS>)
        embeddings = embeddings[1:-1]

        # Sizes of the subsets to sample
        sizes = np.linspace(2, len(embeddings), num=num_subsets, dtype=int)

        # Prepare data for linear regression of log(E) on log(size)
        x = []
        y = []

        for size in sizes:
            # Sample a subset of the embeddings
            subset = np.random.choice(len(embeddings), size, replace=False)
            subset_embeddings = embeddings[subset]

            # Compute the pairwise Euclidean distance matrix
            dist_matrix = np.sqrt(np.sum((subset_embeddings[:, None] - subset_embeddings) ** 2, axis=-1))

            # Compute the minimum spanning tree
            mst = minimum_spanning_tree(dist_matrix).toarray()

            # Persistent score E: the maximum edge length in the MST
            E = np.max(mst)

            # Append to the data for linear regression
            x.append(np.log(size))
            y.append(np.log(E))

        # Reshape for sklearn
        X = np.array(x).reshape(-1, 1)
        Y = np.array(y).reshape(-1, 1)

        # Linear regression: slope of log(E) vs. log(size)
        reg = LinearRegression().fit(X, Y)

        # Estimated persistent homology dimension for this iteration
        phd = 1 / (1 - reg.coef_[0][0])

        phd_values.append(phd)

    avg_phd = np.mean(phd_values)  # Average over all iterations
    return avg_phd
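
# Usage sketch (mirrors the test cell at the bottom of this file, which loads the small
# public checkpoint "facebook/esm2_t6_8M_UR50D"):
#   tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
#   model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
#   phd = estimate_persistent_homology_dimension_avg(seq, model, tokenizer,
#                                                    num_subsets=5, num_iterations=10)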
# %%

prot_seq = "ARNDCEQGHILKMFPSTWYVARNDCEQGHILKMFPSTWYV"
prot_seq_halfhalf = "ARNDCEQGHILKMFPSTWYVAAAAAAAAAAAAAAAAAAAA"
prot_seq_low_comp = "MMAAAMMAAAMMAAAMMAAAMMAAAMMAAAMMAAAMMAAA"
prot_seq_homo_rep = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
prot_seq_homo_rep_halflength = "AAAAAAAAAAAAAAAAAAAA"

alphabet = [
'A',
'R',
'N',
'D',
'C',
'E',
'Q',
'G',
'H',
'I',
'L',
'K',
'M',
'F',
'P',
'S',
'T',
'W',
'Y',
'V'
]

test_fasta_path = "C:/Users/maxsp/Work/prots_test_complexity.fasta"
num_sequences = sum(1 for _ in SeqIO.parse(test_fasta_path, "fasta"))

# %%
# run on test fasta
background = get_background_from_fasta_no_alignment(test_fasta_path, alphabet, num_sequences)
KLD = compute_KLD_fasta(test_fasta_path, alphabet, background)
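
# %%
# Per-sequence entropy for the same FASTA and for the toy sequences defined above
# (illustrative sanity check; assumes test_fasta_path and the 20-letter alphabet from above).
entropies = {
    rec.id: compute_entropy(str(rec.seq), alphabet)
    for rec in SeqIO.parse(test_fasta_path, "fasta")
}

# Expected ordering, from fully diverse to homopolymer: entropy should drop toward 0.
toy_entropies = {
    "uniform": compute_entropy(prot_seq, alphabet),
    "half_half": compute_entropy(prot_seq_halfhalf, alphabet),
    "low_complexity": compute_entropy(prot_seq_low_comp, alphabet),
    "homopolymer": compute_entropy(prot_seq_homo_rep, alphabet),
}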

# %%
# test intrinsic dim stuff
# Load the tokenizer and model
model_path = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = EsmModel.from_pretrained(model_path)

phd_uniform = estimate_persistent_homology_dimension_avg(prot_seq, model, tokenizer, num_subsets=2, num_iterations=10)
phd_low_complexity = estimate_persistent_homology_dimension_avg(prot_seq_low_comp, model, tokenizer, num_subsets=2, num_iterations=10)