
In the previous PR we implemented Column Parallel Linear. In this PR we integrate it into GPTModel and verify it with the train loop.
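For context, here is a minimal sketch of the interface this PR assumes from the previous PR's tensor_parallel module: a ColumnParallelLinear with a gather_output switch, and an init_tensor_parallel that is assumed to set up the tensor-parallel group (here simply the default WORLD group). The real implementation from that PR may differ.

import torch
import torch.distributed as dist
import torch.nn as nn


def init_tensor_parallel():
    # Assumption: the tensor-parallel group is just the default WORLD group,
    # so there is nothing extra to create in this simplified sketch.
    pass


class _GatherFromTPRegion(torch.autograd.Function):
    """All-gather shards along the last dim; backward keeps only the local slice."""

    @staticmethod
    def forward(ctx, x):
        world_size = dist.get_world_size()
        ctx.world_size = world_size
        if world_size == 1:
            return x
        parts = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(parts, x.contiguous())
        return torch.cat(parts, dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.world_size == 1:
            return grad_output
        return grad_output.chunk(ctx.world_size, dim=-1)[dist.get_rank()].contiguous()


class ColumnParallelLinear(nn.Module):
    """Y = X W with W split column-wise: each rank owns out_features // world_size columns."""

    def __init__(self, in_features, out_features, bias=True, gather_output=True):
        super().__init__()
        world_size = dist.get_world_size()
        assert out_features % world_size == 0
        self.gather_output = gather_output
        # Each rank holds only its shard of the output features.
        self.linear = nn.Linear(in_features, out_features // world_size, bias=bias)

    def forward(self, x):
        y = self.linear(x)                    # local shard: (..., out_features // world_size)
        if self.gather_output:
            y = _GatherFromTPRegion.apply(y)  # reassemble the full output on every rank
        return y

A complete implementation would also all-reduce the gradient with respect to x across the tensor-parallel group in backward (Megatron's f operator); that is omitted here because the test below only exercises the forward pass.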

config.py

from dataclasses import dataclass

@dataclass
class GPTConfig:
    vocab_size: int = 32000
    max_position_embeddings: int = 1024
    n_layers: int = 12
    n_heads: int = 12
    d_model: int = 768
    d_ff: int = 3072
    dropout: float = 0.1
    use_tensor_parallel: bool = False  # new: the only flag added in this PR

We only need to add a use_tensor_parallel flag.

model.py

diff --git a/rosellm/rosetrainer/model.py b/rosellm/rosetrainer/model.py
index 9025c71..fcc20d8 100644
--- a/rosellm/rosetrainer/model.py
+++ b/rosellm/rosetrainer/model.py
@@ -1,9 +1,11 @@
 from typing import Optional
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 from config import GPTConfig
+from tensor_parallel import ColumnParallelLinear, init_tensor_parallel
 
 
 class MultiHeadSelfAttention(nn.Module):
@@ -13,7 +15,17 @@ class MultiHeadSelfAttention(nn.Module):
         self.d_model = config.d_model
         self.n_heads = config.n_heads
         self.d_head = config.d_model // config.n_heads
-        self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
+        use_tp = getattr(config, "use_tensor_parallel", False)
+        if use_tp and dist.is_available() and dist.is_initialized():
+            init_tensor_parallel()
+            self.qkv_proj = ColumnParallelLinear(
+                in_features=config.d_model,
+                out_features=3 * config.d_model,
+                bias=True,
+                gather_output=True,
+            )
+        else:
+            self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
         self.out_proj = nn.Linear(config.d_model, config.d_model)
         self.dropout = nn.Dropout(config.dropout)
         self.register_buffer(
@@ -57,7 +69,17 @@ class MultiHeadSelfAttention(nn.Module):
 class FeedForward(nn.Module):
     def __init__(self, config: GPTConfig):
         super().__init__()
-        self.fc1 = nn.Linear(config.d_model, config.d_ff)
+        use_tp = getattr(config, "use_tensor_parallel", False)
+        if use_tp and dist.is_available() and dist.is_initialized():
+            init_tensor_parallel()
+            self.fc1 = ColumnParallelLinear(
+                in_features=config.d_model,
+                out_features=config.d_ff,
+                bias=True,
+                gather_output=True,
+            )
+        else:
+            self.fc1 = nn.Linear(config.d_model, config.d_ff)
         self.fc2 = nn.Linear(config.d_ff, config.d_model)
         self.dropout = nn.Dropout(config.dropout)
 

We add a branch on the QKV projection of the attention layer and on the first projection of the FFN layer, so that Column Parallel Linear is used when tensor parallelism is enabled.

Since we only have Column Parallel Linear so far, we must pass gather_output=True, and every layer all-gathers its sharded output before feeding the unsharded out_proj / fc2 (though even once a Row Parallel Linear is added, each layer would still perform an all-reduce); a sketch of that pairing follows below.
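The Row Parallel counterpart is not part of this PR; as a hedged sketch under the same assumptions as above, out_proj / fc2 could look like the following, paired with ColumnParallelLinear(gather_output=False) in front of it.

import torch
import torch.distributed as dist
import torch.nn as nn


class _ReduceFromTPRegion(torch.autograd.Function):
    """All-reduce (sum) in forward; identity in backward."""

    @staticmethod
    def forward(ctx, x):
        if dist.get_world_size() > 1:
            x = x.clone()
            dist.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class RowParallelLinear(nn.Module):
    """Y = X W with W split row-wise: each rank owns in_features // world_size rows.

    With a ColumnParallelLinear(gather_output=False) in front of it, the input already
    arrives sharded along the last dim, so the per-layer communication shrinks to a
    single all-reduce on the output instead of an all-gather after qkv_proj / fc1.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        world_size = dist.get_world_size()
        assert in_features % world_size == 0
        self.linear = nn.Linear(in_features // world_size, out_features, bias=False)
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

    def forward(self, x_shard):
        y_partial = self.linear(x_shard)          # partial product from the local input slice
        y = _ReduceFromTPRegion.apply(y_partial)  # sum partial products across the TP group
        if self.bias is not None:
            y = y + self.bias                     # add the bias once, after the reduction
        return y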

test_forward_tp.py

import os

import torch
import torch.distributed as dist
from config import GPTConfig
from model import GPTModel
from tensor_parallel import init_tensor_parallel


def setup_distributed():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    return device, local_rank


def cleanup_distributed():
    dist.destroy_process_group()


def main():
    device, local_rank = setup_distributed()
    torch.manual_seed(123)  # seed CPU and CUDA RNGs identically on every rank,
    torch.cuda.manual_seed(123)  # so input_ids and dropout masks match across GPUs
    init_tensor_parallel()
    world_size = dist.get_world_size()
    if local_rank == 0:
        print("world_size:", world_size)
    config = GPTConfig(
        vocab_size=10000,
        max_position_embeddings=128,
        n_layers=2,
        n_heads=4,
        d_model=128,
        d_ff=512,
        dropout=0.1,
        use_tensor_parallel=True,
    )
    model = GPTModel(config).to(device)
    batch_size = 2
    seq_len = 16
    input_ids = torch.randint(
        low=0,
        high=config.vocab_size,
        size=(batch_size, seq_len),
        dtype=torch.long,
        device=device,
    )
    attention_mask = torch.ones(
        batch_size,
        seq_len,
        dtype=torch.long,
        device=device,
    )
    logits, loss = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=input_ids,
    )
    with torch.no_grad():
        logits_ref = logits.clone()
        dist.broadcast(logits_ref, src=0)
        diff = (logits_ref - logits).abs().max()
        diff_val = diff.item()
    print("max diff vs rank0:", diff_val)
    if local_rank == 0:
        print("input_ids shape:", input_ids.shape)
        print("logits shape:", logits.shape)
        print("loss:", loss.item())
    cleanup_distributed()


if __name__ == "__main__":
    main()

Note the use of torch.manual_seed and torch.cuda.manual_seed here: they ensure that every GPU generates the same random input_ids and that dropout behaves deterministically (identically) across ranks.
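If you would rather not rely on identical seeding for the input, an alternative (a sketch, not part of this PR) is to broadcast the batch from rank 0; matching seeds are then only needed for dropout, or the model can be put in eval mode for the comparison.

# Sketch: overwrite every rank's batch with rank 0's, so input_ids are guaranteed
# to match across ranks regardless of each rank's RNG state.
input_ids = torch.randint(
    low=0,
    high=config.vocab_size,
    size=(batch_size, seq_len),
    dtype=torch.long,
    device=device,
)
dist.broadcast(input_ids, src=0)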

Run

$ torchrun --nproc-per-node=2 test_forward_tp.py 
W1126 20:20:40.188000 3266158 site-packages/torch/distributed/run.py:792] 
W1126 20:20:40.188000 3266158 site-packages/torch/distributed/run.py:792] *****************************************
W1126 20:20:40.188000 3266158 site-packages/torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1126 20:20:40.188000 3266158 site-packages/torch/distributed/run.py:792] *****************************************
world_size: 2
max diff vs rank0: 0.0
input_ids shape: torch.Size([2, 16])
logits shape: torch.Size([2, 16, 10000])
loss: 9.380536079406738
max diff vs rank0: 0.0
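
Printing the diff works, but the check can also be made to fail loudly. A small sketch of a stricter comparison that could be added after the broadcast in test_forward_tp.py, using torch.testing from PyTorch:

# Fail the run if any rank's logits drift from rank 0's. With identical seeds and
# gather_output=True the diff is exactly 0.0, as in the run above; the tolerance
# leaves a little room for future non-deterministic kernels.
torch.testing.assert_close(logits, logits_ref, rtol=0.0, atol=1e-5)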