@@ -45,22 +45,20 @@ def train_one_step(epoch,optimizer,optimizer_disc, model, disc_model, trainloade
         optimizer_disc.zero_grad()
         output, loss_w, _ = model(input_wav)  # output: [B, 1, T]: eg. [2, 1, 203760] | loss_w: [1]
         logits_real, fmap_real = disc_model(input_wav)
-        # train discriminator when epoch > warmup_epoch and train_discriminator is True
-        if config.model.train_discriminator and epoch > config.lr_scheduler.warmup_epoch:
-            logits_fake, _ = disc_model(output.detach())  # detach to avoid backpropagation to model
-            loss_disc = disc_loss(logits_real, logits_fake)  # compute discriminator loss
-            loss_disc.backward(retain_graph=True)
-            optimizer_disc.step()
-
         logits_fake, fmap_fake = disc_model(output)
         loss_g = total_loss(fmap_real, logits_fake, fmap_fake, input_wav, output)
         loss = loss_g + loss_w
         loss.backward()
         optimizer.step()
-
         scheduler.step()
-        disc_scheduler.step()
-
+        # train discriminator when epoch > warmup_epoch and train_discriminator is True
+        if config.model.train_discriminator and epoch > config.lr_scheduler.warmup_epoch:
+            logits_fake, _ = disc_model(output.detach())  # detach to avoid backpropagation to model
+            loss_disc = disc_loss([logit_real.detach() for logit_real in logits_real], logits_fake)  # compute discriminator loss
+            loss_disc.backward()
+            optimizer_disc.step()
+            disc_scheduler.step()
+
         if not config.distributed.data_parallel or dist.get_rank() == 0:
             logger.info(f'| epoch: {epoch} | loss: {loss.item()} | loss_g: {loss_g.item()} | loss_w: {loss_w.item()} | lr: {optimizer.param_groups[0]["lr"]} | disc_lr: {optimizer_disc.param_groups[0]["lr"]}')
             if config.model.train_discriminator and epoch > config.lr_scheduler.warmup_epoch:
@@ -261,4 +259,4 @@ def main(config):
 
 
 if __name__ == '__main__':
-    main()
+    main()
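
Note on the first hunk: the discriminator update is moved after the generator step, so the generator's backward pass runs on an intact graph, the discriminator only sees detached tensors, retain_graph=True is no longer needed, and disc_scheduler.step() now advances only when the discriminator actually trains. As a rough illustration of the contract disc_loss(logits_real, logits_fake) is expected to satisfy, a minimal multi-scale hinge-loss sketch follows; this is an assumption for illustration only, since the real disc_loss is defined elsewhere in the repository and may differ.

    import torch
    import torch.nn.functional as F

    def disc_loss(logits_real, logits_fake):
        # Hypothetical sketch of the discriminator loss, assuming a hinge formulation
        # over lists of per-sub-discriminator logits (real pushed above +1, fake below -1).
        loss = torch.zeros((), device=logits_real[0].device)
        for logit_real, logit_fake in zip(logits_real, logits_fake):
            loss = loss + F.relu(1 - logit_real).mean() + F.relu(1 + logit_fake).mean()
        return loss / len(logits_real)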