forked from yzyouzhang/AIR-ASVspoof
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils_dsp.py
249 lines (195 loc) · 6.78 KB
/
utils_dsp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
#!/usr/bin/env python
## Adapted from https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/blob/newfunctions/
"""
util_dsp.py
Utilities for signal processing
MuLaw Code adapted from
https://github.com/fatchord/WaveRNN/blob/master/utils/distribution.py
DCT code adapted from
https://github.com/zh217/torch-dct
"""
from __future__ import absolute_import
from __future__ import print_function
import numpy as np
import torch
import torch.nn as torch_nn
import torch.nn.functional as torch_nn_func
# Module metadata (authorship of the upstream utilities).
# NOTE(review): the e-mail value below looks like residue of web-page e-mail
# obfuscation rather than a real address; restore it from upstream.
__copyright__ = "Copyright 2020, Xin Wang"
__author__ = "Xin Wang"
__email__ = "[email protected]"
######################
### WaveForm utilities
######################
def label_2_float(x, bits):
    """Map integer labels in [0, 2**bits - 1] linearly onto [-1, 1].

    Label 0 maps to -1.0 and label 2**bits - 1 maps to 1.0.
    Note: dtype conversion is not handled explicitly; the float result comes
    from true division.

    Args:
    -----
      x: labels to convert, Tensor.long or int, any shape.
      bits: number of bits, int

    Return:
    -------
      tensor.float (or a Python float for scalar int input)
    """
    top_label = 2 ** bits - 1.
    return 2 * x / top_label - 1.
def float_2_label(x, bits):
    """Map float waveform values onto the label range [0, 2**bits - 1].

    If the absolute peak of ``x`` exceeds 1.0, the whole signal is first
    divided by that peak so that it fits into [-1, 1]. The values are then
    rescaled linearly and clamped to [0, 2**bits - 1].

    Note: no rounding or dtype conversion is performed. The returned values
    are floats and may be fractional.

    Args:
    -----
      x: data to be converted, Tensor.float, any shape. The caller's tensor
         is not modified.
      bits: number of bits, int

    Return:
    -------
      tensor.float
    """
    # An empty tensor has no peak (max() would raise), so skip normalisation.
    if x.numel() > 0:
        peak = torch.abs(x).max()
        if peak > 1.0:
            # Divide out-of-place: an in-place ``/=`` would silently rescale
            # the caller's tensor (and fails for integer tensors).
            x = x / peak
    x = (x + 1.) * (2 ** bits - 1) / 2
    return torch.clamp(x, 0, 2 ** bits - 1)
def mulaw_encode(x, quantization_channels, scale_to_int=True):
    """Apply mu-law companding to a waveform (adapted from torchaudio).

    See https://pytorch.org/audio/functional.html (mu_law_encoding).

    Args:
      x (Tensor): float-valued waveform, expected in (-1, 1); the range is
        not checked. Non-float input is cast to torch.float first.
      quantization_channels (int): number of quantisation levels; mu is
        quantization_channels - 1.
      scale_to_int (bool):
        True: return int64 codes in [0, quantization_channels - 1]
        False: return the companded float values in (-1, 1)

    Returns:
      Tensor: the mu-law encoded input.
    """
    signal = x if x.is_floating_point() else x.to(torch.float)
    mu = torch.tensor(quantization_channels - 1.0,
                      dtype=signal.dtype, device=signal.device)
    companded = (torch.sign(signal) * torch.log1p(mu * torch.abs(signal))
                 / torch.log1p(mu))
    if not scale_to_int:
        return companded
    # Shift [-1, 1] to [0, mu]; the +0.5 rounds to nearest before truncation.
    return ((companded + 1) / 2 * mu + 0.5).to(torch.int64)
def mulaw_decode(x_mu, quantization_channels, input_int=True):
    """Invert mu-law companding (adapted from torchaudio).

    See https://pytorch.org/audio/functional.html (mu_law_decoding).

    Args:
      x_mu (Tensor): encoded input. Non-float input is cast to torch.float.
      quantization_channels (int): number of quantisation levels; mu is
        quantization_channels - 1.
      input_int (bool):
        True: x_mu holds codes in [0, quantization_channels - 1]
        False: x_mu already holds companded values in (-1, 1)

    Returns:
      Tensor: the decoded float-valued waveform, in (-1, 1).
    """
    y = x_mu if x_mu.is_floating_point() else x_mu.to(torch.float)
    mu = torch.tensor(quantization_channels - 1.0, dtype=y.dtype, device=y.device)
    # Integer codes in [0, mu] are mapped back to companded values in [-1, 1].
    companded = (y / mu) * 2 - 1.0 if input_int else y
    # Undo the log compression: |x| = ((1 + mu) ** |c| - 1) / mu.
    magnitude = (torch.exp(torch.abs(companded) * torch.log1p(mu)) - 1.0) / mu
    return torch.sign(companded) * magnitude
######################
### DCT utilities
### https://github.com/zh217/torch-dct
### LICENSE: MIT
###
######################
def dct1(x):
    """
    Discrete Cosine Transform, Type I (unnormalised).

    Computed as the real part of the FFT of the even extension
    [x_0, ..., x_{N-1}, x_{N-2}, ..., x_1] of the last dimension.

    :param x: the input signal, any shape; the transform runs over the
              last dimension
    :return: the DCT-I of the signal over the last dimension (same shape as x)
    """
    x_shape = x.shape
    # reshape (not view) so that non-contiguous inputs are accepted too.
    x = x.reshape(-1, x_shape[-1])
    # Mirror the interior samples to build the length-(2N-2) even extension.
    x_ext = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)
    # torch.rfft was removed in PyTorch 1.8. torch.fft.rfft returns a complex
    # tensor whose real part is the old "[..., 0]" channel.
    return torch.fft.rfft(x_ext, dim=1).real.reshape(*x_shape)
def idct1(X):
    """
    Inverse of DCT-I.

    DCT-I is its own inverse up to a factor of 2 * (N - 1), where N is the
    length of the last dimension. With this scaling, idct1(dct1(x)) == x.

    :param X: the input signal
    :return: the inverse DCT-I of the signal over the last dimension
    """
    return dct1(X).div(2 * (X.shape[-1] - 1))
def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT).

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal, any shape; the transform runs over the
              last dimension
    :param norm: the normalization, None or 'ortho' (any other value is
                 treated like None)
    :return: the DCT-II of the signal over the last dimension (same shape as x)
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    # Reorder samples (even indices, then odd indices reversed) so that one
    # length-N FFT yields the DCT-II after a phase twiddle.
    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
    # torch.rfft(v, 1, onesided=False) was removed in PyTorch 1.8; the full
    # complex FFT is the equivalent.
    Vc = torch.fft.fft(v, dim=1)

    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    # Real part of Vc * exp(-i * pi * k / (2N)).
    V = Vc.real * W_r - Vc.imag * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)
    return V
def idct(X, norm=None):
    """
    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III.

    Our definition of idct is that idct(dct(x)) == x.

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.fftpack.dct.html

    :param X: the input signal, any shape; the transform runs over the
              last dimension. It is not modified.
    :param norm: the normalization, None or 'ortho' (must match the value
                 used for dct; any other value is treated like None)
    :return: the inverse DCT-II of the signal over the last dimension
             (same shape as X)
    """
    x_shape = X.shape
    N = x_shape[-1]

    # The division creates a new tensor, so the in-place scaling below does
    # not touch the caller's X.
    X_v = X.contiguous().view(-1, N) / 2

    if norm == 'ortho':
        X_v[:, 0] *= np.sqrt(N) * 2
        X_v[:, 1:] *= np.sqrt(N / 2) * 2

    k = torch.arange(N, dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    # Rebuild the (Hermitian-symmetric) spectrum of the reordered signal.
    V_t_r = X_v
    V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)

    V_r = V_t_r * W_r - V_t_i * W_i
    V_i = V_t_r * W_i + V_t_i * W_r

    # torch.irfft(V, 1, onesided=False) was removed in PyTorch 1.8. Because
    # the spectrum is Hermitian, torch.fft.irfft with n=N gives the same
    # real signal; it reads only the first N // 2 + 1 bins.
    V = torch.complex(V_r, V_i)
    v = torch.fft.irfft(V, n=N, dim=1)

    # Undo the even/odd sample reordering used by dct.
    x = v.new_zeros(v.shape)
    x[:, ::2] += v[:, :N - (N // 2)]
    x[:, 1::2] += v.flip([1])[:, :N // 2]

    return x.view(*x_shape)
class LinearDCT(torch_nn.Linear):
    """Implement any DCT in this file as a fixed (non-trainable) linear layer.

    The transform matrix is built once, by applying the chosen DCT function
    to an identity matrix, and stored as the layer weight. The upstream note
    says this runs around 50x faster on GPU than the FFT-based functions.
    The cost is storing an N x N matrix, which increases memory usage.

    :param in_features: size of expected input (also the output size)
    :param type: which dct function in this file to use; one of
                 'dct1', 'idct1', 'dct', 'idct'
    :param norm: normalization passed to dct/idct (None or 'ortho');
                 not used for dct1/idct1
    :param bias: whether the layer has a bias; if so, it is initialised to 0
    :raises ValueError: if `type` is not one of the supported names
    """
    def __init__(self, in_features, type, norm=None, bias=False):
        # These attributes must exist before Linear.__init__ runs, because
        # Linear.__init__ calls self.reset_parameters().
        self.type = type
        self.N = in_features
        self.norm = norm
        super(LinearDCT, self).__init__(in_features, in_features, bias=bias)

    def reset_parameters(self):
        # Initialise the weight by applying the dct function to the identity.
        I = torch.eye(self.N)
        if self.type == 'dct1':
            mat = dct1(I)
        elif self.type == 'idct1':
            mat = idct1(I)
        elif self.type == 'dct':
            mat = dct(I, norm=self.norm)
        elif self.type == 'idct':
            mat = idct(I, norm=self.norm)
        else:
            # Without this check, the weight would stay as uninitialised memory.
            raise ValueError("Unknown DCT type: {}".format(self.type))
        # Linear computes input @ weight.t(), so store the transpose; then the
        # layer output equals the transform applied along the last dimension.
        self.weight.data = mat.data.t()
        self.weight.requires_grad = False  # don't learn this!
        if self.bias is not None:
            # Overriding reset_parameters bypasses Linear's own bias init,
            # which would otherwise leave the bias as uninitialised memory.
            torch_nn.init.zeros_(self.bias)
if __name__ == "__main__":
    # This module only provides importable utilities; running it directly
    # just prints the module name.
    print("util_dsp.py")