
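This PR continues improving the inference framework. Our biggest problem right now is that decode loops over the sessions and decodes them one at a time. This PR turns that into batch decode. Before the decode forward, the last token id of every session is gathered into a single tensor. One forward then runs over the whole batch, and the results are split back to the individual sessions. The kv-cache goes through the same merge and split.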
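The gather, forward, split pattern can be sketched in plain Python. This is a toy under stated assumptions. `Session` and `fake_batched_forward` are hypothetical stand-ins: the real code works on torch tensors and a real model. The fake forward simply "predicts" the input token plus one.

```python
from dataclasses import dataclass, field


@dataclass
class Session:
    # hypothetical stand-in for InferenceSession: only the field this sketch needs
    generated_ids: list[int] = field(default_factory=list)


def fake_batched_forward(input_ids: list[list[int]]) -> list[int]:
    # stand-in for one model forward over a [B, 1] batch; "predicts" token + 1
    return [row[0] + 1 for row in input_ids]


def batch_decode_step(sessions: list[Session]) -> None:
    # gather: one last token id per session -> a [B, 1] batch
    input_ids = [[sess.generated_ids[-1]] for sess in sessions]
    # a single forward for the whole batch instead of B separate calls
    next_ids = fake_batched_forward(input_ids)
    # split: row i of the result belongs to sessions[i]
    for sess, tok in zip(sessions, next_ids):
        sess.generated_ids.append(tok)


sessions = [Session([10]), Session([20, 21]), Session([30])]
batch_decode_step(sessions)
print([s.generated_ids for s in sessions])  # [[10, 11], [20, 21, 22], [30, 31]]
```

Sessions can hold different amounts of history. The batch only ever carries one new token per session, so the input is always `[B, 1]`.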

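This change also gives us continuous batching implicitly: new requests can join at any point during decode. There is no concurrency control yet, so new requests must be added manually and synchronously inside the loop.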
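The toy continuous-batching loop below makes "add requests inside the loop" concrete. Everything in it is hypothetical and is not the rosellm API: the `run_continuous` name, the `(request_id, prompt_token, arrival_step)` tuples, and the fake decode that emits token + 1. It only shows the shape of the schedule. Each step admits requests that have arrived, decodes everything active as one batch, and retires finished requests.

```python
from collections import deque


def run_continuous(pending: deque, max_new_tokens: int) -> dict[int, list[int]]:
    """Toy continuous-batching loop (hypothetical sketch).

    pending holds (request_id, prompt_token, arrival_step) tuples sorted by
    arrival_step. Admission happens synchronously at the top of each step,
    i.e. there is no concurrency control, matching the limitation above.
    """
    active: dict[int, list[int]] = {}
    done: dict[int, list[int]] = {}
    step = 0
    while pending or active:
        # admit every request that has arrived by this step
        while pending and pending[0][2] <= step:
            rid, prompt_tok, _ = pending.popleft()
            active[rid] = [prompt_tok]
        # one batched "decode" over whatever is active right now (faked as +1)
        for ids in active.values():
            ids.append(ids[-1] + 1)
        # retire finished requests so they leave the batch
        finished = [r for r, ids in active.items() if len(ids) - 1 >= max_new_tokens]
        for rid in finished:
            done[rid] = active.pop(rid)
        step += 1
    return done


print(run_continuous(deque([(0, 100, 0), (1, 200, 2)]), max_new_tokens=3))
# {0: [100, 101, 102, 103], 1: [200, 201, 202, 203]}
```

Request 1 arrives at step 2 while request 0 is still decoding, and it simply joins the next batch. Neither request waits for the other to finish.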

Code changes

engine.py

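The main change is a new decode_step_sessions method on the InferenceEngine class. It takes a list of sessions and pulls out each session's last token id and kv-cache. It merges them, runs one forward, and splits the forward's outputs back into the individual sessions. Sessions whose kv-cache is shorter than the longest one are right-padded with zeros, and an attention mask hides the padded positions.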
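The merge and split of the kv-cache is the subtle part, so here is a pure-Python sketch of the index layout. Like the diff, it assumes right padding along the sequence axis. It also assumes the model appends the new token's k/v after the padded past. Each list element stands in for one position's k vector, and the function names are hypothetical. One consequence is worth noting. For a session shorter than `max_len`, the new token's k/v lands at index `max_len`. The first `prev_len + 1` positions of the returned present therefore end in a padding slot, not the new token.

```python
def build_attention_mask(seq_lens: list[int]) -> list[list[int]]:
    # 1 = real past position, 0 = right padding; the trailing 1 is the new token
    max_len = max(seq_lens)
    return [[1 if j < t else 0 for j in range(max_len)] + [1] for t in seq_lens]


def pad_right(cache: list[float], max_len: int, pad_value: float = 0.0) -> list[float]:
    # like torch.cat([k_layer, k_pad], dim=2): zeros appended along the sequence axis
    return cache + [pad_value] * (max_len - len(cache))


def split_back(present: list[float], prev_len: int, max_len: int) -> list[float]:
    # present has max_len + 1 positions: [real past | padding | new token]
    # keep the real prefix and the new token, drop the padding in between
    return present[:prev_len] + present[max_len:]


caches = [[1.0, 2.0], [5.0, 6.0, 7.0, 8.0]]  # two sessions, lengths 2 and 4
seq_lens = [len(c) for c in caches]
max_len = max(seq_lens)
padded = [pad_right(c, max_len) for c in caches]
# pretend the forward appended new k values 3.0 and 9.0 at index max_len
presents = [p + [new] for p, new in zip(padded, [3.0, 9.0])]
print(build_attention_mask(seq_lens))  # [[1, 1, 0, 0, 1], [1, 1, 1, 1, 1]]
print([split_back(p, t, max_len) for p, t in zip(presents, seq_lens)])
# [[1.0, 2.0, 3.0], [5.0, 6.0, 7.0, 8.0, 9.0]]
```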

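The OfflineScheduler class also gets a step method. It calls decode_step_sessions and then samples a token for each session from the results. For the sampling, InferenceSession gets a new apply_batch_logits method. It takes that session's row of logits and returns the sampled token id, or None if the session has already finished. A small has_unfinished helper lets run just call step until every session is done.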
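The per-row sampling and stopping logic of step and apply_batch_logits can be sketched the same way. Some assumptions apply. `ToySession` and `step` here are hypothetical. Greedy argmax replaces `_sample_next_token`, leaving out temperature, top-k and top-p. For brevity, `eos_token_id` is stored on the session rather than the engine.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToySession:
    # hypothetical fields mirroring those apply_batch_logits uses in the diff
    generated_ids: list[int] = field(default_factory=list)
    max_new_tokens: int = 4
    eos_token_id: Optional[int] = None
    step_count: int = 0
    finished: bool = False

    def apply_row(self, logits_row: list[float]) -> Optional[int]:
        if self.finished:
            return None
        # greedy argmax stands in for _sample_next_token
        token_id = max(range(len(logits_row)), key=logits_row.__getitem__)
        self.generated_ids.append(token_id)
        self.step_count += 1
        if self.eos_token_id is not None and token_id == self.eos_token_id:
            self.finished = True
        if self.max_new_tokens > 0 and self.step_count >= self.max_new_tokens:
            self.finished = True
        return token_id


def step(sessions: dict[int, ToySession], batch_logits: list[list[float]]) -> dict[int, int]:
    # row i of batch_logits must belong to the i-th *active* session, in order
    active = [(rid, s) for rid, s in sessions.items() if not s.finished]
    out: dict[int, int] = {}
    for (rid, sess), row in zip(active, batch_logits):
        tok = sess.apply_row(row)
        if tok is not None:
            out[rid] = tok
    return out


sessions = {0: ToySession([1], eos_token_id=2), 1: ToySession([1])}
print(step(sessions, [[0.1, 0.2, 0.9], [0.5, 0.1, 0.3]]))  # {0: 2, 1: 0}
print(step(sessions, [[0.0, 1.0, 0.0]]))  # {1: 1}; session 0 hit eos and left
```

The ordering contract matters. The batch is built from the active sessions in order, so the rows have to be matched back to those same sessions in the same order.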

diff --git a/rosellm/roseinfer/engine.py b/rosellm/roseinfer/engine.py
index ab1e640..4fbf631 100644
--- a/rosellm/roseinfer/engine.py
+++ b/rosellm/roseinfer/engine.py
@@ -600,6 +600,124 @@ class InferenceEngine:
         if any(tails):
             yield tails
 
+    @torch.no_grad()
+    def decode_step_sessions(
+        self,
+        sessions: list["InferenceSession"],
+    ) -> torch.Tensor:
+        assert sessions
+        from torch.amp import autocast
+
+        device = self.device
+        batch_size = len(sessions)
+        last_ids: list[int] = []
+        seq_lens: list[int] = []
+        for sess in sessions:
+            if sess.finished:
+                continue
+            assert sess.kv_cache is not None
+            assert sess.generated_ids
+            last_ids.append(sess.generated_ids[-1])
+            key0, _ = sess.kv_cache[0]
+            seq_lens.append(key0.size(2))
+        assert len(last_ids) == batch_size
+        lens = torch.tensor(seq_lens, device=device, dtype=torch.long)
+        max_len = max(seq_lens)
+        input_ids = torch.tensor(  # [B, 1]
+            last_ids,
+            dtype=torch.long,
+            device=device,
+        ).view(batch_size, 1)
+        past_mask = torch.arange(  # [B, max_len], bool
+            max_len,
+            device=device,
+        ).unsqueeze(0) < lens.unsqueeze(1)
+        new_mask = torch.ones(  # [B, 1]
+            batch_size,
+            1,
+            device=device,
+            dtype=past_mask.dtype,
+        )
+        attention_mask = torch.cat(
+            [past_mask, new_mask],
+            dim=1,
+        ).to(torch.long)
+
+        batched_past = []
+        num_layers = len(sessions[0].kv_cache)
+        for layer_idx in range(num_layers):
+            k_list = []
+            v_list = []
+            for idx, sess in enumerate(sessions):
+                k_layer, v_layer = sess.kv_cache[layer_idx]
+                T_i = seq_lens[idx]
+                if T_i < max_len:
+                    pad_len = max_len - T_i
+                    pad_shape = (
+                        1,
+                        k_layer.size(1),
+                        pad_len,
+                        k_layer.size(3),
+                    )
+                    k_pad = torch.zeros(
+                        pad_shape,
+                        dtype=k_layer.dtype,
+                        device=k_layer.device,
+                    )
+                    v_pad = torch.zeros(
+                        pad_shape,
+                        dtype=v_layer.dtype,
+                        device=v_layer.device,
+                    )
+                    k_full = torch.cat(
+                        [k_layer, k_pad],
+                        dim=2,
+                    )
+                    v_full = torch.cat(
+                        [v_layer, v_pad],
+                        dim=2,
+                    )
+                else:
+                    k_full = k_layer
+                    v_full = v_layer
+                k_list.append(k_full)
+                v_list.append(v_full)
+            k_cat = torch.cat(k_list, dim=0)
+            v_cat = torch.cat(v_list, dim=0)
+            batched_past.append((k_cat, v_cat))
+        if self.use_amp:
+            with autocast(
+                device_type=device.type,
+                dtype=self.amp_dtype,
+            ):
+                logits, _, presents = self.model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    labels=None,
+                    past_key_values=tuple(batched_past),
+                    use_cache=True,
+                )
+        else:
+            logits, _, presents = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                labels=None,
+                past_key_values=tuple(batched_past),
+                use_cache=True,
+            )
+        last_logits = logits[:, -1, :]  # [B, V]
+        for layer_idx in range(num_layers):
+            k_b, v_b = presents[layer_idx]
+            for idx, sess in enumerate(sessions):
+                if sess.finished:
+                    continue
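+                prev_len = seq_lens[idx]
+                # the new token's k/v sits at index max_len, after the right padding
+                k_slice = torch.cat([k_b[idx : idx + 1, :, :prev_len, :], k_b[idx : idx + 1, :, max_len:, :]], dim=2)
+                v_slice = torch.cat([v_b[idx : idx + 1, :, :prev_len, :], v_b[idx : idx + 1, :, max_len:, :]], dim=2)
+                sess.kv_cache[layer_idx] = (k_slice, v_slice)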
+        return last_logits
+
 
 class InferenceSession:
     def __init__(self, engine: "InferenceEngine") -> None:
@@ -799,6 +917,33 @@ class InferenceSession:
             self.finished = True
         return token_id
 
+    @torch.no_grad()
+    def apply_batch_logits(
+        self,
+        last_logits: torch.Tensor,
+    ) -> int | None:
+        if self.finished:
+            return None
+        eng = self.engine
+        logits_2d = last_logits.view(1, -1)  # [1, V]
+        next_token = eng._sample_next_token(
+            logits_2d,
+            temperature=self.temperature,
+            top_k=self.top_k,
+            top_p=self.top_p,
+            do_sample=self.do_sample,
+        )
+        token_id = int(next_token)
+        self.generated_ids.append(token_id)
+        self.step_count += 1
+        if self.stop_on_eos:
+            eos_id = eng.eos_token_id
+            if eos_id is not None and token_id == eos_id:
+                self.finished = True
+        if self.max_new_tokens > 0 and self.step_count >= self.max_new_tokens:
+            self.finished = True
+        return token_id
+
     def release_kv_blocks(self) -> None:
         if self.kv_manager is None:
             return
@@ -897,20 +1042,30 @@ class OfflineScheduler:
         self._sessions[request_id] = session
         return request_id
 
+    def has_unfinished(self) -> bool:
+        return any(not sess.finished for sess in self._sessions.values())
+
+    @torch.no_grad()
+    def step(self) -> dict[int, int]:
+        active_pairs: list[tuple[int, InferenceSession]] = [
+            (rid, sess) for rid, sess in self._sessions.items() if not sess.finished
+        ]
+        if not active_pairs:
+            return {}
+        sessions = [pair[1] for pair in active_pairs]
+        last_logits = self.engine.decode_step_sessions(sessions)
+        step_tokens: dict[int, int] = {}
+        for idx, (rid, sess) in enumerate(active_pairs):
+            logits_row = last_logits[idx]
+            token_id = sess.apply_batch_logits(logits_row)
+            if token_id is not None:
+                step_tokens[rid] = token_id
+        return step_tokens
+
     @torch.no_grad()
     def run(self) -> dict[int, str]:
-        active_ids: set[int] = {
-            rid for rid, sess in self._sessions.items() if not sess.finished
-        }
-        while active_ids:
-            for rid in list(active_ids):
-                session = self._sessions[rid]
-                if session.finished:
-                    active_ids.remove(rid)
-                    continue
-                _ = session.step_once()
-                if session.finished:
-                    active_ids.remove(rid)
+        while self.has_unfinished():
+            self.step()
         outputs: dict[int, str] = {}
         for rid, session in self._sessions.items():
             outputs[rid] = session.decode_text()

Running

Run the earlier offline example again:

$ ./offline_example.sh 
### request 0
hi, ¢:
- The figure of the possible locations in the study was that:
- The level of the population was defined by the mean-weight ratio.
- The level of the population was defined by the mean-weight ratio in the total body.
- The level of the population was defined by the mean-weight ratio.
- The overall-weight ratio was defined by the mean-weight ratio.
- The average-weight ratio was defined by the mean-weight ratio

### request 1
hello, the best suited and the control of situations.
and officials for tallening, it is very recently been evokes a thunderstorms, and the task.
Model TEACHESAYSerrillas, with the ambiguous parts for the simplestly, a thoughtfully detaileding up to BDS, the commentaries and lyrics, and insurance companies around 6thane (to-ofiscountains backbones, or the otherworldlyly date backpacks that he/hertzuniversity, but

### request 2
how is that he/or her route.
parallel data, the military bases.
Graphic description in favor of course, you feel free online. It is a series. This has been made the Australian State University of this article, and industry, it in terms requires that keeps a crime of the discovery. Thesserscentreduced from storage, which currently accepted by many times, and culture has been identified.
Azzles with a general public transport your knowledge of the term,