Skip to content

Commit ef99d33

Browse files
authored
Merge pull request #7 from NoFish-528/issue6
[bugs]: fix #6
2 parents b51129f + 6ced7cc commit ef99d33

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

train_multi_gpu.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -45,22 +45,20 @@ def train_one_step(epoch,optimizer,optimizer_disc, model, disc_model, trainloade
4545
optimizer_disc.zero_grad()
4646
output, loss_w, _ = model(input_wav) #output: [B, 1, T]: eg. [2, 1, 203760] | loss_w: [1]
4747
logits_real, fmap_real = disc_model(input_wav)
48-
# train discriminator when epoch > warmup_epoch and train_discriminator is True
49-
if config.model.train_discriminator and epoch > config.lr_scheduler.warmup_epoch:
50-
logits_fake, _ = disc_model(output.detach()) # detach to avoid backpropagation to model
51-
loss_disc = disc_loss(logits_real, logits_fake) # compute discriminator loss
52-
loss_disc.backward(retain_graph=True)
53-
optimizer_disc.step()
54-
5548
logits_fake, fmap_fake = disc_model(output)
5649
loss_g = total_loss(fmap_real, logits_fake, fmap_fake, input_wav, output)
5750
loss = loss_g + loss_w
5851
loss.backward()
5952
optimizer.step()
60-
6153
scheduler.step()
62-
disc_scheduler.step()
63-
54+
# train discriminator when epoch > warmup_epoch and train_discriminator is True
55+
if config.model.train_discriminator and epoch > config.lr_scheduler.warmup_epoch:
56+
logits_fake, _ = disc_model(output.detach()) # detach to avoid backpropagation to model
57+
loss_disc = disc_loss([logit_real.detach() for logit_real in logits_real], logits_fake) # compute discriminator loss
58+
loss_disc.backward()
59+
optimizer_disc.step()
60+
disc_scheduler.step()
61+
6462
if not config.distributed.data_parallel or dist.get_rank()==0:
6563
logger.info(f'| epoch: {epoch} | loss: {loss.item()} | loss_g: {loss_g.item()} | loss_w: {loss_w.item()} | lr: {optimizer.param_groups[0]["lr"]} | disc_lr: {optimizer_disc.param_groups[0]["lr"]}')
6664
if config.model.train_discriminator and epoch > config.lr_scheduler.warmup_epoch:
@@ -261,4 +259,4 @@ def main(config):
261259

262260

263261
if __name__ == '__main__':
264-
main()
262+
main()

0 commit comments

Comments (0)