diff --git a/medkit/audio/transcription/sb_transcriber.py b/medkit/audio/transcription/sb_transcriber.py index 647d4cab..ea666e4c 100644 --- a/medkit/audio/transcription/sb_transcriber.py +++ b/medkit/audio/transcription/sb_transcriber.py @@ -89,7 +89,9 @@ def __init__( self.batch_size = batch_size self._torch_device = "cpu" if self.device < 0 else f"cuda:{self.device}" - asr_class = speechbrain.pretrained.EncoderDecoderASR if needs_decoder else speechbrain.pretrained.EncoderASR + asr_class = ( + speechbrain.inference.ASR.EncoderDecoderASR if needs_decoder else speechbrain.inference.ASR.EncoderASR + ) self._asr = asr_class.from_hparams(source=model, savedir=cache_dir, run_opts={"device": self._torch_device}) diff --git a/pyproject.toml b/pyproject.toml index 13c885ec..79ee606c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ metrics-text-classification = [ "scikit-learn>=1.3.2", ] metrics-transcription = [ - "speechbrain>=0.5", + "speechbrain>=1.0", ] nlstruct = [ "huggingface-hub", diff --git a/tests/unit/audio/transcription/test_sb_transcriber.py b/tests/unit/audio/transcription/test_sb_transcriber.py index 00590658..8bac4b10 100644 --- a/tests/unit/audio/transcription/test_sb_transcriber.py +++ b/tests/unit/audio/transcription/test_sb_transcriber.py @@ -56,8 +56,8 @@ def __init__(self): @pytest.fixture(scope="module", autouse=True) def _mocked_asr(module_mocker): - module_mocker.patch("speechbrain.pretrained.EncoderASR", _MockSpeechbrainASR) - module_mocker.patch("speechbrain.pretrained.EncoderDecoderASR", _MockSpeechbrainASR) + module_mocker.patch("speechbrain.inference.ASR.EncoderASR", _MockSpeechbrainASR) + module_mocker.patch("speechbrain.inference.ASR.EncoderDecoderASR", _MockSpeechbrainASR) def _gen_segment(nb_samples) -> Segment: