自作トークナイザーを作ってみた。
2024-12-25
azblob://2024/12/23/eyecatch/2024-12-25-tokenizer-from-scratch-000_1.png

自作トークナイザーを作ってみた。

前回から時間が空いてしまいましたが、今回も作ってみました。今回はトークナイザーです。トークナイザーのコアであるBPEについて勉強してみました。

BPE (Byte Pair Encoding) とは?

Byte Pair Encoding (BPE) は、文字列(データ)中で頻繁に出現するバイトペアを新しいトークン(サブワード)として再帰的に置き換えるアルゴリズムです。もともとは文字列の圧縮手法として提案されましたが、現在では NLP の トークナイザーとして広く活用されています。特に GPT 系モデルや多くの大規模言語モデル(LLM)で採用されるトークナイザーとして有名です。

  • 単語レベルよりも細かく、文字レベルよりも大きい “ちょうどいい” 分割単位が得られる
  • 未知語(辞書に存在しない単語)を細かく分割して表現できる

以下、具体的なサンプルと実装例を見ながら、その流れを解説します。


1. BPE の動作原理(簡単な例)

下記は、BPE がどのように文字列を置き換えるかを示すシンプルな例です。

aaabdaaabac

この文字列は全部で 11 トークン(文字)です。最も頻繁に出現するペアは "aa" なので、例えば "Z" のような未使用バイトを使って置き換えます。

ZabdZabac
Z=aa

次は "ab" が最頻出となるため、それを "Y" に置換します。

ZYdZYac
Y=ab
Z=aa

これでさらに置き換えを続けると、以下のように "ZY""X" に置き換えることも可能です。

XdXac
X=ZY
Y=ab
Z=aa

最終的には 5 トークンとなりました。ここで、さらに頻出ペアがなければ BPE による圧縮は終了です。逆に デコード する際は、これらの置き換え規則を逆順に適用すれば元の文字列を復元できます。


2. Byte-level BPE について

BPE は「Byte」という単語が入っていますが、実装上は Unicode のコードポイント単位で処理されることが多いです。
ところが、OpenAI の研究 [^1]では “byte-level BPE” が提案され、以下のような議論がされています(ざっくり日本語要約):

  • Unicode のあらゆる文字を扱おうとすると、ベースの語彙(初期トークン集合)が 13万語以上にもなる
  • BPE をバイト列に直接適用するなら、ベースの語彙数は 256 個に収まる
  • しかし BPE が “貪欲 (greedy)” に頻度をもとにペアをマージするため、共通単語が微妙に違う形(例:dog, dog., dog?, dog!)で出現すると、トークンを無駄に使ってしまう
  • そのため、OpenAI の実装では文字カテゴリをまたいだマージを制限し、スペース(空白文字)に関しては多少自由度を持たせるなどの工夫を行うことで、適切な語彙分割を実現している

これにより、バイトレベルというシンプルなアイデアのメリットを維持しつつ、文字カテゴリを意識して不要なマージを避けることで、限られたトークン数をより有効に使っています。

参考文献:

この部分は 2-3 の章で実装しています


Wikipedia ダウンロード & BPE コード実装

ここでは、Wikipedia からテキストをダウンロードし、BPE 学習(ペアの頻度計算 & マージ処理)を行うサンプルコードを通しで解説します。あくまで概念実証・実験用の簡易実装ですが、全体の流れを掴むのに役立ちます。

2-1. Wikipedia からテキストを取得し、1ファイルにまとめる

Pythonimport wikipedia as wiki
import os
import glob
from icecream import ic

wiki.set_lang("en")
en_topics: list = ["Python (programming language)", "Attention Is All You Need","Harry Potter"]

for topic in en_topics:
    try:
        if os.path.exists("data/{}_en.txt".format(topic.replace(" ", "_"))):
            print("Skipping \"{}\" as it already exists".format(topic))
            continue
        page = wiki.page(topic, auto_suggest=False)
        content = page.content
        os.makedirs("data", exist_ok=True)
        with open("data/{}_en.txt".format(topic.replace(" ", "_")), "w") as f:
            f.write(content)
        print("Downloaded \"{}\"".format(topic))
    except:
        print("Failed to download \"{}\"".format(topic))
        continue

data_paths = glob.glob("data/*.txt")
training_data_path = os.path.join("./","training_data.txt")

with open(training_data_path, "w") as f:
    for path in data_paths:
        with open(path, "r") as f2:
            content = f2.read()
            f.write(content)

with open(training_data_path, "r") as f:
    full_content = f.read()
    print("File: {} has {} characters".format(training_data_path, len(full_content)))

2-2. BPE で必要になる基礎関数の定義

(a) ペアの頻度を数える関数

Pythondef get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

(b) 指定ペアのマージ

Pythondef merge(ids, pair, idx):
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i+=2
        else: 
            newids.append(ids[i])
            i +=1
    return newids

(c) マージ履歴→辞書化、デコード

Pythondef get_vocab_dict(merges:dict) -> dict: 
    vocab_dict = {idx: bytes([idx]) for idx in range(256)}
    print(f"merges: {merges}")
    for (p0, p1), idx in merges.items():
        if ((p0 in vocab_dict) and (p1 in vocab_dict)):
            vocab_dict[idx] = vocab_dict[p0] + vocab_dict[p1]
    return vocab_dict

def decode(ids:list, vocab_dict: dict):
    tokens = b"".join(vocab_dict[idx] if idx in vocab_dict else b"" for idx in ids)
    text = tokens.decode("utf-8", errors='replace')
    return text        

(d) テキストを符号化する (encode)

Pythondef encode(text, merges: dict):
    tokens = list(text.encode("utf-8"))
    print(f"encode len tokens: {len(tokens)}")
    while True and (len(tokens) >= 2):
        stats = get_stats(tokens)
        print(f"encode stats: {stats}")
        pair = min(stats, key=lambda pair: merges.get(pair, float("inf"))) # merges にないペアをスキップ
        
        if pair not in merges:
            break  # これ以上マージできない
        
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    
    return tokens

2-3. GPT4oで使われる正規表現の例

Pythonimport regex    

regex_pat_str = "|".join(
        [
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?""",
            r"""\p{N}{1,3}""",
            r""" ?[^\s\p{L}\p{N}]+[\r\n/]*""",
            r"""\s*[\r\n]+""",
            r"""\s+(?!\S)""",
            r"""\s+""",
        ]
    )

print(regex.findall(regex_pat_str, "Hello World!"))

これはあくまでトークン分割の一例です。実際にはモデルや目的に合わせて微調整します。 参考

2-4. BPE 学習

Pythonimport regex    

# BPE マージの履歴を保存する dict
merges = {}

def train_tokens(text):
    unicode_text = text.encode("utf-8")
    all_tokens = list(map(int,unicode_text))
    max_token_id = max(all_tokens)
    idx =  max_token_id
    print(f"max_token_id: {max_token_id}")
    
    regex_match_tokens = regex.findall(regex_pat_str, text)
    print(regex_match_tokens)
    num = 0
    for regex_match_token in regex_match_tokens:
        if num == 30:
            break
        print("===================================================================")
        print(regex_match_token)
        unicode_text = regex_match_token.encode("utf-8")
        print("number of unicode characters: {} characters".format(len(unicode_text)))
        
        tokens = list(map(int,unicode_text))
        print(tokens)
        print("number of tokens: {} tokens".format(len(tokens)))
        
        num_merges = 100
        ids = list(tokens)

        for i in range(num_merges):
            stats:dict = get_stats(ids)
            print(f"stats = {stats}")
            if (len(stats) >= 1):
                # 最頻出ペア(=最大頻度のペア)を探す
                pair:tuple = max(stats, key=stats.get)
                idx +=1
                sorted_stats = sorted(((v,k) for k,v in stats.items()),reverse=True)
                print(f"sorted_stats: {sorted_stats}")
                print(f"Most appeared pair: {chr(sorted_stats[0][1][0])} {chr(sorted_stats[0][1][1])} = {sorted_stats[0][0]} times")    
                print(f"merging {pair} into a new token {idx}")
                ids:list = merge(ids, pair, idx)
                merges[pair] = idx
            
            vocab_dict = get_vocab_dict(merges)
            if (idx in vocab_dict):
                print(f"idx: {idx} => {decode([idx], vocab_dict)}")
            if len(stats) == 0:
                break
        
        print("===================================================================")
        num += 1

train_tokens(full_content)
  • train_tokens ではテキストを正規表現でトークンに分割し、各トークンに対して BPE のペアマージを再帰的に実行。最初は ASCII など最大 256 までしか使わないが、マージ毎に新トークン ID を割り当てていきます。

2-5. 学習済み merges を使ってテキストをエンコード

Pythonencoded_tokens = encode("hello world", merges)
print(encoded_tokens)
encoded_tokens = encode("Harry Potter is a series of seven fantasy novels written by British author J.K. Rowling.", merges)
print(encoded_tokens)

ここで、学習したペアマージ情報 (merges) に基づき、任意のテキストを BPE エンコードできるようになりました。


2-6. vocab_dict の出力例

Pythonprint(merges)

vocab_dict = get_vocab_dict(merges)
print(vocab_dict)

vocab_dict_str = {}
for vocab in range(len(vocab_dict)):
    print(decode([vocab], vocab_dict))
    vocab_dict_str[vocab] = decode([vocab], vocab_dict)
    
print(encode(" Potter",merges))

import json
dir = "./vocab_dict.json"
with open(dir, mode="wt", encoding="utf-8") as f:
	json.dump(vocab_dict_str, f, ensure_ascii=False, indent=2)
  • merges はペアと新トークンIDのマップ
  • vocab_dict は最終的なトークン ID → バイト列へのマッピング
  • これを JSON 形式で保存すれば、推論時にも同じトークナイザーを再利用できます。

まとめ

  • BPE は単語レベルと文字レベルの中間点を実現する実用的な手法として、非常に多くの言語モデルで使われています。
  • byte-level BPE は、Unicode 全領域を扱う従来の方法に比べ、初期語彙が 256 個に限定できるという大きな利点があります。
  • ただし、BPE は貪欲法によりペアマージを行うため、不要なマージが発生しやすいという課題があり、文字カテゴリをまたいだマージを抑制するなどの最適化が OpenAI のモデルでは実装されています。
  • 本記事のコードは簡易的なものですが、独自のドメイン語彙が必要な場合や、テキスト種類(プログラムコードなど)に合わせてトークナイザーをカスタマイズしたいときの出発点として活用できます。

最後に


次は何を作ってみましょう。また気が向いたら作ってみます

 

参考リンク / 文献