LLM Training from Scratch: 016. Multi Doc Dataset
We have already implemented tensor parallelism, mixed precision, checkpointing, and eval/logging, run some small-scale training on real data, and done some very simple generation. This PR is a good opportunity to fix the Dataset handling, which is currently unreasonable in two ways. First, we append an eos after every line, which is wrong; the natural fix is to introduce multiple documents (multiple txt files), treat each document as a unit, and insert eos only between documents. Second, the sample start positions used to be fixed (they had to be multiples of seq_len), which is also unreasonable; any position should be a valid starting point.
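A toy illustration of the eos change (tok and the eos id 0 are made-up notation): say the data is two documents, doc1 = "A\nB" and doc2 = "C". The old code stripped each line and appended eos per line, dropping the newline tokens; the new code tokenizes each file whole and appends eos once per document:
old: tok("A") + [0] + tok("B") + [0] + tok("C") + [0]   # eos inside doc1, "\n" lost
new: tok("A\nB") + [0] + tok("C") + [0]                 # doc1 stays contiguous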
The overall goal of this post's PR is still simplicity: meet the minimal need, roughly handling on the order of a few hundred txt files. Beyond that scale you need Megatron's industrial-grade design:
- An IndexedDataset, which stores a .bin and a .idx file (the former is the flat, tokenized concatenation of all docs; the latter records each doc's start position and length within the .bin)
- A GPTDataset, which stores three index arrays (a toy sketch of them follows this list):
  - Do_idx: document ids, repeated and shuffled, used to simulate a random document stream
  - Sa_idx: the sample index; each entry is (doc_index_pos, offset_in_that_doc), i.e. which position in Do_idx this sample starts from and at which offset inside that document. Two adjacent Sa_idx entries delimit the token span of one sample; note that a sample can clearly cross document boundaries
  - Sh_idx: the shuffle index; fetching the k-th sample actually reads Sa_idx[Sh_idx[k]], randomizing the sample order
- A BlendedDataset, which applies blending strategies across different corpora
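To make these three arrays concrete, here is a toy sketch in the same spirit (this is not Megatron's actual code; build_gpt_indices, doc_lens, and epochs are made-up names for illustration):

import numpy as np

def build_gpt_indices(doc_lens, seq_len, epochs=1, seed=0):
    rng = np.random.default_rng(seed)
    # Do_idx: document ids repeated per epoch, then shuffled -> a random doc stream
    do_idx = np.concatenate([np.arange(len(doc_lens))] * epochs)
    rng.shuffle(do_idx)
    # Sa_idx: one (position in Do_idx, offset in that doc) pair per sample boundary;
    # sample k spans Sa_idx[k] .. Sa_idx[k+1] and may cross document boundaries
    total = int(sum(doc_lens[d] for d in do_idx))
    num_samples = (total - 1) // seq_len
    sa_idx = [(0, 0)]
    pos, offset = 0, 0
    for _ in range(num_samples):
        remaining = seq_len
        while remaining > 0:
            avail = doc_lens[do_idx[pos]] - offset
            if avail > remaining:  # this sample ends inside the current doc
                offset += remaining
                remaining = 0
            else:                  # consume the rest of this doc and move on
                remaining -= avail
                pos, offset = pos + 1, 0
        sa_idx.append((pos, offset))
    # Sh_idx: a permutation, so that fetching sample k reads Sa_idx[Sh_idx[k]]
    sh_idx = rng.permutation(num_samples)
    return do_idx, np.array(sa_idx), sh_idx

For example, build_gpt_indices([5, 3, 7], seq_len=4) produces 3 samples, at least one of which straddles a document boundary; that is exactly why the two-level (position, offset) addressing is needed.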
To keep things simple, this PR stays with the in-memory version: all tokenized tokens are stored in a single all_ids list in memory, and the start positions are randomly sampled.
dataset.py
diff --git a/rosellm/rosetrainer/dataset.py b/rosellm/rosetrainer/dataset.py
index 9a9fc99..b0970d2 100644
--- a/rosellm/rosetrainer/dataset.py
+++ b/rosellm/rosetrainer/dataset.py
@@ -1,4 +1,5 @@
-from typing import List
+import random
+from typing import List, Optional
import torch
from torch.utils.data import Dataset
@@ -31,46 +32,64 @@ class TextDatasetForCausalLM(Dataset):
tokenizer,
seq_len: int,
add_eos: bool = True,
+ max_tokens: Optional[int] = None,
+ seed: Optional[int] = None,
) -> None:
super().__init__()
self.tokenizer = tokenizer
self.seq_len = seq_len
- self.add_eos = add_eos
- texts: List[str] = []
+ eos_id: Optional[int] = None
+ if add_eos and hasattr(tokenizer, "eos_token_id"):
+ eos_id = tokenizer.eos_token_id
+ all_ids: List[int] = []
+ total_tokens = 0
+ total_files = 0
for path in file_paths:
+ total_files += 1
with open(path, "r", encoding="utf-8") as f:
- for line in f:
- line = line.strip()
- if not line:
- continue
- texts.append(line)
- if not texts:
- raise ValueError("the file is full of empty lines")
- all_ids: List[int] = []
- for text in texts:
- if add_eos and hasattr(tokenizer, "eos_token_id"):
- encoded = tokenizer.encode(
- text,
- add_special_tokens=False,
- )
- all_ids.extend(encoded)
- all_ids.append(tokenizer.eos_token_id)
- else:
- encoded = tokenizer.encode(
- text,
- add_special_tokens=False,
- )
- all_ids.extend(encoded)
+ text = f.read()
+ if not text.strip():
+ continue
+ ids = tokenizer.encode(text, add_special_tokens=False)
+ if not ids:
+ continue
+ all_ids.extend(ids)
+ total_tokens += len(ids)
+ if eos_id is not None:
+ all_ids.append(eos_id)
+ total_tokens += 1
+ if max_tokens is not None and total_tokens >= max_tokens:
+ all_ids = all_ids[:max_tokens]
+ total_tokens = len(all_ids)
+ break
+ print(f"total files: {total_files}")
+ print(f"total tokens: {total_tokens}")
if len(all_ids) < seq_len:
raise ValueError("the total number of tokens is less than seq_len")
self.all_ids = all_ids
- self.num_samples = len(all_ids) // seq_len
+ self.total_tokens = len(all_ids)
+ max_start = self.total_tokens - seq_len
+ num_samples = self.total_tokens // seq_len
+ if num_samples > max_start + 1:
+ num_samples = max_start + 1
+ if seed is None:
+ rng = random.Random()
+ else:
+ rng = random.Random(seed)
+ candidates = list(range(max_start + 1))
+ if num_samples < len(candidates):
+ start_indices = rng.sample(candidates, num_samples)
+ else:
+ rng.shuffle(candidates)
+ start_indices = candidates
+ self.start_indices = start_indices
+ self.num_samples = len(self.start_indices)
def __len__(self) -> int:
return self.num_samples
def __getitem__(self, idx: int):
- start = idx * self.seq_len
+ start = self.start_indices[idx]
end = start + self.seq_len
ids = self.all_ids[start:end]
input_ids = torch.tensor(ids, dtype=torch.long)
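A quick way to sanity-check the new sampling (a hypothetical snippet; the file paths are made up, and the import assumes the class is reachable via the repo layout above):

from transformers import AutoTokenizer

from rosellm.rosetrainer.dataset import TextDatasetForCausalLM

tok = AutoTokenizer.from_pretrained("gpt2")
ds = TextDatasetForCausalLM(
    file_paths=["data/gutenberg/a.txt", "data/gutenberg/b.txt"],  # made-up paths
    tokenizer=tok,
    seq_len=1024,
    add_eos=True,
    max_tokens=1_000_000,
    seed=42,
)
print(len(ds))               # number of seq_len windows
print(ds.start_indices[:5])  # arbitrary start offsets, no longer multiples of seq_len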
train_minimal.py
diff --git a/rosellm/rosetrainer/train_minimal.py b/rosellm/rosetrainer/train_minimal.py
index ed74147..8553d00 100644
--- a/rosellm/rosetrainer/train_minimal.py
+++ b/rosellm/rosetrainer/train_minimal.py
@@ -130,12 +130,18 @@ def main(args: argparse.Namespace) -> None:
file_paths=args.train_data,
tokenizer=tokenizer,
seq_len=args.seq_len,
+ add_eos=True,
+ max_tokens=args.max_tokens,
+ seed=args.data_seed,
)
if args.val_data:
val_dataset = TextDatasetForCausalLM(
file_paths=args.val_data,
tokenizer=tokenizer,
seq_len=args.seq_len,
+ add_eos=True,
+ max_tokens=args.max_tokens,
+            seed=None if args.data_seed is None else args.data_seed + 1,
)
else:
val_size = max(int(0.1 * len(train_dataset)), 1)
@@ -342,6 +348,18 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="use random toy data rather than real data",
)
+ parser.add_argument(
+ "--max-tokens",
+ type=int,
+ default=None,
+ help="max tokens to sample from the dataset",
+ )
+ parser.add_argument(
+ "--data-seed",
+ type=int,
+ default=None,
+ help="seed for the data sampler",
+ )
return parser.parse_args()
train_ddp.py
diff --git a/rosellm/rosetrainer/train_ddp.py b/rosellm/rosetrainer/train_ddp.py
index fcfe19e..867ca31 100644
--- a/rosellm/rosetrainer/train_ddp.py
+++ b/rosellm/rosetrainer/train_ddp.py
@@ -163,12 +163,18 @@ def main(args: argparse.Namespace) -> None:
file_paths=args.train_data,
tokenizer=tokenizer,
seq_len=args.seq_len,
+ add_eos=True,
+ max_tokens=args.max_tokens,
+ seed=args.data_seed,
)
if args.val_data:
val_dataset = TextDatasetForCausalLM(
file_paths=args.val_data,
tokenizer=tokenizer,
seq_len=args.seq_len,
+ add_eos=True,
+ max_tokens=args.max_tokens,
+            seed=None if args.data_seed is None else args.data_seed + 1,
)
else:
val_size = max(int(0.1 * len(train_dataset)), 1)
@@ -411,6 +417,18 @@ def parse_args() -> argparse.Namespace:
action="store_true",
help="use toy data",
)
+ parser.add_argument(
+ "--max-tokens",
+ type=int,
+ default=None,
+ help="max tokens to sample from the dataset",
+ )
+ parser.add_argument(
+ "--data-seed",
+ type=int,
+ default=None,
+ help="seed for the data sampler",
+ )
return parser.parse_args()
Running
We can grab several Gutenberg txt files and train on them:
#!/bin/bash
torchrun --nproc_per_node=2 train_ddp.py \
--n-layers 12 \
--n-heads 12 \
--d-model 768 \
--d-ff 3072 \
--dropout 0.1 \
--max-position-embeddings 1024 \
--seq-len 1024 \
--batch-size 2 \
--num-steps 6000 \
--lr 3e-4 \
--train-data data/gutenberg/*.txt \
--tokenizer-name gpt2 \
--max-tokens 10000000 \
--data-seed 42 \
--checkpoint-path checkpoints/gpt2_small_ddp.pt
The output looks like:
('epoch 1 step 760 / 6000 ', 'train loss: 5.8938 ', 'val loss: 5.5948 ', 'val ppl: 269.0184 ', 'amp: True')
('epoch 1 step 770 / 6000 ', 'train loss: 6.0291 ', 'val loss: 5.5967 ', 'val ppl: 269.5446 ', 'amp: True')
('epoch 1 step 780 / 6000 ', 'train loss: 5.6675 ', 'val loss: 5.5812 ', 'val ppl: 265.3791 ', 'amp: True')
('epoch 1 step 790 / 6000 ', 'train loss: 5.9663 ', 'val loss: 5.5908 ', 'val ppl: 267.9512 ', 'amp: True')
('epoch 1 step 800 / 6000 ', 'train loss: 5.0629 ', 'val loss: 5.5759 ', 'val ppl: 263.9961 ', 'amp: True')
('epoch 1 step 810 / 6000 ', 'train loss: 6.4135 ', 'val loss: 5.5756 ', 'val ppl: 263.9148 ', 'amp: True')
('epoch 1 step 820 / 6000 ', 'train loss: 5.8329 ', 'val loss: 5.5482 ', 'val ppl: 256.7634 ', 'amp: True')
('epoch 1 step 830 / 6000 ', 'train loss: 4.9374 ', 'val loss: 5.5305 ', 'val ppl: 252.2779 ', 'amp: True')
('epoch 1 step 840 / 6000 ', 'train loss: 3.3967 ', 'val loss: 5.5210 ', 'val ppl: 249.8808 ', 'amp: True')
('epoch 1 step 850 / 6000 ', 'train loss: 5.2579 ', 'val loss: 5.5221 ', 'val ppl: 250.1612 ', 'amp: True')
('epoch 1 step 860 / 6000 ', 'train loss: 5.2088 ', 'val loss: 5.5418 ', 'val ppl: 255.1386 ', 'amp: True')