LLM Training from Scratch: 008. Use Row Parallel for Attention
In a previous post we implemented Row Parallel Linear and applied it to the FFN. The PR for this post applies it to the Attention layer. The PR is still fairly simple: it replaces out_proj with a Row Parallel Linear, while keeping gather_output=True on the QKV Column Parallel Linear and input_is_parallel=False on the Row Parallel Linear. This means the Attention layer performs one all-gather plus one all-reduce. Ideally, QKV should be sharded along the head dimension; we will do that split in the next PR, at which point gather_output becomes False and input_is_parallel becomes True.
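To make the communication pattern concrete, below is a minimal sketch, written against raw torch.distributed collectives, of what gather_output=True (column parallel) and input_is_parallel=False (row parallel) imply. This is my own illustration of the idea, not the repo's ColumnParallelLinear / RowParallelLinear code; the function and tensor names are made up for the sketch.

# A minimal sketch of the two communication steps this PR leaves in the
# Attention layer. Illustration only; not the repo's implementation.
import torch
import torch.distributed as dist


def column_parallel_forward(x, weight_shard, bias_shard):
    # Each rank computes a slice of the output features, then all-gathers the
    # slices back to full width because gather_output=True.
    world_size = dist.get_world_size()
    local_out = x @ weight_shard.t() + bias_shard        # [B, T, out_features // world_size]
    shards = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(shards, local_out)                   # the all-gather in this PR
    return torch.cat(shards, dim=-1)                     # [B, T, out_features]


def row_parallel_forward(x, weight_shard, bias):
    # input_is_parallel=False: the layer receives the full-width input, takes its
    # own slice of the input features, multiplies locally, then all-reduces the
    # partial results; the bias is added once, after the reduce.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    x_shard = x.chunk(world_size, dim=-1)[rank]          # [B, T, in_features // world_size]
    partial = x_shard @ weight_shard.t()                 # [B, T, out_features], partial sum
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)
    return partial + bias

With the test config used below (d_model=64) on 2 ranks, each rank holds a [96, 64] shard of the dense qkv_proj weight (a slice of the output rows) and a [64, 32] shard of the dense out_proj weight (a slice of the input columns), while the out_proj bias stays whole. That is exactly how the test copies the dense weights into the TP layers.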
model.py
diff --git a/rosellm/rosetrainer/model.py b/rosellm/rosetrainer/model.py
index d95add6..102e308 100644
--- a/rosellm/rosetrainer/model.py
+++ b/rosellm/rosetrainer/model.py
@@ -28,9 +28,15 @@ class MultiHeadSelfAttention(nn.Module):
                 bias=True,
                 gather_output=True,
             )
+            self.out_proj = RowParallelLinear(
+                in_features=config.d_model,
+                out_features=config.d_model,
+                bias=True,
+                input_is_parallel=False,
+            )
         else:
             self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
-        self.out_proj = nn.Linear(config.d_model, config.d_model)
+            self.out_proj = nn.Linear(config.d_model, config.d_model)
         self.dropout = nn.Dropout(config.dropout)
         self.register_buffer(
             "mask",
test_attention_tp_vs_dense.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from config import GPTConfig
from model import MultiHeadSelfAttention
from tensor_parallel import init_tensor_parallel


def setup_distributed():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    return device, local_rank


def cleanup_distributed():
    dist.destroy_process_group()
def build_attention(config_base: GPTConfig, device: torch.device):
    dense_cfg = GPTConfig(
        vocab_size=config_base.vocab_size,
        max_position_embeddings=config_base.max_position_embeddings,
        n_layers=config_base.n_layers,
        n_heads=config_base.n_heads,
        d_model=config_base.d_model,
        d_ff=config_base.d_ff,
        dropout=config_base.dropout,
        use_tensor_parallel=False,
    )
    attn_dense = MultiHeadSelfAttention(dense_cfg).to(device)

    tp_cfg = GPTConfig(
        vocab_size=config_base.vocab_size,
        max_position_embeddings=config_base.max_position_embeddings,
        n_layers=config_base.n_layers,
        n_heads=config_base.n_heads,
        d_model=config_base.d_model,
        d_ff=config_base.d_ff,
        dropout=config_base.dropout,
        use_tensor_parallel=True,
    )
    attn_tp = MultiHeadSelfAttention(tp_cfg).to(device)

    return attn_dense, attn_tp
def copy_qkv_from_dense_to_tp(
    attn_dense: MultiHeadSelfAttention,
    attn_tp: MultiHeadSelfAttention,
    world_size: int,
    rank: int,
):
    # Column parallel: each rank keeps a contiguous slice of the dense weight's
    # output rows, plus the matching slice of the bias.
    with torch.no_grad():
        linear_dense: nn.Linear = attn_dense.qkv_proj
        col_tp = attn_tp.qkv_proj

        out_features = linear_dense.out_features
        out_per_rank = out_features // world_size
        start = rank * out_per_rank
        end = start + out_per_rank

        col_tp.weight.copy_(linear_dense.weight[start:end, :])
        if col_tp.bias is not None:
            col_tp.bias.copy_(linear_dense.bias[start:end])
def copy_out_proj_from_dense_to_tp(
    attn_dense: MultiHeadSelfAttention,
    attn_tp: MultiHeadSelfAttention,
    world_size: int,
    rank: int,
):
    # Row parallel: each rank keeps a contiguous slice of the dense weight's
    # input columns; the bias is not sharded, since it is added once after
    # the all-reduce.
    with torch.no_grad():
        linear_dense: nn.Linear = attn_dense.out_proj
        row_tp = attn_tp.out_proj

        in_features = linear_dense.in_features
        in_per_rank = in_features // world_size
        start = rank * in_per_rank
        end = start + in_per_rank

        row_tp.weight.copy_(linear_dense.weight[:, start:end])
        if row_tp.bias is not None:
            row_tp.bias.copy_(linear_dense.bias)
def main():
    device, local_rank = setup_distributed()
    init_tensor_parallel()

    world_size = dist.get_world_size()
    rank = dist.get_rank()
    if rank == 0:
        print(f"world_size = {world_size}")

    base_cfg = GPTConfig(
        vocab_size=10000,
        max_position_embeddings=128,
        n_layers=1,
        n_heads=4,
        d_model=64,
        d_ff=256,
        dropout=0.0,
    )

    torch.manual_seed(1234)
    torch.cuda.manual_seed(1234)

    attn_dense, attn_tp = build_attention(base_cfg, device)
    copy_qkv_from_dense_to_tp(attn_dense, attn_tp, world_size, rank)
    copy_out_proj_from_dense_to_tp(attn_dense, attn_tp, world_size, rank)

    batch_size = 2
    seq_len = 8
    x = torch.randn(batch_size, seq_len, base_cfg.d_model, device=device)
    attention_mask = torch.ones(
        batch_size,
        seq_len,
        dtype=torch.long,
        device=device,
    )

    attn_dense.eval()
    attn_tp.eval()
    with torch.no_grad():
        y_dense = attn_dense(x, attention_mask=attention_mask)
        y_tp = attn_tp(x, attention_mask=attention_mask)

    diff = (y_dense - y_tp).abs().max()
    diff_val = diff.item()
    if rank == 0:
        print("y_dense shape:", y_dense.shape)
        print("y_tp shape:", y_tp.shape)
        print("max |y_dense - y_tp| = ", diff_val)

    cleanup_distributed()


if __name__ == "__main__":
    main()
Run result:
$ torchrun --nproc-per-node=2 test_attention_tp_vs_dense.py
W1127 12:56:30.316000 22986 site-packages/torch/distributed/run.py:792]
W1127 12:56:30.316000 22986 site-packages/torch/distributed/run.py:792] *****************************************
W1127 12:56:30.316000 22986 site-packages/torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1127 12:56:30.316000 22986 site-packages/torch/distributed/run.py:792] *****************************************
world_size = 2
y_dense shape: torch.Size([2, 8, 64])
y_tp shape: torch.Size([2, 8, 64])
max |y_dense - y_tp| = 2.384185791015625e-07
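The maximum element-wise difference between the dense and tensor-parallel outputs is on the order of 1e-7, i.e. ordinary fp32 rounding noise, so the row-parallel out_proj together with the gathered column-parallel QKV reproduces the dense attention output.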