
We have already implemented basic tensor parallelism, mixed-precision training, checkpointing, and so on. To move toward an industrial-grade setup, it's time to add command-line options!

train_minimal.py

diff --git a/rosellm/rosetrainer/train_minimal.py b/rosellm/rosetrainer/train_minimal.py
index 12fa1fc..338ec3b 100644
--- a/rosellm/rosetrainer/train_minimal.py
+++ b/rosellm/rosetrainer/train_minimal.py
@@ -1,3 +1,4 @@
+import argparse
 import os
 
 import torch
@@ -33,36 +34,36 @@ class ToyRandomDataset(Dataset):
         }
 
 
-def main():
+def main(args: argparse.Namespace) -> None:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print("Using device:", device)
-    checkpoint_path = "checkpoints/minigpt_single.pt"
-    resume = False
+    checkpoint_path = args.checkpoint_path
+    resume = args.resume
     config = GPTConfig(
-        vocab_size=10000,
-        max_position_embeddings=128,
-        n_layers=2,
-        n_heads=4,
-        d_model=128,
-        d_ff=512,
-        dropout=0.1,
+        vocab_size=args.vocab_size,
+        max_position_embeddings=args.max_position_embeddings,
+        n_layers=args.n_layers,
+        n_heads=args.n_heads,
+        d_model=args.d_model,
+        d_ff=args.d_ff,
+        dropout=args.dropout,
     )
     model = GPTModel(config).to(device)
     dataset = ToyRandomDataset(
         vocab_size=config.vocab_size,
-        seq_len=32,
+        seq_len=args.seq_len,
         num_samples=1000,
     )
     dataloader = DataLoader(
         dataset,
-        batch_size=8,
+        batch_size=args.batch_size,
         shuffle=True,
     )
-    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
-    use_amp = device.type == "cuda"
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+    use_amp = device.type == "cuda" and not args.no_amp
     scaler = GradScaler(enabled=use_amp)
     model.train()
-    num_steps = 50
+    num_steps = args.num_steps
     step = 0
     if resume and os.path.exists(checkpoint_path):
         print(f"Resuming from checkpoint {checkpoint_path}")
@@ -115,5 +116,98 @@ def main():
             )
 
 
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Train minimal GPT model.")
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        default=10000,
+        help="Vocabulary size.",
+    )
+    parser.add_argument(
+        "--max-position-embeddings",
+        type=int,
+        default=128,
+        help="Max sequence length.",
+    )
+    parser.add_argument(
+        "--n-layers",
+        type=int,
+        default=2,
+        help="Number of Transformer layers.",
+    )
+    parser.add_argument(
+        "--n-heads",
+        type=int,
+        default=4,
+        help="Number of attention heads.",
+    )
+    parser.add_argument(
+        "--d-model",
+        type=int,
+        default=128,
+        help="Model hidden size.",
+    )
+    parser.add_argument(
+        "--d-ff",
+        type=int,
+        default=512,
+        help="FFN hidden size.",
+    )
+    parser.add_argument(
+        "--dropout",
+        type=float,
+        default=0.1,
+        help="Dropout probability.",
+    )
+    parser.add_argument(
+        "--use-tensor-parallel",
+        action="store_true",
+        help="Enable tensor parallel blocks.",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=8,
+        help="Batch size per step.",
+    )
+    parser.add_argument(
+        "--seq-len",
+        type=int,
+        default=32,
+        help="Sequence length.",
+    )
+    parser.add_argument(
+        "--num-steps",
+        type=int,
+        default=50,
+        help="Number of training steps.",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=3e-4,
+        help="Learning rate.",
+    )
+    parser.add_argument(
+        "--no-amp",
+        action="store_true",
+        help="Disable AMP even on CUDA.",
+    )
+    parser.add_argument(
+        "--checkpoint-path",
+        type=str,
+        default="checkpoints/minigpt_single.pt",
+        help="Path to checkpoint file.",
+    )
+    parser.add_argument(
+        "--resume",
+        action="store_true",
+        help="Resume training from checkpoint.",
+    )
+    return parser.parse_args()
+
+
 if __name__ == "__main__":
-    main()
+    args = parse_args()
+    main(args)
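
The parser above accepts each flag independently, but it does not check that the values are mutually consistent. Assuming the attention implementation here splits d_model evenly across heads, as standard multi-head attention does, --d-model must be divisible by --n-heads, and a bad combination only surfaces as a shape error deep inside the model. A small hypothetical check at the end of parse_args (not part of the diff above; only the two relevant flags are spelled out here) would fail fast instead:

import argparse


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train minimal GPT model.")
    parser.add_argument("--d-model", type=int, default=128, help="Model hidden size.")
    parser.add_argument("--n-heads", type=int, default=4, help="Number of attention heads.")
    # ... the remaining add_argument calls from the diff above ...
    args = parser.parse_args()
    # Fail fast with a readable message instead of a shape error inside the model.
    if args.d_model % args.n_heads != 0:
        parser.error(
            f"--d-model ({args.d_model}) must be divisible by --n-heads ({args.n_heads})"
        )
    return args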

train_ddp.py

diff --git a/rosellm/rosetrainer/train_ddp.py b/rosellm/rosetrainer/train_ddp.py
index 47c2b68..0c3ad7f 100644
--- a/rosellm/rosetrainer/train_ddp.py
+++ b/rosellm/rosetrainer/train_ddp.py
@@ -1,3 +1,4 @@
+import argparse
 import os
 
 import torch
@@ -51,21 +52,21 @@ def is_main_process(local_rank: int) -> bool:
     return local_rank == 0
 
 
-def main():
+def main(args: argparse.Namespace) -> None:
     device, local_rank = setup_distributed()
-    checkpoint_path = "checkpoints/minigpt_ddp.pt"
-    resume = False
+    checkpoint_path = args.checkpoint_path
+    resume = args.resume
     if is_main_process(local_rank):
         print(f"[rank {local_rank}] Using device: {device}")
         os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
     config = GPTConfig(
-        vocab_size=10000,
-        max_position_embeddings=128,
-        n_layers=2,
-        n_heads=4,
-        d_model=128,
-        d_ff=512,
-        dropout=0.1,
+        vocab_size=args.vocab_size,
+        max_position_embeddings=args.max_position_embeddings,
+        n_layers=args.n_layers,
+        n_heads=args.n_heads,
+        d_model=args.d_model,
+        d_ff=args.d_ff,
+        dropout=args.dropout,
     )
     model = GPTModel(config).to(device)
     ddp_model = DDP(
@@ -76,7 +77,7 @@ def main():
     )
     dataset = ToyRandomDataset(
         vocab_size=config.vocab_size,
-        seq_len=32,
+        seq_len=args.seq_len,
         num_samples=1000,
     )
     sampler = DistributedSampler(
@@ -87,14 +88,14 @@ def main():
     )
     dataloader = DataLoader(
         dataset,
-        batch_size=8,
+        batch_size=args.batch_size,
         sampler=sampler,
     )
-    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=3e-4)
+    optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=args.lr)
-    use_amp = device.type == "cuda"
+    use_amp = device.type == "cuda" and not args.no_amp
     scaler = GradScaler(enabled=use_amp)
     ddp_model.train()
-    num_steps = 50
+    num_steps = args.num_steps
     step = 0
     if resume and os.path.exists(checkpoint_path):
         print(f"[rank {local_rank}] Resuming from checkpoint {checkpoint_path}")
@@ -162,5 +163,98 @@ def main():
     cleanup_distributed()
     
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="DDP training for GPT model.")
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        default=10000,
+        help="Vocabulary size.",
+    )
+    parser.add_argument(
+        "--max-position-embeddings",
+        type=int,
+        default=128,
+        help="Max sequence length.",
+    )
+    parser.add_argument(
+        "--n-layers",
+        type=int,
+        default=2,
+        help="Number of Transformer layers.",
+    )
+    parser.add_argument(
+        "--n-heads",
+        type=int,
+        default=4,
+        help="Number of attention heads.",
+    )
+    parser.add_argument(
+        "--d-model",
+        type=int,
+        default=128,
+        help="Model hidden size.",
+    )
+    parser.add_argument(
+        "--d-ff",
+        type=int,
+        default=512,
+        help="FFN hidden size.",
+    )
+    parser.add_argument(
+        "--dropout",
+        type=float,
+        default=0.1,
+        help="Dropout probability.",
+    )
+    parser.add_argument(
+        "--use-tensor-parallel",
+        action="store_true",
+        help="Enable tensor parallel blocks.",
+    )
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=8,
+        help="Batch size per rank.",
+    )
+    parser.add_argument(
+        "--seq-len",
+        type=int,
+        default=32,
+        help="Sequence length.",
+    )
+    parser.add_argument(
+        "--num-steps",
+        type=int,
+        default=50,
+        help="Total training steps.",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=3e-4,
+        help="Learning rate.",
+    )
+    parser.add_argument(
+        "--no-amp",
+        action="store_true",
+        help="Disable AMP even on CUDA.",
+    )
+    parser.add_argument(
+        "--checkpoint-path",
+        type=str,
+        default="checkpoints/minigpt_ddp.pt",
+        help="Path to checkpoint file.",
+    )
+    parser.add_argument(
+        "--resume",
+        action="store_true",
+        help="Resume training from checkpoint.",
+    )
+    return parser.parse_args()
+
+
 if __name__ == "__main__":
-    main()
+    args = parse_args()
+    main(args)
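
One detail the hunks above only hint at: when resuming under DDP, every rank loads the checkpoint (both ranks print "Resuming from checkpoint" in the logs below), while only rank 0 should write it; the os.makedirs call is already guarded by is_main_process. The save/load code predates this diff and is not shown here, so the sketch below is only a rough guess at its shape. The helper names and the checkpoint keys ("model", "optimizer", "scaler", "step") are my assumptions; the parts that matter are the rank-0 guard on the save side and map_location on the load side, which makes each rank deserialize tensors onto its own device instead of rank 0's GPU.

import torch
import torch.distributed as dist


def save_checkpoint(path, ddp_model, optimizer, scaler, step, local_rank):
    # Only the main process writes the file; the barrier keeps other ranks
    # from racing ahead and reading a half-written checkpoint.
    if local_rank == 0:
        torch.save(
            {
                "model": ddp_model.module.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scaler": scaler.state_dict(),
                "step": step,
            },
            path,
        )
    dist.barrier()


def load_checkpoint(path, ddp_model, optimizer, scaler, device):
    # map_location maps saved CUDA tensors onto this rank's own device.
    ckpt = torch.load(path, map_location=device)
    ddp_model.module.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scaler.load_state_dict(ckpt["scaler"])
    return ckpt["step"]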

Run it:

$ python train_minimal.py \
  --n-layers 4 \
  --d-model 256 \
  --n-heads 4 \
  --d-ff 1024 \
  --batch-size 16 \
  --seq-len 64 \
  --num-steps 100 \
  --use-tensor-parallel \
  --resume \
  --checkpoint-path checkpoints/exp1.pt
Using device: cuda
Resume flag is set, but checkpoint not found. Starting from scratch.
step 10 / 100  loss: 9.3688  amp: True
step 20 / 100  loss: 9.3830  amp: True
step 30 / 100  loss: 9.3733  amp: True
step 40 / 100  loss: 9.3916  amp: True
step 50 / 100  loss: 9.3693  amp: True
step 60 / 100  loss: 9.3690  amp: True

Second run:

$ python train_minimal.py   --n-layers 4   --d-model 256   --n-heads 4   --d-ff 1024   --batch-size 16   --seq-len 64   --num-steps 100   --use-tensor-parallel   --resume   --checkpoint-path checkpoints/exp1.pt
Using device: cuda
Resuming from checkpoint checkpoints/exp1.pt
Resumed from step 60
step 70 / 100  loss: 9.3858  amp: True
step 80 / 100  loss: 9.3328  amp: True
step 90 / 100  loss: 9.3256  amp: True
step 100 / 100  loss: 9.3718  amp: True

Then run the DDP version:

$ torchrun --nproc-per-node=2 train_ddp.py \
  --n-layers 4 \
  --d-model 256 \
  --n-heads 4 \
  --d-ff 1024 \
  --seq-len 64 \
  --batch-size 16 \
  --num-steps 200 \
  --checkpoint-path checkpoints/exp_ddp.pt \
  --resume
W1127 20:16:16.653000 220625 site-packages/torch/distributed/run.py:792] 
W1127 20:16:16.653000 220625 site-packages/torch/distributed/run.py:792] *****************************************
W1127 20:16:16.653000 220625 site-packages/torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1127 20:16:16.653000 220625 site-packages/torch/distributed/run.py:792] *****************************************
[rank 0] Using device: cuda:0
[rank 0] Resume flag is set, but checkpoint not found. Starting from scratch.
[step 10 / 200]  loss = 9.4041  amp = True
[step 20 / 200]  loss = 9.3670  amp = True
[step 30 / 200]  loss = 9.3738  amp = True
[step 40 / 200]  loss = 9.3965  amp = True
[step 50 / 200]  loss = 9.3431  amp = True
[step 60 / 200]  loss = 9.3778  amp = True
[step 70 / 200]  loss = 9.3556  amp = True
[step 80 / 200]  loss = 9.3475  amp = True
[step 90 / 200]  loss = 9.3084  amp = True
[step 100 / 200]  loss = 9.3308  amp = True
[step 110 / 200]  loss = 9.2820  amp = True
[step 120 / 200]  loss = 9.2637  amp = True
[step 130 / 200]  loss = 9.2769  amp = True
[step 140 / 200]  loss = 9.2609  amp = True
[step 150 / 200]  loss = 9.2423  amp = True
[step 160 / 200]  loss = 9.2251  amp = True
[step 170 / 200]  loss = 9.2467  amp = True
[step 180 / 200]  loss = 9.2489  amp = True
[step 190 / 200]  loss = 9.2630  amp = True
[step 200 / 200]  loss = 9.2392  amp = True
Training finished.

Second run:

$ torchrun --nproc-per-node=2 train_ddp.py \
  --n-layers 4 \
  --d-model 256 \
  --n-heads 4 \
  --d-ff 1024 \
  --seq-len 64 \
  --batch-size 16 \
  --num-steps 200 \
  --checkpoint-path checkpoints/exp_ddp.pt \
  --resume
W1127 20:16:27.440000 220758 site-packages/torch/distributed/run.py:792] 
W1127 20:16:27.440000 220758 site-packages/torch/distributed/run.py:792] *****************************************
W1127 20:16:27.440000 220758 site-packages/torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1127 20:16:27.440000 220758 site-packages/torch/distributed/run.py:792] *****************************************
[rank 0] Using device: cuda:0
[rank 1] Resuming from checkpoint checkpoints/exp_ddp.pt
[rank 0] Resuming from checkpoint checkpoints/exp_ddp.pt
[rank 0] Resumed from step 200
Training finished.
[rank 1] Resumed from step 200
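
Note that the second DDP run exits almost immediately: the checkpoint already records step 200, which equals --num-steps, so after resuming there is nothing left to do and the loop falls straight through to "Training finished." If that silent behavior is confusing, a small check after resuming could make it explicit. This is a hypothetical helper, not something in the scripts above:

def warn_if_already_finished(step: int, num_steps: int) -> bool:
    """Return True when the resumed step already covers the requested run length."""
    if step >= num_steps:
        print(
            f"Checkpoint is already at step {step}, which is >= --num-steps "
            f"({num_steps}); pass a larger --num-steps to keep training."
        )
        return True
    return False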