LLM Inference from Scratch: 033. Prefill Admission Packing
In the previous installment we introduced a prefill token budget: a cap on how many prompt tokens a single admission round may consume, so that one prefill round cannot hog the worker for too long.
But as long as admission is still FIFO, a workload with widely varying prompt lengths still runs into classic head-of-line blocking:
- A long prompt at the head of the queue (possibly even larger than the budget) has to be handled first this round (or forced into a batch of its own), while the short prompts behind it, which would finish quickly, are stuck waiting with it
- The result: the TTFT tail of short requests gets dragged up by the long ones
A common approach in production systems is to go one step beyond the token budget: do packing over the waiting queue (roughly, "peek at the requests behind the head and fill the budget with the small ones that fit").
That is exactly what this mini PR does:
- Add prefill_admission_policy=fifo|pack (default fifo, so existing behavior is unchanged)
- In pack mode, scan a lookahead window over the pending requests and greedily pick, by cost (prompt tokens), a set that fits into the budget (a toy sketch of this selection follows the list)
- Add a fairness fallback: prefill_force_fifo_every, which forces FIFO admission every N rounds so a long prompt can never be skipped forever
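A minimal, standalone sketch of that selection step (illustrative only: pack_within_budget and its arguments are made-up names, not the server code; the real implementation is in the diff below):

```python
# Toy version of the pack idea: given the prompt lengths of the pending
# requests and a per-round token budget, greedily admit the cheapest
# requests that still fit, and return them in arrival order.
def pack_within_budget(prompt_lens: list[int], budget: int) -> list[int]:
    order = sorted(range(len(prompt_lens)), key=lambda i: (prompt_lens[i], i))
    picked: set[int] = set()
    used = 0
    for i in order:
        if used + prompt_lens[i] <= budget:
            picked.add(i)
            used += prompt_lens[i]
    if not picked and prompt_lens:
        picked.add(0)  # forward-progress fallback: admit the head anyway
    return sorted(picked)  # indices admitted this round, in arrival order


print(pack_within_budget([512, 1, 1, 1], budget=256))  # -> [1, 2, 3]
```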
Code changes
roseinfer/server.py
- Add the admission policy + lookahead + force-FIFO knobs
- _take_pending_for_prefill supports pack: greedily select by token cost within the window (the returned order still follows arrival order, so it stays stable)
diff --git a/rosellm/roseinfer/server.py b/rosellm/roseinfer/server.py
index 0f28643..d4f9885 100644
--- a/rosellm/roseinfer/server.py
+++ b/rosellm/roseinfer/server.py
@@ -107,6 +107,9 @@ class _PendingRequest:
do_sample: bool
+PrefillAdmissionPolicy = Literal["fifo", "pack"]
+
+
def _take_pending_for_prefill(
pending_buf: "deque[_PendingRequest]",
pending_q: "queue.Queue[_PendingRequest]",
@@ -114,6 +117,9 @@ def _take_pending_for_prefill(
max_reqs: int,
max_tokens: Optional[int],
max_context: int,
+ admission_policy: PrefillAdmissionPolicy,
+ lookahead: int,
+ force_fifo: bool,
) -> list[_PendingRequest]:
if max_reqs <= 0:
raise ValueError("max_reqs must be positive")
@@ -121,29 +127,75 @@ def _take_pending_for_prefill(
raise ValueError("max_context must be positive")
if max_tokens is not None and max_tokens <= 0:
raise ValueError("max_tokens must be positive")
+ if lookahead <= 0:
+ raise ValueError("lookahead must be positive")
+
+ if force_fifo or admission_policy == "fifo" or max_tokens is None:
+ out: list[_PendingRequest] = []
+ tokens_used = 0
+ while len(out) < max_reqs:
+ if pending_buf:
+ req = pending_buf.popleft()
+ else:
+ try:
+ req = pending_q.get_nowait()
+ except queue.Empty:
+ break
+
+ cost = min(len(req.prompt_token_ids), max_context)
+ if max_tokens is not None:
+ if not out and cost > max_tokens:
+ out.append(req)
+ break
+ if out and tokens_used + cost > max_tokens:
+ pending_buf.appendleft(req)
+ break
+
+ out.append(req)
+ tokens_used += cost
+ return out
+
+ # admission_policy == "pack" and max_tokens is not None.
+ window: list[_PendingRequest] = []
+ while len(window) < lookahead:
+ if pending_buf:
+ window.append(pending_buf.popleft())
+ continue
+ try:
+ window.append(pending_q.get_nowait())
+ except queue.Empty:
+ break
+ if not window:
+ return []
+
+ costs = [min(len(req.prompt_token_ids), max_context) for req in window]
+ order = sorted(range(len(window)), key=lambda i: (costs[i], i))
+ selected = [False for _ in window]
+ tokens_used = 0
+ selected_count = 0
+ for idx in order:
+ if selected_count >= max_reqs:
+ break
+ cost = costs[idx]
+ if cost > max_tokens:
+ continue
+ if tokens_used + cost > max_tokens:
+ continue
+ selected[idx] = True
+ tokens_used += cost
+ selected_count += 1
+ if selected_count == 0:
+ selected[0] = True
+
+ out: list[_PendingRequest] = []
+ for idx, req in enumerate(window):
+ if selected[idx]:
+ out.append(req)
+ if len(out) >= max_reqs:
+ break
+ for idx in range(len(window) - 1, -1, -1):
+ if not selected[idx]:
+ pending_buf.appendleft(window[idx])
return out
@@ -156,6 +208,9 @@ class SchedulerManager:
prefill_max_tokens: Optional[int] = None,
record_token_timestamps: bool = False,
decode_first: bool = False,
+ prefill_admission_policy: PrefillAdmissionPolicy = "fifo",
+ prefill_admission_lookahead: int = 64,
+ prefill_force_fifo_every: int = 0,
) -> None:
if max_batch_size <= 0:
raise ValueError("max_batch_size must be positive")
@@ -170,6 +225,16 @@ class SchedulerManager:
if self._prefill_max_tokens is not None and self._prefill_max_tokens <= 0:
raise ValueError("prefill_max_tokens must be positive")
self._decode_first = bool(decode_first)
+ if prefill_admission_policy not in ("fifo", "pack"):
+ raise ValueError("prefill_admission_policy must be fifo|pack")
+ self._prefill_admission_policy = prefill_admission_policy
+ self._prefill_admission_lookahead = int(prefill_admission_lookahead)
+ if self._prefill_admission_lookahead <= 0:
+ raise ValueError("prefill_admission_lookahead must be positive")
+ self._prefill_force_fifo_every = int(prefill_force_fifo_every)
+ if self._prefill_force_fifo_every < 0:
+ raise ValueError("prefill_force_fifo_every must be non-negative")
+ self._prefill_iter = 0
@@
with self._lock:
if not self._running:
break
max_new = self._prefill_max_batch_size
max_tokens = self._prefill_max_tokens
max_context = int(self.engine.config.max_position_embeddings)
decode_first = self._decode_first
+ admission_policy = self._prefill_admission_policy
+ lookahead = self._prefill_admission_lookahead
+ force_fifo_every = self._prefill_force_fifo_every
+
+ self._prefill_iter += 1
+ force_fifo = (
+ force_fifo_every > 0
+ and (self._prefill_iter % force_fifo_every == 0)
+ )
@@
pending = _take_pending_for_prefill(
self._pending_buf,
self._pending,
max_reqs=max_new,
max_tokens=max_tokens,
max_context=max_context,
+ admission_policy=admission_policy,
+ lookahead=lookahead,
+ force_fifo=force_fifo,
)
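One small detail in the diff above: the iteration counter is incremented before the modulo check, so with prefill_force_fifo_every=8 the forced-FIFO rounds are 8, 16, 24, ..., not round 1. A toy reproduction of just that cadence check (not the server code):

```python
# Toy reproduction of the force-FIFO cadence check (illustrative only).
force_fifo_every = 8
forced_rounds = [
    it
    for it in range(1, 25)  # 1-based iteration counter, as in the worker loop
    if force_fifo_every > 0 and it % force_fifo_every == 0
]
print(forced_rounds)  # -> [8, 16, 24]
```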
roseinfer/benchmark_streaming.py
Expose the knobs as CLI flags so before/after runs are easy:
diff --git a/rosellm/roseinfer/benchmark_streaming.py b/rosellm/roseinfer/benchmark_streaming.py
index a06d549..cfa7439 100644
--- a/rosellm/roseinfer/benchmark_streaming.py
+++ b/rosellm/roseinfer/benchmark_streaming.py
@@ -107,6 +107,25 @@ def parse_args() -> argparse.Namespace:
"Max prompt tokens to prefill per worker iteration " "(default: unlimited)."
),
)
+ parser.add_argument(
+ "--prefill-admission-policy",
+ type=str,
+ default="fifo",
+ choices=["fifo", "pack"],
+ help="Prefill admission policy (default: fifo).",
+ )
+ parser.add_argument(
+ "--prefill-admission-lookahead",
+ type=int,
+ default=64,
+ help="Pending lookahead window for pack admission.",
+ )
+ parser.add_argument(
+ "--prefill-force-fifo-every",
+ type=int,
+ default=0,
+ help="Force FIFO admission every N iterations (0: disable).",
+ )
@@ -231,6 +250,9 @@ def main() -> None:
prefill_max_batch_size=args.prefill_max_batch_size,
prefill_max_tokens=args.prefill_max_tokens,
decode_first=args.decode_first,
+ prefill_admission_policy=args.prefill_admission_policy,
+ prefill_admission_lookahead=int(args.prefill_admission_lookahead),
+ prefill_force_fifo_every=int(args.prefill_force_fifo_every),
record_token_timestamps=True,
)
Tests
Add a test for the core packing semantics: it can get past an oversize request at the head of the queue and admit the short requests first, while still guaranteeing forward progress.
diff --git a/tests/test_prefill_admission_packing.py b/tests/test_prefill_admission_packing.py
new file mode 100644
index 0000000..2836c43
--- /dev/null
+++ b/tests/test_prefill_admission_packing.py
@@ -0,0 +1,66 @@
+import queue
+from collections import deque
+
+from rosellm.roseinfer.server import _PendingRequest, _take_pending_for_prefill
+
+
+def _req(rid: int, n: int) -> _PendingRequest:
+ return _PendingRequest(
+ request_id=int(rid),
+ prompt="",
+ prompt_token_ids=[1] * int(n),
+ max_new_tokens=1,
+ temperature=1.0,
+ top_k=0,
+ top_p=1.0,
+ stop_on_eos=False,
+ do_sample=False,
+ )
+
+
+def test_take_pending_for_prefill_pack_skips_oversize_head() -> None:
+ buf: deque[_PendingRequest] = deque()
+ q: "queue.Queue[_PendingRequest]" = queue.Queue()
+ q.put(_req(0, 100))
+ q.put(_req(1, 2))
+ q.put(_req(2, 2))
+
+ out = _take_pending_for_prefill(
+ buf,
+ q,
+ max_reqs=8,
+ max_tokens=4,
+ max_context=1024,
+ admission_policy="pack",
+ lookahead=16,
+ force_fifo=False,
+ )
+ assert [r.request_id for r in out] == [1, 2]
+ assert list(buf)[0].request_id == 0
+ try:
+ q.get_nowait()
+ except queue.Empty:
+ pass
+ else:
+ raise AssertionError("expected queue to be empty")
+
+
+def test_take_pending_for_prefill_pack_progresses_when_all_oversize() -> None:
+ buf: deque[_PendingRequest] = deque()
+ q: "queue.Queue[_PendingRequest]" = queue.Queue()
+ q.put(_req(0, 100))
+ q.put(_req(1, 100))
+
+ out = _take_pending_for_prefill(
+ buf,
+ q,
+ max_reqs=8,
+ max_tokens=4,
+ max_context=1024,
+ admission_policy="pack",
+ lookahead=16,
+ force_fifo=False,
+ )
+ assert [r.request_id for r in out] == [0]
+ assert list(buf)[0].request_id == 1
(The call to _take_pending_for_prefill in test_scheduler_manager_prefill_max_tokens.py was also updated to the new signature; that diff is not repeated here.)
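For reference, a FIFO-mode call with the new keyword arguments looks roughly like this (a sketch based on the new signature above; the exact arguments in that test may differ):

```python
# FIFO tests only need to pass the new knobs with FIFO-equivalent values.
out = _take_pending_for_prefill(
    buf,
    q,
    max_reqs=8,
    max_tokens=4,
    max_context=1024,
    admission_policy="fifo",  # keep the old FIFO behavior
    lookahead=16,
    force_fifo=False,
)
```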
Running it
Unit tests
pytest -q
Output:
.................. [100%]
=============================== warnings summary ===============================
../anaconda3/lib/python3.10/site-packages/urllib3/util/ssl_.py:260
/data/projects/anaconda3/lib/python3.10/site-packages/urllib3/util/ssl_.py:260: DeprecationWarning: ssl.PROTOCOL_TLS is deprecated
context = SSLContext(ssl_version or PROTOCOL_TLS)
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
18 passed, 1 warning in 1.63s
Benchmark (HF GPT-2)
This workload deliberately creates head-of-line blocking:
- --prefill-max-tokens 256: each admission round has a token budget
- --prompt-repeats "512,1,1,1": oversize long prompts keep showing up in the queue
- Under FIFO: every oversize request at the head drags the short requests behind it along with it
- Under pack: the lookahead window admits the short requests that fit into the budget first and leaves the long request at the head (with --prefill-force-fifo-every 8 making sure it still gets processed)
Before (FIFO)
HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python -m rosellm.roseinfer.benchmark_streaming \
--hf-model-id gpt2 \
--device cuda \
--prompt "Hello" \
--prompt-repeats "512,1,1,1" \
--unique-prompts \
--num-requests 128 \
--submit-interval-ms 0 \
--max-batch-size 8 \
--prefill-max-batch-size 128 \
--prefill-max-tokens 256 \
--max-new-tokens 32 \
--no-stop-on-eos \
--no-prefix-cache
=== streaming benchmark ===
Model: gpt2
Device: cuda
Requests: 128
Prompt tokens (total): 16864
Completion tokens (total): 4096
Submit wall: 0.103464 s
add_request latency p50/p95/p99: 0.03/0.42/0.46 ms
TTFT p50/p95/p99: 672.29/1106.39/1144.49 ms
TPOT p50/p95/p99: 142.26/145.73/146.08 ms/token
ITL p50/p95/p99: 152.93/160.77/189.82 ms
Latency p50/p95/p99: 5184.35/5341.20/5347.30 ms
Throughput (completion,total): 751.55 tokens/s
After (pack + force FIFO)
HF_HUB_OFFLINE=1 TRANSFORMERS_OFFLINE=1 python -m rosellm.roseinfer.benchmark_streaming \
--hf-model-id gpt2 \
--device cuda \
--prompt "Hello" \
--prompt-repeats "512,1,1,1" \
--unique-prompts \
--num-requests 128 \
--submit-interval-ms 0 \
--max-batch-size 8 \
--prefill-max-batch-size 128 \
--prefill-max-tokens 256 \
--prefill-admission-policy pack \
--prefill-admission-lookahead 64 \
--prefill-force-fifo-every 8 \
--max-new-tokens 32 \
--no-stop-on-eos \
--no-prefix-cache
=== streaming benchmark ===
Model: gpt2
Device: cuda
Requests: 128
Prompt tokens (total): 16864
Completion tokens (total): 4096
Submit wall: 0.105095 s
add_request latency p50/p95/p99: 0.03/0.42/0.46 ms
TTFT p50/p95/p99: 216.77/611.07/689.69 ms
TPOT p50/p95/p99: 156.40/158.05/158.17 ms/token
ITL p50/p95/p99: 155.06/190.19/221.63 ms
Latency p50/p95/p99: 5086.20/5258.17/5263.87 ms
Throughput (completion,total): 763.48 tokens/s
The key changes in these numbers:
- TTFT p99: 1144.49 ms -> 689.69 ms (-39.7%)
- Latency p99: 5347.30 ms -> 5263.87 ms (-1.6%)
- Throughput: 751.55 -> 763.48 tokens/s (+1.6%)
There is also a familiar trade-off here:
- pack lets more short requests enter the active set earlier, so decode-side concurrency ramps up faster and the TPOT/ITL tails can rise (as they do above)
- That part usually has to be tackled together with follow-up work, e.g. finer-grained admission (chunked prefill) or more refined decode-side fairness/priority