从零实现 LLM Inference:060. KV append identity fast path(少做一次 index_select)
KVBlockManager.append_token_batch() 的 baseline 写法比较直接:
- 先把 fast batch 的
batch_idx做成 tensor key_new/value_new.index_select(0, batch_idx)得到k_src/v_src- 再把
k_src/v_srcscatter 写进 KV cache
但在 decode 的大多数 step 里,fast_batch_idx 其实就是一个很常见的形态:
- identity:
[0, 1, 2, ..., B-1](batch 没有被拆成 fast/slow 两块) - const pos:同一个 step 下,
pos对整个 batch 是常量
这时 index_select 就纯属多做一次拷贝;pos_t 也是不必要的分配。
这版 PR 就做两件小事:
1) fast_batch_idx 是 identity 时:直接用 key_new/value_new,跳过 index_select
2) fast_pos 是常量时:用标量 pos0 直接索引,跳过 pos_t 的构造
代码变更
roseinfer/engine.py
核心 diff:
diff --git a/rosellm/roseinfer/engine.py b/rosellm/roseinfer/engine.py
@@
- idx_t = torch.tensor(fast_batch_idx, device=device, dtype=torch.long)
blk_t = torch.tensor(fast_block_idx, device=device, dtype=torch.long)
- pos_t = torch.tensor(fast_pos, device=device, dtype=torch.long)
- k_src = key_new.index_select(0, idx_t)
- v_src = value_new.index_select(0, idx_t)
- k_layer[blk_t, :, pos_t, :] = k_src
- v_layer[blk_t, :, pos_t, :] = v_src
+ full_fast = ... # fast_batch_idx == [0..B-1] && no slow
+ if full_fast:
+ k_src = key_new
+ v_src = value_new
+ else:
+ idx_t = torch.tensor(fast_batch_idx, device=device, dtype=torch.long)
+ k_src = key_new.index_select(0, idx_t)
+ v_src = value_new.index_select(0, idx_t)
+
+ pos0 = fast_pos[0]
+ const_pos = all(p == pos0 for p in fast_pos)
+ if const_pos:
+ k_layer[blk_t, :, pos0, :] = k_src
+ v_layer[blk_t, :, pos0, :] = v_src
+ else:
+ pos_t = torch.tensor(fast_pos, device=device, dtype=torch.long)
+ k_layer[blk_t, :, pos_t, :] = k_src
+ v_layer[blk_t, :, pos_t, :] = v_src
运行
pytest -q
.................................. [100%]
34 passed, 1 warning in 2.72s
Benchmark(HF GPT-2 / streaming)
这组 benchmark 关掉 prefix cache,专注看 decode 的 steady-state(KV append 会被调用很多次)。
HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1 \
python -m rosellm.roseinfer.benchmark_streaming \
--hf-model-id gpt2 --device cuda \
--prompt 'Hello' --pretok --tokenize-workers 0 \
--num-requests 256 --max-new-tokens 256 \
--submit-interval-ms 0 \
--max-batch-size 16 --prefill-max-batch-size 16 \
--prefill-admission-policy fifo \
--paged-attn --no-prefix-cache --no-stop-on-eos \
--warmup-runs 1 --repeat-runs 1
Before:
TPOT p50/p95/p99: 66.77/66.86/66.89 ms/token
ITL p50/p95/p99: 66.01/71.79/82.06 ms
Throughput (completion,total): 3807.13 tokens/s
After:
TPOT p50/p95/p99: 65.41/65.49/65.52 ms/token
ITL p50/p95/p99: 64.68/70.10/74.99 ms
Throughput (completion,total): 3886.10 tokens/s
TPOT p50 大约 -2.0%,吞吐大约 +2.1%,ITL p99 也更稳。
结论
index_select在 identity batch 上属于“重复拷贝”,能省就省。pos常量时避免构造pos_t,属于很典型的“把 Python 小对象和小 tensor 从热路径里挪出去”。- 下一步如果继续抠 KV 写入:可以把
blk_t/pos_t这类 index tensor 也做成可复用 buffer(尤其是 decode 稳态时它们变化很慢)。