LLM Inference from Scratch: 053. CUDA Graph Host Buffers (making the metadata copies truly non_blocking)
In the previous installments we kept shaving CPU overhead off the streaming / scheduler path, but once decode itself runs fast (e.g. with --paged-attn --cuda-graph), a lot of "seemingly negligible" small costs get amplified.
The most typical example is this piece of logic (right before the paged-attn + CUDA graph replay):
- every step converts Python lists such as last_ids/seq_lens/slot_ids into torch.tensor(...)
- then uses copy_ to move them into the graph's input buffers
- and passes non_blocking=True, which is not actually non-blocking when the source is not pinned memory (on top of that, each step pays for an allocation/free)
The goal of this version is clear-cut: make these metadata H2D copies truly pinned + non_blocking, and get rid of the per-step tensor allocations.
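To make the problem concrete, here is a minimal standalone sketch of the old path (the list contents are made up for illustration; this is not the engine code itself):
import torch

last_ids = [3, 7, 11, 42]                      # stand-in for the per-step metadata list
src = torch.tensor(last_ids)                   # a fresh pageable CPU tensor every step
print(src.is_pinned())                         # False: the copy below cannot be truly async
dst = torch.empty(len(last_ids), dtype=torch.long, device="cuda")
dst.copy_(src, non_blocking=True)              # with a pageable source this still stages/blocks on the host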
Code changes
roseinfer/engine.py
Core idea:
1) add a set of CPU pinned staging buffers (plus matching numpy views) to _PagedDecodeCudaGraph
2) on every decode step, write the Python lists straight into the pinned buffers via numpy (very fast)
3) then copy_ from the pinned buffers into the graph's GPU buffers (this is where non_blocking=True actually takes effect); see the sketch right after this list
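A minimal sketch of this staging pattern in isolation (the names B, host, dev, slot_ids are illustrative, not the actual engine fields): the numpy view shares memory with the pinned tensor, so filling it from a Python list writes the tensor in place with no new allocation, and the following copy_ to the GPU can be genuinely non-blocking.
import numpy as np
import torch

B = 64
host = torch.empty((B,), dtype=torch.int32, pin_memory=True)  # pinned CPU staging buffer
host_np = host.numpy()                                        # zero-copy view of the same memory

slot_ids = list(range(B))                                     # the per-step Python list
host_np[:] = slot_ids                                         # bulk write through numpy, no new tensor
dev = torch.empty((B,), dtype=torch.int32, device="cuda")
dev.copy_(host, non_blocking=True)                            # pinned source -> effective async H2D copy
# caveat: the staging buffer must not be rewritten before this copy has completed;
# the per-step flow (e.g. syncing to read back sampled tokens) is assumed to guarantee that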
The core diff:
diff --git a/rosellm/roseinfer/engine.py b/rosellm/roseinfer/engine.py
@@
+import numpy as np
@@
class _PagedDecodeCudaGraph:
@@
+ input_ids_host: torch.Tensor # [B] int64 cpu (pinned)
+ position_ids_host: torch.Tensor # [B] int64 cpu (pinned)
+ slot_mapping_host: torch.Tensor # [B] int32 cpu (pinned)
+ context_lens_host: torch.Tensor # [B] int32 cpu (pinned)
+ input_ids_host_np: np.ndarray
+ position_ids_host_np: np.ndarray
+ slot_mapping_host_np: np.ndarray
+ context_lens_host_np: np.ndarray
@@
def _get_or_create_paged_decode_cuda_graph(...):
@@
+ input_ids_host = torch.empty((B,), device="cpu", dtype=torch.long, pin_memory=True)
+ position_ids_host = torch.empty((B,), device="cpu", dtype=torch.long, pin_memory=True)
+ slot_mapping_host = torch.empty((B,), device="cpu", dtype=torch.int32, pin_memory=True)
+ context_lens_host = torch.empty((B,), device="cpu", dtype=torch.int32, pin_memory=True)
@@
captured = _PagedDecodeCudaGraph(
@@
+ input_ids_host=input_ids_host,
+ position_ids_host=position_ids_host,
+ slot_mapping_host=slot_mapping_host,
+ context_lens_host=context_lens_host,
+ input_ids_host_np=input_ids_host.numpy(),
+ position_ids_host_np=position_ids_host.numpy(),
+ slot_mapping_host_np=slot_mapping_host.numpy(),
+ context_lens_host_np=context_lens_host.numpy(),
)
@@
if self.use_cuda_graph:
graph = self._get_or_create_paged_decode_cuda_graph(...)
- graph.input_ids[:,0].copy_(torch.tensor(last_ids), non_blocking=True)
- graph.position_ids[:,0].copy_(torch.tensor(seq_lens), non_blocking=True)
- graph.slot_mapping.copy_(torch.tensor(slot_ids), non_blocking=True)
- graph.context_lens.copy_(torch.tensor(seq_lens), non_blocking=True)
+ graph.input_ids_host_np[:] = last_ids
+ graph.position_ids_host_np[:] = seq_lens
+ graph.slot_mapping_host_np[:] = slot_ids
+ graph.context_lens_host_np[:] = seq_lens
+ graph.input_ids[:,0].copy_(graph.input_ids_host, non_blocking=True)
+ graph.position_ids[:,0].copy_(graph.position_ids_host, non_blocking=True)
+ graph.slot_mapping.copy_(graph.slot_mapping_host, non_blocking=True)
+ graph.context_lens.copy_(graph.context_lens_host, non_blocking=True)
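As a rough sanity check outside the engine, a micro-benchmark sketch like the following (names and iteration count are arbitrary; data correctness is ignored in the timing loop) contrasts the host-side cost of the two paths:
import time
import torch

B = 64
vals = list(range(B))
dev = torch.empty((B,), dtype=torch.long, device="cuda")
pinned = torch.empty((B,), dtype=torch.long, pin_memory=True)
pinned_np = pinned.numpy()

def step_pageable():
    # old path: fresh pageable tensor + nominally non_blocking copy
    dev.copy_(torch.tensor(vals), non_blocking=True)

def step_pinned():
    # new path: write into the pinned staging buffer, then a real async copy
    pinned_np[:] = vals
    dev.copy_(pinned, non_blocking=True)

for fn in (step_pageable, step_pinned):
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(1000):
        fn()
    torch.cuda.synchronize()
    print(f"{fn.__name__}: {(time.perf_counter() - t0) * 1e3:.2f} ms / 1000 steps")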
Running the tests
pytest -q
.............................. [100%]
=============================== warnings summary ===============================
../anaconda3/lib/python3.10/site-packages/urllib3/util/ssl_.py:260
/data/projects/anaconda3/lib/python3.10/site-packages/urllib3/util/ssl_.py:260: DeprecationWarning: ssl.PROTOCOL_TLS is deprecated
context = SSLContext(ssl_version or PROTOCOL_TLS)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
30 passed, 1 warning in 2.12s
Benchmark (HF GPT-2 / streaming / paged-attn + cuda graph)
This run deliberately turns on --paged-attn --cuda-graph so that decode is faster and CPU-side jitter is easier to see.
python -m rosellm.roseinfer.benchmark_streaming \
--hf-model-id gpt2 --device cuda \
--prompt "Hello" \
--pretok --tokenize-workers 0 \
--stream-interval 8 \
--num-requests 64 \
--warmup-runs 1 --repeat-runs 1 \
--submit-interval-ms 20 --submit-schedule absolute \
--max-batch-size 64 --prefill-max-batch-size 64 \
--max-new-tokens 128 \
--no-stop-on-eos --no-prefix-cache \
--paged-attn --cuda-graph
Before
=== warmup 1/1 ===
=== streaming benchmark ===
Model: gpt2
Device: cuda
Pretok: True
Tokenize workers: 0
Stream interval: 8
Paged attention: True
CUDA Graph: True
NVTX: False
Requests: 64
Prompt tokens (total): 64
Completion tokens (total): 8192
Submit wall: 1.308910 s
Submit interval/schedule: 20.000 ms / absolute
Submit lag p50/p95/p99: 18.33/54.12/57.65 ms
add_request latency p50/p95/p99: 0.02/0.08/0.09 ms
Tokenize p50/p95/p99: 0.00/0.00/0.00 ms
Queue wait (post-tok) p50/p95/p99: 7.45/78.36/80.96 ms
Prefill->first token p50/p95/p99: 4.75/5.04/7.04 ms
TTFT p50/p95/p99: 12.33/83.08/85.95 ms
TPOT p50/p95/p99: 12.51/13.29/13.87 ms/token
ITL p50/p95/p99: 4.09/77.69/85.01 ms
Latency p50/p95/p99: 1602.08/1699.43/1770.23 ms
Throughput (completion,total): 3042.19 tokens/s
After (host buffers)
=== warmup 1/1 ===
=== streaming benchmark ===
Model: gpt2
Device: cuda
Pretok: True
Tokenize workers: 0
Stream interval: 8
Paged attention: True
CUDA Graph: True
NVTX: False
Requests: 64
Prompt tokens (total): 64
Completion tokens (total): 8192
Submit wall: 1.296431 s
Submit interval/schedule: 20.000 ms / absolute
Submit lag p50/p95/p99: 23.90/55.74/57.10 ms
add_request latency p50/p95/p99: 0.02/0.07/0.08 ms
Tokenize p50/p95/p99: 0.00/0.00/0.00 ms
Queue wait (post-tok) p50/p95/p99: 7.43/77.83/78.57 ms
Prefill->first token p50/p95/p99: 4.70/5.01/5.90 ms
TTFT p50/p95/p99: 12.21/82.61/83.27 ms
TPOT p50/p95/p99: 11.77/13.08/13.63 ms/token
ITL p50/p95/p99: 4.02/78.46/80.44 ms
Latency p50/p95/p99: 1509.76/1672.10/1738.46 ms
Throughput (completion,total): 3210.51 tokens/s
Conclusion
In essence, this version turns the per-decode-step metadata preparation from:
- torch.tensor(list) (allocation + copy)
- pageable host → device (non_blocking is nominally set but not necessarily truly asynchronous)
into:
- a write into a pinned host buffer (filling it from the list through a numpy view, nearly free)
- pinned host → device (where non_blocking=True is actually effective)
Under this configuration the gains are fairly consistent:
- Throughput: 3042.19 → 3210.51 tokens/s (+5.5%)
- TPOT p50: 12.51 → 11.77 ms/token (-5.9%)
- Latency p50: 1602.08 → 1509.76 ms (-5.8%)
If we keep squeezing further:
- the index updates for dirty_slot_ids_t / dirty rows could reuse buffers in the same way
- capture a dedicated trace for the submit lag / queue wait tail (p95/p99) to pin down whether some remaining CPU/GPU sync point is still causing the jitter