
Commit d22b3ad

[automodel] consolidate sft peft scripts (#13634)
* move files
* remove peft
* add --lora option
* fix paths

Signed-off-by: Alexandros Koumparoulis <[email protected]>
1 parent 6e7b414 commit d22b3ad

9 files changed: 27 additions & 335 deletions

examples/llm/sft/automodel.py renamed to examples/llm/finetune/automodel.py

Lines changed: 14 additions & 5 deletions
@@ -27,12 +27,13 @@
 
 # Run this example with torchrun, for example:
 # torchrun --nproc-per-node=8 \
-#   examples/llm/peft/automodel.py \
+#   examples/llm/finetune/automodel.py \
 #   --strategy fsdp2 \
 #   --devices 8 \
 #   --model meta-llama/Llama-3.2-1B \
 #   --ckpt-folder "output"
 #
+# For PEFT please also pass --lora
 # Note: ensure that the --nproc-per-node and --devices values match.
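For PEFT under the new layout, the same command with the extra flag would look roughly like this (all paths and options are taken from this diff):

torchrun --nproc-per-node=8 \
    examples/llm/finetune/automodel.py \
    --strategy fsdp2 \
    --devices 8 \
    --model meta-llama/Llama-3.2-1B \
    --ckpt-folder "output" \
    --lora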

@@ -201,7 +202,7 @@ def logger(ckpt_folder, save_every_n_train_steps) -> nl.NeMoLogger:
     )
 
     return nl.NeMoLogger(
-        name="nemo2_sft",
+        name="nemo2_finetune",
         log_dir=ckpt_folder,
         use_datetime_version=False,  # must be false if using auto resume
         ckpt=ckpt,
@@ -210,7 +211,7 @@ def logger(ckpt_folder, save_every_n_train_steps) -> nl.NeMoLogger:
 
 
 def main():
-    """Example script to run SFT with a HF transformers-instantiated model on squad."""
+    """Example script to run SFT/PEFT with a HF transformers-instantiated model on squad."""
     import argparse
 
     parser = argparse.ArgumentParser()
@@ -308,6 +309,7 @@ def main():
         'run with --attn-implementation=flash_attention_2',
     )
     parser.add_argument('--start-of-turn-token', default=None, help='Chat turn token')
+    parser.add_argument('--lora', action='store_true', help='Enables PEFT (LoRA) finetuning; Default: off (SFT).')
 
     args = parser.parse_args()
     if args.dp_size is None:
@@ -371,7 +373,7 @@ def main():
         trust_remote_code=args.trust_remote_code,
         use_liger_kernel=args.liger,
         enable_grad_ckpt=args.enable_grad_ckpt,
-        use_linear_ce_loss=not args.no_lce,
+        use_linear_ce_loss=args.no_lce,
     )
 
     assert (
@@ -427,7 +429,13 @@ def main():
         limit_dataset_samples=args.limit_dataset_samples,
         fp8=args.fp8,
     )
-
+    if args.peft:
+        peft = llm.peft.LoRA(
+            target_modules=['*_proj'],
+            dim=8,
+        )
+    else:
+        peft = None
     llm.api.finetune(
         model=model,
         data=dataset,
@@ -448,6 +456,7 @@ def main():
             callbacks=callbacks,
             precision="bf16-mixed",
         ),
+        peft=peft,
         optim=optimizer,
         log=logger(args.ckpt_folder, args.max_steps // 2),
         resume=resume,
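Pulled out of the surrounding script, the consolidated SFT/PEFT toggle amounts to the sketch below. It is a minimal sketch, not the full example: it assumes the nemo.collections.llm API used in this diff, reads the flag as args.lora to match the --lora option added above, and omits the model, dataset, trainer, optimizer and logger setup from the rest of the script, so the final finetune call is left as a comment.

# Minimal sketch of the SFT/PEFT toggle introduced by this commit.
# Assumes NeMo's nemo.collections.llm API as used in the diff above; everything
# outside the diff (model/data/trainer/optim/logger construction) is omitted.
import argparse

from nemo.collections import llm

parser = argparse.ArgumentParser()
parser.add_argument('--lora', action='store_true',
                    help='Enables PEFT (LoRA) finetuning; default: off (SFT).')
args = parser.parse_args()

if args.lora:
    # LoRA adapters on every *_proj linear layer with rank 8, as configured above.
    peft = llm.peft.LoRA(
        target_modules=['*_proj'],
        dim=8,
    )
else:
    # No PEFT config: the finetune API performs full-parameter SFT.
    peft = None

# The actual call, with the remaining arguments shown elsewhere in this diff:
# llm.api.finetune(model=model, data=dataset, trainer=trainer, peft=peft,
#                  optim=optimizer, log=logger(...), resume=resume)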

examples/llm/peft/automodel.py

Lines changed: 0 additions & 325 deletions
This file was deleted.
