fix: add audio padding #74 (Merged)
hallo/datasets/audio_processor.py: 6 additions & 2 deletions

@@ -73,7 +73,7 @@ def __init__(
         self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)
 
 
-    def preprocess(self, wav_file: str):
+    def preprocess(self, wav_file: str, clip_length: int):
         """
         Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate.
         The separated vocal track is then converted into wav2vec2 for further processing or analysis.
@@ -106,8 +106,12 @@ def preprocess(self, wav_file: str):
         speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate)
         audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
         seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)
+        audio_length = seq_len
 
         audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
+        if seq_len % clip_length != 0:
+            audio_feature = torch.nn.functional.pad(audio_feature, (0, (clip_length - seq_len % clip_length) * (self.sample_rate // self.fps)), 'constant', 0.0)
+            seq_len += clip_length - seq_len % clip_length
         audio_feature = audio_feature.unsqueeze(0)
 
         with torch.no_grad():
@@ -121,7 +125,7 @@ def preprocess(self, wav_file: str):
 
         audio_emb = audio_emb.cpu().detach()
 
-        return audio_emb
+        return audio_emb, audio_length
 
     def get_embedding(self, wav_file: str):
         """preprocess wav audio file convert to embeddings
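Note on the change: `seq_len` counts the video frames the audio spans (the ceiling of samples / sample_rate * fps). When that count is not a multiple of `clip_length`, the wav2vec input is zero-padded on the right by `sample_rate // fps` samples per missing frame, so the embedding covers a whole number of clips; `audio_length` keeps the original frame count for trimming later. A minimal sketch of the arithmetic, using hypothetical values (16 kHz audio, 25 fps, 16-frame clips; the real numbers come from the processor config and the pipeline):

```python
import math

# Hypothetical values for illustration only; the actual numbers come from
# AudioProcessor's sample_rate and fps and the pipeline's clip_length.
sample_rate, fps, clip_length = 16000, 25, 16
num_samples = 60000  # 3.75 s of audio

# Frames covered by the audio, as computed in preprocess().
seq_len = math.ceil(num_samples / sample_rate * fps)              # 94

# Frames needed to reach the next multiple of clip_length.
pad_frames = (clip_length - seq_len % clip_length) % clip_length  # 2

# Zero samples appended to the wav2vec input: one frame's worth per pad frame.
pad_samples = pad_frames * (sample_rate // fps)                   # 1280

assert (seq_len + pad_frames) % clip_length == 0
```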
scripts/inference.py: 2 additions & 1 deletion

@@ -178,7 +178,7 @@ def inference_process(args: argparse.Namespace):
         os.path.basename(audio_separator_model_file),
         os.path.join(save_path, "audio_preprocess")
     ) as audio_processor:
-        audio_emb = audio_processor.preprocess(driving_audio_path)
+        audio_emb, audio_length = audio_processor.preprocess(driving_audio_path, clip_length)
 
     # 4. build modules
     sched_kwargs = OmegaConf.to_container(config.noise_scheduler_kwargs)
@@ -339,6 +339,7 @@ def inference_process(args: argparse.Namespace):
 
     tensor_result = torch.cat(tensor_result, dim=2)
     tensor_result = tensor_result.squeeze(0)
+    tensor_result = tensor_result[:, :audio_length]
 
     output_file = config.output
     # save the result after all iteration
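Because the embedding was padded to a whole number of clips, the generated video can end with a few extra frames; the new slice cuts the result back to the original `audio_length`. A minimal sketch of the trim, with hypothetical shapes (assuming the tensor after `squeeze(0)` is channels x frames x height x width):

```python
import torch

# Hypothetical shapes for illustration: 96 is the padded frame count and
# audio_length = 94 is the pre-padding count returned by preprocess().
audio_length = 94
tensor_result = torch.rand(3, 96, 64, 64)

# Keep only the frames that correspond to real audio; the trailing
# frames were generated from the zero padding.
tensor_result = tensor_result[:, :audio_length]
assert tensor_result.shape[1] == audio_length
```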