LLM Inference from Scratch: 029. Prefix Cache + Prefill Micro-Batching
The previous post (Prefill Micro-Batching) merged the prefills of streaming admission into a single batched forward, with clear gains in both TTFT and throughput.
But to keep that PR small, add_requests() fell back to per-request add_request() whenever the prefix cache was enabled. Prefix cache hits were fine, but a miss degraded to serial prefill, which blows up TTFT and tail latency under bursts.
This mini PR closes that gap: with the prefix cache enabled, prefill is still batched (hit/miss routing + same-prompt dedup), and while we're at it, the streaming benchmark's metrics are rounded out to TTFT/TPOT/ITL (p99).
Code changes
Four pieces changed this time:
- OnlineScheduler.add_requests(): no more fallback when the prefix cache is enabled; misses go through a single batched prefill.
- Same-prompt dedup during batch admission: each prompt is prefilled only once per round; the other requests simply share its blocks + logits.
- benchmark_streaming: add TPOT/ITL and print p50/p95/p99 for every key metric.
- A minimal unit test: verifies that with the prefix cache enabled, add_requests() still runs only one forward.
roseinfer/engine.py
The core logic splits into three stages (the order matters):
- Prefix cache hit: attach(prompt, session) directly, which returns last_logits and reuses the KV blocks (no prefill).
- Prefix cache miss: merge the missed requests into one batched prefill forward, write each request's KV into the block manager, and prefix_cache.put() the new entry.
- Same-prompt dedup: if the same prompt shows up several times in one admission batch, prefill it only once; the other requests share the blocks + logits (so a burst of identical prompts does not end up slower instead of faster).
One more detail: to avoid randomness differences caused by a changed sampling order (torch.multinomial consumes RNG state in call order), the first token is sampled per request in the original request order.
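Before the full diff, here is a runnable toy of just the hit/miss/duplicate classification. FakePrefixCache and classify are illustrative stand-ins; the real code below also allocates sessions and KV blocks, but the bucketing order (dedup check first, then cache attach, then miss) is the same:

```python
from typing import Optional


class FakePrefixCache:
    """Stand-in for the real prefix cache: attach() returns a cached value or None."""

    def __init__(self) -> None:
        self._entries: dict[str, str] = {}

    def attach(self, prompt: str) -> Optional[str]:
        return self._entries.get(prompt)

    def put(self, prompt: str, value: str) -> None:
        self._entries[prompt] = value


def classify(prompts: list[str], cache: FakePrefixCache):
    """Split an admission batch into cache hits, misses, and in-batch duplicates."""
    hits: dict[int, str] = {}
    miss_idx: list[int] = []
    dup_of: dict[int, int] = {}
    first_idx_for_prompt: dict[str, int] = {}
    for i, prompt in enumerate(prompts):
        src = first_idx_for_prompt.get(prompt)
        if src is not None:           # duplicate prompt inside this batch
            dup_of[i] = src
            continue
        first_idx_for_prompt[prompt] = i
        cached = cache.attach(prompt)
        if cached is not None:        # prefix cache hit: reuse blocks + logits
            hits[i] = cached
            continue
        miss_idx.append(i)            # miss: joins the single batched prefill
    return hits, miss_idx, dup_of


cache = FakePrefixCache()
cache.put("warm prompt", "cached-logits")
print(classify(["warm prompt", "cold A", "cold A", "cold B"], cache))
# -> ({0: 'cached-logits'}, [1, 3], {2: 1})
```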
diff --git a/rosellm/roseinfer/engine.py b/rosellm/roseinfer/engine.py
index 78ba210..b981dee 100644
--- a/rosellm/roseinfer/engine.py
+++ b/rosellm/roseinfer/engine.py
@@ -1574,21 +1574,6 @@ class OnlineScheduler:
) -> list[int]:
if not requests:
return []
- if self.use_prefix_cache:
- return [
- self.add_request(
- prompt=req.prompt,
- max_new_tokens=req.max_new_tokens,
- temperature=req.temperature,
- top_k=req.top_k,
- top_p=req.top_p,
- stop_on_eos=req.stop_on_eos,
- do_sample=req.do_sample,
- prompt_token_ids=req.prompt_token_ids,
- request_id=req.request_id,
- )
- for req in requests
- ]
eng = self.engine
eng.model.eval()
@@ -1616,7 +1601,12 @@ class OnlineScheduler:
sessions: list[InferenceSession] = []
token_ids_list: list[list[int]] = []
- for req in requests:
+ last_logits_per_req: list[torch.Tensor | None] = [None for _ in requests]
+ miss_idx: list[int] = []
+ dup_of: dict[int, int] = {}
+ first_idx_for_prompt: dict[str, int] = {}
+
+ for i, req in enumerate(requests):
rid = alloc_rid(req.request_id)
rids.append(rid)
@@ -1656,37 +1646,84 @@ class OnlineScheduler:
)
sessions.append(sess)
- batch_idx = [i for i, s in enumerate(sessions) if not s.finished]
- if batch_idx:
- batch_token_ids = [token_ids_list[i] for i in batch_idx]
+ if self.use_prefix_cache:
+ src = first_idx_for_prompt.get(req.prompt)
+ if src is not None:
+ dup_of[i] = src
+ continue
+ first_idx_for_prompt[req.prompt] = i
+
+ cached_logits = eng.prefix_cache.attach(req.prompt, sess)
+ if cached_logits is not None:
+ last_logits_per_req[i] = cached_logits
+ continue
+
+ miss_idx.append(i)
+
+ if miss_idx:
+ batch_token_ids = [token_ids_list[i] for i in miss_idx]
input_ids, attn_mask, lengths, _ = eng._encode_prompt_token_ids_batch(
batch_token_ids
)
- batch_sessions = [sessions[i] for i in batch_idx]
+ batch_sessions = [sessions[i] for i in miss_idx]
last_logits = eng._prefill_register_kv_batch(
sessions=batch_sessions,
input_ids=input_ids,
attention_mask=attn_mask,
lengths=lengths,
)
- for b, sess in enumerate(batch_sessions):
- token_id = eng._sample_next_token(
- logits=last_logits[b : b + 1],
- temperature=sess.temperature,
- top_k=sess.top_k,
- top_p=sess.top_p,
- do_sample=sess.do_sample,
- )
- sess.generated_ids.append(int(token_id))
- sess.step_count = 1
- if sess.stop_on_eos:
- eos_id = eng.eos_token_id
- if eos_id is not None and int(token_id) == eos_id:
- sess.finished = True
- if sess.max_new_tokens > 0 and sess.step_count >= sess.max_new_tokens:
- sess.finished = True
+ for b, idx in enumerate(miss_idx):
+ logits = last_logits[b : b + 1]
+ last_logits_per_req[idx] = logits
+ if self.use_prefix_cache:
+ eng.prefix_cache.put(
+ requests[idx].prompt,
+ sessions[idx],
+ logits,
+ )
+
+ if dup_of:
+ kvm = eng.kv_manager
+ for idx, src in dup_of.items():
+ sess = sessions[idx]
+ if sess.finished:
+ continue
+ src_sess = sessions[src]
+ if src_sess.finished:
+ sess.finished = True
+ continue
+ sess.prompt_length = src_sess.prompt_length
+ sess.block_ids_per_layer = [[] for _ in range(kvm.num_layers)]
+ for layer_idx, block_ids in enumerate(src_sess.block_ids_per_layer):
+ if not block_ids:
+ continue
+ kvm.incref_blocks(block_ids)
+ sess.block_ids_per_layer[layer_idx] = list(block_ids)
+ last_logits_per_req[idx] = last_logits_per_req[src]
+
+ for idx, sess in enumerate(sessions):
+ if sess.finished:
+ continue
+ logits = last_logits_per_req[idx]
+ if logits is None:
+ raise RuntimeError(f"missing prefill logits for request {rids[idx]}")
+ token_id = eng._sample_next_token(
+ logits=logits,
+ temperature=sess.temperature,
+ top_k=sess.top_k,
+ top_p=sess.top_p,
+ do_sample=sess.do_sample,
+ )
+ sess.generated_ids.append(int(token_id))
+ sess.step_count = 1
+ if sess.stop_on_eos:
+ eos_id = eng.eos_token_id
+ if eos_id is not None and int(token_id) == eos_id:
+ sess.finished = True
+ if sess.max_new_tokens > 0 and sess.step_count >= sess.max_new_tokens:
+ sess.finished = True
+ if sess.finished:
+ sess.release_kv_blocks()
for rid, sess in zip(rids, sessions):
self._sessions[rid] = sess
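One detail worth calling out from the dedup branch above: duplicate requests never copy KV, they incref the source session's block ids, so the blocks stay alive until the last sharer releases them. A tiny conceptual toy of that ref-counting (not the real kv_manager API beyond the incref/release idea shown in the diff):

```python
class ToyBlockManager:
    """Conceptual sketch of ref-counted KV blocks shared across sessions."""

    def __init__(self) -> None:
        self.refcount: dict[int, int] = {}

    def allocate(self, block_id: int) -> None:
        self.refcount[block_id] = 1          # owned by the session that prefilled it

    def incref_blocks(self, block_ids: list[int]) -> None:
        for b in block_ids:
            self.refcount[b] += 1            # a dedup'd session shares the block

    def release_blocks(self, block_ids: list[int]) -> None:
        for b in block_ids:
            self.refcount[b] -= 1
            if self.refcount[b] == 0:
                del self.refcount[b]         # last sharer gone: block is free again


kvm = ToyBlockManager()
kvm.allocate(7)              # the source session's prefill wrote block 7
kvm.incref_blocks([7])       # a duplicate request attaches to the same block
kvm.release_blocks([7])      # one of them finishes; the other still holds it
print(kvm.refcount)          # {7: 1}
```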
roseinfer/server.py
To compute TPOT/ITL we need the timestamp at which each token is actually pushed into the streaming queue. The simplest way is to record time.perf_counter() in the worker.
It is off by default (so the normal server path is unaffected); benchmark_streaming turns it on explicitly.
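The pattern is small enough to show as a standalone toy. TokenTimestampRecorder here is hypothetical; in the real code the equivalent state lives inside SchedulerManager (see the diff below) behind the record_token_timestamps flag:

```python
import time


class TokenTimestampRecorder:
    """Toy version of the per-request token timestamps kept by the worker."""

    def __init__(self, enabled: bool = False) -> None:
        self.enabled = enabled
        self._ts: dict[int, list[float]] = {}

    def register(self, request_id: int) -> None:
        if self.enabled:
            self._ts[request_id] = []        # created when the request is submitted

    def on_token(self, request_id: int) -> None:
        ts = self._ts.get(request_id) if self.enabled else None
        if ts is not None:
            ts.append(time.perf_counter())   # stamped right before the token is queued

    def pop(self, request_id: int) -> list[float]:
        return self._ts.pop(request_id, [])  # the benchmark drains this afterwards
```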
diff --git a/rosellm/roseinfer/server.py b/rosellm/roseinfer/server.py
index 8479835..b84fafa 100644
--- a/rosellm/roseinfer/server.py
+++ b/rosellm/roseinfer/server.py
@@ -111,6 +111,7 @@ class SchedulerManager:
self,
engine: InferenceEngine,
max_batch_size: int = 8,
+ record_token_timestamps: bool = False,
) -> None:
@@ -121,6 +122,8 @@ class SchedulerManager:
self._wakeup = threading.Event()
self._queues: Dict[int, "queue.Queue[Optional[str]]"] = {}
self._detoks: Dict[int, BaseDetokenizer] = {}
+ self._record_token_timestamps = bool(record_token_timestamps)
+ self._token_timestamps: Dict[int, list[float]] = {}
@@ -177,6 +181,8 @@ class SchedulerManager:
q: "queue.Queue[Optional[str]]" = queue.Queue()
self._queues[request_id] = q
self._detoks[request_id] = detok
+ if self._record_token_timestamps:
+ self._token_timestamps[request_id] = []
@@ -193,6 +199,14 @@ class SchedulerManager:
self._wakeup.set()
return request_id
+
+ def pop_token_timestamps(
+ self,
+ request_id: int,
+ ) -> list[float]:
+ with self._lock:
+ out = self._token_timestamps.pop(request_id, None)
+ return list(out) if out is not None else []
@@ -249,10 +263,17 @@ class SchedulerManager:
with self._lock:
q = self._queues.get(rid)
detok = self._detoks.get(rid)
+ token_ts = (
+ self._token_timestamps.get(rid)
+ if self._record_token_timestamps
+ else None
+ )
@@ -257,6 +278,8 @@ class SchedulerManager:
if q is None or detok is None:
self.scheduler.discard_request(rid)
continue
for tid in self.scheduler.get_generated_ids(rid):
+ if token_ts is not None:
+ token_ts.append(time.perf_counter())
piece = detok.on_token(int(tid))
if piece:
q.put(piece)
@@ -274,9 +295,16 @@ class SchedulerManager:
with self._lock:
q = self._queues.get(rid)
detok = self._detoks.get(rid)
+ token_ts = (
+ self._token_timestamps.get(rid)
+ if self._record_token_timestamps
+ else None
+ )
@@ -281,6 +309,8 @@ class SchedulerManager:
if q is None or detok is None:
self.scheduler.discard_request(rid)
continue
+ if token_ts is not None:
+ token_ts.append(time.perf_counter())
piece = detok.on_token(int(token_id))
if piece:
q.put(piece)
roseinfer/benchmark_streaming.py
TPOT/ITL are computed directly from token_timestamps, and every key metric is printed as p50/p95/p99.
diff --git a/rosellm/roseinfer/benchmark_streaming.py b/rosellm/roseinfer/benchmark_streaming.py
index e44daed..6462272 100644
--- a/rosellm/roseinfer/benchmark_streaming.py
+++ b/rosellm/roseinfer/benchmark_streaming.py
@@ -23,6 +23,7 @@ class StreamResult:
finish_ts: float
completion_text: str
completion_tokens: int
+ token_timestamps: list[float]
@@ -178,7 +179,11 @@ def main() -> None:
kv_cache_max_concurrency=kv_cache_max_concurrency,
prefix_cache_max_entries=len(set(prompts)),
)
- mgr = SchedulerManager(engine, max_batch_size=int(args.max_batch_size))
+ mgr = SchedulerManager(
+ engine,
+ max_batch_size=int(args.max_batch_size),
+ record_token_timestamps=True,
+ )
@@ -194,14 +199,15 @@ def main() -> None:
first_token_ts: float | None = None
pieces: list[str] = []
for piece in mgr.stream_text(request_id):
- if first_token_ts is None:
- first_token_ts = time.perf_counter()
pieces.append(piece)
finish_ts = time.perf_counter()
+ token_ts = mgr.pop_token_timestamps(request_id)
+ if token_ts:
+ first_token_ts = token_ts[0]
if first_token_ts is None:
first_token_ts = finish_ts
completion_text = "".join(pieces)
- completion_tokens = count_tokens(engine.tokenizer, completion_text)
+ completion_tokens = len(token_ts)
@@ -244,8 +251,18 @@ def main() -> None:
add_lats = [r.submit_end - r.submit_start for r in results]
ttfts = [r.first_token_ts - r.submit_start for r in results]
totals = [r.finish_ts - r.submit_start for r in results]
- completion_tokens = [r.completion_tokens for r in results]
- sum_tokens = sum(completion_tokens)
+ completion_tokens = [int(r.completion_tokens) for r in results]
+ sum_tokens = int(sum(completion_tokens))
+
+ tpots: list[float] = []
+ itls: list[float] = []
+ for r in results:
+ ts = r.token_timestamps
+ if len(ts) < 2:
+ continue
+ tpots.append((ts[-1] - ts[0]) / float(len(ts) - 1))
+ for i in range(1, len(ts)):
+ itls.append(ts[i] - ts[i - 1])
@@ -270,6 +287,20 @@ def main() -> None:
f"{percentile(ttfts, 95)*1e3:.2f}/"
f"{percentile(ttfts, 99)*1e3:.2f} ms"
)
+ tpot_p50 = statistics.median(tpots) if tpots else 0.0
+ itl_p50 = statistics.median(itls) if itls else 0.0
+ print(
+ f"TPOT p50/p95/p99: "
+ f"{tpot_p50*1e3:.2f}/"
+ f"{percentile(tpots, 95)*1e3:.2f}/"
+ f"{percentile(tpots, 99)*1e3:.2f} ms/token"
+ )
+ print(
+ f"ITL p50/p95/p99: "
+ f"{itl_p50*1e3:.2f}/"
+ f"{percentile(itls, 95)*1e3:.2f}/"
+ f"{percentile(itls, 99)*1e3:.2f} ms"
+ )
New tests
They cover two things:
- With the prefix cache enabled, add_requests() no longer falls back to per-request prefill (the forward runs exactly once).
- Identical prompts within one admission batch are deduplicated (the forward still runs exactly once).
File: tests/test_online_scheduler_add_requests_prefix_cache.py
diff --git a/tests/test_online_scheduler_add_requests_prefix_cache.py b/tests/test_online_scheduler_add_requests_prefix_cache.py
new file mode 100644
index 0000000..830ea02
--- /dev/null
+++ b/tests/test_online_scheduler_add_requests_prefix_cache.py
@@ -0,0 +1,156 @@
+import torch
+
+from rosellm.roseinfer.engine import InferenceEngine, OnlineRequest, OnlineScheduler
+from rosellm.rosetrainer.config import GPTConfig
+from rosellm.rosetrainer.model import GPTModel
+
+
+class _CountingTokenizer:
+ def __init__(self, vocab_size: int = 128) -> None:
+ self.vocab_size = int(vocab_size)
+ self.eos_token_id = 0
+ self.pad_token_id = 0
+ self.encode_calls = 0
+
+ def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
+ self.encode_calls += 1
+ del text, add_special_tokens
+ return [1, 2, 3]
+
+ def decode(self, ids: list[int], skip_special_tokens: bool = True) -> str:
+ del ids, skip_special_tokens
+ return ""
+
+
+def test_online_scheduler_add_requests_prefix_cache_batches_prefill() -> None:
+ torch.manual_seed(0)
+ cfg = GPTConfig(
+ vocab_size=128,
+ max_position_embeddings=32,
+ n_layers=2,
+ n_heads=2,
+ d_model=32,
+ d_ff=64,
+ dropout=0.0,
+ )
+ tok = _CountingTokenizer(vocab_size=128)
+ model = GPTModel(cfg)
+ forward_calls = 0
+ orig_forward = model.forward
+
+ def counting_forward(*args, **kwargs): # type: ignore[no-untyped-def]
+ nonlocal forward_calls
+ forward_calls += 1
+ return orig_forward(*args, **kwargs)
+
+ model.forward = counting_forward # type: ignore[method-assign]
+
+ engine = InferenceEngine(
+ model=model,
+ config=cfg,
+ tokenizer=tok,
+ tokenizer_name="dummy",
+ device="cpu",
+ use_amp=False,
+ kv_cache_max_concurrency=8,
+ prefix_cache_max_entries=8,
+ )
+
+ scheduler = OnlineScheduler(engine, max_batch_size=8, use_prefix_cache=True)
+ scheduler.add_requests(
+ [
+ OnlineRequest(
+ prompt="p0",
+ prompt_token_ids=[1, 2, 3],
+ max_new_tokens=1,
+ stop_on_eos=False,
+ do_sample=False,
+ request_id=0,
+ ),
+ OnlineRequest(
+ prompt="p1",
+ prompt_token_ids=[1, 2, 3, 4],
+ max_new_tokens=1,
+ stop_on_eos=False,
+ do_sample=False,
+ request_id=1,
+ ),
+ OnlineRequest(
+ prompt="p2",
+ prompt_token_ids=[1, 2],
+ max_new_tokens=1,
+ stop_on_eos=False,
+ do_sample=False,
+ request_id=2,
+ ),
+ ]
+ )
+ assert tok.encode_calls == 0
+ assert forward_calls == 1
+
+
+def test_online_scheduler_add_requests_prefix_cache_dedups_prompts_in_batch() -> None:
+ torch.manual_seed(0)
+ cfg = GPTConfig(
+ vocab_size=128,
+ max_position_embeddings=32,
+ n_layers=2,
+ n_heads=2,
+ d_model=32,
+ d_ff=64,
+ dropout=0.0,
+ )
+ tok = _CountingTokenizer(vocab_size=128)
+ model = GPTModel(cfg)
+ forward_calls = 0
+ orig_forward = model.forward
+
+ def counting_forward(*args, **kwargs): # type: ignore[no-untyped-def]
+ nonlocal forward_calls
+ forward_calls += 1
+ return orig_forward(*args, **kwargs)
+
+ model.forward = counting_forward # type: ignore[method-assign]
+
+ engine = InferenceEngine(
+ model=model,
+ config=cfg,
+ tokenizer=tok,
+ tokenizer_name="dummy",
+ device="cpu",
+ use_amp=False,
+ kv_cache_max_concurrency=8,
+ prefix_cache_max_entries=8,
+ )
+
+ scheduler = OnlineScheduler(engine, max_batch_size=8, use_prefix_cache=True)
+ scheduler.add_requests(
+ [
+ OnlineRequest(
+ prompt="same",
+ prompt_token_ids=[1, 2, 3],
+ max_new_tokens=1,
+ stop_on_eos=False,
+ do_sample=False,
+ request_id=0,
+ ),
+ OnlineRequest(
+ prompt="same",
+ prompt_token_ids=[1, 2, 3],
+ max_new_tokens=1,
+ stop_on_eos=False,
+ do_sample=False,
+ request_id=1,
+ ),
+ OnlineRequest(
+ prompt="same",
+ prompt_token_ids=[1, 2, 3],
+ max_new_tokens=1,
+ stop_on_eos=False,
+ do_sample=False,
+ request_id=2,
+ ),
+ ]
+ )
+ assert tok.encode_calls == 0
+ assert forward_calls == 1
Metric definitions
How the three metrics are defined in this benchmark:
- TTFT: t_first_token - t_submit
- TPOT: (t_last_token - t_first_token) / (n_tokens - 1)
- ITL: t_i - t_{i-1} (all inter-token gaps flattened into one distribution, reported as p50/p95/p99)
Here t_first_token / t_i are stamped in the worker at the moment a token is actually pushed into the streaming queue, so they include:
- the time prefill occupies the worker
- the worker's scheduling/lock/queue overhead
- gaps between decode steps
That is why TPOT/ITL can improve even though no decode kernel changed: previously a miss meant serial prefills occupying the worker for a long stretch, blocking decode steps and widening the token gaps; now the misses collapse into one batched prefill, so that blocking window shrinks.
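As a worked example under these definitions (percentile() here is an illustrative nearest-rank helper, not necessarily the benchmark's exact implementation):

```python
def percentile(values: list[float], pct: float) -> float:
    """Nearest-rank percentile; returns 0.0 for an empty list."""
    if not values:
        return 0.0
    xs = sorted(values)
    k = min(len(xs) - 1, max(0, round(pct / 100.0 * (len(xs) - 1))))
    return xs[k]


def token_metrics(submit_ts: float, token_ts: list[float]) -> dict[str, float]:
    """TTFT / TPOT / ITL p99 for one request, from its token timestamps."""
    ttft = token_ts[0] - submit_ts
    tpot = (
        (token_ts[-1] - token_ts[0]) / (len(token_ts) - 1)
        if len(token_ts) > 1
        else 0.0
    )
    itls = [b - a for a, b in zip(token_ts, token_ts[1:])]
    return {"ttft": ttft, "tpot": tpot, "itl_p99": percentile(itls, 99)}


# 4 tokens arriving 150 / 250 / 360 / 470 ms after submit:
print(token_metrics(0.0, [0.15, 0.25, 0.36, 0.47]))
# -> ttft 0.15 s, tpot ~0.107 s/token, itl_p99 ~0.11 s
```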
Running it
Unit tests
pytest -q
Output:
......... [100%]
9 passed, 1 warning in 1.63s
Benchmark (HF GPT-2)
Same command both times, for a direct comparison (add HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1 to run offline):
HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python -m rosellm.roseinfer.benchmark_streaming \
--hf-model-id gpt2 \
--device cpu \
--prompt "Hello" \
--unique-prompts \
--num-requests 32 \
--max-new-tokens 8 \
--no-stop-on-eos
Before (prefix cache enabled, but misses prefill serially)
=== streaming benchmark ===
Model: gpt2
Device: cpu
Requests: 32
Prompt tokens (total): 128
Completion tokens (total): 256
Submit wall: 0.084512 s
add_request latency p50/p95/p99: 0.05/0.14/54.01 ms
TTFT p50/p95/p99: 418.48/798.66/798.77 ms
TPOT p50/p95/p99: 140.64/161.01/165.49 ms/token
ITL p50/p95/p99: 107.39/270.98/400.98 ms
Latency p50/p95/p99: 1415.70/1478.94/1479.14 ms
Throughput (completion,total): 163.84 tokens/s
After (prefix cache hit/miss routing + misses merged into one batched prefill)
=== streaming benchmark ===
Model: gpt2
Device: cpu
Requests: 32
Prompt tokens (total): 128
Completion tokens (total): 256
Submit wall: 0.082502 s
add_request latency p50/p95/p99: 0.04/0.16/52.72 ms
TTFT p50/p95/p99: 143.41/260.17/260.26 ms
TPOT p50/p95/p99: 100.38/102.51/106.35 ms/token
ITL p50/p95/p99: 101.38/126.63/150.47 ms
Latency p50/p95/p99: 862.18/927.23/927.38 ms
Throughput (completion,total): 253.68 tokens/s
Key metrics side by side:
- TTFT p50: 418.48 ms -> 143.41 ms (~2.9x)
- TTFT p95: 798.66 ms -> 260.17 ms (~3.1x)
- TPOT p50: 140.64 -> 100.38 ms/token (~1.4x)
- ITL p99: 400.98 ms -> 150.47 ms (~2.7x)
- Latency p50: 1415.70 ms -> 862.18 ms (~1.6x)
- Throughput: 163.84 -> 253.68 tokens/s (~1.5x)