@@ -12,6 +12,7 @@
 from ptgnn.baseneuralmodel.abstractneuralmodel import AbstractNeuralModel
 from ptgnn.baseneuralmodel.modulewithmetrics import ModuleWithMetrics
 from ptgnn.baseneuralmodel.utils.data import MemorizedDataIterable
+from ptgnn.baseneuralmodel.utils.oom import catch_cuda_oom
 
 TRawDatapoint = TypeVar("TRawDatapoint")
 TTensorizedDatapoint = TypeVar("TTensorizedDatapoint")
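The newly imported `catch_cuda_oom` comes from `ptgnn/baseneuralmodel/utils/oom.py`, which is not part of this diff. Judging only from how it is called below (`catch_cuda_oom(self._catch_cuda_ooms)`), it is a context manager toggled by a boolean. A minimal sketch of such a context manager might look like the following; the parameter name `enabled`, the warning log, and the `torch.cuda.empty_cache()` call are illustrative assumptions, not the library's actual implementation.

import logging
from contextlib import contextmanager

import torch


@contextmanager
def catch_cuda_oom(enabled: bool = True):
    """Hypothetical sketch: swallow CUDA out-of-memory RuntimeErrors when `enabled` is True."""
    if not enabled:
        # Act as a plain pass-through so exceptions propagate normally.
        yield
        return
    try:
        yield
    except RuntimeError as e:
        if "out of memory" in str(e):
            # Assumed recovery path: log the event, free cached blocks, and hand control
            # back to the caller so the loop can continue with the next minibatch.
            logging.getLogger(__name__).warning("CUDA OOM caught; skipping this step.")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            raise

Because the hunks below nest the entire minibatch body under this context manager, a step that runs out of memory is skipped wholesale when the flag is on: no loss is accumulated and no optimizer step is taken for that minibatch.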
@@ -52,6 +53,7 @@ def __init__(
         target_validation_metric: Optional[str] = None,
         target_validation_metric_higher_is_better: bool = False,
         enable_amp: bool = False,
+        catch_cuda_ooms: bool = False,
     ):
         """
         :param model: The Component to be built and trained
@@ -64,6 +66,15 @@ def __init__(
         :param scheduler_creator: An optional function that accepts an optimizer and creates a scheduler
             implementing `AbstractScheduler`. This could be a wrapper for existing learning schedulers.
             The scheduler will be invoked at after each training step.
+        :param clip_gradient_norm: An optional norm for clipping the gradient norms during training.
+        :param target_validation_metric: An optional string of the name of the metric (returned by
+            the TNeuralModule) which is used to detect if the model performance improved in validation.
+            This is used for early stopping, and checkpointing the best model. If `None` the model
+            loss (value returned from `forward()` of TNeuralModule) is used.
+        :param target_validation_metric_higher_is_better: if `True` increases to `target_validation_metric`
+            imply improvements. Ignored if `target_validation_metric` is `None`.
+        :param enable_amp: Enable automatic mixed precision during training.
+        :param catch_cuda_ooms: Catch CUDA out-of-memory errors (OOM) and resume training when they happen.
         """
         self.__model = model
         self.__neural_network: Optional[TNeuralModule] = None
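For context, a hedged usage sketch of the constructor flags documented above. The class name `ModelTrainer`, the import path, and the placeholder `my_model` are assumptions about code outside this diff, and other required constructor arguments are omitted.

from ptgnn.baseneuralmodel import ModelTrainer  # assumed import path

# Hypothetical usage sketch: `my_model` stands in for a concrete AbstractNeuralModel
# implementation; any other required trainer arguments are omitted for brevity.
trainer = ModelTrainer(
    my_model,
    clip_gradient_norm=1.0,
    target_validation_metric="Accuracy",
    target_validation_metric_higher_is_better=True,
    enable_amp=True,
    catch_cuda_ooms=True,  # new flag: catch CUDA OOMs and keep training/validating
)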
@@ -87,6 +98,7 @@ def __init__(
         self._improved_epoch_end_hooks: List[EndOfEpochHook] = []
         self._clip_gradient_norm = clip_gradient_norm
         self._enable_amp = enable_amp
+        self._catch_cuda_ooms = catch_cuda_ooms
 
         self._target_metric = target_validation_metric
         if target_validation_metric is not None:
@@ -203,40 +215,41 @@ def _run_training(
             )
         ):
             optimizer.zero_grad()
-            with torch.cuda.amp.autocast(enabled=self._enable_amp):
-                mb_loss = self.neural_module(**mb_data)
+            with catch_cuda_oom(self._catch_cuda_ooms):
+                with torch.cuda.amp.autocast(enabled=self._enable_amp):
+                    mb_loss = self.neural_module(**mb_data)
 
-            scaler.scale(mb_loss).backward()
+                scaler.scale(mb_loss).backward()
 
-            if torch.isnan(mb_loss):
-                raise Exception("Loss has a NaN value.")
+                if torch.isnan(mb_loss):
+                    raise Exception("Loss has a NaN value.")
 
-            if self._clip_gradient_norm is not None:
-                scaler.unscale_(optimizer)
-                torch.nn.utils.clip_grad_norm_(
-                    self.neural_module.parameters(recurse=True), self._clip_gradient_norm
-                )
-
-            scaler.step(optimizer)
-            scaler.update()
-            if scheduler is not None:
-                scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
-
-            num_minibatches += 1
-            num_samples += len(raw_samples)
-            with torch.no_grad():
-                sum_epoch_loss += mb_loss
-            if show_progress_bar:
-                mb_loss = float(mb_loss)
-                if num_minibatches == 1:  # First minibatch
-                    running_avg_loss = mb_loss
-                else:
-                    running_avg_loss = (
-                        exponential_running_average_factor * running_avg_loss
-                        + (1 - exponential_running_average_factor) * mb_loss
+                if self._clip_gradient_norm is not None:
+                    scaler.unscale_(optimizer)
+                    torch.nn.utils.clip_grad_norm_(
+                        self.neural_module.parameters(recurse=True), self._clip_gradient_norm
                     )
-                progress_bar.update()
-                progress_bar.set_postfix(Loss=f"{running_avg_loss:.2f}")
+
+                scaler.step(optimizer)
+                scaler.update()
+                if scheduler is not None:
+                    scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
+
+                num_minibatches += 1
+                num_samples += len(raw_samples)
+                with torch.no_grad():
+                    sum_epoch_loss += mb_loss
+                if show_progress_bar:
+                    mb_loss = float(mb_loss)
+                    if num_minibatches == 1:  # First minibatch
+                        running_avg_loss = mb_loss
+                    else:
+                        running_avg_loss = (
+                            exponential_running_average_factor * running_avg_loss
+                            + (1 - exponential_running_average_factor) * mb_loss
+                        )
+                    progress_bar.update()
+                    progress_bar.set_postfix(Loss=f"{running_avg_loss:.2f}")
 
         elapsed_time = time.time() - start_time
         self.LOGGER.info(
@@ -275,14 +288,17 @@ def _run_validation(
             shuffle_input=False,
             parallelize=parallelize,
         ):
-            with torch.cuda.amp.autocast(enabled=self._enable_amp):
-                mb_loss = self.neural_module(**mb_data)
-            num_minibatches += 1
-            num_samples += len(raw_samples)
-            sum_epoch_loss += mb_loss
-            if show_progress_bar:
-                progress_bar.update()
-                progress_bar.set_postfix(Loss=f"{float(sum_epoch_loss) / num_minibatches:.2f}")
+            with catch_cuda_oom(self._catch_cuda_ooms):
+                with torch.cuda.amp.autocast(enabled=self._enable_amp):
+                    mb_loss = self.neural_module(**mb_data)
+                num_minibatches += 1
+                num_samples += len(raw_samples)
+                sum_epoch_loss += mb_loss
+                if show_progress_bar:
+                    progress_bar.update()
+                    progress_bar.set_postfix(
+                        Loss=f"{float(sum_epoch_loss) / num_minibatches:.2f}"
+                    )
 
         elapsed_time = time.time() - start_time
         assert num_samples > 0, "No validation data was found."