
KVBlockManager.append_token_batch() is the decode hot path: at every layer, on every step, it has to look up the current last block's index/length from block_ids, then write the KV and update the length.

Previously, KVBlockManager maintained two dicts:

  • _block_infos: dict[int, KVBlockInfo]
  • _block_refcounts: dict[int, int]

But here global_id = layer_idx * max_blocks_per_layer + block_idx is essentially a contiguous integer id with a fixed upper bound (num_layers * max_blocks_per_layer).

So this change replaces the two dicts with fixed-length lists, indexed directly by global_id, to cut the Python dict overhead on the decode hot path.
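A tiny sketch of why a flat list is safe here (the dimensions below are made up for illustration): every (layer_idx, block_idx) pair maps to a unique id in [0, num_layers * max_blocks_per_layer), so the id space is dense and a fixed-length list covers it with no gaps.

```python
# made-up small dimensions, just to show the id space is dense
num_layers, max_blocks_per_layer = 2, 4

def global_id(layer_idx: int, block_idx: int) -> int:
    return layer_idx * max_blocks_per_layer + block_idx

# enumerate every (layer, block) pair: the ids are exactly 0..N-1
ids = [global_id(l, b) for l in range(num_layers) for b in range(max_blocks_per_layer)]
assert ids == list(range(num_layers * max_blocks_per_layer))  # no holes
```

Because there are no holes, a list slot exists for every valid gid up front, which is what lets the `.get()` / `.pop()` defensive dict calls go away.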

Code changes

roseinfer/engine.py

Core diff:

diff --git a/rosellm/roseinfer/engine.py b/rosellm/roseinfer/engine.py
@@ class KVBlockManager.__init__
-self._block_infos: dict[int, KVBlockInfo] = {}
+self._block_infos: list[KVBlockInfo | None] = [None for _ in range(num_layers * max_blocks_per_layer)]
@@
-self._block_refcounts: dict[int, int] = {}
+self._block_refcounts: list[int] = [0 for _ in range(num_layers * max_blocks_per_layer)]
@@ def incref_blocks(...)
-self._block_refcounts[gid] = self._block_refcounts.get(gid, 0) + 1
+self._block_refcounts[gid] += 1
@@ def free_blocks(...)
-ref = self._block_refcounts.get(gid)
+ref = self._block_refcounts[gid]
 ...
-info = self._block_infos.pop(gid, None)
+info = self._block_infos[gid]; self._block_infos[gid] = None

Running the tests

pytest -q
.....................................                                    [100%]
37 passed, 1 warning in 2.69s

Benchmark (HF GPT-2 / streaming / decode-heavy)

To make append_token_batch() get called enough times, use a decode-heavy configuration:

  • --no-prefix-cache: avoid interference from the prefix-cache branch
  • --max-new-tokens 128: give each request enough decode steps
  • --prompt-repeats "1,4": mix different prompt lengths (so pos is not constant)

python -m rosellm.roseinfer.benchmark_streaming \
  --hf-model-id gpt2 --device cuda \
  --prompt "Hello" --prompt-repeats "1,4" \
  --pretok --num-requests 128 \
  --max-new-tokens 128 --no-stop-on-eos \
  --no-prefix-cache --paged-attn \
  --warmup-runs 1 --repeat-runs 3

Before (3 runs):

Throughput: 2105.17 / 2114.89 / 2114.55 tokens/s (avg 2111.54)
TPOT p99: 60.27 / 60.05 / 60.05 ms/token (avg 60.12)

After (3 runs):

Throughput: 2147.02 / 2131.60 / 2114.18 tokens/s (avg 2130.93)
TPOT p99: 59.13 / 59.54 / 60.06 ms/token (avg 59.58)

Conclusion

  • completion throughput: avg +0.92% (2111.54 -> 2130.93 tokens/s)
  • TPOT p99: avg -0.91% (60.12 -> 59.58 ms/token)
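The averages and percentage deltas can be reproduced from the per-run numbers above:

```python
before_tp = [2105.17, 2114.89, 2114.55]   # tokens/s, 3 runs
after_tp  = [2147.02, 2131.60, 2114.18]
before_tpot = [60.27, 60.05, 60.05]       # ms/token, p99
after_tpot  = [59.13, 59.54, 60.06]

avg = lambda xs: sum(xs) / len(xs)
tp_gain = (avg(after_tp) - avg(before_tp)) / avg(before_tp) * 100
tpot_gain = (avg(after_tpot) - avg(before_tpot)) / avg(before_tpot) * 100
print(f"throughput: {avg(before_tp):.2f} -> {avg(after_tp):.2f} ({tp_gain:+.2f}%)")
print(f"TPOT p99:   {avg(before_tpot):.2f} -> {avg(after_tpot):.2f} ({tpot_gain:+.2f}%)")
```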

This is a typical "replace hot-path dict lookups with array indexing" optimization: the gain is small but stable and low-risk, and it lays the groundwork for more aggressive KV hot-path optimizations later (buffer reuse / kernel fusion).
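To see the dict-vs-list gap in isolation, a synthetic microbenchmark along these lines can be run (the table size here is made up and unrelated to the real workload, so absolute numbers will differ from the end-to-end benchmark):

```python
import timeit

# hypothetical table size, e.g. 48 layers x 1024 blocks per layer
n = 48 * 1024
setup = f"d = {{i: 0 for i in range({n})}}; l = [0] * {n}; gid = {n // 2}"

# time 1e6 raw lookups of the same gid in a dict vs a list
t_dict = timeit.timeit("d[gid]", setup=setup, number=1_000_000)
t_list = timeit.timeit("l[gid]", setup=setup, number=1_000_000)
print(f"dict lookup: {t_dict:.3f}s  list index: {t_list:.3f}s per 1e6 ops")
```

On CPython, `l[gid]` skips the hashing step that `d[gid]` has to do, which is where the (small) per-call saving comes from; multiplied by every layer of every decode step, it adds up to the ~1% seen above.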