Implementing LLM Training from Scratch: 004. Column Parallel
Having implemented a mini-GPT, a simple train loop, and a basic form of data parallelism, we can now properly start on the most fundamental kind of tensor parallelism.
Tensor parallelism comes in two flavors: column parallelism and row parallelism. Column parallelism splits the weight matrix by columns: every GPU sees the same input activations but multiplies them against a different column slice of the weight, so each GPU produces a slice of the output, and an all-gather is needed to reassemble the full activation. When a column parallel layer is immediately followed by a row parallel layer, this all-gather can be skipped, because row parallelism wants exactly those sliced activations as its input. A row parallel layer produces an output with the full activation shape, but each GPU holds only a partial sum, so an all-reduce is needed to obtain the same final result on every GPU.
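Before going distributed, a single-process sketch may make the two splits and their matching collectives more concrete. Nothing here uses torch.distributed; the 2-way split and the tensor sizes are purely illustrative:

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)        # activations: (batch, in_features)
w = torch.randn(12, 8)       # full weight: (out_features, in_features)
y_full = x @ w.t()           # reference output: (4, 12)

# Column parallel: split the weight along the output dimension.
# Each "rank" computes a slice of the output; all-gather = concat along the last dim.
w_col_0, w_col_1 = w.chunk(2, dim=0)
y_col = torch.cat([x @ w_col_0.t(), x @ w_col_1.t()], dim=-1)

# Row parallel: split the weight along the input dimension.
# Each "rank" holds a slice of the activations; all-reduce = sum of partial outputs.
w_row_0, w_row_1 = w.chunk(2, dim=1)
x_0, x_1 = x.chunk(2, dim=-1)
y_row = x_0 @ w_row_0.t() + x_1 @ w_row_1.t()

print(torch.allclose(y_full, y_col))              # True
print(torch.allclose(y_full, y_row, atol=1e-6))   # True (up to float rounding)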
The new PR in this post starts with the simplest piece: column tensor parallelism.
tensor_parallel.py
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

_TP_GROUP: Optional[dist.ProcessGroup] = None


def init_tensor_parallel(tp_size: Optional[int] = None) -> None:
    global _TP_GROUP
    if not dist.is_initialized():
        raise RuntimeError("dist is not initialized")
    world_size = dist.get_world_size()
    if tp_size is None:
        tp_size = world_size
    if tp_size != world_size:
        raise NotImplementedError("currently we only support tp_size == world_size")
    _TP_GROUP = dist.group.WORLD


def get_tensor_parallel_group() -> dist.ProcessGroup:
    if _TP_GROUP is None:
        raise RuntimeError("tensor parallel group is not initialized")
    return _TP_GROUP


class ColumnParallelLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        gather_output: bool = True,
    ) -> None:
        super().__init__()
        if not dist.is_initialized():
            raise RuntimeError("dist is not initialized")
        tp_group = get_tensor_parallel_group()
        tp_world_size = dist.get_world_size(tp_group)
        if out_features % tp_world_size != 0:
            raise ValueError("out_features must be divisible by tp_world_size")
        self.in_features = in_features
        self.out_features = out_features
        self.gather_output = gather_output
        self.tp_group = tp_group
        self.tp_world_size = tp_world_size
        self.out_per_rank = out_features // tp_world_size
        self.rank = dist.get_rank(tp_group)
        self.weight = nn.Parameter(torch.empty(self.out_per_rank, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_per_rank))
        else:
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            fan_in = self.in_features
            bound = 1 / fan_in**0.5
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y_local = torch.matmul(x, self.weight.t())
        if self.bias is not None:
            y_local = y_local + self.bias
        if not self.gather_output or self.tp_world_size == 1:
            return y_local
        out_list = [torch.empty_like(y_local) for _ in range(self.tp_world_size)]
        dist.all_gather(out_list, y_local, group=self.tp_group)
        y = torch.cat(out_list, dim=-1)
        return y
Here we define a global variable _TP_GROUP that holds the tensor parallel communication group (process group). For now, init_tensor_parallel simply assigns it dist.group.WORLD, which means that by default all ranks form a single tensor parallel group; later PRs will refine this properly.
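As a rough preview of what that refinement might look like (a hypothetical sketch only, not part of this PR; the function name init_tensor_parallel_subgroups is made up), one could carve the world into consecutive groups of tp_size ranks with dist.new_group and keep only the group the current rank belongs to:

def init_tensor_parallel_subgroups(tp_size: int) -> None:
    # Hypothetical sketch: split world_size ranks into world_size // tp_size
    # tensor parallel groups. Note that every rank must call dist.new_group
    # for every group, even for the groups it does not belong to.
    global _TP_GROUP
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    if world_size % tp_size != 0:
        raise ValueError("world_size must be divisible by tp_size")
    for start in range(0, world_size, tp_size):
        ranks = list(range(start, start + tp_size))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            _TP_GROUP = group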
Then we implement ColumnParallelLinear: in __init__ the weight is sharded along the output dimension so that each rank only holds its own slice, and in forward we manually call all-gather to reassemble the full output.
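To make the bookkeeping concrete, here is the shape flow for the configuration used in the test below (in_features=8, out_features=12, 2 ranks); this is just a walkthrough of the code above, not additional code:

# x                       : (batch, 8)   identical on every rank
# self.weight             : (6, 8)       rank 0 logically owns output rows 0..5,
#                                         rank 1 owns rows 6..11
# y_local = x @ weight.t() : (batch, 6)   each rank's slice of the output
# all_gather + cat(dim=-1) : (batch, 12)  full output, identical on every rank
#                                         (all_gather returns shards in rank order)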
Next we need a small test to verify that the implementation is correct.
test_tensor_parallel_linear.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from tensor_parallel import ColumnParallelLinear, init_tensor_parallel


def setup_distributed() -> torch.device:
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return torch.device("cuda", local_rank)


def cleanup_distributed() -> None:
    dist.destroy_process_group()


def main() -> None:
    device = setup_distributed()
    init_tensor_parallel()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    if rank == 0:
        print(f"world_size = {world_size}")
    batch_size = 4
    in_features = 8
    out_features = 12
    if out_features % world_size != 0:
        raise RuntimeError("out_features must be divisible by world_size")
    torch.manual_seed(42)
    ref_linear = nn.Linear(in_features, out_features, bias=True).to(device)
    tp_linear = ColumnParallelLinear(
        in_features=in_features,
        out_features=out_features,
        bias=True,
        gather_output=True,
    ).to(device)
    with torch.no_grad():
        out_per_rank = out_features // world_size
        start = rank * out_per_rank
        end = start + out_per_rank
        tp_linear.weight.copy_(ref_linear.weight[start:end, :])
        tp_linear.bias.copy_(ref_linear.bias[start:end])
    torch.manual_seed(123)
    x = torch.randn(batch_size, in_features, device=device)
    y_ref = ref_linear(x)
    y_tp = tp_linear(x)
    diff = (y_ref - y_tp).abs().max()
    diff_val = diff.item()
    if rank == 0:
        print("max |y_ref - y_tp| = ", diff_val)
    cleanup_distributed()


if __name__ == "__main__":
    main()
This test checks that the column parallel linear produces exactly the same output as an ordinary nn.Linear holding the full weight. Run it as follows:
$ torchrun --nproc-per-node=2 test_tensor_parallel_linear.py
W1126 16:57:28.144000 3084989 site-packages/torch/distributed/run.py:792]
W1126 16:57:28.144000 3084989 site-packages/torch/distributed/run.py:792] *****************************************
W1126 16:57:28.144000 3084989 site-packages/torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W1126 16:57:28.144000 3084989 site-packages/torch/distributed/run.py:792] *****************************************
world_size = 2
max |y_ref - y_tp| = 0.0
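If you have no GPUs (or only one), the same test should also run on CPU with the gloo backend. A minimal sketch of the change, assuming everything else stays as above:

def setup_distributed() -> torch.device:
    # CPU variant: use the gloo backend and skip the CUDA device binding.
    dist.init_process_group(backend="gloo")
    return torch.device("cpu")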