Skip to content

Commit

Permalink
optimize the cli (#15)
Browse files Browse the repository at this point in the history
* support command line mode

* support commandline mode

* format

* format

* check

* optimize the cli

* remove comment
  • Loading branch information
KyleZhang1118 authored Oct 31, 2024
1 parent 5b3e808 commit f0e479e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 3 additions & 2 deletions wesep/cli/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, model_dir: str):
self.resample_rate = configs["dataset_args"].get("resample_rate", 16000)
self.apply_vad = False
self.device = torch.device("cpu")
self.wavform_norm = False
self.wavform_norm = True

self.speaker_feat = configs["model_args"]["tse_model"].get("spk_feat", False)
self.joint_training = configs["model_args"]["tse_model"].get(
Expand Down Expand Up @@ -165,7 +165,6 @@ def main():
model = load_model("bsrnn")
else:
model = load_model(args.language)
model.set_wavform_norm(True)
else:
model = load_model_local(args.pretrain)
model.set_resample_rate(args.resample_rate)
Expand All @@ -174,6 +173,8 @@ def main():
if args.task == "extraction":
speech = model.extract_speech(args.audio_file, args.audio_file2)
if speech is not None:
if args.normalize_output:
speech = speech / abs(speech).max(dim=1, keepdim=True).values * 0.9
soundfile.write(args.output_file, speech[0], args.resample_rate)
print("Succeed, see {}".format(args.output_file))
else:
Expand Down
5 changes: 5 additions & 0 deletions wesep/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,10 @@ def get_args():
default='./extracted_speech.wav',
help="extracted speech saved in .wav"
)
parser.add_argument(
"--normalize_output",
default=True,
help="Control if normalize the ouput audio in .wav"
)
args = parser.parse_args()
return args

0 comments on commit f0e479e

Please sign in to comment.