# Run this example with torchrun, for example:
# torchrun --nproc-per-node=8 \
-#   examples/llm/peft/automodel.py \
+#   examples/llm/finetune/automodel.py \
#   --strategy fsdp2 \
#   --devices 8 \
#   --model meta-llama/Llama-3.2-1B \
#   --ckpt-folder "output"
#
+# For PEFT please also pass --lora
# Note: ensure that the --nproc-per-node and --devices values match.
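For a PEFT run, the same launch gains the new --lora flag described below (a usage sketch assembled from this comment block; it assumes the remaining arguments stay unchanged):

torchrun --nproc-per-node=8 \
    examples/llm/finetune/automodel.py \
    --strategy fsdp2 \
    --devices 8 \
    --model meta-llama/Llama-3.2-1B \
    --ckpt-folder "output" \
    --lora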
@@ -201,7 +202,7 @@ def logger(ckpt_folder, save_every_n_train_steps) -> nl.NeMoLogger:
    )

    return nl.NeMoLogger(
-        name="nemo2_sft",
+        name="nemo2_finetune",
        log_dir=ckpt_folder,
        use_datetime_version=False,  # must be false if using auto resume
        ckpt=ckpt,
@@ -210,7 +211,7 @@ def logger(ckpt_folder, save_every_n_train_steps) -> nl.NeMoLogger:


def main():
-    """Example script to run SFT with a HF transformers-instantiated model on squad."""
+    """Example script to run SFT/PEFT with a HF transformers-instantiated model on squad."""
    import argparse

    parser = argparse.ArgumentParser()
@@ -308,6 +309,7 @@ def main():
        'run with --attn-implementation=flash_attention_2',
    )
    parser.add_argument('--start-of-turn-token', default=None, help='Chat turn token')
+    parser.add_argument('--lora', action='store_true', help='Enables PEFT (LoRA) finetuning; default: off (SFT).')

    args = parser.parse_args()
    if args.dp_size is None:
@@ -371,7 +373,7 @@ def main():
        trust_remote_code=args.trust_remote_code,
        use_liger_kernel=args.liger,
        enable_grad_ckpt=args.enable_grad_ckpt,
-        use_linear_ce_loss=not args.no_lce,
+        use_linear_ce_loss=args.no_lce,
    )

    assert (
@@ -427,7 +429,13 @@ def main():
        limit_dataset_samples=args.limit_dataset_samples,
        fp8=args.fp8,
    )
-
+    if args.lora:
+        peft = llm.peft.LoRA(
+            target_modules=['*_proj'],
+            dim=8,
+        )
+    else:
+        peft = None
    llm.api.finetune(
        model=model,
        data=dataset,
@@ -448,6 +456,7 @@ def main():
            callbacks=callbacks,
            precision="bf16-mixed",
        ),
+        peft=peft,
        optim=optimizer,
        log=logger(args.ckpt_folder, args.max_steps // 2),
        resume=resume,
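Taken together, the toggle these hunks add reduces to the following (a minimal sketch, assuming the NeMo 2.0 import path from nemo.collections import llm for the llm.peft and llm.api namespaces used above, and omitting the model/data/trainer construction that the diff does not touch):

import argparse

from nemo.collections import llm  # assumed import path for llm.peft / llm.api

parser = argparse.ArgumentParser()
parser.add_argument('--lora', action='store_true', help='Enables PEFT (LoRA) finetuning; default: off (SFT).')
args = parser.parse_args()

# With --lora, build a rank-8 LoRA adapter targeting every module whose name
# matches '*_proj'; without it, peft stays None and finetuning runs as plain SFT.
peft = llm.peft.LoRA(target_modules=['*_proj'], dim=8) if args.lora else None

# The adapter (or None) is then passed straight through to the finetuning call:
# llm.api.finetune(model=model, data=dataset, trainer=..., peft=peft, optim=optimizer, ...)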