
Commit ef13a9f

Miltos Allamanis (mallamanis) authored and committed
Option to catch occasional CUDA OOMs to allow for more robust training.
1 parent: bf72e7f

File tree: 3 files changed (+76 -39 lines)


ptgnn/baseneuralmodel/trainer.py (+54 -38)
@@ -12,6 +12,7 @@
 from ptgnn.baseneuralmodel.abstractneuralmodel import AbstractNeuralModel
 from ptgnn.baseneuralmodel.modulewithmetrics import ModuleWithMetrics
 from ptgnn.baseneuralmodel.utils.data import MemorizedDataIterable
+from ptgnn.baseneuralmodel.utils.oom import catch_cuda_oom

 TRawDatapoint = TypeVar("TRawDatapoint")
 TTensorizedDatapoint = TypeVar("TTensorizedDatapoint")
@@ -52,6 +53,7 @@ def __init__(
         target_validation_metric: Optional[str] = None,
         target_validation_metric_higher_is_better: bool = False,
         enable_amp: bool = False,
+        catch_cuda_ooms: bool = False,
     ):
         """
         :param model: The Component to be built and trained
@@ -64,6 +66,15 @@ def __init__(
         :param scheduler_creator: An optional function that accepts an optimizer and creates a scheduler
             implementing `AbstractScheduler`. This could be a wrapper for existing learning schedulers.
             The scheduler will be invoked at after each training step.
+        :param clip_gradient_norm: An optional norm for clipping the gradient norms during training.
+        :param target_validation_metric: An optional string of the name of the metric (returned by
+            the TNeuralModule) which is used to detect if the model performance improved in validation.
+            This is used for early stopping, and checkpointing the best model. If `None` the model
+            loss (value returned from `forward()` of TNeuralModule) is used.
+        :param target_validation_metric_higher_is_better: if `True` increases to `target_validation_metric`
+            imply improvements. Ignored if `target_validation_metric` is `None`.
+        :param enable_amp: Enable automatic mixed precision during training.
+        :param catch_cuda_ooms: Catch CUDA out-of-memory errors (OOM) and resume training when they happen.
         """
         self.__model = model
         self.__neural_network: Optional[TNeuralModule] = None
@@ -87,6 +98,7 @@ def __init__(
         self._improved_epoch_end_hooks: List[EndOfEpochHook] = []
         self._clip_gradient_norm = clip_gradient_norm
         self._enable_amp = enable_amp
+        self._catch_cuda_ooms = catch_cuda_ooms

         self._target_metric = target_validation_metric
         if target_validation_metric is not None:
@@ -203,40 +215,41 @@ def _run_training(
             )
         ):
             optimizer.zero_grad()
-            with torch.cuda.amp.autocast(enabled=self._enable_amp):
-                mb_loss = self.neural_module(**mb_data)
+            with catch_cuda_oom(self._catch_cuda_ooms):
+                with torch.cuda.amp.autocast(enabled=self._enable_amp):
+                    mb_loss = self.neural_module(**mb_data)

-            scaler.scale(mb_loss).backward()
+                scaler.scale(mb_loss).backward()

-            if torch.isnan(mb_loss):
-                raise Exception("Loss has a NaN value.")
+                if torch.isnan(mb_loss):
+                    raise Exception("Loss has a NaN value.")

-            if self._clip_gradient_norm is not None:
-                scaler.unscale_(optimizer)
-                torch.nn.utils.clip_grad_norm_(
-                    self.neural_module.parameters(recurse=True), self._clip_gradient_norm
-                )
-
-            scaler.step(optimizer)
-            scaler.update()
-            if scheduler is not None:
-                scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
-
-            num_minibatches += 1
-            num_samples += len(raw_samples)
-            with torch.no_grad():
-                sum_epoch_loss += mb_loss
-            if show_progress_bar:
-                mb_loss = float(mb_loss)
-                if num_minibatches == 1:  # First minibatch
-                    running_avg_loss = mb_loss
-                else:
-                    running_avg_loss = (
-                        exponential_running_average_factor * running_avg_loss
-                        + (1 - exponential_running_average_factor) * mb_loss
+                if self._clip_gradient_norm is not None:
+                    scaler.unscale_(optimizer)
+                    torch.nn.utils.clip_grad_norm_(
+                        self.neural_module.parameters(recurse=True), self._clip_gradient_norm
                     )
-                progress_bar.update()
-                progress_bar.set_postfix(Loss=f"{running_avg_loss:.2f}")
+
+                scaler.step(optimizer)
+                scaler.update()
+                if scheduler is not None:
+                    scheduler.step(epoch_idx=epoch, epoch_step=step_idx)
+
+                num_minibatches += 1
+                num_samples += len(raw_samples)
+                with torch.no_grad():
+                    sum_epoch_loss += mb_loss
+                if show_progress_bar:
+                    mb_loss = float(mb_loss)
+                    if num_minibatches == 1:  # First minibatch
+                        running_avg_loss = mb_loss
+                    else:
+                        running_avg_loss = (
+                            exponential_running_average_factor * running_avg_loss
+                            + (1 - exponential_running_average_factor) * mb_loss
+                        )
+                    progress_bar.update()
+                    progress_bar.set_postfix(Loss=f"{running_avg_loss:.2f}")

         elapsed_time = time.time() - start_time
         self.LOGGER.info(
@@ -275,14 +288,17 @@ def _run_validation(
             shuffle_input=False,
             parallelize=parallelize,
         ):
-            with torch.cuda.amp.autocast(enabled=self._enable_amp):
-                mb_loss = self.neural_module(**mb_data)
-            num_minibatches += 1
-            num_samples += len(raw_samples)
-            sum_epoch_loss += mb_loss
-            if show_progress_bar:
-                progress_bar.update()
-                progress_bar.set_postfix(Loss=f"{float(sum_epoch_loss) / num_minibatches:.2f}")
+            with catch_cuda_oom(self._catch_cuda_ooms):
+                with torch.cuda.amp.autocast(enabled=self._enable_amp):
+                    mb_loss = self.neural_module(**mb_data)
+                num_minibatches += 1
+                num_samples += len(raw_samples)
+                sum_epoch_loss += mb_loss
+                if show_progress_bar:
+                    progress_bar.update()
+                    progress_bar.set_postfix(
+                        Loss=f"{float(sum_epoch_loss) / num_minibatches:.2f}"
+                    )

         elapsed_time = time.time() - start_time
         assert num_samples > 0, "No validation data was found."
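
For orientation, a minimal usage sketch of the new option; this is not part of the commit. It assumes the trainer class defined in this file is `ModelTrainer`, that it is re-exported from `ptgnn.baseneuralmodel`, and that `model`, `training_data`, and `validation_data` already exist:

    from pathlib import Path

    from ptgnn.baseneuralmodel import ModelTrainer

    # `model` is an AbstractNeuralModel built elsewhere (assumed).
    trainer = ModelTrainer(
        model,
        Path("model.pkl.gz"),   # illustrative checkpoint location
        catch_cuda_ooms=True,   # new in this commit: tolerate occasional CUDA OOMs
    )
    # Minibatches that trigger a CUDA OOM are now logged and skipped
    # instead of aborting the whole run.
    trainer.train(training_data, validation_data)

Since the whole per-minibatch body (forward pass, backward pass, optimizer step, and loss bookkeeping) now sits inside `catch_cuda_oom`, a caught OOM skips that minibatch entirely and leaves the epoch counters untouched.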

ptgnn/baseneuralmodel/utils/oom.py (+22 -0, new file)
@@ -0,0 +1,22 @@
+from typing_extensions import Final
+
+import logging
+import torch
+from contextlib import contextmanager
+
+LOGGER: Final = logging.getLogger(__name__)
+
+
+@contextmanager
+def catch_cuda_oom(enabled: bool = True):
+    if enabled:
+        try:
+            yield
+        except RuntimeError as re:
+            if "CUDA out of memory." in repr(re):
+                LOGGER.exception("CUDA Out-Of-Memory Caught and Execution Resumed.", exc_info=re)
+                torch.cuda.empty_cache()
+            else:
+                raise re
+    else:
+        yield
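
The context manager is also usable directly around any CUDA-heavy region. A small sketch under stated assumptions: `module` and `batches` are placeholders, and only `catch_cuda_oom` comes from this commit:

    from ptgnn.baseneuralmodel.utils.oom import catch_cuda_oom

    for batch in batches:  # placeholder iterable of already-tensorized inputs
        with catch_cuda_oom(enabled=True):
            loss = module(**batch)  # placeholder module; may raise RuntimeError on OOM
            loss.backward()
        # If an OOM was caught, the rest of the block was abandoned: the handler
        # logs the event and calls torch.cuda.empty_cache() to release cached
        # allocator blocks, then execution resumes here with the next batch.

One caveat: because the error is swallowed, names assigned inside the block (such as `loss` above) may be stale or unbound after a caught OOM, so follow-up code should not depend on them. This is why the trainer nests all per-minibatch work inside the context manager.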

ptgnn/implementations/typilus/traindistributed.py (+0 -1)
@@ -17,7 +17,6 @@
     -h --help     Show this screen.
     --debug       Enable debug routines. [default: False]
 """
-import logging
 import random
 import torch
 import torch.distributed as dist
