
In the previous PR we added a Row Parallel Linear to the attention layer, but it was never actually paired with the Column Parallel Linear used for QKV, which incurs an extra all-gather. In the PR covered in this post, we leave the QKV output naturally sharded by head, so that extra all-gather disappears.
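
To see why this pairing removes the all-gather, here is a minimal single-process sketch (plain torch, no rosetrainer imports; the weight names are only illustrative). The column-parallel weight is split along its output dimension, the row-parallel weight along its input dimension, and the only communication the real RowParallelLinear needs is one all-reduce over the partial outputs.

import torch

torch.manual_seed(0)
d_model, tp = 8, 2
x = torch.randn(4, d_model)

# Dense reference: a QKV-like projection followed by an output projection.
w_col = torch.randn(3 * d_model, d_model)   # plays the role of qkv_proj.weight
w_row = torch.randn(d_model, 3 * d_model)   # plays the role of out_proj.weight
y_dense = (x @ w_col.t()) @ w_row.t()

# Column parallel: shard w_col along the output dim; with gather_output=False
# each rank keeps only its slice of the activation.
col_shards = w_col.chunk(tp, dim=0)
# Row parallel: shard w_row along the input dim; with input_is_parallel=True
# each rank consumes exactly the slice it already holds.
row_shards = w_row.chunk(tp, dim=1)

partials = [(x @ wc.t()) @ wr.t() for wc, wr in zip(col_shards, row_shards)]
y_tp = sum(partials)  # the single all-reduce RowParallelLinear performs

print(torch.allclose(y_dense, y_tp, atol=1e-5))  # True

In the real layer the attention sits between the two projections, but since attention is computed independently per head and the shards are cut on head boundaries, the same identity still holds.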

model.py

diff --git a/rosellm/rosetrainer/model.py b/rosellm/rosetrainer/model.py
index 102e308..56a2054 100644
--- a/rosellm/rosetrainer/model.py
+++ b/rosellm/rosetrainer/model.py
@@ -19,22 +19,30 @@ class MultiHeadSelfAttention(nn.Module):
         self.d_model = config.d_model
         self.n_heads = config.n_heads
         self.d_head = config.d_model // config.n_heads
-        use_tp = getattr(config, "use_tensor_parallel", False)
-        if use_tp and dist.is_available() and dist.is_initialized():
+        use_tp_cfg = getattr(config, "use_tensor_parallel", False)
+        self.use_tp = use_tp_cfg and dist.is_available() and dist.is_initialized()
+        if self.use_tp:
             init_tensor_parallel()
+            tp_world_size = dist.get_world_size()
+            if self.n_heads % tp_world_size != 0:
+                raise ValueError("n_heads must be divisible by tp_world_size")
+            self.tp_world_size = tp_world_size
+            self.local_heads = self.n_heads // tp_world_size
             self.qkv_proj = ColumnParallelLinear(
                 in_features=config.d_model,
                 out_features=3 * config.d_model,
                 bias=True,
-                gather_output=True,
+                gather_output=False,
             )
             self.out_proj = RowParallelLinear(
                 in_features=config.d_model,
                 out_features=config.d_model,
                 bias=True,
-                input_is_parallel=False,
+                input_is_parallel=True,
             )
         else:
+            self.tp_world_size = 1
+            self.local_heads = self.n_heads
             self.qkv_proj = nn.Linear(config.d_model, 3 * config.d_model)
             self.out_proj = nn.Linear(config.d_model, config.d_model)
         self.dropout = nn.Dropout(config.dropout)
@@ -57,7 +65,7 @@ class MultiHeadSelfAttention(nn.Module):
     ):
         bsz, seq_len, _ = x.size()
         qkv = self.qkv_proj(x)
-        qkv = qkv.view(bsz, seq_len, 3, self.n_heads, self.d_head)
+        qkv = qkv.view(bsz, seq_len, 3, self.local_heads, self.d_head)
         qkv = qkv.permute(2, 0, 3, 1, 4)
         q, k, v = qkv[0], qkv[1], qkv[2]
         attn_scores = q @ k.transpose(-2, -1) * self.d_head**-0.5
@@ -70,7 +78,7 @@ class MultiHeadSelfAttention(nn.Module):
         attn_weights = self.dropout(attn_weights)
         attn_output = attn_weights @ v
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.view(bsz, seq_len, self.d_model)
+        attn_output = attn_output.view(bsz, seq_len, self.local_heads * self.d_head)
         out = self.out_proj(attn_output)
         out = self.dropout(out)
         return out

The main changes are the new local_heads bookkeeping, flipping gather_output to False, and flipping input_is_parallel to True, so the head-sharded QKV output flows straight into the row-parallel out_proj without being gathered first.
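
For intuition, here is a per-rank shape walk through the new forward pass. It assumes d_model=64 and a TP world size of 2 (matching the test run below) and n_heads=8, which is only an assumption for illustration, so d_head=8 and local_heads=4; scaling, mask and dropout are omitted.

import torch

bsz, seq = 2, 8
d_model, n_heads, tp = 64, 8, 2
d_head, local_heads = d_model // n_heads, n_heads // tp

# qkv_proj output with gather_output=False: each rank only sees its 3 * d_model / tp slice.
qkv = torch.randn(bsz, seq, 3 * d_model // tp)
qkv = qkv.view(bsz, seq, 3, local_heads, d_head).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]                 # each [bsz, local_heads, seq, d_head]
attn_output = (q @ k.transpose(-2, -1)).softmax(-1) @ v
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq, local_heads * d_head)
print(attn_output.shape)  # torch.Size([2, 8, 32]); out_proj then all-reduces back to d_model=64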

test_attention_tp_vs_dense.py

diff --git a/rosellm/rosetrainer/test_attention_tp_vs_dense.py b/rosellm/rosetrainer/test_attention_tp_vs_dense.py
index fd3b63b..24e21d7 100644
--- a/rosellm/rosetrainer/test_attention_tp_vs_dense.py
+++ b/rosellm/rosetrainer/test_attention_tp_vs_dense.py
@@ -55,13 +55,41 @@ def copy_qkv_from_dense_to_tp(
     with torch.no_grad():
         linear_dense: nn.Linear = attn_dense.qkv_proj
         col_tp = attn_tp.qkv_proj
-        out_features = linear_dense.out_features
-        out_per_rank = out_features // world_size
-        start = rank * out_per_rank
-        end = start + out_per_rank
-        col_tp.weight.copy_(linear_dense.weight[start:end, :])
+
+        d_model = attn_dense.d_model
+        n_heads = attn_dense.n_heads
+        d_head = attn_dense.d_head
+        assert d_model == n_heads * d_head
+
+        local_heads = n_heads // world_size
+        local_dim = local_heads * d_head
+        head_start = rank * local_heads
+        head_end = head_start + local_heads
+
+        q_offset = 0
+        k_offset = d_model
+        v_offset = 2 * d_model
+
+        q_start = q_offset + head_start * d_head
+        q_end = q_offset + head_end * d_head
+        k_start = k_offset + head_start * d_head
+        k_end = k_offset + head_end * d_head
+        v_start = v_offset + head_start * d_head
+        v_end = v_offset + head_end * d_head
+
+        q_weight = linear_dense.weight[q_start:q_end, :]
+        k_weight = linear_dense.weight[k_start:k_end, :]
+        v_weight = linear_dense.weight[v_start:v_end, :]
+
+        col_tp.weight[:local_dim, :].copy_(q_weight)
+        col_tp.weight[local_dim : 2 * local_dim, :].copy_(k_weight)
+        col_tp.weight[2 * local_dim : 3 * local_dim, :].copy_(v_weight)
+
         if col_tp.bias is not None:
-            col_tp.bias.copy_(linear_dense.bias[start:end])
+            q_bias = linear_dense.bias[q_start:q_end]
+            k_bias = linear_dense.bias[k_start:k_end]
+            v_bias = linear_dense.bias[v_start:v_end]
+            col_tp.bias.copy_(torch.cat([q_bias, k_bias, v_bias], dim=0))

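The copy helper grows because the dense qkv_proj weight is stacked as [Q; K; V] along dim 0, while each rank's column-parallel shard expects a contiguous [Q_local; K_local; V_local] block for its own heads, so the test slices each of the three blocks by head range and re-stacks them. A standalone sanity check of that index mapping (no process group needed; d_model=64 matches the test output below, while n_heads=8 and the variable names are only illustrative):

import torch

d_model, n_heads, world_size = 64, 8, 2
d_head = d_model // n_heads
local_heads = n_heads // world_size
local_dim = local_heads * d_head

# Dense layout along dim 0: [Q | K | V], each block d_model rows.
dense_weight = torch.randn(3 * d_model, d_model)

shards = []
for rank in range(world_size):
    head_start = rank * local_heads
    head_end = head_start + local_heads
    blocks = [
        dense_weight[offset + head_start * d_head : offset + head_end * d_head]
        for offset in (0, d_model, 2 * d_model)   # Q, K, V offsets
    ]
    shards.append(torch.cat(blocks, dim=0))       # rank-local [Q_local; K_local; V_local]

# Re-interleaving the per-rank Q/K/V blocks reconstructs the dense weight exactly.
q = torch.cat([s[:local_dim] for s in shards], dim=0)
k = torch.cat([s[local_dim : 2 * local_dim] for s in shards], dim=0)
v = torch.cat([s[2 * local_dim :] for s in shards], dim=0)
print(torch.equal(torch.cat([q, k, v], dim=0), dense_weight))  # True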
With the test's weight-copy helper updated accordingly, the run looks like this:

$ torchrun --nproc-per-node=2 test_attention_tp_vs_dense.py 
W1127 14:25:50.446000 72319 site-packages/torch/distributed/run.py:792] 
W1127 14:25:50.446000 72319 site-packages/torch/distributed/run.py:792] *****************************************
W1127 14:25:50.446000 72319 site-packages/torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1127 14:25:50.446000 72319 site-packages/torch/distributed/run.py:792] *****************************************
world_size = 2
y_dense shape: torch.Size([2, 8, 64])
y_tp shape: torch.Size([2, 8, 64])
max |y_dense - y_tp| =  2.384185791015625e-07