Commit d7bf915

[Hackathon 8th No.9] Reproduce the losses needed for DAC training in PaddleSpeech (#3988)

* add DAC loss
* fix bug
* fix codestyle
* fix codestyle
* fix codestyle
* fix codestyle
* fix codestyle
* fix codestyle

1 parent afa6f12 commit d7bf915
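
The bulk of this commit is paddlespeech/t2s/modules/losses.py, which adds three loss classes used for DAC training: MultiScaleSTFTLoss, GANLoss, and SISDRLoss (the spectral and SI-SDR losses are copied from descriptinc/audiotools, per their docstrings). Given the file locations in the diff below, they would be imported roughly as follows; this snippet is illustrative and not part of the commit:

# Illustrative only -- not part of this commit.
# Import paths follow the file locations shown in the diff below.
from paddlespeech.t2s.modules.losses import GANLoss
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss
from paddlespeech.t2s.modules.losses import SISDRLoss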

File tree: 9 files changed, +351 -16 lines changed

paddlespeech/__init__.py (-4)

@@ -13,7 +13,3 @@
 # limitations under the License.
 import _locale
 _locale._getdefaultlocale = (lambda *args: ['en_US', 'utf8'])
-
-__version__ = '0.0.0'
-
-__commit__ = '9cf8c1985a98bb380c183116123672976bdfe5c9'

paddlespeech/audiotools/core/__init__.py (+2 -2)

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from . import util
-from ._julius import fft_conv1d
-from ._julius import FFTConv1D
+from ...t2s.modules import fft_conv1d
+from ...t2s.modules import FFTConv1D
 from ._julius import highpass_filter
 from ._julius import highpass_filters
 from ._julius import lowpass_filter

paddlespeech/audiotools/core/_julius.py (+1 -2)

@@ -20,8 +20,6 @@
 import paddle.nn as nn
 import paddle.nn.functional as F
 
-from paddlespeech.t2s.modules import fft_conv1d
-from paddlespeech.t2s.modules import FFTConv1D
 from paddlespeech.utils import satisfy_paddle_version
 
 __all__ = [
@@ -312,6 +310,7 @@ def forward(self, _input):
                 mode="replicate",
                 data_format="NCL")
         if self.fft:
+            from paddlespeech.t2s.modules import fft_conv1d
             out = fft_conv1d(_input, self.filters, stride=self.stride)
         else:
             out = F.conv1d(_input, self.filters, stride=self.stride)

paddlespeech/audiotools/core/util.py (+4 -4)

@@ -32,7 +32,6 @@
 from flatten_dict import flatten
 from flatten_dict import unflatten
 
-from .audio_signal import AudioSignal
 from paddlespeech.utils import satisfy_paddle_version
 from paddlespeech.vector.training.seeding import seed_everything
 
@@ -232,8 +231,7 @@ def ensure_tensor(
 
 def _get_value(other):
     #
-    from . import AudioSignal
-
+    from .audio_signal import AudioSignal
     if isinstance(other, AudioSignal):
         return other.audio_data
     return other
@@ -784,6 +782,8 @@ def collate(list_of_dicts: list, n_splits: int=None):
         Dictionary containing batched data.
     """
 
+    from .audio_signal import AudioSignal
+
     batches = []
     list_len = len(list_of_dicts)
 
@@ -873,7 +873,7 @@ def generate_chord_dataset(
 
     """
     import librosa
-    from . import AudioSignal
+    from .audio_signal import AudioSignal
     from ..data.preprocess import create_csv
 
     min_midi = librosa.note_to_midi(min_note)

paddlespeech/t2s/modules/losses.py (+279)

@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+from typing import Callable
+from typing import List
+from typing import Optional
 from typing import Tuple
+from typing import Union
 
 import librosa
 import numpy as np
@@ -23,6 +27,8 @@
 from scipy.stats import betabinom
 from typeguard import typechecked
 
+from paddlespeech.audiotools.core.audio_signal import AudioSignal
+from paddlespeech.audiotools.core.audio_signal import STFTParams
 from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
 from paddlespeech.t2s.modules.predictor.duration_predictor import (
     DurationPredictorLoss,  # noqa: H301
@@ -1326,3 +1332,276 @@ def _generate_prior(self, text_lengths, feats_lengths,
             bb_prior[bidx, :T, :N] = prob
 
         return bb_prior
+
+
+class MultiScaleSTFTLoss(nn.Layer):
+    """Computes the multi-scale STFT loss from [1].
+
+    References
+    ----------
+
+    1.  Engel, Jesse, Chenjie Gu, and Adam Roberts.
+        "DDSP: Differentiable Digital Signal Processing."
+        International Conference on Learning Representations. 2019.
+
+    Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/spectral.py
+    """
+
+    def __init__(
+            self,
+            window_lengths: List[int]=[2048, 512],
+            loss_fn: Callable=nn.L1Loss(),
+            clamp_eps: float=1e-5,
+            mag_weight: float=1.0,
+            log_weight: float=1.0,
+            pow: float=2.0,
+            weight: float=1.0,
+            match_stride: bool=False,
+            window_type: Optional[str]=None, ):
+        """
+        Args:
+            window_lengths : List[int], optional
+                Length of each window of each STFT, by default [2048, 512]
+            loss_fn : typing.Callable, optional
+                How to compare each loss, by default nn.L1Loss()
+            clamp_eps : float, optional
+                Clamp on the log magnitude (from below), by default 1e-5
+            mag_weight : float, optional
+                Weight of raw magnitude portion of loss, by default 1.0
+            log_weight : float, optional
+                Weight of log magnitude portion of loss, by default 1.0
+            pow : float, optional
+                Power to raise magnitude to before taking log, by default 2.0
+            weight : float, optional
+                Weight of this loss, by default 1.0
+            match_stride : bool, optional
+                Whether to match the stride of convolutional layers, by default False
+            window_type : str, optional
+                Type of window to use, by default None.
+        """
+        super().__init__()
+
+        self.stft_params = [
+            STFTParams(
+                window_length=w,
+                hop_length=w // 4,
+                match_stride=match_stride,
+                window_type=window_type, ) for w in window_lengths
+        ]
+        self.loss_fn = loss_fn
+        self.log_weight = log_weight
+        self.mag_weight = mag_weight
+        self.clamp_eps = clamp_eps
+        self.weight = weight
+        self.pow = pow
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes the multi-scale STFT loss between an estimate and a
+        reference signal.
+
+        Args:
+            x : AudioSignal
+                Estimate signal
+            y : AudioSignal
+                Reference signal
+
+        Returns:
+            paddle.Tensor
+                Multi-scale STFT loss.
+
+        Example:
+            >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
+            >>> import paddle
+
+            >>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
+            >>> y = x * 0.01
+            >>> loss = MultiScaleSTFTLoss()
+            >>> loss(x, y).numpy()
+            7.562150
+        """
+        loss = 0.0
+        for s in self.stft_params:
+            x.stft(s.window_length, s.hop_length, s.window_type)
+            y.stft(s.window_length, s.hop_length, s.window_type)
+            loss += self.log_weight * self.loss_fn(
+                x.magnitude.clip(self.clamp_eps).pow(self.pow).log10(),
+                y.magnitude.clip(self.clamp_eps).pow(self.pow).log10(), )
+            loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
+        return loss
+
+
+class GANLoss(nn.Layer):
+    """
+    Computes a discriminator loss, given a discriminator on
+    generated waveforms/spectrograms compared to ground truth
+    waveforms/spectrograms. Computes the loss for both the
+    discriminator and the generator in separate functions.
+
+    Example:
+        >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
+        >>> import paddle
+
+        >>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
+        >>> y = x * 0.01
+        >>> class My_discriminator0:
+        >>>     def __call__(self, x):
+        >>>         return x.sum()
+        >>> loss = GANLoss(My_discriminator0())
+        >>> [loss(x, y)[0].numpy(), loss(x, y)[1].numpy()]
+        [-0.102722, -0.001027]
+
+        >>> class My_discriminator1:
+        >>>     def __call__(self, x):
+        >>>         return x.sum()
+        >>> loss = GANLoss(My_discriminator1())
+        >>> [loss.generator_loss(x, y)[0].numpy(), loss.generator_loss(x, y)[1].numpy()]
+        [1.00019, 0]
+
+        >>> loss.discriminator_loss(x, y)
+        1.000200
+    """
+
+    def __init__(self, discriminator):
+        """
+        Args:
+            discriminator : paddle.nn.Layer
+                Discriminator model
+        """
+        super().__init__()
+        self.discriminator = discriminator
+
+    def forward(self,
+                fake: Union[AudioSignal, paddle.Tensor],
+                real: Union[AudioSignal, paddle.Tensor]):
+        if isinstance(fake, AudioSignal):
+            d_fake = self.discriminator(fake.audio_data)
+        else:
+            d_fake = self.discriminator(fake)
+
+        if isinstance(real, AudioSignal):
+            d_real = self.discriminator(real.audio_data)
+        else:
+            d_real = self.discriminator(real)
+        return d_fake, d_real
+
+    def discriminator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake, real)
+
+        loss_d = 0
+        for x_fake, x_real in zip(d_fake, d_real):
+            loss_d += paddle.mean(x_fake[-1]**2)
+            loss_d += paddle.mean((1 - x_real[-1])**2)
+        return loss_d
+
+    def generator_loss(self, fake, real):
+        d_fake, d_real = self.forward(fake, real)
+
+        loss_g = 0
+        for x_fake in d_fake:
+            loss_g += paddle.mean((1 - x_fake[-1])**2)
+
+        loss_feature = 0
+
+        for i in range(len(d_fake)):
+            for j in range(len(d_fake[i]) - 1):
+                loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
+        return loss_g, loss_feature
+
+
+class SISDRLoss(nn.Layer):
+    """
+    Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
+    of estimated and reference audio signals or aligned features.
+
+    Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/distance.py
+
+    Example:
+        >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal
+        >>> import paddle
+
+        >>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05)
+        >>> y = x * 0.01
+        >>> sisdr = SISDRLoss()
+        >>> sisdr(x, y).numpy()
+        -145.377640
+    """
+
+    def __init__(
+            self,
+            scaling: bool=True,
+            reduction: str="mean",
+            zero_mean: bool=True,
+            clip_min: Optional[int]=None,
+            weight: float=1.0, ):
+        """
+        Args:
+            scaling : bool, optional
+                Whether to use scale-invariant (True) or
+                signal-to-noise ratio (False), by default True
+            reduction : str, optional
+                How to reduce across the batch (either 'mean',
+                'sum', or 'none'), by default 'mean'
+            zero_mean : bool, optional
+                Zero mean the references and estimates before
+                computing the loss, by default True
+            clip_min : int, optional
+                The minimum possible loss value. Helps network
+                to not focus on making already good examples better, by default None
+            weight : float, optional
+                Weight of this loss, defaults to 1.0.
+        """
+        self.scaling = scaling
+        self.reduction = reduction
+        self.zero_mean = zero_mean
+        self.clip_min = clip_min
+        self.weight = weight
+        super().__init__()
+
+    def forward(self,
+                x: Union[AudioSignal, paddle.Tensor],
+                y: Union[AudioSignal, paddle.Tensor]):
+        eps = 1e-8
+        # B, C, T
+        if isinstance(x, AudioSignal):
+            references = x.audio_data
+            estimates = y.audio_data
+        else:
+            references = x
+            estimates = y
+
+        nb = references.shape[0]
+        references = references.reshape([nb, 1, -1]).transpose([0, 2, 1])
+        estimates = estimates.reshape([nb, 1, -1]).transpose([0, 2, 1])
+
+        # samples now on axis 1
+        if self.zero_mean:
+            mean_reference = references.mean(axis=1, keepdim=True)
+            mean_estimate = estimates.mean(axis=1, keepdim=True)
+        else:
+            mean_reference = 0
+            mean_estimate = 0
+
+        _references = references - mean_reference
+        _estimates = estimates - mean_estimate
+
+        references_projection = (_references**2).sum(axis=-2) + eps
+        references_on_estimates = (_estimates * _references).sum(axis=-2) + eps
+
+        scale = (
+            (references_on_estimates / references_projection).unsqueeze(axis=1)
+            if self.scaling else 1)
+
+        e_true = scale * _references
+        e_res = _estimates - e_true
+
+        signal = (e_true**2).sum(axis=1)
+        noise = (e_res**2).sum(axis=1)
+        sdr = -10 * paddle.log10(signal / noise + eps)
+
+        if self.clip_min is not None:
+            sdr = paddle.clip(sdr, min=self.clip_min)
+
+        if self.reduction == "mean":
+            sdr = sdr.mean()
+        elif self.reduction == "sum":
+            sdr = sdr.sum()
+        return sdr
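
Taken together, these losses cover the reconstruction and adversarial terms of DAC-style codec training: MultiScaleSTFTLoss and SISDRLoss score the reconstructed waveform against the reference, while GANLoss supplies the adversarial and feature-matching terms. The snippet below is a minimal sketch of how they might be wired into one training step. It is not part of this commit: TinyGenerator, TinyDiscriminator, and the loss weights are hypothetical stand-ins, and the AudioSignal(tensor, sample_rate=...) constructor usage is assumed to mirror the audiotools API.

# A minimal, illustrative DAC-style training step (not part of this commit).
# TinyGenerator / TinyDiscriminator and the loss weights are hypothetical
# stand-ins for a real codec generator and multi-scale discriminator.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlespeech.audiotools.core.audio_signal import AudioSignal
from paddlespeech.t2s.modules.losses import GANLoss
from paddlespeech.t2s.modules.losses import MultiScaleSTFTLoss
from paddlespeech.t2s.modules.losses import SISDRLoss


class TinyGenerator(nn.Layer):
    """Stand-in generator: waveform [B, 1, T] -> waveform [B, 1, T]."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1D(1, 1, kernel_size=7, padding=3)

    def forward(self, wav):
        return paddle.tanh(self.conv(wav))


class TinyDiscriminator(nn.Layer):
    """Stand-in discriminator returning one list of feature maps per scale,
    with the raw logits as the last element (the layout GANLoss expects)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1D(1, 4, kernel_size=15, stride=4, padding=7)
        self.conv2 = nn.Conv1D(4, 1, kernel_size=15, stride=4, padding=7)

    def forward(self, wav):
        feat = F.leaky_relu(self.conv1(wav))
        return [[feat, self.conv2(feat)]]


generator = TinyGenerator()
gan_loss = GANLoss(TinyDiscriminator())
stft_loss = MultiScaleSTFTLoss()
sisdr_loss = SISDRLoss()

opt_g = paddle.optimizer.Adam(learning_rate=1e-4, parameters=generator.parameters())
opt_d = paddle.optimizer.Adam(learning_rate=1e-4, parameters=gan_loss.parameters())

# Assumed constructor: AudioSignal(tensor, sample_rate=...), mirroring audiotools.
sr = 44100
real = AudioSignal(paddle.randn([1, 1, sr]), sample_rate=sr)
fake = AudioSignal(generator(real.audio_data), sample_rate=sr)

# Discriminator step: push real logits toward 1 and (detached) fake logits toward 0.
loss_d = gan_loss.discriminator_loss(fake.audio_data.detach(), real.audio_data)
opt_d.clear_grad()
loss_d.backward()
opt_d.step()

# Generator step: adversarial + feature-matching + multi-scale STFT + SI-SDR terms.
# Note that SISDRLoss takes the reference signal as its first argument.
loss_adv, loss_feat = gan_loss.generator_loss(fake.audio_data, real.audio_data)
loss_g = loss_adv + 2.0 * loss_feat + stft_loss(fake, real) + sisdr_loss(real, fake)
opt_g.clear_grad()
loss_g.backward()
opt_g.step()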

tests/unit/audiotools/core/test_util.py (+2 -3)

@@ -13,7 +13,6 @@
 
 from paddlespeech.audiotools import util
 from paddlespeech.audiotools.core.audio_signal import AudioSignal
-from paddlespeech.vector.training.seeding import seed_everything
 
 
 def test_check_random_state():
@@ -36,12 +35,12 @@ def test_check_random_state():
 
 
 def test_seed():
-    seed_everything(0)
+    util.seed_everything(0)
     paddle_result_a = paddle.randn([1])
     np_result_a = np.random.randn(1)
     py_result_a = random.random()
 
-    seed_everything(0)
+    util.seed_everything(0)
     paddle_result_b = paddle.randn([1])
     np_result_b = np.random.randn(1)
     py_result_b = random.random()

tests/unit/audiotools/test_audiotools.sh (-1)

@@ -1,4 +1,3 @@
-python -m pip install -r ../../../paddlespeech/audiotools/requirements.txt
 wget https://paddlespeech.bj.bcebos.com/PaddleAudio/audio_tools/audio.tar.gz
 wget https://paddlespeech.bj.bcebos.com/PaddleAudio/audio_tools/regression.tar.gz
 tar -zxvf audio.tar.gz

0 commit comments
