# msa_emd.py
from typing import List, Tuple

import string

import numpy as np
import torch
from scipy.spatial.distance import cdist
from Bio import SeqIO
import esm

# Embedding extraction only: disable autograd globally.
torch.set_grad_enabled(False)
# Build a translation table that deletes lowercase characters (insertions)
# and the "." / "*" characters from a string.
deletekeys = dict.fromkeys(string.ascii_lowercase)
deletekeys["."] = None
deletekeys["*"] = None
translation = str.maketrans(deletekeys)

def read_sequence(filename: str) -> Tuple[str, str]:
    """Read the first (reference) sequence from a FASTA or MSA file."""
    record = next(SeqIO.parse(filename, "fasta"))
    return record.description, str(record.seq)

def remove_insertions(sequence: str) -> str:
    """Remove any insertions from the sequence. Needed to load aligned sequences in an MSA."""
    return sequence.translate(translation)
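
# Illustration (hypothetical input): lowercase insertion characters and
# "." / "*" are stripped, while uppercase residues and "-" gaps are kept:
#   remove_insertions("MK-ay.LV*")  ->  "MK-LV"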

def read_msa(filename: str) -> List[Tuple[str, str]]:
    """Read the sequences from an MSA file, automatically removing insertions."""
    return [(record.description, remove_insertions(str(record.seq)))
            for record in SeqIO.parse(filename, "fasta")]

# Example MSA location on the cluster (unused here):
# folders = glob.glob("/p/haicluster/msa_embeddings/*")

# Select num_seqs sequences from the MSA that greedily maximize (or minimize)
# the average Hamming distance to the sequences already chosen.
# Alternatively, hhfilter can be used for diversity filtering.
def greedy_select(msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max") -> List[Tuple[str, str]]:
    assert mode in ("max", "min")
    if len(msa) <= num_seqs:
        return msa

    # View each aligned sequence as a row of uint8 codes so cdist can
    # compute per-character Hamming distances.
    array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)

    optfunc = np.argmax if mode == "max" else np.argmin
    all_indices = np.arange(len(msa))
    indices = [0]  # always keep the reference (first) sequence
    pairwise_distances = np.zeros((0, len(msa)))
    for _ in range(num_seqs - 1):
        # Distances from the most recently selected sequence to all sequences.
        dist = cdist(array[indices[-1:]], array, "hamming")
        pairwise_distances = np.concatenate([pairwise_distances, dist])
        # Mean distance from every not-yet-selected sequence to the selected set.
        shifted_distance = np.delete(pairwise_distances, indices, axis=1).mean(0)
        shifted_index = optfunc(shifted_distance)
        index = np.delete(all_indices, indices)[shifted_index]
        indices.append(index)
    indices = sorted(indices)
    return [msa[idx] for idx in indices]
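
# A toy illustration (hypothetical 4-column MSA): starting from the reference
# "s0", the sequence farthest from it by Hamming distance ("s2") is kept next:
#   toy_msa = [("s0", "MKLV"), ("s1", "MKIV"), ("s2", "AAAA")]
#   greedy_select(toy_msa, num_seqs=2)  # -> [("s0", "MKLV"), ("s2", "AAAA")]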

# Load the pretrained 12-layer MSA Transformer and its alphabet.
msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
msa_transformer = msa_transformer.eval().cuda()
msa_transformer_batch_converter = msa_transformer_alphabet.get_batch_converter()

msas = {}
msas["A0A0A0B5Q0"] = read_msa("A0A0A0B5Q0.a3m")

msa_transformer_results = []
for name, inputs in msas.items():
    # Subsample the MSA; pass more/fewer sequences by changing num_seqs.
    inputs = greedy_select(inputs, num_seqs=128)
    msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
    msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
    # Final-layer representations: shape (1, num_seqs, seq_len + 1, 768),
    # where position 0 of each row is the beginning-of-sequence token.
    temp = msa_transformer(msa_transformer_batch_tokens, repr_layers=[12])["representations"]
    temp = temp[12][:, :, 0, :]        # per-sequence BOS-token embeddings
    temp = torch.mean(temp, (0, 1))    # average into one 768-dim MSA embedding
    msa_transformer_results.append((name, temp.cpu()))
    print(temp)
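
# A minimal follow-up sketch (assumption: one mean embedding per MSA is the
# desired output; the filename "msa_embeddings.pt" is hypothetical). Stacks
# the collected embeddings and saves them for downstream use:
#   names = [n for n, _ in msa_transformer_results]
#   embeddings = torch.stack([e for _, e in msa_transformer_results])
#   torch.save({"names": names, "embeddings": embeddings}, "msa_embeddings.pt")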