4 minute read

在实现了 offline scheduler 之后,我们可以往 paged attention 迈一小步:先实现一个 Python 版本的 kv block manager。为了使单次变更比较小,本次仅为 kv-cache 记录 kv block 的 metadata,暂时不在实际的 forward 中使用。

代码变更

engine.py

最主要的变更是添加了 kv block manager,其核心是构造一个从 global id 到 kv block info 的映射(global id 由 layer_idx * max_blocks_per_layer + block_index 计算得到);kv block info 包含 layer index、block index、start、length 这些 metadata。然后在每个 inference session 中维护一个 block ids per layer,从 layer index 映射到该 layer 所对应的 global ids(block ids):

diff --git a/rosellm/roseinfer/engine.py b/rosellm/roseinfer/engine.py
index 16cd1f8..ab1e640 100644
--- a/rosellm/roseinfer/engine.py
+++ b/rosellm/roseinfer/engine.py
@@ -1,4 +1,4 @@
-from typing import Iterator, Optional
+from typing import Iterator, NamedTuple, Optional
 
 import torch
 from roseinfer.detokenizer import (
@@ -63,6 +63,18 @@ class InferenceEngine:
             return PrefixDiffDetokenizer(self.tokenizer)
 
         self._make_detok = make_detok
+        block_size = 64
+        max_context = max_position_embeddings or self.config.max_position_embeddings
+        max_blocks_per_layer = (max_context + block_size - 1) // block_size
+        self.kv_manager = KVBlockManager(
+            num_layers=self.config.n_layers,
+            num_heads=self.config.n_heads,
+            head_dim=self.config.d_model // self.config.n_heads,
+            block_size=block_size,
+            max_blocks_per_layer=max_blocks_per_layer,
+            device=self.device,
+            dtype=self.amp_dtype if self.use_amp else self.model.dtype,
+        )
 
         if self.config.vocab_size < self.tokenizer.vocab_size:
             raise ValueError("the model vocab_size is less than tokenizer vocab_size")
@@ -603,6 +615,11 @@ class InferenceSession:
         self.do_sample: bool = False
         self.stop_on_eos: bool = True
         self.step_count: int = 0
+        self.kv_manager = engine.kv_manager
+        self.block_ids_per_layer: list[list[int]] = [
+            [] for _ in range(self.kv_manager.num_layers)
+        ]
+        self.prompt_length: int = 0
 
     def set_generation_config(
         self,
@@ -633,6 +650,28 @@ class InferenceSession:
             skip_special_tokens=True,
         )
 
+    def _register_prefill_kv(
+        self,
+        presents,
+        seq_len: int,
+    ) -> None:
+        if self.kv_manager is None:
+            return
+        self.prompt_length = seq_len
+        self.block_ids_per_layer = [[] for _ in range(self.kv_manager.num_layers)]
+        for layer_idx, layer_past in enumerate(presents):
+            if layer_idx >= self.kv_manager.num_layers:
+                break
+            key, value = layer_past  # [B, H, T, D]
+            if key.size(2) != seq_len:
+                continue
+            block_ids = self.kv_manager.register_prefill_layer(
+                layer_idx,
+                key,
+                value,
+            )
+            self.block_ids_per_layer[layer_idx] = block_ids
+
     @torch.no_grad()
     def prefill(
         self,
@@ -659,6 +698,7 @@ class InferenceSession:
                 past_key_values=None,
                 use_cache=True,
             )
+        self._register_prefill_kv(presents, input_ids.size(1))
         self.kv_cache = presents
         return logits  # [..., T0, vocab]
 
@@ -694,6 +734,8 @@ class InferenceSession:
                 past_key_values=None,
                 use_cache=True,
             )
+        if input_ids.size(0) == 1:  # temporarily only support batch size 1
+            self._register_prefill_kv(presents, input_ids.size(1))
         self.kv_cache = presents
         last_logits = logits[:, -1, :]  # [batch, vocab]
         return last_logits
@@ -757,6 +799,15 @@ class InferenceSession:
             self.finished = True
         return token_id
 
+    def release_kv_blocks(self) -> None:
+        if self.kv_manager is None:
+            return
+        for layer_idx, block_ids in enumerate(self.block_ids_per_layer):
+            if not block_ids:
+                continue
+            self.kv_manager.free_blocks(layer_idx, block_ids)
+        self.block_ids_per_layer = [[] for _ in range(self.kv_manager.num_layers)]
+
     @torch.no_grad()
     def decode_step_batch(
         self,
@@ -863,4 +914,98 @@ class OfflineScheduler:
         outputs: dict[int, str] = {}
         for rid, session in self._sessions.items():
             outputs[rid] = session.decode_text()
+        for session in self._sessions.values():
+            session.release_kv_blocks()
         return outputs
+
+
+class KVBlockInfo(NamedTuple):
+    layer: int
+    block_index: int
+    start: int
+    length: int
+
+
+class KVBlockManager:
+    def __init__(
+        self,
+        num_layers: int,
+        num_heads: int,
+        head_dim: int,
+        block_size: int,
+        max_blocks_per_layer: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> None:
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.block_size = block_size
+        self.max_blocks_per_layer = max_blocks_per_layer
+        self.device = device
+        self.dtype = dtype
+        self._next_block_index: list[int] = [0 for _ in range(num_layers)]
+        self._free_block_indices: list[list[int]] = [[] for _ in range(num_layers)]
+        self._block_infos: dict[int, KVBlockInfo] = {}
+
+    def _alloc_block_index(self, layer_idx: int) -> int:
+        free_list = self._free_block_indices[layer_idx]
+        if free_list:
+            return free_list.pop()
+        idx = self._next_block_index[layer_idx]
+        if idx >= self.max_blocks_per_layer:
+            raise RuntimeError(f"no more blocks available for layer {layer_idx}")
+        self._next_block_index[layer_idx] += 1
+        return idx
+
+    def _to_global_block_id(
+        self,
+        layer_idx: int,
+        block_index: int,
+    ) -> int:
+        return layer_idx * self.max_blocks_per_layer + block_index
+
+    def register_prefill_layer(
+        self,
+        layer_idx: int,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> list[int]:
+        assert layer_idx < self.num_layers
+        seq_len = key.size(2)
+        block_size = self.block_size
+        num_blocks = (seq_len + block_size - 1) // block_size
+        block_ids: list[int] = []
+        for i in range(num_blocks):
+            start = i * block_size
+            end = min(start + block_size, seq_len)
+            length = end - start
+            block_idx = self._alloc_block_index(layer_idx)
+            global_id = self._to_global_block_id(
+                layer_idx,
+                block_idx,
+            )
+            info = KVBlockInfo(
+                layer=layer_idx,
+                block_index=block_idx,
+                start=start,
+                length=length,
+            )
+            self._block_infos[global_id] = info
+            block_ids.append(global_id)
+        return block_ids
+
+    def free_blocks(
+        self,
+        layer_idx: int,
+        block_ids: list[int],
+    ) -> None:
+        for global_id in block_ids:
+            info = self._block_infos.pop(global_id, None)
+            if info is None:
+                continue
+            if info.layer != layer_idx:
+                continue
+            self._free_block_indices[layer_idx].append(
+                info.block_index,
+            )

运行

重新运行之前的脚本,确保一切正常:


(/data/projects/rosellm/.conda) wine@wine-MS-7D90:/data/projects/rosellm/rosellm$ ./generate.sh 
[roseinfer] device: cuda
[roseinfer] use_amp: True
[roseinfer] prompt: hi, 
[roseinfer] streaming output: ˈp, the early development and and the state at the two weeks to be considered the health or by a few in their way to be that are so on that you you will create the environment.
There’s all the two or the new new new business and a significant by the more to see it would look at a well have a way. It have this can do more of the use and the world-t be a little for the year by it, "Pt.
B, a way to the “The way it,”. The first the life will be the process. By the world-
The the same type of the future to have you can help.