Skip to content

Commit

Permalink
Update for PraatIO 6.0 and related fixes (#547)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcauliffe authored Feb 6, 2023
1 parent a285cb4 commit 6c589c5
Show file tree
Hide file tree
Showing 15 changed files with 133 additions and 208 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies:
- hdbscan
- baumwelch
- ngram
- praatio
- praatio=6.0.0
- biopython=1.79
- sqlalchemy>=2.0
- pgvector
Expand Down
41 changes: 5 additions & 36 deletions montreal_forced_aligner/alignment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,7 @@ def align(self, workflow_name=None) -> None:
logger.info("Performing first-pass alignment...")
self.uses_speaker_adaptation = False
for j in self.jobs:
paths = j.construct_dictionary_dependent_paths(
self.working_directory, "trans", "ark"
)
paths = j.construct_path_dictionary(self.working_directory, "trans", "ark")
for p in paths.values():
if os.path.exists(p):
os.remove(p)
Expand Down Expand Up @@ -679,31 +677,10 @@ def collect_alignments(self) -> None:
:meth:`.CorpusAligner.alignment_extraction_arguments`
Arguments for extraction
"""
indices = [
("word_utterance_workflow_index", "word_interval", ["utterance_id", "workflow_id"]),
("phone_utterance_workflow_index", "phone_interval", ["utterance_id", "workflow_id"]),
("ix_word_interval_workflow_id", "word_interval", ["workflow_id"]),
("ix_word_interval_word_id", "word_interval", ["word_id"]),
("ix_word_interval_utterance_id", "word_interval", ["utterance_id"]),
("ix_word_interval_pronunciation_id", "word_interval", ["pronunciation_id"]),
("ix_word_interval_begin", "word_interval", ["begin"]),
("ix_phone_interval_workflow_id", "phone_interval", ["workflow_id"]),
("ix_phone_interval_word_interval_id", "phone_interval", ["word_interval_id"]),
("ix_phone_interval_utterance_id", "phone_interval", ["utterance_id"]),
("ix_phone_interval_phone_id", "phone_interval", ["phone_id"]),
("ix_phone_interval_begin", "phone_interval", ["begin"]),
]
with self.session() as session:
session.execute(sqlalchemy.text("ALTER TABLE word_interval DISABLE TRIGGER all"))
session.execute(sqlalchemy.text("ALTER TABLE phone_interval DISABLE TRIGGER all"))
session.commit()
for ix in indices:
try:
session.execute(sqlalchemy.text(f"DROP INDEX {ix[0]}"))
except Exception:
pass
session.commit()
with self.session() as session:
workflow = (
session.query(CorpusWorkflow)
.filter(CorpusWorkflow.current == True) # noqa
Expand Down Expand Up @@ -840,21 +817,10 @@ def collect_alignments(self) -> None:
phone_buf.seek(0)
conn.commit()
conn.close()
logger.info("Refreshing indices...")
with tqdm.tqdm(
total=len(indices), disable=GLOBAL_CONFIG.quiet
) as pbar, self.session() as session:
with self.session() as session:
if new_words:
session.execute(sqlalchemy.insert(Word).values(new_words))
for ix in indices:
session.execute(
sqlalchemy.text(f'CREATE INDEX {ix[0]} ON {ix[1]} ({", ".join(ix[2])})')
)
session.commit()
pbar.update(1)
session.execute(sqlalchemy.text("ALTER TABLE word_interval ENABLE TRIGGER all"))
session.execute(sqlalchemy.text("ALTER TABLE phone_interval ENABLE TRIGGER all"))
session.commit()

with self.session() as session:
workflow = (
Expand Down Expand Up @@ -885,6 +851,9 @@ def collect_alignments(self) -> None:
{CorpusWorkflow.alignments_collected: True}
)
session.commit()
session.execute(sqlalchemy.text("ALTER TABLE word_interval ENABLE TRIGGER all"))
session.execute(sqlalchemy.text("ALTER TABLE phone_interval ENABLE TRIGGER all"))
session.commit()

def fine_tune_alignments(self) -> None:
"""
Expand Down
33 changes: 16 additions & 17 deletions montreal_forced_aligner/corpus/acoustic_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,37 +134,36 @@ def has_alignments(self, workflow_id: typing.Optional[int] = None):
)
return check

def has_ivectors(self, speaker=False):
def has_ivectors(self):
with self.session() as session:
if speaker:
check = (
session.query(Speaker).filter(Speaker.ivector != None).limit(1).first() # noqa
is not None
)
else:
check = (
session.query(Utterance)
.filter(Utterance.ivector != None) # noqa
.limit(1)
.first()
is not None
)
check = (
session.query(Corpus)
.filter(Corpus.ivectors_calculated == True) # noqa
.limit(1)
.first()
is not None
)
return check

def has_xvectors(self):
with self.session() as session:
check = (
session.query(Utterance).filter(Utterance.xvector != None).limit(1).first() # noqa
session.query(Corpus)
.filter(Corpus.xvectors_loaded == True) # noqa
.limit(1)
.first()
is not None
)
return check

def has_any_ivectors(self):
with self.session() as session:
check = (
session.query(Utterance)
session.query(Corpus)
.filter(
sqlalchemy.or_(Utterance.xvector != None, Utterance.ivector != None) # noqa
sqlalchemy.or_(
Corpus.ivectors_calculated == True, Corpus.xvectors_loaded == True # noqa
)
)
.limit(1)
.first()
Expand Down
10 changes: 5 additions & 5 deletions montreal_forced_aligner/corpus/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,11 @@ def load_text(
"\n".join(traceback.format_exception(exc_type, exc_value, exc_traceback)),
)

num_tiers = len(tg.tierNameList)
num_tiers = len(tg.tierNames)
if num_tiers == 0:
raise TextGridParseError(self.text_path, "Number of tiers parsed was zero")
for i, tier_name in enumerate(tg.tierNameList):
ti = tg.tierDict[tier_name]
for i, tier_name in enumerate(tg.tierNames):
ti = tg._tierDict[tier_name]
if tier_name.lower() == "notes":
continue
if not isinstance(ti, textgrid.IntervalTier):
Expand All @@ -181,14 +181,14 @@ def load_text(
num_channels = self.wav_info.num_channels
else:
duration = tg.maxTimestamp
for begin, end, text in ti.entryList:
for begin, end, text in ti.entries:
text = text.lower().strip()
if not text:
continue
begin, end = round(begin, 4), round(end, 4)
end = min(end, duration)
channel = 0
if num_channels == 2 and i >= i / len(tg.tierNameList):
if num_channels == 2 and i >= i / num_tiers:
channel = 1
utt = UtteranceData(
speaker_name=speaker_name,
Expand Down
12 changes: 6 additions & 6 deletions montreal_forced_aligner/corpus/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1986,7 +1986,7 @@ def process_ivectors(self, ivectors: np.ndarray, counts: np.ndarray = None) -> n
numpy.ndarray
Transformed ivectors
"""
# ivectors = self.preprocess_ivectors(ivectors)
ivectors = self.preprocess_ivectors(ivectors)
# ivectors = self.compute_pca_transform(ivectors)
ivectors = self.transform_ivectors(ivectors, counts=counts)
return ivectors
Expand All @@ -2005,11 +2005,12 @@ def preprocess_ivectors(self, ivectors: np.ndarray) -> np.ndarray:
numpy.ndarray
Preprocessed ivectors
"""
print(ivectors.shape)
ivectors = ivectors.T # DX N
dim = ivectors.shape[1]
# preprocessing
# mean subtraction
# ivectors = ivectors - self.mean[:, np.newaxis]
ivectors = ivectors - self.mean[:, np.newaxis]
# PCA transform
# ivectors = self.diagonalizing_transform @ ivectors
l2_norm = np.linalg.norm(ivectors, axis=0, keepdims=True)
Expand Down Expand Up @@ -2122,17 +2123,16 @@ def transform_ivectors(self, ivectors: np.ndarray, counts: np.ndarray = None) ->
# Defaults : normalize_length(true), simple_length_norm(false)
X_new_sq = X_new**2

Dim = D.shape[0]
if counts is not None:
dot_prod = np.zeros((X_new.shape[0], 1))
for i in range(dot_prod.shape[0]):
inv_covar = self.psi + (1.0 / counts[i])
inv_covar = 1.0 / inv_covar
dot_prod[i] = np.dot(X_new_sq[i], inv_covar)
normfactor = np.sqrt(Dim / dot_prod)
else:
inv_covar = (1.0 / (1.0 + self.psi)).reshape(-1, 1)
dot_prod = X_new_sq @ inv_covar # N X 1
Dim = D.shape[0]
normfactor = np.sqrt(Dim / dot_prod)
normfactor = np.sqrt(Dim) / np.sqrt(np.sum(X_new_sq))
X_new = X_new * normfactor

return X_new
Expand Down
63 changes: 17 additions & 46 deletions montreal_forced_aligner/corpus/ivector_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import List

import numpy as np
import sqlalchemy
import tqdm

from montreal_forced_aligner.config import GLOBAL_CONFIG, IVECTOR_DIMENSION
Expand Down Expand Up @@ -273,7 +272,7 @@ def transform_ivectors(self):
if os.path.exists(plda_transform_path):
with open(plda_transform_path, "rb") as f:
self.plda = pickle.load(f)
if self.has_ivectors(speaker=False) and os.path.exists(plda_transform_path):
if self.has_ivectors() and os.path.exists(plda_transform_path):
return
plda_path = (
self.adapted_plda_path if os.path.exists(self.adapted_plda_path) else self.plda_path
Expand All @@ -287,9 +286,6 @@ def transform_ivectors(self):
)
self.plda = PldaModel.load(plda_path)
with self.session() as session:
session.execute(sqlalchemy.text("DROP INDEX IF EXISTS utterance_plda_vector_index"))
session.execute(sqlalchemy.text("ALTER TABLE utterance DISABLE TRIGGER all"))
session.commit()
query = session.query(Utterance.id, Utterance.ivector).filter(
Utterance.ivector != None # noqa
)
Expand All @@ -303,14 +299,6 @@ def transform_ivectors(self):
for i, utt_id in enumerate(utterance_ids):
update_mapping.append({"id": utt_id, "plda_vector": ivectors[i, :]})
bulk_update(session, Utterance, update_mapping)
session.execute(
sqlalchemy.text(
"CREATE INDEX utterance_plda_vector_index ON utterance "
"USING ivfflat (plda_vector vector_cosine_ops) "
"WITH (lists = 100)"
)
)
session.execute(sqlalchemy.text("ALTER TABLE utterance ENABLE TRIGGER all"))
session.commit()
with open(plda_transform_path, "wb") as f:
pickle.dump(self.plda, f)
Expand All @@ -328,9 +316,6 @@ def collect_utterance_ivectors(self) -> None:
with self.session() as session, tqdm.tqdm(
total=self.num_utterances, disable=GLOBAL_CONFIG.quiet
) as pbar:
session.execute(sqlalchemy.text("DROP INDEX IF EXISTS utterance_ivector_index"))
session.execute(sqlalchemy.text("ALTER TABLE utterance DISABLE TRIGGER all"))
session.commit()
update_mapping = {}
for j in self.jobs:
ivector_scp_path = j.construct_path(self.split_directory, "ivectors", "scp")
Expand Down Expand Up @@ -375,37 +360,30 @@ def collect_utterance_ivectors(self) -> None:
pbar.update(1)
bulk_update(session, Utterance, list(update_mapping.values()))
session.query(Corpus).update({Corpus.ivectors_calculated: True})
session.execute(
sqlalchemy.text(
"CREATE INDEX utterance_ivector_index ON utterance "
"USING ivfflat (ivector vector_cosine_ops) "
"WITH (lists = 100)"
)
)
session.execute(sqlalchemy.text("ALTER TABLE utterance ENABLE TRIGGER all"))
session.commit()
self._write_ivectors()
self.transform_ivectors()

def collect_speaker_ivectors(self) -> None:
"""Collect trained per-speaker ivectors"""
if self.has_ivectors(speaker=True):
return
if self.plda is None:
self.collect_utterance_ivectors()
logger.info("Collecting speaker ivectors...")
speaker_ivector_ark_path = os.path.join(
self.working_directory, "current_speaker_ivectors.ark"
)
num_utts_path = os.path.join(self.working_directory, "current_num_utts.ark")
if not os.path.exists(speaker_ivector_ark_path):
self.compute_speaker_ivectors()
with self.session() as session, tqdm.tqdm(
total=self.num_speakers, disable=GLOBAL_CONFIG.quiet
) as pbar:
session.execute(sqlalchemy.text("ALTER TABLE speaker DISABLE TRIGGER all"))
session.execute(sqlalchemy.text("DROP INDEX IF EXISTS speaker_ivector_index"))
session.execute(sqlalchemy.text("DROP INDEX IF EXISTS speaker_plda_vector_index"))
session.commit()
utterance_counts = {}
with open(num_utts_path) as f:
for line in f:
speaker, utt_count = line.strip().split()
utt_count = int(utt_count)
utterance_counts[int(speaker)] = utt_count
copy_proc = subprocess.Popen(
[thirdparty_binary("copy-vector"), f"ark:{speaker_ivector_ark_path}", "ark,t:-"],
stdout=subprocess.PIPE,
Expand All @@ -414,31 +392,24 @@ def collect_speaker_ivectors(self) -> None:
)
ivectors = []
speaker_ids = []
speaker_counts = []
update_mapping = {}
for speaker_id, ivector in read_feats(copy_proc, raw_id=True):
speaker_id = int(speaker_id)
if speaker_id not in utterance_counts:
continue
speaker_ids.append(speaker_id)
ivectors.append(ivector)
speaker_counts.append(utterance_counts[speaker_id])
update_mapping[speaker_id] = {"id": speaker_id, "ivector": ivector}
pbar.update(1)
ivectors = np.array(ivectors)
ivectors = self.plda.process_ivectors(ivectors)
if len(ivectors.shape) < 2:
ivectors = ivectors[np.newaxis, :]
print(ivectors.shape)
speaker_counts = np.array(speaker_counts)
ivectors = self.plda.process_ivectors(ivectors, counts=speaker_counts)
for i, speaker_id in enumerate(speaker_ids):
update_mapping[speaker_id]["plda_vector"] = ivectors[i, :]
bulk_update(session, Speaker, list(update_mapping.values()))
session.execute(
sqlalchemy.text(
"CREATE INDEX speaker_ivector_index ON speaker "
"USING ivfflat (ivector vector_cosine_ops) "
"WITH (lists = 1000)"
)
)
session.execute(
sqlalchemy.text(
"CREATE INDEX speaker_plda_vector_index ON speaker "
"USING ivfflat (plda_vector vector_cosine_ops) "
"WITH (lists = 1000)"
)
)
session.execute(sqlalchemy.text("ALTER TABLE speaker ENABLE TRIGGER all"))
session.commit()
6 changes: 3 additions & 3 deletions montreal_forced_aligner/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1741,10 +1741,10 @@ def to_tg_interval(self, file_duration=None) -> Interval:
"""
if self.end < -1 or self.begin == 1000000:
raise CtmError(self)
end = round(self.end, 5)
end = round(self.end, 6)
if file_duration is not None and end > file_duration:
end = file_duration
return Interval(round(self.begin, 5), end, self.label)
end = round(file_duration, 6)
return Interval(round(self.begin, 6), end, self.label)


# noinspection PyUnresolvedReferences
Expand Down
Loading

0 comments on commit 6c589c5

Please sign in to comment.