|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 | import math
|
| 15 | +from typing import Callable |
| 16 | +from typing import List |
| 17 | +from typing import Optional |
15 | 18 | from typing import Tuple
|
| 19 | +from typing import Union |
16 | 20 |
|
17 | 21 | import librosa
|
18 | 22 | import numpy as np
|
|
23 | 27 | from scipy.stats import betabinom
|
24 | 28 | from typeguard import typechecked
|
25 | 29 |
|
| 30 | +from paddlespeech.audiotools.core.audio_signal import AudioSignal |
| 31 | +from paddlespeech.audiotools.core.audio_signal import STFTParams |
26 | 32 | from paddlespeech.t2s.modules.nets_utils import make_non_pad_mask
|
27 | 33 | from paddlespeech.t2s.modules.predictor.duration_predictor import (
|
28 | 34 | DurationPredictorLoss, # noqa: H301
|
@@ -1326,3 +1332,276 @@ def _generate_prior(self, text_lengths, feats_lengths,
|
1326 | 1332 | bb_prior[bidx, :T, :N] = prob
|
1327 | 1333 |
|
1328 | 1334 | return bb_prior
|
| 1335 | + |
| 1336 | + |
| 1337 | +class MultiScaleSTFTLoss(nn.Layer): |
| 1338 | + """Computes the multi-scale STFT loss from [1]. |
| 1339 | +
|
| 1340 | + References |
| 1341 | + ---------- |
| 1342 | +
|
| 1343 | + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. |
| 1344 | + "DDSP: Differentiable Digital Signal Processing." |
| 1345 | + International Conference on Learning Representations. 2020. |
| 1346 | +
|
| 1347 | + Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/spectral.py |
| 1348 | + """ |
| 1349 | + |
| 1350 | + def __init__( |
| 1351 | + self, |
| 1352 | + window_lengths: List[int]=[2048, 512], |
| 1353 | + loss_fn: Callable=nn.L1Loss(), |
| 1354 | + clamp_eps: float=1e-5, |
| 1355 | + mag_weight: float=1.0, |
| 1356 | + log_weight: float=1.0, |
| 1357 | + pow: float=2.0, |
| 1358 | + weight: float=1.0, |
| 1359 | + match_stride: bool=False, |
| 1360 | + window_type: Optional[str]=None, ): |
| 1361 | + """ |
| 1362 | + Args: |
| 1363 | + window_lengths : List[int], optional |
| 1364 | + Length of each window of each STFT, by default [2048, 512] |
| 1365 | + loss_fn : typing.Callable, optional |
| 1366 | + How to compare each loss, by default nn.L1Loss() |
| 1367 | + clamp_eps : float, optional |
| 1368 | + Lower bound applied to the magnitude before taking the log, by default 1e-5 |
| 1369 | + mag_weight : float, optional |
| 1370 | + Weight of raw magnitude portion of loss, by default 1.0 |
| 1371 | + log_weight : float, optional |
| 1372 | + Weight of log magnitude portion of loss, by default 1.0 |
| 1373 | + pow : float, optional |
| 1374 | + Power to raise magnitude to before taking log, by default 2.0 |
| 1375 | + weight : float, optional |
| 1376 | + Weight of this loss, by default 1.0 |
| 1377 | + match_stride : bool, optional |
| 1378 | + Whether to match the stride of convolutional layers, by default False |
| 1379 | + window_type : str, optional |
| 1380 | + Type of window to use, by default None. |
| 1381 | + """ |
| 1382 | + super().__init__() |
| 1383 | + |
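| | + # one set of STFT parameters per resolution; the hop length is fixed |
| | + # at one quarter of each window length |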
| 1384 | + self.stft_params = [ |
| 1385 | + STFTParams( |
| 1386 | + window_length=w, |
| 1387 | + hop_length=w // 4, |
| 1388 | + match_stride=match_stride, |
| 1389 | + window_type=window_type, ) for w in window_lengths |
| 1390 | + ] |
| 1391 | + self.loss_fn = loss_fn |
| 1392 | + self.log_weight = log_weight |
| 1393 | + self.mag_weight = mag_weight |
| 1394 | + self.clamp_eps = clamp_eps |
| 1395 | + self.weight = weight |
| 1396 | + self.pow = pow |
| 1397 | + |
| 1398 | + def forward(self, x: AudioSignal, y: AudioSignal): |
| 1399 | + """Computes the multi-scale STFT loss between an estimate and a reference |
| 1400 | + signal. |
| 1401 | +
|
| 1402 | + Args: |
| 1403 | + x : AudioSignal |
| 1404 | + Estimate signal |
| 1405 | + y : AudioSignal |
| 1406 | + Reference signal |
| 1407 | +
|
| 1408 | + Returns: |
| 1409 | + paddle.Tensor |
| 1410 | + Multi-scale STFT loss. |
| 1411 | + |
| 1412 | + Example: |
| 1413 | + >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal |
| 1414 | + >>> import paddle |
| 1415 | +
|
| 1416 | + >>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05) |
| 1417 | + >>> y = x * 0.01 |
| 1418 | + >>> loss = MultiScaleSTFTLoss() |
| 1419 | + >>> loss(x, y).numpy() |
| 1420 | + 7.562150 |
| 1421 | + """ |
| | + loss = 0.0 |
| 1422 | + for s in self.stft_params: |
| 1423 | + x.stft(s.window_length, s.hop_length, s.window_type) |
| 1424 | + y.stft(s.window_length, s.hop_length, s.window_type) |
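| | + # compare log-compressed and raw magnitude spectrograms at this resolution |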
| 1425 | + loss += self.log_weight * self.loss_fn( |
| 1426 | + x.magnitude.clip(self.clamp_eps).pow(self.pow).log10(), |
| 1427 | + y.magnitude.clip(self.clamp_eps).pow(self.pow).log10(), ) |
| 1428 | + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) |
| 1429 | + return loss |
| 1430 | + |
| 1431 | + |
| 1432 | +class GANLoss(nn.Layer): |
| 1433 | + """ |
| 1434 | + Computes adversarial losses given a discriminator applied to |
| 1435 | + generated waveforms/spectrograms and to the corresponding ground-truth |
| 1436 | + waveforms/spectrograms. The discriminator and generator losses are |
| 1437 | + computed by separate methods. |
| 1438 | +
|
| 1439 | + Example: |
| 1440 | + >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal |
| 1441 | + >>> import paddle |
| 1442 | +
|
| 1443 | + >>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05) |
| 1444 | + >>> y = x * 0.01 |
| 1445 | + >>> class My_discriminator0: |
| 1446 | + >>> def __call__(self, x): |
| 1447 | + >>> return x.sum() |
| 1448 | + >>> loss = GANLoss(My_discriminator0()) |
| 1449 | + >>> [loss(x, y)[0].numpy(), loss(x, y)[1].numpy()] |
| 1450 | + [-0.102722, -0.001027] |
| 1451 | +
|
| 1452 | + >>> class My_discriminator1: |
| 1453 | + >>> def __call__(self, x): |
| 1454 | + >>> return x.sum() |
| 1455 | + >>> loss = GANLoss(My_discriminator1()) |
| 1456 | + >>> [loss.generator_loss(x, y)[0].numpy(), loss.generator_loss(x, y)[1].numpy()] |
| 1457 | + [1.00019, 0] |
| 1458 | +
|
| 1459 | + >>> loss.discriminator_loss(x, y) |
| 1460 | + 1.000200 |
| 1461 | + """ |
| 1462 | + |
| 1463 | + def __init__(self, discriminator): |
| 1464 | + """ |
| 1465 | + Args: |
| 1466 | + discriminator : paddle.nn.Layer |
| 1467 | + Discriminator model |
| 1468 | + """ |
| 1469 | + super().__init__() |
| 1470 | + self.discriminator = discriminator |
| 1471 | + |
| 1472 | + def forward(self, |
| 1473 | + fake: Union[AudioSignal, paddle.Tensor], |
| 1474 | + real: Union[AudioSignal, paddle.Tensor]): |
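| | + # run the discriminator on the generated and the ground-truth input |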
| 1475 | + if isinstance(fake, AudioSignal): |
| 1476 | + d_fake = self.discriminator(fake.audio_data) |
| 1477 | + else: |
| 1478 | + d_fake = self.discriminator(fake) |
| 1479 | + |
| 1480 | + if isinstance(real, AudioSignal): |
| 1481 | + d_real = self.discriminator(real.audio_data) |
| 1482 | + else: |
| 1483 | + d_real = self.discriminator(real) |
| 1484 | + return d_fake, d_real |
| 1485 | + |
| 1486 | + def discriminator_loss(self, fake, real): |
| 1487 | + d_fake, d_real = self.forward(fake, real) |
| 1488 | + |
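| | + # least-squares GAN objective: real outputs are pushed toward 1, |
| | + # fake outputs toward 0; only the final output of each discriminator branch is scored |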
| 1489 | + loss_d = 0 |
| 1490 | + for x_fake, x_real in zip(d_fake, d_real): |
| 1491 | + loss_d += paddle.mean(x_fake[-1]**2) |
| 1492 | + loss_d += paddle.mean((1 - x_real[-1])**2) |
| 1493 | + return loss_d |
| 1494 | + |
| 1495 | + def generator_loss(self, fake, real): |
| 1496 | + d_fake, d_real = self.forward(fake, real) |
| 1497 | + |
| 1498 | + loss_g = 0 |
| 1499 | + for x_fake in d_fake: |
| 1500 | + loss_g += paddle.mean((1 - x_fake[-1])**2) |
| 1501 | + |
| 1502 | + loss_feature = 0 |
| 1503 | + |
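| | + # feature-matching loss: L1 distance between intermediate discriminator |
| | + # features of the generated and ground-truth signals |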
| 1504 | + for i in range(len(d_fake)): |
| 1505 | + for j in range(len(d_fake[i]) - 1): |
| 1506 | + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) |
| 1507 | + return loss_g, loss_feature |
| 1508 | + |
| 1509 | + |
| 1510 | +class SISDRLoss(nn.Layer): |
| 1511 | + """ |
| 1512 | + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch |
| 1513 | + of estimated and reference audio signals or aligned features. |
| 1514 | +
|
| 1515 | + Implementation copied from: https://github.com/descriptinc/audiotools/blob/master/audiotools/metrics/distance.py |
| 1516 | +
|
| 1517 | + Example: |
| 1518 | + >>> from paddlespeech.audiotools.core.audio_signal import AudioSignal |
| 1519 | + >>> import paddle |
| 1520 | +
|
| 1521 | + >>> x = AudioSignal("https://paddlespeech.bj.bcebos.com/PaddleAudio/en.wav", 2_05) |
| 1522 | + >>> y = x * 0.01 |
| 1523 | + >>> sisdr = SISDRLoss() |
| 1524 | + >>> sisdr(x, y).numpy() |
| 1525 | + -145.377640 |
| 1526 | + """ |
| 1527 | + |
| 1528 | + def __init__( |
| 1529 | + self, |
| 1530 | + scaling: bool=True, |
| 1531 | + reduction: str="mean", |
| 1532 | + zero_mean: bool=True, |
| 1533 | + clip_min: Optional[int]=None, |
| 1534 | + weight: float=1.0, ): |
| 1535 | + """ |
| 1536 | + Args: |
| 1537 | + scaling : bool, optional |
| 1538 | + Whether to compute the scale-invariant SDR (True) or the |
| 1539 | + plain signal-to-noise ratio (False), by default True |
| 1540 | + reduction : str, optional |
| 1541 | + How to reduce across the batch (either 'mean', |
| 1542 | + 'sum', or 'none'), by default 'mean' |
| 1543 | + zero_mean : bool, optional |
| 1544 | + Zero mean the references and estimates before |
| 1545 | + computing the loss, by default True |
| 1546 | + clip_min : int, optional |
| 1547 | + The minimum possible loss value. Helps keep the network from |
| 1548 | + focusing on examples that are already good, by default None |
| 1549 | + weight : float, optional |
| 1550 | + Weight of this loss, defaults to 1.0. |
| 1551 | + """ |
| 1552 | + self.scaling = scaling |
| 1553 | + self.reduction = reduction |
| 1554 | + self.zero_mean = zero_mean |
| 1555 | + self.clip_min = clip_min |
| 1556 | + self.weight = weight |
| 1557 | + super().__init__() |
| 1558 | + |
| 1559 | + def forward(self, |
| 1560 | + x: Union[AudioSignal, paddle.Tensor], |
| 1561 | + y: Union[AudioSignal, paddle.Tensor]): |
| 1562 | + eps = 1e-8 |
| 1563 | + # B, C, T |
| 1564 | + if isinstance(x, AudioSignal): |
| 1565 | + references = x.audio_data |
| 1566 | + estimates = y.audio_data |
| 1567 | + else: |
| 1568 | + references = x |
| 1569 | + estimates = y |
| 1570 | + |
| 1571 | + nb = references.shape[0] |
| 1572 | + references = references.reshape([nb, 1, -1]).transpose([0, 2, 1]) |
| 1573 | + estimates = estimates.reshape([nb, 1, -1]).transpose([0, 2, 1]) |
| 1574 | + |
| 1575 | + # samples now on axis 1 |
| 1576 | + if self.zero_mean: |
| 1577 | + mean_reference = references.mean(axis=1, keepdim=True) |
| 1578 | + mean_estimate = estimates.mean(axis=1, keepdim=True) |
| 1579 | + else: |
| 1580 | + mean_reference = 0 |
| 1581 | + mean_estimate = 0 |
| 1582 | + |
| 1583 | + _references = references - mean_reference |
| 1584 | + _estimates = estimates - mean_estimate |
| 1585 | + |
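| | + # optimal projection scale: <estimates, references> / ||references||^2 |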
| 1586 | + references_projection = (_references**2).sum(axis=-2) + eps |
| 1587 | + references_on_estimates = (_estimates * _references).sum(axis=-2) + eps |
| 1588 | + |
| 1589 | + scale = ( |
| 1590 | + (references_on_estimates / references_projection).unsqueeze(axis=1) |
| 1591 | + if self.scaling else 1) |
| 1592 | + |
| 1593 | + e_true = scale * _references |
| 1594 | + e_res = _estimates - e_true |
| 1595 | + |
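| | + # SI-SDR in dB, negated so that minimizing the loss maximizes the ratio |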
| 1596 | + signal = (e_true**2).sum(axis=1) |
| 1597 | + noise = (e_res**2).sum(axis=1) |
| 1598 | + sdr = -10 * paddle.log10(signal / noise + eps) |
| 1599 | + |
| 1600 | + if self.clip_min is not None: |
| 1601 | + sdr = paddle.clip(sdr, min=self.clip_min) |
| 1602 | + |
| 1603 | + if self.reduction == "mean": |
| 1604 | + sdr = sdr.mean() |
| 1605 | + elif self.reduction == "sum": |
| 1606 | + sdr = sdr.sum() |
| 1607 | + return sdr |
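
For orientation, these three losses are typically combined into reconstruction plus adversarial terms when training a waveform generator. Below is a minimal sketch, assuming an existing discriminator layer and two AudioSignal objects fake_sig (generator output) and real_sig (ground truth); these names and the 2.0 weight are illustrative assumptions, not values from this PR. Note the argument order: MultiScaleSTFTLoss takes (estimate, reference), while SISDRLoss takes (reference, estimate).

    stft_loss = MultiScaleSTFTLoss()
    sisdr_loss = SISDRLoss()
    gan_loss = GANLoss(discriminator)

    # discriminator objective: the generator output is detached so only the
    # discriminator receives gradients
    loss_d = gan_loss.discriminator_loss(fake_sig.audio_data.detach(),
                                         real_sig.audio_data)

    # generator objective: adversarial + feature matching + reconstruction
    adv_g, feat_g = gan_loss.generator_loss(fake_sig.audio_data,
                                            real_sig.audio_data)
    recon = stft_loss(fake_sig, real_sig) + sisdr_loss(real_sig, fake_sig)
    loss_g = adv_g + 2.0 * feat_g + recon  # 2.0 is an illustrative weight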