4 minute read

在实现了简单的 inference server 和初步的 OpenAI API 接口之后，我们需要进一步接入 online scheduler。之前的 inference server 只接了 generate 和 stream_generate 这两个旧接口，还没有接入使用 kv-block 的 online scheduler。接入 online scheduler 之后，我们就能真正从入口请求感受到初步的 continuous batching 能力了。

代码变更

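主要是添加了一个 scheduler manager（`SchedulerManager`）来处理 server 和 online scheduler 之间的交互：它在后台 worker 线程中循环调用 `scheduler.step()`，把每个请求新生成的 token 经 detokenizer 转成文本片段后放入该请求对应的队列，流式接口再通过 `stream_text` 从队列中逐段读取。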

server.py

diff --git a/rosellm/roseinfer/server.py b/rosellm/roseinfer/server.py
index 099af58..c96e252 100644
--- a/rosellm/roseinfer/server.py
+++ b/rosellm/roseinfer/server.py
@@ -1,13 +1,16 @@
 import argparse
+import queue
+import threading
 import time
 import uuid
-from typing import List, Literal, Optional
+from typing import Dict, Iterator, List, Literal, Optional
 
 from fastapi import FastAPI
 from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
 
-from .engine import InferenceEngine
+from .detokenizer import BaseDetokenizer
+from .engine import InferenceEngine, OnlineScheduler
 
 
 class GenerateRequest(BaseModel):
@@ -88,6 +91,114 @@ class ChatCompletionRequest(BaseModel):
     stream: bool = False
 
 
+class SchedulerManager:
+    def __init__(
+        self,
+        engine: InferenceEngine,
+        max_batch_size: int = 8,
+    ) -> None:
+        self.engine = engine
+        self.scheduler = OnlineScheduler(
+            engine,
+            max_batch_size=max_batch_size,
+        )
+        self._lock = threading.Lock()
+        self._queues: Dict[int, "queue.Queue[Optional[str]]"] = {}
+        self._detoks: Dict[int, BaseDetokenizer] = {}
+        self._running = True
+        self._worker = threading.Thread(
+            target=self._worker_loop,
+            daemon=True,
+        )
+        self._worker.start()
+
+    def add_request(
+        self,
+        prompt: str,
+        max_new_tokens: int = 64,
+        temperature: float = 1.0,
+        top_k: int = 0,
+        top_p: float = 1.0,
+        stop_on_eos: bool = True,
+        do_sample: bool = False,
+    ) -> int:
+        token_ids = self.engine.tokenizer.encode(
+            prompt,
+            add_special_tokens=False,
+        )
+        if not token_ids:
+            token_ids = [self.engine.eos_token_id]
+        detok = self.engine._make_detok()
+        detok.start_prompt(token_ids)
+        with self._lock:
+            request_id = self.scheduler.add_request(
+                prompt=prompt,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                stop_on_eos=stop_on_eos,
+                do_sample=do_sample,
+            )
+            q: "queue.Queue[Optional[str]]" = queue.Queue()
+            self._queues[request_id] = q
+            self._detoks[request_id] = detok
+            session = self.scheduler._sessions[request_id]
+            for tid in session.generated_ids:
+                piece = detok.on_token(int(tid))
+                if piece:
+                    q.put(piece)
+            if session.finished:
+                tail = detok.flush()
+                if tail:
+                    q.put(tail)
+                q.put(None)
+        return request_id
+
+    def stream_text(self, request_id: int) -> Iterator[str]:
+        q = self._queues[request_id]
+        while True:
+            piece = q.get()
+            if piece is None:
+                break
+            yield piece
+
+    def _worker_loop(self) -> None:
+        while self._running:
+            with self._lock:
+                has_work = self.scheduler.has_unfinished()
+                if has_work:
+                    step_tokens = self.scheduler.step()
+                else:
+                    step_tokens = {}
+            if not has_work:
+                time.sleep(0.005)
+                continue
+            for rid, token_id in step_tokens.items():
+                detok = self._detoks.get(rid)
+                q = self._queues.get(rid)
+                if detok is None or q is None:
+                    continue
+                piece = detok.on_token(int(token_id))
+                if piece:
+                    q.put(piece)
+            finished_ids: list[int] = []
+            with self._lock:
+                for rid in list(self._queues.keys()):
+                    if self.scheduler.is_finished(rid):
+                        finished_ids.append(rid)
+            for rid in finished_ids:
+                detok = self._detoks.pop(rid, None)
+                q = self._queues.pop(rid, None)
+                if q is None:
+                    continue
+                if detok is not None:
+                    tail = detok.flush()
+                    if tail:
+                        q.put(tail)
+                q.put(None)
+
+
 def format_messages_as_prompt(messages: List[ChatMessage]) -> str:
     lines: list[str] = []
     for m in messages:
@@ -120,6 +231,7 @@ def estimate_usage(
 
 def create_app(engine: InferenceEngine) -> FastAPI:
     app = FastAPI(title="roseinfer", version="0.1.0")
+    sched_manager = SchedulerManager(engine, max_batch_size=8)
 
     @app.get("/health")
     def health() -> dict[str, str]:
@@ -130,18 +242,18 @@ def create_app(engine: InferenceEngine) -> FastAPI:
         body: GenerateRequest,
     ) -> GenerateResponse | StreamingResponse:
         if body.stream:
+            request_id = sched_manager.add_request(
+                prompt=body.prompt,
+                max_new_tokens=body.max_new_tokens,
+                temperature=body.temperature,
+                top_k=body.top_k,
+                top_p=body.top_p,
+                stop_on_eos=body.stop_on_eos,
+                do_sample=body.do_sample,
+            )
 
-            def token_stream() -> bytes:
-                chunks = engine.stream_generate(
-                    prompt=body.prompt,
-                    max_new_tokens=body.max_new_tokens,
-                    temperature=body.temperature,
-                    top_k=body.top_k,
-                    top_p=body.top_p,
-                    stop_on_eos=body.stop_on_eos,
-                    do_sample=body.do_sample,
-                )
-                for piece in chunks:
+            def token_stream() -> Iterator[bytes]:
+                for piece in sched_manager.stream_text(request_id):
                     yield piece.encode("utf-8")
 
             return StreamingResponse(
@@ -168,6 +280,15 @@ def create_app(engine: InferenceEngine) -> FastAPI:
         completion_id = f"chatcmpl-{uuid.uuid4().hex}"
         model_name = body.model or "roseinfer"
         if body.stream:
+            request_id = sched_manager.add_request(
+                prompt=prompt,
+                max_new_tokens=body.max_tokens,
+                temperature=body.temperature,
+                top_k=body.top_k,
+                top_p=body.top_p,
+                stop_on_eos=True,
+                do_sample=True,
+            )
 
             def event_stream():
                 first_chunk = ChatCompletionChunk(
@@ -187,15 +308,7 @@ def create_app(engine: InferenceEngine) -> FastAPI:
                     ],
                 )
                 yield f"data: {first_chunk.model_dump_json()}\n\n".encode("utf-8")
-                for piece in engine.stream_generate(
-                    prompt=prompt,
-                    max_new_tokens=body.max_tokens,
-                    temperature=body.temperature,
-                    top_k=body.top_k,
-                    top_p=body.top_p,
-                    stop_on_eos=True,
-                    do_sample=True,
-                ):
+                for piece in sched_manager.stream_text(request_id):
                     if not piece:
                         continue
                     chunk = ChatCompletionChunk(

运行

和之前一样启动服务：

(/data/projects/rosellm/.conda) wine@wine-MS-7D90:/data/projects/rosellm/rosellm$ python -m rosellm.roseinfer.server --checkpoint-path rosetrainer/checkpoints/gpt2_small_ddp_edu_amp_bf16_init.pt --tokenizer-name gpt2 --device cuda --port 8080
INFO:     Started server process [1625699]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8080 (Press CTRL+C to quit)
INFO:     127.0.0.1:54660 - "POST /v1/chat/completions HTTP/1.1" 200 OK

发送请求：

(/data/projects/rosellm/.conda) wine@wine-MS-7D90:/data/projects/rosellm/rosellm$ curl -X POST "http://127.0.0.1:8080/v1/chat/completions"   -H "Content-Type: application/json"   -d '{
    "model": "roseinfer-gpt2",
    "messages": [
      {"role": "system", "content": "You are a helpful assistant."},
      {"role": "user", "content": "Hello!"}
    ],
    "max_tokens": 32,
    "temperature": 0.7,
    "top_p": 0.95,
    "stream": true
  }'
data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":"assistant","content":""},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" "},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"\n"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"A"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"."},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" Why"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" is"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" the"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" client"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" not"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" to"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" attend"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" a"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" high"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" school"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"?"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"\n"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"A"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"."},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" What"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" is"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" the"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" client"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"?"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"\n"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"A"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"."},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" What"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" is"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" the"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":" client"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"?"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":"\n"},"finish_reason":null}]}

data: {"id":"chatcmpl-147a5660b7544006ae0d91499894913d","object":"chat.completion.chunk","created":1765281448,"model":"roseinfer-gpt2","choices":[{"index":0,"delta":{"role":null,"content":null},"finish_reason":"stop"}]}

data: [DONE]

(/data/projects/rosellm/.conda) wine@wine-MS-7D90:/data/projects/rosellm/rosellm$