diff --git a/.gitignore b/.gitignore index 86d70e3..be4ce66 100644 --- a/.gitignore +++ b/.gitignore @@ -57,7 +57,8 @@ coverage.xml *.txt *.yttm artifacts/ -stress +bpe_stress +wordpiece_stress # Translations *.mo diff --git a/LICENSE b/LICENSE index 52a300b..40b88d8 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2019 VK.com +Copyright (c) 2019-2023 VK.com Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/MANIFEST.in b/MANIFEST.in index 4ce0eae..27234f2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,13 +1,13 @@ include youtokentome/cpp/utils.h include youtokentome/cpp/bpe.h include youtokentome/cpp/utf8.h +include youtokentome/cpp/wordpiece.h include youtokentome/cpp/yttm.pyx -include youtokentome/cpp/third_party/flat_hash_map.h -include youtokentome/cpp/third_party/LICENSE +include youtokentome/cpp/third_party/flat_hash_map/flat_hash_map.h +include youtokentome/cpp/third_party/flat_hash_map/LICENSE +include youtokentome/cpp/third_party/thread_pool/thread_pool.h +include youtokentome/cpp/third_party/thread_pool/LICENSE include LICENSE include README.md include requirements.txt include yttm_cli.py - - - diff --git a/README.md b/README.md index 00b123e..74335a0 100644 --- a/README.md +++ b/README.md @@ -6,20 +6,21 @@ # YouTokenToMe -YouTokenToMe is an unsupervised text tokenizer focused on computational efficiency. It currently implements fast Byte Pair Encoding (BPE) [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162)]. -Our implementation is much faster in training and tokenization than [Hugging Face](https://github.com/huggingface/tokenizers), [fastBPE](https://github.com/glample/fastBPE) - and [SentencePiece](https://github.com/google/sentencepiece). In some test cases, it is 60 times faster. - Check out our [benchmark](benchmark.md) results. +YouTokenToMe is an unsupervised text tokenizer focused on computational efficiency. It currently contains the fastest implementations of: +- Byte Pair Encoding (BPE) [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162)], [benchmark results](benchmark_bpe.md); +- WordPiece [[Song et al.](https://arxiv.org/abs/2012.15524)], [benchmark results](benchmark_wordpiece.md). Key advantages: * Multithreading for training and tokenization -* The algorithm has `O(N)` complexity, where `N` is the length of training data * Highly efficient implementation in C++ * Python wrapper and command-line interface -Extra features: -* BPE-dropout (as described in [Provilkov et al, 2019](https://arxiv.org/abs/1910.13267)) +## BPE implementation + +Algorithm properties: +* Time complexity is `O(N)`, where `N` is the length of training data +* Supports BPE-dropout (as described in [Provilkov et al, 2019](https://arxiv.org/abs/1910.13267)) As well as in the algorithm from the original paper, ours does not consider tokens that cross word boundaries. Just like in [SentencePiece](https://github.com/google/sentencepiece), all space symbols were replaced by meta symbol "▁" (U+2581). It allows sequences of tokens to be converted back to text and for word boundaries to be restored.
@@ -28,15 +29,21 @@ For example, the phrase ```Blazingly fast tokenization!``` can be tokenized into `['▁Bl', 'az', 'ingly', '▁fast', '▁token', 'ization', '!']` +## WordPiece implementation + +Algorithm properties: +* Currently supports tokenization only; training is not supported yet +* Time complexity is `O(Nm^2)`, where `N` is the length of tokenized data and `m` is the maximum word length in the vocabulary + ## Installation ```bash pip install youtokentome ``` -## Python interface + +## Python BPE interface ### Example -Let's start with a self-contained example. ```python import random @@ -67,11 +74,28 @@ bpe = yttm.BPE(model=model_path) print(bpe.encode([test_text], output_type=yttm.OutputType.ID)) print(bpe.encode([test_text], output_type=yttm.OutputType.SUBWORD)) ``` + +### Methods +Class `youtokentome.BPE` has the following methods: + +#### constructor + +```python +youtokentome.BPE(model, n_threads=-1) +``` +Class constructor. Loads the trained model. + +* `model`: string, path to the trained model +* `n_threads`: int, number of parallel threads used to run. + If equal to -1, then the maximum number of threads available will be used. +   -### Training model + +#### train + ```python -youtokentome.BPE.train(data, model, vocab_size, coverage, n_threads=-1, pad_id=0, unk_id=1, bos_id=2, eos_id=3) +train(data, model, vocab_size, coverage, n_threads=-1, pad_id=0, unk_id=1, bos_id=2, eos_id=3) ``` Trains BPE model and saves to file. @@ -92,22 +116,6 @@ Trains BPE model and saves to file.   -### Model loading - -```python -youtokentome.BPE(model, n_threads=-1) -``` - -Class constructor. Loads the trained model. - -* `model`: string, path to the trained model -* `n_threads`: int, number of parallel threads used to run. - If equal to -1, then the maximum number of threads available will be used. - -  - -### Methods -Class `youtokentome.BPE` has the following methods: #### encode ```python encode(self, sentences, output_type=yttm.OutputType.ID, bos=False, eos=False, reverse=False, dropout_prob=0) ``` @@ -185,16 +193,23 @@ Convert each id to subword and concatenate with space symbol. **Returns:** List of strings. - -## Command line interface -### Example +## Python WordPiece interface -```bash -$ yttm bpe --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000 -$ yttm encode --model OUTPUT_MODEL_FILE --output_type subword < TEST_DATA_FILE > ENCODED_DATA -``` +### Example +TODO + +### Methods +Class `youtokentome.WordPiece` has the following methods: + +#### constructor + +#### encode + +#### decode + +## Command line interface ### Supported commands @@ -209,16 +224,16 @@ Options: --help Show this message and exit. Commands: - bpe Train BPE model. - decode Decode ids to text. - encode Encode text to ids or subwords. - vocab Print list of learned subwords. + bpe-train Train BPE model. + bpe-decode Decode ids to text. + bpe-encode Encode text to ids or subwords. + bpe-vocab Print list of learned subwords. ``` Command `bpe` allows you to train Byte Pair Encoding model based on a text file. ``` -$ yttm bpe --help +$ yttm bpe-train --help Usage: yttm bpe [OPTIONS] @@ -237,18 +252,31 @@ Options: --help Show this message and exit. ``` +Convert ids back to text. Use `stdin` for input and `stdout` for output. + +``` +$ yttm bpe-decode --help + +Usage: yttm decode [OPTIONS] + + Decode ids to text. + +Options: + --model PATH Path to file with learned model. [required] + --ignore_ids List of indices to ignore for decoding. Example: --ignore_ids=1,2,3 + --help Show this message and exit.
+``` Apply BPE encoding for a corpus of sentences. Use `stdin` for input and `stdout` for output. By default, encoding works in parallel using `n_threads` threads. Number of threads is limited by -8 (see [benchmark](benchmark.md#number-of-threads)). +8 (see [benchmark](benchmark_bpe.md#number-of-threads)). With the `--stream` option, `--n_threads` will be ignored and all sentences will be processed one by one. Each sentence will be tokenized and written to the `stdout` before the next sentence is read. - ``` -$ yttm encode --help +$ yttm bpe-encode --help Usage: yttm encode [OPTIONS] @@ -269,7 +297,7 @@ Options: Print vocabulary. This can be useful for understanding the model. ``` -$ yttm vocab --help +$ yttm bpe-vocab --help Usage: yttm vocab [OPTIONS] @@ -281,24 +309,11 @@ Options: --help Show this message and exit. ``` -Convert ids back to text. Use `stdin` for input and `stdout` for output. - -``` -$ yttm decode --help - -Usage: yttm decode [OPTIONS] +### Examples - Decode ids to text. +TODO: wordpiece -Options: - --model PATH Path to file with learned model. [required] - --ignore_ids List of indices to ignore for decoding. Example: --ignore_ids=1,2,3 - --help Show this message and exit. +```bash +$ yttm bpe-train --data TRAINING_DATA_FILE --model OUTPUT_MODEL_FILE --vocab_size 2000 +$ yttm bpe-encode --model OUTPUT_MODEL_FILE --output_type subword < TEST_DATA_FILE > ENCODED_DATA ``` - - - - - - - diff --git a/benchmark.md b/benchmark_bpe.md similarity index 92% rename from benchmark.md rename to benchmark_bpe.md index 36c43b2..bc1805f 100644 --- a/benchmark.md +++ b/benchmark_bpe.md @@ -1,7 +1,11 @@ -## Speed tests +## BPE Speed tests -`YouTokenToMe` will be compared with [Hugging Face](https://github.com/huggingface/tokenizers), [SentencePiece](https://github.com/google/sentencepiece/) - and [fastBPE](https://github.com/glample/fastBPE). These three algorithms are considered to be fast. +`YouTokenToMe` will be compared with: +* [Hugging Face](https://github.com/huggingface/tokenizers) +* [SentencePiece](https://github.com/google/sentencepiece/) +* [fastBPE](https://github.com/glample/fastBPE) + +These algorithms are considered to be fast. Data from [Wikipedia](https://linguatools.org/tools/corpora/wikipedia-monolingual-corpora/) was used to evaluate algorithm speed. In a similar way to `enwik8` and `enwik9`, the experiments were run on first `10^8` and `10^9` bytes of datasets for English, Russian, Chinese and Japanese. @@ -11,7 +15,7 @@ In this benchmark, `YouTokenToMe` used 4 threads for training and tokenization. doesn't support multithreading for **BPE** at all. `fastBPE` doesn't support multithreading for training. For tokenization, it also used 4 threads. -Source code for benchmark can be found [here](tests/speed_test/speed_test.py). +Source code for benchmark can be found [here](tests/speed_test/bpe.py). The results of the experiments are below. The time is measured in seconds. All experiments were run on the following machine: diff --git a/benchmark_wordpiece.md b/benchmark_wordpiece.md new file mode 100644 index 0000000..aaf4e27 --- /dev/null +++ b/benchmark_wordpiece.md @@ -0,0 +1,38 @@ +## WordPiece Speed tests + +`YouTokenToMe` will be compared with: +* [Hugging Face](https://github.com/huggingface/tokenizers) +* [Keras](https://github.com/keras-team/keras-nlp) +* [Tensorflow](https://github.com/tensorflow/text) +* [Torch](https://github.com/pytorch/text) + +These algorithms are considered to be fast. 
+ +Data from [Wikipedia](https://linguatools.org/tools/corpora/wikipedia-monolingual-corpora/) was used to evaluate algorithm speed. In a similar way to `enwik8` and `enwik9`, the experiments were run on first `10^8` and `10^9` bytes of datasets for English, Russian, Chinese and Japanese. + +Used vocabulary: [bert-base-cased](https://huggingface.co/bert-base-cased). + +In this benchmark, `YouTokenToMe` used 4 threads for training and tokenization. + +Source code for benchmark can be found [here](tests/speed_test/wordpiece.py). +The results of the experiments are below. The time is measured in seconds. + +All experiments were run on the following machine: TODO + +### Tokenization 100MB +TODO: TABLE + +### Tokenization 1GB +TODO: TABLE + +`YouTokenToMe` performed really well in this benchmark. This is especially noticeable for languages with large alphabets. + +## Number of threads + +The table below shows the dependence of performance on the number of threads for `YouTokenToMe`. + +### Tokenization 1GB +TODO: TABLE + + +TODO: CONCLUSION ON THREADS diff --git a/requirements.txt b/requirements.txt index 16d2c18..32e09ab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,10 @@ -setuptools>=32.0.0 -Click>=7.0 -pytest==4.3.1 -tabulate==0.8.5 -Cython==0.29.14 \ No newline at end of file +atomicwrites==1.4.1 +attrs==22.2.0 +click==8.1.3 +Cython==0.29.34 +more-itertools==9.1.0 +pluggy==1.0.0 +py==1.11.0 +pytest==7.2.1 +six==1.16.0 +tabulate==0.9.0 diff --git a/setup.py b/setup.py index 867fc0f..60abd1c 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ "youtokentome/cpp/bpe.cpp", "youtokentome/cpp/utils.cpp", "youtokentome/cpp/utf8.cpp", + "youtokentome/cpp/wordpiece.cpp" ], extra_compile_args=["-std=c++11", "-pthread", "-O3"], language="c++", @@ -35,7 +36,7 @@ python_requires=">=3.5.0", install_requires=["Click>=7.0"], entry_points={"console_scripts": ["yttm = youtokentome.yttm_cli:main"]}, - author="Ivan Belonogov", + author="VKCOM", license="MIT", classifiers=[ "License :: OSI Approved :: MIT License", diff --git a/tests/speed_test/Dockerfile b/tests/speed_test/Dockerfile index 2dc36fa..9631aec 100644 --- a/tests/speed_test/Dockerfile +++ b/tests/speed_test/Dockerfile @@ -8,8 +8,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ cmake \ make \ g++ \ - wget && \ - pip3 install tabulate youtokentome tokenizers + wget \ + bzip2 \ + perl && \ + pip3 install -r requirements.txt && \ + pip3 install youtokentome WORKDIR /repos @@ -26,8 +29,13 @@ RUN g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast WORKDIR /workspace -COPY ./speed_test.py ./speed_test.py RUN cp /repos/fastBPE/fast /workspace/fastBPE +RUN wget -O bert-base-cased.txt https://huggingface.co/bert-base-cased/resolve/main/vocab.txt -# CMD ["python", "speed_test.py", "--langs", "en", "ru", "zh", "ja", "--corpus_size", "100", "--vocab_size", "30000"] -CMD ["python", "speed_test.py", "--langs", "ru", "--corpus_size", "10", "--vocab_size", "30000"] +COPY ./bpe.py ./bpe.py +COPY ./wordpiece.py ./wordpiece.py + +# use comma to separate langs, e.g.: "--langs", "en", "ru", "zh", "ja" +CMD ["python", "bpe.py", "--langs", "ru", "--corpus_size", "10", "--vocab_size", "30000"] + +CMD ["python", "wordpiece.py", "--langs", "ru", "--corpus_size", "10", "--vocab", "bert-base-cased.txt"] diff --git a/tests/speed_test/README.md b/tests/speed_test/README.md index 3283c13..9460755 100644 --- a/tests/speed_test/README.md +++ b/tests/speed_test/README.md @@ -1,25 +1,38 @@ # Running benchmark -* Install 
[YouTokenToMe](https://github.com/vkcom/youtokentome) -* Install [SentencePiece](https://github.com/google/sentencepiece) -* Install [Hugging Face Tokenizer](https://github.com/huggingface/tokenizers) -* Compile [fastBPE](https://github.com/glample/fastBPE) and specify path to binary file in variable - `PATH_TO_FASTBPE` in `speed_test.py` -* `python speed_test.py` - - **Warning!** This test requires about **20 GBs** of free space on your disk and can take **about one hour** for running. +**Warning!** This test requires about **20 GBs** of free space on your disk and can take **about one hour** to run. It uses Wikipedia monolingual corpora for training and tokenization. [Here](https://linguatools.org/tools/corpora/wikipedia-monolingual-corpora/) you can find more details about the data. - -## Docker -Alternatively benchmark can be run using Docker. +## Recommended approach + +Benchmark can be run using Docker. Substitute `PATH_TO_DOWNLOADED_DATA` with absolute path to the directory where wiki dumps will be downloaded. -``` +```bash cd tests/speed_test docker build -t yttm/speed_test . docker run --rm -v PATH_TO_DOWNLOADED_DATA:/workspace/data -it yttm/speed_test:latest ``` + +## Alternative approach + +### BPE benchmark + +* Install [YouTokenToMe](https://github.com/vkcom/youtokentome) +* Install [Hugging Face Tokenizer](https://github.com/huggingface/tokenizers) +* Install [SentencePiece](https://github.com/google/sentencepiece) +* Compile [fastBPE](https://github.com/glample/fastBPE) and specify path to binary file in variable + `PATH_TO_FASTBPE` in `bpe.py` +* `python bpe.py` + +### WordPiece benchmark + +* Install [YouTokenToMe](https://github.com/vkcom/youtokentome) +* Install [Hugging Face Tokenizer](https://github.com/huggingface/tokenizers) +* Install [Keras](https://github.com/keras-team/keras-nlp) +* Install [Tensorflow](https://github.com/tensorflow/text) +* Install [Torch](https://github.com/pytorch/text) +* `python wordpiece.py` diff --git a/tests/speed_test/speed_test.py b/tests/speed_test/bpe.py similarity index 98% rename from tests/speed_test/speed_test.py rename to tests/speed_test/bpe.py index 56e4d20..e695941 100644 --- a/tests/speed_test/speed_test.py +++ b/tests/speed_test/bpe.py @@ -15,12 +15,12 @@ YOU_TOKEN_TO_ME = "YouTokenToMe" SENTENCE_PIECE = "SentencePiece" FAST_BPE = "fastBPE" -HUGGING_FACE_BPE = "Hugging_Face_BPE" +HUGGING_FACE = "Hugging Face" PATH_TO_FASTBPE = "./fastBPE" -class HuggingfaceInterface: +class HuggingFaceInterface: def train_from_file(self, train_file, vocab_size, model_file, _): tokenizer = HuggingFaceBPETokenizer(HuggingFaceBPEModel(unk_token="[UNK]")) trainer = HuggingFaceBPETrainer(special_tokens=["[UNK]", "[PAD]"], vocab_size=vocab_size) @@ -90,8 +90,8 @@ def get_bpe(impl_name): return SentencePieceInterface() if impl_name == FAST_BPE: return FastBPEInterface() - if impl_name == HUGGING_FACE_BPE: - return HuggingfaceInterface() + if impl_name == HUGGING_FACE: + return HuggingFaceInterface() assert False diff --git a/tests/speed_test/wordpiece.py b/tests/speed_test/wordpiece.py new file mode 100644 index 0000000..6ab72dc --- /dev/null +++ b/tests/speed_test/wordpiece.py @@ -0,0 +1,251 @@ +import argparse +import os +from pathlib import Path +from time import time + +import keras_nlp +import tensorflow +from tabulate import tabulate +from tensorflow_text import BertTokenizer as TensorflowBertTokenizer +from tokenizers import BertWordPieceTokenizer as HuggingFaceBertTokenizer +from torchtext.transforms import BERTTokenizer as
TorchBertTokenizer + + +YOU_TOKEN_TO_ME = "YouTokenToMe" +HUGGING_FACE = 'Hugging Face' +KERAS = 'Keras' +TENSORFLOW = 'TensorFlow' +TORCH = 'Torch' + +ALGORITHMS = [YOU_TOKEN_TO_ME, HUGGING_FACE, KERAS, TENSORFLOW, TORCH] +LOWER_CASE = False + + +def collect_to_file(out_file, ids): + if out_file is not None: + with open(out_file, 'w') as f: + for i in ids: + f.write(f'{i} ') + +def run_tensorflow(text_file, vocab_file, n_threads, out_file): + text = "" + with open(text_file, 'r') as f: + text = f.read() + vocab_list = [] + with open(vocab_file, 'r') as f: + for word in f: + vocab_list.append(word.rstrip('\n'))  # strip the trailing newline so vocab keys match produced tokens + lookup_table = tensorflow.lookup.StaticVocabularyTable( + tensorflow.lookup.KeyValueTensorInitializer( + keys=vocab_list, + key_dtype=tensorflow.string, + values=tensorflow.range( + tensorflow.size(vocab_list, out_type=tensorflow.int64), dtype=tensorflow.int64), + value_dtype=tensorflow.int64 + ), + num_oov_buckets=1 + ) + tokenizer = TensorflowBertTokenizer(lookup_table, token_out_type=tensorflow.int64, lower_case=LOWER_CASE) + ids = tokenizer.tokenize(text).numpy().tolist() + assert len(ids) > 0 + collect_to_file(out_file, ids) + return len(ids) + + +def run_hugging_face(text_file, vocab_file, n_threads, out_file): + with open(text_file, 'r') as f: + text = f.read() + tokenizer = HuggingFaceBertTokenizer(vocab_file, lowercase=LOWER_CASE) + ids = tokenizer.encode(text).ids + assert len(ids) > 0 + collect_to_file(out_file, ids) + return len(ids) + + +def run_torch(text_file, vocab_file, n_threads, out_file): + with open(text_file, 'r') as f: + text = f.read() + tokenizer = TorchBertTokenizer(vocab_file, do_lower_case=LOWER_CASE) + ids = tokenizer(text) + assert len(ids) > 0 + collect_to_file(out_file, ids) + return len(ids) + + +def run_keras(text_file, vocab_file, n_threads, out_file): + with open(text_file, 'r') as f: + text = f.read() + tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(vocabulary=vocab_file, lowercase=LOWER_CASE) + ids = tokenizer.tokenize(text).numpy().tolist() + assert len(ids) > 0 + collect_to_file(out_file, ids) + return len(ids) + + +def run_you_token_to_me(text_file, vocab_file, n_threads, out_file): + assert(LOWER_CASE == False) + out_file = out_file if out_file is not None else "" + rc = 0 # TODO + assert rc == 0 + return rc + + +def get_wordpiece(impl_name): + if impl_name == YOU_TOKEN_TO_ME: + return run_you_token_to_me + elif impl_name == HUGGING_FACE: + return run_hugging_face + elif impl_name == KERAS: + return run_keras + elif impl_name == TENSORFLOW: + return run_tensorflow + elif impl_name == TORCH: + return run_torch + assert False + + +def download_xml2txt(): + if not Path("xml2txt.pl").exists(): + print("downloading xml2txt.pl ...") + os.system("wget https://www.dropbox.com/s/p3ta9spzfviovk0/xml2txt.pl") + + +def prepare_data(zip_path, size_mb): + expected_extension = ".xml.bz2" + assert zip_path.endswith(expected_extension) + base_path = Path(zip_path).parent + unzip_path = base_path / "wiki.xml" + full_text_path = base_path / "wiki.txt" + cutted_text_path = base_path / f"wiki_{size_mb}MB.txt" + if not Path(unzip_path).exists(): + print(f"unzipping file {zip_path} ...") + assert os.system(f"bzip2 -kdc {zip_path} > {unzip_path}") == 0 + if not Path(full_text_path).exists(): + print(f"converting xml to text {unzip_path} ...") + download_xml2txt() + preprocess_command = f"perl xml2txt.pl " + preprocess_command += f" -nomath -notables " + preprocess_command += f" {unzip_path} {full_text_path}" + assert os.system(preprocess_command) == 0 + if not
Path(cutted_text_path).exists(): + byte_processed = 0 + with open(cutted_text_path, "w") as fout: + with open(full_text_path, "r") as fin: + while byte_processed < size_mb * 1_000_000: + s = fin.readline() + byte_processed += len(s.encode()) + fout.write(s) + return cutted_text_path + + +def check_inference_file(algorithm, text_file, vocab_file, n_threads, out_file): + wordpiece = get_wordpiece(algorithm) + start_time = time() + res = wordpiece(text_file, vocab_file, n_threads, out_file) + elapsed = time() - start_time + print(f"Runner returned: {res}") + return elapsed + + +def speed_test(text_file: str, vocab_file: str, algorithms, n_threads: int, collect: bool,): + result = {} + for algorithm in algorithms: + print(f'Running {algorithm}') + out_file = f"result_{algorithm}.txt" if collect else None + time_infer = check_inference_file(algorithm, text_file, vocab_file, n_threads, out_file) + print(f'{algorithm} finished in {time_infer:.1f} sec') + result[algorithm] = time_infer + + return result + + +def print_results(cfg, result_name, corpuses, algorithms): + result_table = [ + ["#" for _ in range(len(corpuses) + 1)] for _ in range(len(algorithms)) + ] + table_header = ["#"] + [lang for lang in corpuses] + rev_lang = {lang: i for i, lang in enumerate(table_header)} + rev_algo = {algo: i for i, algo in enumerate(algorithms)} + for i, algo_name in enumerate(algorithms): + result_table[i][0] = algo_name + + for lang, res in cfg.items(): + best = min(res.values()) + for algo in res: + j = rev_lang[lang] + i = rev_algo[algo] + multiplier_str = f"{res[algo]/best:.1f}".rstrip('0').rstrip('.') + result_table[i][j] = f"{res[algo]:.1f} (x{multiplier_str})" + + table_header[0] = result_name + column_align = ["left"] + ["center" for _ in corpuses] + print(tabulate(result_table, table_header, tablefmt="github", colalign=column_align)) + + +def parse_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--vocab", type=str, required=True, help="path to vocab file" + ) + parser.add_argument("--n_threads", type=int, default=8) + parser.add_argument( + "--corpus_size", type=int, default=10, help="Size of testing corpus in MB" + ) + parser.add_argument( + "--langs", + type=str, + nargs="+", + help="list of languages for speed test", + default="en", + ) + parser.add_argument("--collect", action="store_true") + + return parser.parse_args() + + +def main(args): + langs = args.langs if isinstance(args.langs, list) else [args.langs] + # Hugging Face - limit number of processes + os.environ["RAYON_RS_NUM_CPUS"] = str(args.n_threads) + + short_to_long_names = { + "en": "English", + "ru": "Russian", + "ja": "Japanese", + "zh": "Chinese", + } + + # For adding more languages check out this page https://linguatools.org/tools/corpora/wikipedia-monolingual-corpora/ + all_links = { + "English": "https://www.dropbox.com/s/cnrhd11zdtc1pic/enwiki-20181001-corpus.xml.bz2?dl=1", + "Russian": "https://www.dropbox.com/s/lpfmyrl7nxn5ugg/ruwiki-20181001-corpus.xml.bz2?dl=1", + "Japanese": "https://www.dropbox.com/s/wf496hlu512z9kc/jawiki-20140807-corpus.xml.bz2?dl=1", + "Chinese": "https://www.dropbox.com/s/czhr6s5jwaljeue/zhwiki-20140804-corpus.xml.bz2?dl=1", + } + links = { + short_to_long_names[lang]: all_links[short_to_long_names[lang]] + for lang in langs + } + + corpuses = {} + Path("data").mkdir(exist_ok=True) + for lang, link in links.items(): + Path(f"data/{lang}").mkdir(exist_ok=True) + zip_file = f"data/{lang}/wiki.xml.bz2" + if not Path(zip_file).exists(): + os.system(f"wget -O {zip_file} {link}") + 
corpuses[lang] = prepare_data(zip_file, args.corpus_size) + + global_tokenization = {} + + for lang, corpus_path in corpuses.items(): + tokenization_stat = speed_test(corpus_path, args.vocab, ALGORITHMS, args.n_threads, args.collect) + global_tokenization[lang] = tokenization_stat + + print_results(global_tokenization, f"Tokenization {args.corpus_size}MB", corpuses, ALGORITHMS) + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/tests/unit_tests/README.md b/tests/unit_tests/bpe/README.md similarity index 64% rename from tests/unit_tests/README.md rename to tests/unit_tests/bpe/README.md index f2ef135..500bf35 100644 --- a/tests/unit_tests/README.md +++ b/tests/unit_tests/bpe/README.md @@ -3,7 +3,4 @@ For tests execution simply run: pip install pytest pytest ``` -Testing may take several minutes. - - - +Testing may take several minutes. \ No newline at end of file diff --git a/tests/unit_tests/stress_test.cpp b/tests/unit_tests/bpe/stress_test.cpp similarity index 97% rename from tests/unit_tests/stress_test.cpp rename to tests/unit_tests/bpe/stress_test.cpp index 91a7c63..ce8d850 100644 --- a/tests/unit_tests/stress_test.cpp +++ b/tests/unit_tests/bpe/stress_test.cpp @@ -5,11 +5,10 @@ #include #include #include -#include "stress_test.h" -#include "../../youtokentome/cpp/utils.h" -#include "../../youtokentome/cpp/bpe.h" -#include "../../youtokentome/cpp/utf8.h" +#include "../../../youtokentome/cpp/utils.h" +#include "../../../youtokentome/cpp/bpe.h" +#include "../../../youtokentome/cpp/utf8.h" namespace vkcom { @@ -321,7 +320,7 @@ void manual_test() { BpeConfig bpe_config = {1.0, 1, special_tokens_config}; BPEState model_fast; - status = learn_bpe_from_string(trn_data_copy, n_tokens, "remove_it.txt", bpe_config, &model_fast); + status = bpe_learn_from_string(trn_data_copy, n_tokens, "remove_it.txt", bpe_config, &model_fast); assert(status.ok()); auto model_slow = learn_bpe_slow(trn_data, n_tokens, "remove_it.txt", bpe_config); assert(model_fast.rules == model_slow.rules); @@ -370,7 +369,7 @@ void parallel_test(int n_iter, int n_threads) { auto train_data_copy = train_data; BpeConfig bpe_config = {character_coverage, n_threads, {0, 1, 2, 3}}; BPEState learned_model; - status = learn_bpe_from_string(train_data_copy, vocab_size, "remove_it.txt", bpe_config, &learned_model); + status = bpe_learn_from_string(train_data_copy, vocab_size, "remove_it.txt", bpe_config, &learned_model); assert(status.ok()); BaseEncoder applyer(learned_model, 20); @@ -413,7 +412,7 @@ void base_stress(int n_iter) { auto train_data_copy = train_data; BpeConfig bpe_config = {character_coverage, n_threads, {0, 1, 2, 3}}; BPEState fast_solution_model; - status = learn_bpe_from_string(train_data_copy, vocab_size, "remove_it.txt", bpe_config, &fast_solution_model); + status = bpe_learn_from_string(train_data_copy, vocab_size, "remove_it.txt", bpe_config, &fast_solution_model); assert(status.ok()); auto slow_solution_model = learn_bpe_slow(train_data, vocab_size, "remove_it.txt", bpe_config); diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/bpe/test_cli.py similarity index 85% rename from tests/unit_tests/test_cli.py rename to tests/unit_tests/bpe/test_cli.py index c9cadee..a8b7f89 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/bpe/test_cli.py @@ -18,7 +18,7 @@ def test_bos_eos_reverse(): generate_artifacts() cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=subword", "--n_threads=1", @@ -29,7 
+29,7 @@ def test_bos_eos_reverse(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=subword", "--n_threads=1", @@ -41,7 +41,7 @@ def test_bos_eos_reverse(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=id", "--n_threads=1", @@ -52,7 +52,7 @@ def test_bos_eos_reverse(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=id", "--n_threads=1", @@ -67,11 +67,11 @@ def test_bos_eos_reverse(): def test_interactive_mode(): generate_artifacts() print("interactive helper running id ...") - cmd = f"python interactor.py | yttm encode --stream --model={BASE_MODEL_FILE} --output_type=id > log.txt" + cmd = f"python interactor.py | yttm bpe-encode --stream --model={BASE_MODEL_FILE} --output_type=id > log.txt" assert os.system(cmd) == 0 print("interactive helper running subword ...") - cmd = f"python interactor.py | yttm encode --stream --model={BASE_MODEL_FILE} --output_type=subword > log.txt" + cmd = f"python interactor.py | yttm bpe-encode --stream --model={BASE_MODEL_FILE} --output_type=subword > log.txt" assert os.system(cmd) == 0 os.remove("log.txt") @@ -80,7 +80,7 @@ def test_multithreading(): generate_artifacts() cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=subword", "--n_threads=10", @@ -92,7 +92,7 @@ def test_renaming(): generate_artifacts() cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={RENAME_ID_MODEL_FILE}", "--output_type=id", "--bos", @@ -103,7 +103,7 @@ def test_renaming(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={RENAME_ID_MODEL_FILE}", "--output_type=id", "--eos", @@ -122,7 +122,7 @@ def test_renaming_unknown(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={RENAME_ID_MODEL_FILE}", "--output_type=id", "--reverse", @@ -143,8 +143,8 @@ def test_renaming_unknown(): def test_vocab(): generate_artifacts() - run(["yttm", "vocab", f"--model={BASE_MODEL_FILE}"], check=True) - run(["yttm", "vocab", f"--model={BASE_MODEL_FILE}", "--verbose"], check=True) + run(["yttm", "bpe-vocab", f"--model={BASE_MODEL_FILE}"], check=True) + run(["yttm", "bpe-vocab", f"--model={BASE_MODEL_FILE}", "--verbose"], check=True) def test_decode(): @@ -153,7 +153,7 @@ def test_decode(): with open("decode_text_in.txt", "w") as fout: fout.write(text_in) - cmd_args = ["yttm", "encode", f"--model={BASE_MODEL_FILE}", "--output_type=id"] + cmd_args = ["yttm", "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=id"] run( cmd_args, stdin=open("decode_text_in.txt", "r"), @@ -161,7 +161,7 @@ def test_decode(): check=True, ) - cmd_args = ["yttm", "decode", f"--model={BASE_MODEL_FILE}"] + cmd_args = ["yttm", "bpe-decode", f"--model={BASE_MODEL_FILE}"] run( cmd_args, stdin=open("decode_id.txt", "r"), @@ -176,7 +176,7 @@ def test_decode(): cmd_args = [ "yttm", - "encode", + "bpe-encode", f"--model={BASE_MODEL_FILE}", "--output_type=id", "--bos", @@ -191,7 +191,7 @@ def test_decode(): cmd_args = [ "yttm", - "decode", + "bpe-decode", f"--model={BASE_MODEL_FILE}", f"--ignore_ids={BOS_ID},{EOS_ID}", ] diff --git a/tests/unit_tests/test_manual.py b/tests/unit_tests/bpe/test_manual.py similarity index 100% rename from tests/unit_tests/test_manual.py rename to tests/unit_tests/bpe/test_manual.py diff --git a/tests/unit_tests/test_python_api.py b/tests/unit_tests/bpe/test_python_api.py similarity index 99% rename from tests/unit_tests/test_python_api.py rename to tests/unit_tests/bpe/test_python_api.py 
index 4fce4c5..2bfe10f 100644 --- a/tests/unit_tests/test_python_api.py +++ b/tests/unit_tests/bpe/test_python_api.py @@ -2,6 +2,7 @@ import random import youtokentome as yttm + from utils_for_testing import ( BASE_MODEL_FILE, RENAME_ID_MODEL_FILE, diff --git a/tests/unit_tests/test_stress.py b/tests/unit_tests/bpe/test_stress.py similarity index 73% rename from tests/unit_tests/test_stress.py rename to tests/unit_tests/bpe/test_stress.py index 98a576e..03e94b0 100644 --- a/tests/unit_tests/test_stress.py +++ b/tests/unit_tests/bpe/test_stress.py @@ -9,16 +9,16 @@ def compile_test(): if tests_compiled: return build_files = ["bpe.cpp", "utils.cpp", "utf8.cpp"] - files = ["../../youtokentome/cpp/" + file_name for file_name in build_files] + files = ["../../../youtokentome/cpp/" + file_name for file_name in build_files] files.append("stress_test.cpp") - print("compiling stress test ...") + print("compiling bpe stress test ...") command = [ "g++", *files, "-o", - "stress", + "bpe_stress", "-std=c++11", "-pthread", "-Og", @@ -35,16 +35,16 @@ def compile_test(): def test_stress(): compile_test() - run(["./stress", "base", "1000"], check=True) + run(["./bpe_stress", "base", "1000"], check=True) def test_manual(): compile_test() - run(["./stress", "manual"], check=True) + run(["./bpe_stress", "manual"], check=True) os.remove("remove_it.txt") def test_parallel(): compile_test() - run(["./stress", "parallel", "50"], check=True) + run(["./bpe_stress", "parallel", "50"], check=True) os.remove("remove_it.txt") diff --git a/tests/unit_tests/utils_for_testing.py b/tests/unit_tests/bpe/utils_for_testing.py similarity index 97% rename from tests/unit_tests/utils_for_testing.py rename to tests/unit_tests/bpe/utils_for_testing.py index 42dba38..eceaa96 100644 --- a/tests/unit_tests/utils_for_testing.py +++ b/tests/unit_tests/bpe/utils_for_testing.py @@ -37,7 +37,7 @@ def generate_artifacts(): cmd_args = [ "yttm", - "bpe", + "bpe-train", f"--data={TRAIN_FILE}", f"--model={BASE_MODEL_FILE}", "--vocab_size=16000", @@ -49,7 +49,7 @@ def generate_artifacts(): run(cmd_args, check=True) cmd_args = [ "yttm", - "bpe", + "bpe-train", f"--data={TRAIN_FILE}", f"--model={RENAME_ID_MODEL_FILE}", "--vocab_size=16000", diff --git a/tests/unit_tests/stress_test.h b/tests/unit_tests/stress_test.h deleted file mode 100644 index 5c877f8..0000000 --- a/tests/unit_tests/stress_test.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "../../youtokentome/cpp/third_party/flat_hash_map.h" -#include "../../youtokentome/cpp/bpe.h" - -namespace vkcom { - -flat_hash_map -compute_alphabet_helper(const flat_hash_map &char_cnt, - uint64_t data_len, - flat_hash_set &removed_chars, - const BpeConfig &bpe_config); - -Status learn_bpe_from_string(std::string &text_utf8, - int n_tokens, - const std::string &output_file, - BpeConfig bpe_config, - BPEState *bpe_state); - -} // namespace vkcom diff --git a/tests/unit_tests/wordpiece/README.md b/tests/unit_tests/wordpiece/README.md new file mode 100644 index 0000000..500bf35 --- /dev/null +++ b/tests/unit_tests/wordpiece/README.md @@ -0,0 +1,6 @@ +For tests execution simply run: +``` +pip install pytest +pytest +``` +Testing may take several minutes. 
\ No newline at end of file diff --git a/tests/unit_tests/wordpiece/stress_test.cpp b/tests/unit_tests/wordpiece/stress_test.cpp new file mode 100644 index 0000000..a9a7fd7 --- /dev/null +++ b/tests/unit_tests/wordpiece/stress_test.cpp @@ -0,0 +1,186 @@ +#include <algorithm> +#include <cassert> +#include <fstream> +#include <iostream> +#include <random> +#include <stdexcept> +#include <string> +#include <unordered_map> +#include <vector> + +#include "../../../youtokentome/cpp/wordpiece.h" + + +using namespace vkcom; + +struct TestCase { + std::string text; + std::vector<std::string> vocab; + std::vector<int> answer_encoded; + std::vector<std::string> answer_decoded; +}; + +template <typename T> +void dump_vector(const std::string &filename, const std::vector<T> &vec, char delim) { + std::ofstream fout(filename); + for (const auto& item : vec) { + fout << item << delim; + } +} + +void dump_test_case(const TestCase &test_case) { + { + std::ofstream fout("stress.txt"); + fout << test_case.text; + } + dump_vector("vocab.txt", test_case.vocab, '\n'); + dump_vector("answer_encoded.txt", test_case.answer_encoded, ' '); + dump_vector("answer_decoded.txt", test_case.answer_decoded, ' '); +} + +void check(const TestCase &test_case, const std::vector<int> &encoded, const std::vector<std::string> &decoded) { + if (encoded != test_case.answer_encoded || decoded != test_case.answer_decoded) { + dump_test_case(test_case); + throw std::runtime_error("STRESS TEST FAILED, test case dumped"); + } +} + +std::string get_random_string(std::mt19937 &rnd, size_t string_length) { + static const std::string kAllChars = "abcdefghijklmnopqrstuvwxyz"; + if (string_length == 0) { + throw std::runtime_error("string_length cannot be 0"); + } + std::string result; + result.reserve(string_length); + while (string_length > 0) { + --string_length; + size_t index = std::uniform_int_distribution<size_t>(0ul, kAllChars.size() - 1)(rnd); + result.push_back(kAllChars[index]); + } + return result; +} + +TestCase generate_test_case(size_t text_len, size_t parts) { + std::mt19937 rnd(17); + std::string text; + text.reserve(text_len + parts); + std::uniform_int_distribution<size_t> word_len(1ul, std::max(2 * text_len / parts, 3ul)); + + std::unordered_map<std::string, int> vocab_map; + std::vector<int> answer_encoded; + std::vector<std::string> answer_decoded; + answer_encoded.reserve(parts); + answer_decoded.reserve(parts); + + for (size_t i = 0; i < parts && text.size() < text.capacity(); i++) { + const size_t vocab_size = vocab_map.size(); + if (i + 1 == parts) { + size_t leftover = text.capacity() - text.size(); + std::string word = get_random_string(rnd, leftover); + if (vocab_map[word] == 0) { + vocab_map[word] = static_cast<int>(vocab_size) + 1; + } + text.append(word); + answer_encoded.push_back(vocab_map[word] - 1); + answer_decoded.push_back(std::move(word)); + } else if (i > 0 && i % 10 == 0) { + std::uniform_int_distribution<size_t> rnd_word(0ul, vocab_size - 1); + auto it = std::next(vocab_map.begin(), rnd_word(rnd)); + text.append(it->first); + text.push_back(' '); + answer_encoded.push_back(it->second - 1); + answer_decoded.push_back(it->first); + } else { + std::string word = get_random_string(rnd, word_len(rnd)); + if (vocab_map[word] == 0) { + vocab_map[word] = static_cast<int>(vocab_size) + 1; + } + text.append(word); + text.push_back(' '); + answer_encoded.push_back(vocab_map[word] - 1); + answer_decoded.push_back(std::move(word)); + } + } + + std::vector<std::string> vocab; + vocab.resize(vocab_map.size()); + for (auto it = vocab_map.begin(); it != vocab_map.end(); it++) { + vocab[it->second - 1] = it->first; + } + return TestCase{std::move(text), std::move(vocab), std::move(answer_encoded), std::move(answer_decoded)}; +} + +void test_stress(size_t
text_len_from, + size_t text_len_to, + size_t text_len_step, + size_t parts_from, + size_t parts_to, + int n_threads) { + for (size_t text_len = text_len_from; text_len <= text_len_to; text_len += text_len_step) { + for (size_t parts = std::min(text_len, parts_from); parts <= std::min(text_len, parts_to); + parts++) { + + const std::string text_filename("stress.txt"); + TestCase test_case = generate_test_case(text_len, parts); + std::cout << "running stress, text_len " << test_case.text.size() << ' ' << text_len << ", vocab_size " + << test_case.vocab.size() << std::endl; + { + std::ofstream fout(text_filename); + fout << test_case.text; + } + + Status status; + std::vector<int> encoded; + wordpiece::Encoder encoder(test_case.vocab, n_threads); + status = encoder.encode_as_ids(text_filename, &encoded); + if (!status.ok()) { + dump_test_case(test_case); + throw std::runtime_error("encode_as_ids failed, test_case dumped: " + status.error_message()); + } + std::vector<std::string> decoded; + status = encoder.encode_as_subwords(text_filename, &decoded); + if (!status.ok()) { + dump_test_case(test_case); + throw std::runtime_error("encode_as_subwords failed, test_case dumped: " + status.error_message()); + } + check(test_case, encoded, decoded); + } + } +} + +void run_small(int n_threads) { + test_stress(10, 300, 5, 2, 100, n_threads); + test_stress(10, 300, 5, 2, 100, n_threads); +} + +void run_large(int n_threads) { + test_stress(100000, + 1000000, + 400000, + 30000, + 30000, + n_threads); + test_stress(10000000, + 10000000, + 200000, + 30000, + 30000, + n_threads); +} + +int main(int argc, char **argv) { + if (argc != 2) { + assert(false); + } + std::string mode = argv[1]; + if (argc == 2 && mode == "small") { + run_small(1); + } else if (argc == 2 && mode == "large") { + run_large(1); + } else if (argc == 2 && mode == "parallel") { + run_small(0); + run_large(0); + } else { + assert(false); + } +} + diff --git a/tests/unit_tests/wordpiece/test_cli.py b/tests/unit_tests/wordpiece/test_cli.py new file mode 100644 index 0000000..49960d2 --- /dev/null +++ b/tests/unit_tests/wordpiece/test_cli.py @@ -0,0 +1,3 @@ +import os +import random +from subprocess import run \ No newline at end of file diff --git a/tests/unit_tests/wordpiece/test_manual.py b/tests/unit_tests/wordpiece/test_manual.py new file mode 100644 index 0000000..fe1c77c --- /dev/null +++ b/tests/unit_tests/wordpiece/test_manual.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +import youtokentome as yttm + + +def check(text, vocab, output_type=yttm.OutputType.ID): + TEXT_FILE = "text_file.txt" + VOCAB_FILE = "vocab_file.txt" + with open(TEXT_FILE, 'w') as f: + f.write(text) + with open(VOCAB_FILE, 'w') as f: + for word in vocab: + f.write(word) + f.write('\n') + + encoder = yttm.WordPiece(VOCAB_FILE) + return encoder.encode(TEXT_FILE, output_type=output_type) + + +def test_russian(): + ids = check("привет мир", ["привет", "мир"]) + assert ids == [0, 1] + + ids = check("привет мир", ["при", "##вет", "мир"]) + assert ids == [0, 1, 2] + + ids = check("токенизация это круто", ["ток", "крут", "это", "##за", "##ция", "ция"]) + assert ids == [-1, 2, -1] + + ids = check("токенизация это круто", ["ток", "крут", "это", "##за", "##ени", "##о", "##ция", "ция"]) + assert ids == [0, 4, 3, 6, 2, 1, 5] + + +def test_english(): + ids = check("self-made", ["self", "made", "-", "##-", "##made"]) + assert ids == [0, 2, 1] + + ids = check("self, made", ["self", "made", ",", "##,", "##made"]) + assert ids == [0, 2, 1] + + ids = check("self , made", ["self", "made", ",",
"##,", "##made"]) + assert ids == [0, 2, 1] + + +def test_japanese(): + pass + + +def test_misc(): + ids = check("abcdef", ["a", "##bcdef", "ab", "##c", "##d", "##e", "##f"]) + assert ids == [2, 3, 4, 5, 6] + + ids = check("abcdef abc abcd", ["abcd", "def", "abc"]) + assert ids == [-1, 2, 0] + + ids = check("abc", ["a", "abd"]) + assert ids == [-1] + + ids = check("abc a abc abd", ["a", "abd"]) + assert ids == [-1, 0, -1, 1] + + ids = check("abcdef", ["bcde", "ac", "def", "bc", "bcdef", "##a", "##b", "##c", "##d"]) + assert ids == [-1] diff --git a/tests/unit_tests/wordpiece/test_python_api.py b/tests/unit_tests/wordpiece/test_python_api.py new file mode 100644 index 0000000..a8acf38 --- /dev/null +++ b/tests/unit_tests/wordpiece/test_python_api.py @@ -0,0 +1,4 @@ +import os +import random + +import youtokentome as yttm \ No newline at end of file diff --git a/tests/unit_tests/wordpiece/test_stress.py b/tests/unit_tests/wordpiece/test_stress.py new file mode 100644 index 0000000..520dbd8 --- /dev/null +++ b/tests/unit_tests/wordpiece/test_stress.py @@ -0,0 +1,47 @@ +import os +from subprocess import run + + +tests_compiled = False + +def compile_test(): + global tests_compiled + if tests_compiled: + return + build_files = ["wordpiece.cpp", "utils.cpp", "utf8.cpp"] + files = ["../../../youtokentome/cpp/" + file_name for file_name in build_files] + files.append("stress_test.cpp") + + print("compiling wordpiece stress test ...") + + command = [ + "g++", + *files, + "-o", + "wordpiece_stress", + "-std=c++11", + "-pthread", + "-Og", + "-D_GLIBCXX_DEBUG", + "-fno-omit-frame-pointer -fsanitize=address -fsanitize=leak -fsanitize=undefined", + ] + + command = " ".join(command) + print("command:", command) + run(command, check=True, shell=True) + tests_compiled = True + + +def test_small(): + compile_test() + run(["./wordpiece_stress", "small"], check=True) + + +def test_manual(): + compile_test() + run(["./wordpiece_stress", "large"], check=True) + + +def test_parallel(): + compile_test() + run(["./wordpiece_stress", "parallel"], check=True) diff --git a/youtokentome/__init__.py b/youtokentome/__init__.py index a0d7baa..ab836bc 100644 --- a/youtokentome/__init__.py +++ b/youtokentome/__init__.py @@ -1,2 +1,3 @@ from .youtokentome import BPE from .youtokentome import OutputType +from .youtokentome import WordPiece diff --git a/youtokentome/cpp/bpe.cpp b/youtokentome/cpp/bpe.cpp index 0175afa..1bc07de 100644 --- a/youtokentome/cpp/bpe.cpp +++ b/youtokentome/cpp/bpe.cpp @@ -1,5 +1,3 @@ -#include - #include "bpe.h" #include @@ -7,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -19,70 +18,76 @@ #include #include -#include "third_party/flat_hash_map.h" +#include "third_party/flat_hash_map/flat_hash_map.h" #include "utf8.h" #include "utils.h" -namespace vkcom { +namespace { -struct VectorSegment { - constexpr static uint64_t MOD = 2032191299; - constexpr static uint64_t P = 726328703; +const std::string UNK_TOKEN = ""; +const std::string PAD_TOKEN = ""; +const std::string BOS_TOKEN = ""; +const std::string EOS_TOKEN = ""; - const char* begin; - const char* end; - uint64_t hash; +} // namespace - VectorSegment(const char* begin, const char* end): begin(begin), end(end) { - hash = 0; - for (auto it = begin; it != end; it++) { - hash = (hash * P + (unsigned char)(*it)) % MOD; - } - } +namespace vkcom { - bool operator==(const VectorSegment &other) const { - if (other.hash != hash || end - begin != other.end - other.begin) { - return false; - } - for (auto it = begin, other_it = 
other.begin; it != end; it++, other_it++) { - if (*it != *other_it) { - return false; - } - } - return true; - } -}; +bool BPE_Rule::operator==(const BPE_Rule &other) const { + return x == other.x && y == other.y && z == other.z; +} -} // namespace vkcom +BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {} -namespace std { -template<> -struct hash { - uint64_t operator()(const vkcom::VectorSegment &x) const { return x.hash; } -}; -} // namespace std +void BPEState::dump(const std::string &file_name) { + std::ofstream fout(file_name, std::ios::out); + if (fout.fail()) { + std::cerr << "Can't open file: " << file_name << std::endl; + assert(false); + } + fout << char2id.size() << " " << rules.size() << std::endl; + for (auto s : char2id) { + fout << s.first << " " << s.second << std::endl; + } -namespace vkcom { + for (auto rule : rules) { + fout << rule.x << " " << rule.y << " " << rule.z << std::endl; + } + special_tokens.dump(fout); + fout.close(); +} -Status fast_read_file_utf8(const std::string &file_name, std::string *file_content) { - static const int buf_size = 1000000; - *file_content = ""; - auto fin = fopen(file_name.data(), "rb"); - if (fin == nullptr) { - return Status(1, "Failed to open file: " + file_name); +Status BPEState::load(const std::string &file_name) { + char2id.clear(); + rules.clear(); + std::ifstream fin(file_name, std::ios::in); + if (fin.fail()) { + return Status(1, "Can not open file with model: " + file_name); } - while (true) { - uint64_t cur_size = file_content->size(); - file_content->resize(cur_size + buf_size); - int buf_len = fread((void *) (file_content->data() + cur_size), 1, buf_size, fin); - if (buf_len < buf_size) { - file_content->resize(file_content->size() - (buf_size - buf_len)); - fclose(fin); - return Status(); - } + int n, m; + fin >> n >> m; + for (int i = 0; i < n; i++) { + uint32_t inner_id; + uint32_t utf32_id; + fin >> inner_id >> utf32_id; + char2id[inner_id] = utf32_id; } + for (int i = 0; i < m; i++) { + uint32_t x, y, z; + fin >> x >> y >> z; + rules.emplace_back(x, y, z); + } + special_tokens.load(fin); + fin.close(); + return Status(); } +BpeConfig::BpeConfig(double _character_coverage, int _n_threads, + const SpecialTokens &_special_tokens) + : character_coverage(_character_coverage), + n_threads(_n_threads), + special_tokens(_special_tokens) {} + std::string token2word(const std::vector &source, const flat_hash_map &id2char) { std::vector res; @@ -385,10 +390,10 @@ struct WordCount { }; -flat_hash_map compute_word_count( +flat_hash_map compute_word_count( char* sbegin, char* send, const flat_hash_map &char2id) { - flat_hash_map hash2wordcnt; + flat_hash_map hash2wordcnt; std::vector word; UTF8Iterator utf8_iter(sbegin, send); @@ -400,7 +405,8 @@ flat_hash_map compute_word_count( char* begin_of_word = utf8_iter.get_ptr(); for (; !utf8_iter.empty() && !is_space(*utf8_iter); ++utf8_iter); char* end_of_word = utf8_iter.get_ptr(); - VectorSegment word_hash(begin_of_word, end_of_word); + VectorSegmentBuilder word_hash_builder(begin_of_word, end_of_word); + BpeVectorSegment word_hash = word_hash_builder.finish(); auto it = hash2wordcnt.find(word_hash); if (it == hash2wordcnt.end()) { word.clear(); @@ -856,7 +862,7 @@ uint64_t compute_char_count(flat_hash_map& char_cnt, char* b return char_count; } -Status learn_bpe_from_string(std::string &text_utf8, int n_tokens, +Status bpe_learn_from_string(std::string &text_utf8, int n_tokens, const std::string &output_file, BpeConfig bpe_config, BPEState *bpe_state) { 
assert(bpe_config.n_threads >= 1 || bpe_config.n_threads == -1); @@ -883,7 +889,7 @@ Status learn_bpe_from_string(std::string &text_utf8, int n_tokens, flat_hash_set removed_chars; flat_hash_map char2id; - std::vector> hash2wordcnt(n_threads); + std::vector> hash2wordcnt(n_threads); int error_flag = 0; flat_hash_map> recipe; @@ -1041,7 +1047,7 @@ Status learn_bpe_from_string(std::string &text_utf8, int n_tokens, word_cnt_global.resize(hash2wordcnt[0].size()); std::transform( hash2wordcnt[0].begin(), hash2wordcnt[0].end(), word_cnt_global.begin(), - [](const std::pair &x) { return x.second; }); + [](const std::pair &x) { return x.second; }); hash2wordcnt.shrink_to_fit(); text_utf8.shrink_to_fit(); @@ -1365,7 +1371,7 @@ void print_config(const std::string &input_path, const std::string &model_path, std::cerr << std::endl; } -Status train_bpe(const std::string &input_path, const std::string &model_path, +Status bpe_train(const std::string &input_path, const std::string &model_path, int vocab_size, BpeConfig bpe_config) { Status status = check_config(bpe_config, vocab_size); if (!status.ok()) { @@ -1380,7 +1386,7 @@ Status train_bpe(const std::string &input_path, const std::string &model_path, } std::cerr << "learning bpe..." << std::endl; BPEState bpe_state; - status = learn_bpe_from_string(data, vocab_size, model_path, bpe_config, &bpe_state); + status = bpe_learn_from_string(data, vocab_size, model_path, bpe_config, &bpe_state); if (!status.ok()) { return status; } @@ -1980,7 +1986,7 @@ Status BaseEncoder::encode_cli(const std::string &output_type_str, bool stream, int chars_remove = 0; do { processed = 0; - auto sentences = read_lines_from_stdin(batch_limit, &processed); + auto sentences = read_lines(std::cin, batch_limit, &processed); if (output_type == SUBWORD) { std::vector> subwords; Status status = encode_as_subwords(sentences, &subwords, bos, eos, reverse, dropout_prob); diff --git a/youtokentome/cpp/bpe.h b/youtokentome/cpp/bpe.h index e563aa7..901207c 100644 --- a/youtokentome/cpp/bpe.h +++ b/youtokentome/cpp/bpe.h @@ -1,23 +1,72 @@ #pragma once -#include #include #include -#include "third_party/flat_hash_map.h" +#include + +#include "third_party/flat_hash_map/flat_hash_map.h" #include "utils.h" +// TODO: introduce vkcom::bpe namespace namespace vkcom { -const std::string UNK_TOKEN = ""; -const std::string PAD_TOKEN = ""; -const std::string BOS_TOKEN = ""; -const std::string EOS_TOKEN = ""; +struct BPE_Rule { + // x + y -> z + uint32_t x{0}; + uint32_t y{0}; + uint32_t z{0}; + + BPE_Rule() = default; + + BPE_Rule(uint32_t x, uint32_t y, uint32_t z); + + bool operator==(const BPE_Rule &other) const; +}; + +struct BpeConfig { + double character_coverage = 1; + int n_threads = 0; + SpecialTokens special_tokens; + + BpeConfig() = default; + + BpeConfig(double character_coverage, int n_threads, const SpecialTokens &special_tokens); +}; -enum OutputType { ID, SUBWORD }; +struct BPEState { + flat_hash_map char2id; + std::vector rules; + SpecialTokens special_tokens; -Status train_bpe(const std::string &input_path, const std::string &model_path, - int vocab_size, BpeConfig config); + void dump(const std::string &file_name); + + Status load(const std::string &file_name); +}; + +struct EncodingConfig { + bool bos; + bool eos; + bool reverse; + double dropout_prob; +}; + +Status bpe_train(const std::string &input_path, + const std::string &model_path, + int vocab_size, + BpeConfig config); + +Status bpe_learn_from_string(std::string &text_utf8, + int n_tokens, + const std::string 
&output_file, + BpeConfig bpe_config, + BPEState *bpe_state); + +flat_hash_map +compute_alphabet_helper(const flat_hash_map &char_cnt, + uint64_t data_len, + flat_hash_set &removed_chars, + const BpeConfig &bpe_config); class BaseEncoder { public: @@ -34,15 +83,19 @@ class BaseEncoder { void fill_from_state(); - Status encode_as_ids( - const std::vector &sentences, std::vector> *ids, bool bos = false, - bool eos = false, bool reverse = false, double dropout_prob=0) const; + Status encode_as_ids(const std::vector &sentences, + std::vector> *ids, + bool bos = false, + bool eos = false, + bool reverse = false, + double dropout_prob = 0) const; - Status encode_as_subwords( - const std::vector &sentences, - std::vector> *subwords, - bool bos = false, - bool eos = false, bool reverse = false, double dropout_prob=0) const; + Status encode_as_subwords(const std::vector &sentences, + std::vector> *subwords, + bool bos = false, + bool eos = false, + bool reverse = false, + double dropout_prob = 0) const; Status id_to_subword(int id, std::string *subword, bool replace_space = false) const; @@ -52,7 +105,9 @@ class BaseEncoder { std::vector *sentences, const std::unordered_set *ignore_ids) const; - Status decode(const std::vector &ids, std::string *sentence, const std::unordered_set *ignore_ids) const; + Status decode(const std::vector &ids, + std::string *sentence, + const std::unordered_set *ignore_ids) const; Status decode(const std::vector &ids, std::vector *sentences, @@ -62,8 +117,12 @@ class BaseEncoder { std::vector vocabulary() const; - Status encode_cli(const std::string &output_type, bool stream, bool bos = false, - bool eos = false, bool reverse = false, double dropout_prob = 0) const; + Status encode_cli(const std::string &output_type, + bool stream, + bool bos = false, + bool eos = false, + bool reverse = false, + double dropout_prob = 0) const; Status decode_cli(const std::unordered_set *ignore_ids) const; @@ -74,11 +133,10 @@ class BaseEncoder { const EncodingConfig &encoding_config, OutputType output_type) const; - Status encode_parallel( - const std::vector &sentences, - const EncodingConfig &encoding_config, OutputType output_type, - std::vector *decoder_results - ) const; + Status encode_parallel(const std::vector &sentences, + const EncodingConfig &encoding_config, + OutputType output_type, + std::vector *decoder_results) const; }; } // namespace vkcom diff --git a/youtokentome/cpp/third_party/LICENSE b/youtokentome/cpp/third_party/flat_hash_map/LICENSE similarity index 100% rename from youtokentome/cpp/third_party/LICENSE rename to youtokentome/cpp/third_party/flat_hash_map/LICENSE diff --git a/youtokentome/cpp/third_party/flat_hash_map.h b/youtokentome/cpp/third_party/flat_hash_map/flat_hash_map.h similarity index 100% rename from youtokentome/cpp/third_party/flat_hash_map.h rename to youtokentome/cpp/third_party/flat_hash_map/flat_hash_map.h diff --git a/youtokentome/cpp/third_party/thread_pool/LICENSE b/youtokentome/cpp/third_party/thread_pool/LICENSE new file mode 100644 index 0000000..3b66ae6 --- /dev/null +++ b/youtokentome/cpp/third_party/thread_pool/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Ibragim Dzhiblavi + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit 
persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/youtokentome/cpp/third_party/thread_pool/thread_pool.h b/youtokentome/cpp/third_party/thread_pool/thread_pool.h new file mode 100644 index 0000000..459cbc9 --- /dev/null +++ b/youtokentome/cpp/third_party/thread_pool/thread_pool.h @@ -0,0 +1,91 @@ +// Copyright (c) 2023 Ibragim Dzhiblavi +// Modified 2023 Gleb Koveshnikov + +#pragma once + +#include <atomic> +#include <condition_variable> +#include <functional> +#include <mutex> +#include <queue> +#include <thread> +#include <vector> + +namespace vkcom { + +class ThreadPool { + public: + using Task = std::function<void()>; + + public: + ThreadPool(int thread_count) { + if (thread_count <= 0) { + thread_count = static_cast<int>(std::thread::hardware_concurrency()); + } + if (thread_count == 0) { + thread_count = 8; + } + for (int thread = 0; thread < thread_count; ++thread) { + threads_.emplace_back([this] { + while (!stop_.load(std::memory_order_relaxed)) { + std::unique_lock<std::mutex> lock(mutex_); + work_cv_.wait(lock, [this] { + return stop_.load(std::memory_order_relaxed) || !task_queue_.empty(); + }); + if (stop_.load(std::memory_order_relaxed)) { + break; + } + if (task_queue_.empty()) { + continue; + } + ++active_tasks_; + auto task = std::move(task_queue_.front()); + task_queue_.pop(); + lock.unlock(); + task(); + lock.lock(); + --active_tasks_; + complete_cv_.notify_one(); + } + }); + } + } + + ~ThreadPool() { + stop_.store(true, std::memory_order_relaxed); + work_cv_.notify_all(); + for (auto &thread : threads_) { + if (thread.joinable()) { + thread.join(); + } + } + } + + void submit(Task &&task) { + { + std::lock_guard<std::mutex> lg(mutex_); + task_queue_.emplace(std::move(task)); + } + work_cv_.notify_one(); + } + + void waitCompletion() { + std::unique_lock<std::mutex> lock(mutex_); + if (active_tasks_ != 0 || !task_queue_.empty()) { + complete_cv_.wait(lock, [this] { return active_tasks_ == 0 && task_queue_.empty(); }); + } + } + + [[nodiscard]] size_t maxThreads() const noexcept { return threads_.size(); } + + private: + std::atomic<bool> stop_{false}; + size_t active_tasks_{0}; + std::mutex mutex_; + std::condition_variable work_cv_; + std::condition_variable complete_cv_; + std::vector<std::thread> threads_; + std::queue<Task> task_queue_; +}; + +} // namespace vkcom \ No newline at end of file diff --git a/youtokentome/cpp/utf8.cpp b/youtokentome/cpp/utf8.cpp index 3a67172..8b2b2ca 100644 --- a/youtokentome/cpp/utf8.cpp +++ b/youtokentome/cpp/utf8.cpp @@ -1,21 +1,34 @@ #include "utf8.h" -#include #include -#include -#include -#include "utils.h" namespace vkcom { -using std::string; -using std::vector; +bool is_space(uint32_t ch) { + return (ch < 256 && std::isspace(static_cast(ch))) || (ch == SPACE_TOKEN); +} + +bool is_punctuation(uint32_t ch) { + return (ch < 256 && std::ispunct(static_cast(ch))) || ch == 183 || ch == 171 + || ch == 187 || ch == 8249 || ch == 8250 || (8208 <= ch && ch <= 8248); +} + +bool is_chinese(uint32_t ch) { + if ((ch >=
0x4E00 && ch <= 0x9FFF) || (ch >= 0x3400 && ch <= 0x4DBF) + || (ch >= 0x20000 && ch <= 0x2A6DF) || (ch >= 0x2A700 && ch <= 0x2B73F) + || (ch >= 0x2B740 && ch <= 0x2B81F) || (ch >= 0x2B820 && ch <= 0x2CEAF) + || (ch >= 0xF900 && ch <= 0xFAFF) || (ch >= 0x2F800 && ch <= 0x2FA1F)) { + return true; + } + return false; +} +bool is_spacing_char(uint32_t ch) { return is_space(ch) || is_punctuation(ch) || is_chinese(ch); } bool check_byte(char x) { return (static_cast(x) & 0xc0u) == 0x80u; } -bool check_codepoint(uint32_t x) { - return (x < 0xd800) || (0xdfff < x && x < 0x110000); -} +bool check_symbol_start(char x) { return !check_byte(x); }; + +bool check_codepoint(uint32_t x) { return (x < 0xd800) || (0xdfff < x && x < 0x110000); } uint64_t utf_length(char ch) { if ((static_cast(ch) & 0x80u) == 0) { @@ -34,7 +47,7 @@ uint64_t utf_length(char ch) { return 0; } -uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len) { +uint32_t chars_to_utf8(const char *begin, uint64_t size, uint64_t *utf8_len) { uint64_t length = utf_length(begin[0]); if (length == 1) { *utf8_len = 1; @@ -48,8 +61,7 @@ uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len) { *utf8_len = 2; return code_point; } - } else if (size >= 3 && length == 3 && check_byte(begin[1]) && - check_byte(begin[2])) { + } else if (size >= 3 && length == 3 && check_byte(begin[1]) && check_byte(begin[2])) { code_point += (static_cast(begin[0]) & 0x0fu) << 12u; code_point += (static_cast(begin[1]) & 0x3fu) << 6u; code_point += (static_cast(begin[2]) & 0x3fu); @@ -57,8 +69,8 @@ uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len) { *utf8_len = 3; return code_point; } - } else if (size >= 4 && length == 4 && check_byte(begin[1]) && - check_byte(begin[2]) && check_byte(begin[3])) { + } else if (size >= 4 && length == 4 && check_byte(begin[1]) && check_byte(begin[2]) + && check_byte(begin[3])) { code_point += (static_cast(begin[0]) & 0x07u) << 18u; code_point += (static_cast(begin[1]) & 0x3fu) << 12u; code_point += (static_cast(begin[2]) & 0x3fu) << 6u; @@ -73,7 +85,13 @@ uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len) { return INVALID_UNICODE; } -void utf8_to_chars(uint32_t x, std::back_insert_iterator it) { +bool starts_with_space(const char *begin, int64_t size) { + uint64_t len = 0; + uint32_t symbol = chars_to_utf8(begin, size, &len); + return is_space(symbol); +} + +void utf8_to_chars(uint32_t x, std::back_insert_iterator it) { assert(check_codepoint(x)); if (x <= 0x7f) { @@ -100,16 +118,16 @@ void utf8_to_chars(uint32_t x, std::back_insert_iterator it) { *(it++) = 0x80u | (x & 0x3fu); } -string encode_utf8(const vector& text) { - string utf8_text; +std::string encode_utf8(const std::vector &text) { + std::string utf8_text; for (const uint32_t c : text) { utf8_to_chars(c, std::back_inserter(utf8_text)); } return utf8_text; } -vector decode_utf8(const char* begin, const char* end) { - vector decoded_text; +std::vector decode_utf8(const char *begin, const char *end) { + std::vector decoded_text; uint64_t utf8_len = 0; bool invalid_input = false; for (; begin < end; begin += utf8_len) { @@ -121,14 +139,13 @@ vector decode_utf8(const char* begin, const char* end) { } } if (invalid_input) { - std::cerr << "WARNING Input contains invalid unicode characters." - << std::endl; + std::cerr << "WARNING Input contains invalid unicode characters." 
<< std::endl; } return decoded_text; } -vector decode_utf8(const string& utf8_text) { +std::vector decode_utf8(const std::string &utf8_text) { return decode_utf8(utf8_text.data(), utf8_text.data() + utf8_text.size()); } -} // namespace vkcom +} // namespace vkcom diff --git a/youtokentome/cpp/utf8.h b/youtokentome/cpp/utf8.h index 0888bb3..e3a7c72 100644 --- a/youtokentome/cpp/utf8.h +++ b/youtokentome/cpp/utf8.h @@ -1,25 +1,46 @@ #pragma once -#include -#include #include +#include +#include +// TODO: introduce vkcom::utf8 namespace namespace vkcom { -constexpr static uint32_t INVALID_UNICODE = 0x0fffffff; +const uint32_t SPACE_TOKEN = 9601; + +constexpr static uint32_t INVALID_UNICODE = 0x110000; + +bool is_space(uint32_t ch); + +bool is_punctuation(uint32_t ch); + +bool is_chinese_char(uint32_t ch); + +bool is_spacing_char(uint32_t ch); + +bool check_byte(char x); -uint32_t chars_to_utf8(const char* begin, uint64_t size, uint64_t* utf8_len); +bool check_symbol_start(char x); + +bool check_codepoint(uint32_t x); + +uint64_t utf_length(char ch); + +uint32_t chars_to_utf8(const char *begin, uint64_t size, uint64_t *utf8_len); + +bool starts_with_space(const char *begin, int64_t size); void utf8_to_chars(uint32_t x, std::back_insert_iterator it); -std::string encode_utf8(const std::vector &utext); +std::string encode_utf8(const std::vector &text); std::vector decode_utf8(const char *begin, const char *end); std::vector decode_utf8(const std::string &utf8_text); struct UTF8Iterator { - UTF8Iterator(char* begin, char* end): begin(begin), end(end) {} + UTF8Iterator(char *begin, char *end) : begin(begin), end(end) {} UTF8Iterator operator++() { if (!state) { @@ -37,22 +58,21 @@ struct UTF8Iterator { return code_point; } - char* get_ptr() { - return begin; - } - uint64_t get_utf8_len() { - return utf8_len; - } + char *get_ptr() { return begin; } + + uint64_t get_utf8_len() { return utf8_len; } bool empty() { assert(begin <= end); return begin == end; } -private: + + private: char *begin, *end; uint32_t code_point = 0; uint64_t utf8_len = 0; bool state = false; + void parse() { if (state) { return; diff --git a/youtokentome/cpp/utils.cpp b/youtokentome/cpp/utils.cpp index 64ed922..45502d0 100644 --- a/youtokentome/cpp/utils.cpp +++ b/youtokentome/cpp/utils.cpp @@ -1,20 +1,22 @@ #include "utils.h" -#include + #include -#include #include #include namespace vkcom { +Status::Status(int code, std::string message) : code(code), message(std::move(message)) {} + +const std::string &Status::error_message() const { return message; } + +bool Status::ok() const { return code == 0; } + void SpecialTokens::dump(std::ofstream &fout) { - fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id - << std::endl; + fout << unk_id << " " << pad_id << " " << bos_id << " " << eos_id << std::endl; } -void SpecialTokens::load(std::ifstream &fin) { - fin >> unk_id >> pad_id >> bos_id >> eos_id; -} +void SpecialTokens::load(std::ifstream &fin) { fin >> unk_id >> pad_id >> bos_id >> eos_id; } uint32_t SpecialTokens::max_id() const { int ret = 0; @@ -39,83 +41,45 @@ uint64_t SpecialTokens::n_special_tokens() const { } SpecialTokens::SpecialTokens(int pad_id, int unk_id, int bos_id, int eos_id) - : pad_id(pad_id), unk_id(unk_id), bos_id(bos_id), eos_id(eos_id) {} - -bool BPE_Rule::operator==(const BPE_Rule &other) const { - return x == other.x && y == other.y && z == other.z; -} - -BPE_Rule::BPE_Rule(uint32_t x, uint32_t y, uint32_t z) : x(x), y(y), z(z) {} - -void BPEState::dump(const std::string &file_name) { - 
std::ofstream fout(file_name, std::ios::out); - if (fout.fail()) { - std::cerr << "Can't open file: " << file_name << std::endl; - assert(false); - } - fout << char2id.size() << " " << rules.size() << std::endl; - for (auto s : char2id) { - fout << s.first << " " << s.second << std::endl; - } - - for (auto rule : rules) { - fout << rule.x << " " << rule.y << " " << rule.z << std::endl; - } - special_tokens.dump(fout); - fout.close(); -} + : pad_id(pad_id), unk_id(unk_id), bos_id(bos_id), eos_id(eos_id) {} -Status BPEState::load(const std::string &file_name) { - char2id.clear(); - rules.clear(); - std::ifstream fin(file_name, std::ios::in); - if (fin.fail()) { - return Status(1, "Can not open file with model: " + file_name); - } - int n, m; - fin >> n >> m; - for (int i = 0; i < n; i++) { - uint32_t inner_id; - uint32_t utf32_id; - fin >> inner_id >> utf32_id; - char2id[inner_id] = utf32_id; - } - for (int i = 0; i < m; i++) { - uint32_t x, y, z; - fin >> x >> y >> z; - rules.emplace_back(x, y, z); +std::vector read_all_lines(std::istream& stream) { + std::vector sentences; + std::string s; + while (std::getline(stream, s)) { + sentences.push_back(std::move(s)); } - special_tokens.load(fin); - fin.close(); - return Status(); -} - -BpeConfig::BpeConfig(double _character_coverage, int _n_threads, - const SpecialTokens &_special_tokens) - : character_coverage(_character_coverage), - n_threads(_n_threads), - special_tokens(_special_tokens) {} - -bool is_space(uint32_t ch) { - return (ch < 256 && isspace(ch)) || (ch == SPACE_TOKEN); + return sentences; } -std::vector read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed) { +std::vector read_lines(std::istream& stream, uint64_t batch_limit, uint64_t *processed) { std::vector sentences; std::string s; - while (*processed < batch_limit && getline(std::cin, s)) { + while (*processed < batch_limit && std::getline(stream, s)) { *processed += s.size(); sentences.push_back(std::move(s)); } return sentences; } -Status::Status(int code, std::string message) : code(code), message(std::move(message)) {} - -const std::string &Status::error_message() const { - return message; -} -bool Status::ok() const { - return code == 0; +Status fast_read_file_utf8(const std::string &file_name, std::string *file_content) { + static const int buf_size = 1000000; + *file_content = ""; + // TODO: use ifstream and seekg+tellg+seekg to reserve + auto fin = fopen(file_name.data(), "rb"); + if (fin == nullptr) { + return Status(1, "Failed to open file: " + file_name); + } + while (true) { + uint64_t cur_size = file_content->size(); + file_content->resize(cur_size + buf_size); + int buf_len = fread((void *)(file_content->data() + cur_size), 1, buf_size, fin); + if (buf_len < buf_size) { + file_content->resize(file_content->size() - (buf_size - buf_len)); + fclose(fin); + return Status(); + } + } } -} // namespace vkcom + +} // namespace vkcom diff --git a/youtokentome/cpp/utils.h b/youtokentome/cpp/utils.h index ce802d5..687dbf7 100644 --- a/youtokentome/cpp/utils.h +++ b/youtokentome/cpp/utils.h @@ -1,24 +1,27 @@ #pragma once #include +#include #include #include -#include "third_party/flat_hash_map.h" namespace vkcom { -const uint32_t SPACE_TOKEN = 9601; -struct BPE_Rule { - // x + y -> z - uint32_t x{0}; - uint32_t y{0}; - uint32_t z{0}; +enum OutputType { ID, SUBWORD }; - BPE_Rule() = default; +struct DecodeResult { + std::vector ids; + std::vector pieces; +}; - BPE_Rule(uint32_t x, uint32_t y, uint32_t z); +struct Status { + int code{0}; + std::string message; + 
Status() = default; + Status(int code, std::string message); - bool operator==(const BPE_Rule &other) const; + const std::string &error_message() const; + bool ok() const; }; struct SpecialTokens { @@ -42,64 +45,109 @@ struct SpecialTokens { uint64_t n_special_tokens() const; }; -struct BpeConfig { - double character_coverage = 1; - int n_threads = 0; - SpecialTokens special_tokens; +std::vector read_all_lines(std::istream& stream); - BpeConfig() = default; +std::vector read_lines(std::istream& stream, uint64_t batch_limit, uint64_t *processed); - BpeConfig(double character_coverage, int n_threads, - const SpecialTokens &special_tokens); -}; +Status fast_read_file_utf8(const std::string &file_name, std::string *file_content); -struct Status { - int code{0}; - std::string message; - Status() = default; - Status(int code, std::string message); +template +void write_to_stdout(const std::vector &items, bool flush) { + for (const auto &item : items) { + std::cout << item << " "; + } + std::cout << "\n"; + if (flush) { + std::cout << std::flush; + } +} - const std::string &error_message() const; - bool ok() const; -}; +template +void write_to_stdout(const std::vector> &sentences, bool flush) { + for (const auto &sentence : sentences) { + write_to_stdout(sentence, false); + } + if (flush) { + std::cout << std::flush; + } +} -struct BPEState { - flat_hash_map char2id; - std::vector rules; - SpecialTokens special_tokens; +template +struct VectorSegment { + private: + const T *begin_; + const T *end_; + uint64_t hash_; - void dump(const std::string &file_name); + public: + VectorSegment(const T *begin, const T *end, uint64_t hash) + : begin_(begin), end_(end), hash_(hash) {} - Status load(const std::string &file_name); -}; + bool operator==(const VectorSegment &other) const { + if (other.hash() != hash() || end_ - begin_ != other.end_ - other.begin_) { + return false; + } + for (auto it = begin_, other_it = other.begin_; it != end_; it++, other_it++) { + if (*it != *other_it) { + return false; + } + } + return true; + } -struct DecodeResult { - std::vector ids; - std::vector pieces; + uint64_t hash() const { return hash_; } }; -struct EncodingConfig { - bool bos; - bool eos; - bool reverse; - double dropout_prob; -}; +template +class VectorSegmentBuilder { + private: + constexpr static uint64_t MOD = 2032191299; + constexpr static uint64_t P = 726328703; + + const T *begin_; + const T *end_; + std::vector prefix_hash_; + + public: + explicit VectorSegmentBuilder(const std::vector &segment) + : VectorSegmentBuilder(segment.data(), segment.data() + segment.size()) {} + + VectorSegmentBuilder(const T *begin, const T *end) : begin_(begin), end_(end) { + using HashT = typename std::make_unsigned::type; + uint64_t hash = 0; + prefix_hash_.reserve(static_cast(end - begin)); + for (const T *it = begin_; it != end_; it++) { + hash = (hash * P + static_cast(*it)) % MOD; + prefix_hash_.push_back(hash); + } + } -bool is_space(uint32_t ch); + VectorSegment finish() const { return VectorSegment(begin_, end_, hash()); } -std::vector read_lines_from_stdin(uint64_t batch_limit, uint64_t *processed); + size_t size() const { return prefix_hash_.size(); } -template -void write_to_stdout(const std::vector> &sentences, bool flush) { - for (const auto &sentence : sentences) { - for (const auto &token : sentence) { - std::cout << token << " "; + bool empty() const { return prefix_hash_.empty(); } + + uint64_t hash() const { return prefix_hash_.empty() ? 
0 : prefix_hash_.back(); } + + void pop_back() noexcept { + if (!prefix_hash_.empty()) { + prefix_hash_.pop_back(); + --end_; } - std::cout << "\n"; - } - if (flush) { - std::cout << std::flush; } -} +}; + +using BpeVectorSegment = VectorSegment; +using WordPieceVectorSegment = VectorSegment; + +} // namespace vkcom + +namespace std { + +template +struct hash> { + uint64_t operator()(const vkcom::VectorSegment &x) const { return x.hash(); } +}; -} // namespace vkcom +} // namespace std diff --git a/youtokentome/cpp/wordpiece.cpp b/youtokentome/cpp/wordpiece.cpp new file mode 100644 index 0000000..da21313 --- /dev/null +++ b/youtokentome/cpp/wordpiece.cpp @@ -0,0 +1,323 @@ +#include "wordpiece.h" + +#include +#include +#include +#include + +#include "utf8.h" + +namespace vkcom::wordpiece { + +namespace { + +const std::string UNK_TOKEN = "[UNK]"; +const std::string PAD_TOKEN = "[PAD]"; +const std::string BOS_TOKEN = "[BOS]"; +const std::string EOS_TOKEN = "[EOS]"; + +bool is_suffix_vocab(const std::vector &word) { + static const uint32_t kSharp = static_cast('#'); + return word.size() >= 2 && word[0] == kSharp && word[1] == kSharp; +} + +bool is_special_token(const std::vector &word) { + return word.size() > 2 && word[0] == static_cast('[') + && word.back() == static_cast(']'); +} + +std::vector parse_text(const char *text, size_t size, vkcom::ThreadPool &thread_pool) { + static const size_t kWorkBatch = 5000000; + + if (size < 2 * kWorkBatch) { + return vkcom::decode_utf8(text, text + size); + } + + const size_t thread_count = std::min(thread_pool.maxThreads(), size / kWorkBatch); + const size_t work_batch = size / thread_count + 1; + std::vector> per_thread_text_utf8(thread_count); + size_t work_start = 0; + for (size_t thread_id = 0; thread_id < thread_count && work_start < size; thread_id++) { + size_t work_end = std::min(size, work_start + work_batch); + while (work_end < size && !vkcom::check_symbol_start(text[work_end])) { + ++work_end; + } + thread_pool.submit([thread_id, work_start, work_end, text, &per_thread_text_utf8] { + const char *begin = text + work_start; + const size_t len = work_end - work_start; + per_thread_text_utf8[thread_id] = vkcom::decode_utf8(begin, begin + len); + }); + work_start = work_end; + } + + thread_pool.waitCompletion(); + size_t text_utf8_size = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + text_utf8_size += per_thread_text_utf8[thread_id].size(); + } + std::vector text_utf8(text_utf8_size); + text_utf8.resize(text_utf8_size); + work_start = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + std::vector &segment = per_thread_text_utf8[thread_id]; + if (!segment.empty()) { + std::memcpy(text_utf8.data() + work_start, segment.data(), segment.size() * sizeof(uint32_t)); + work_start += segment.size(); + } + } + + return text_utf8; +} + +std::vector read_lines_helper(const std::string &filename) { + std::ifstream fin(filename); + return read_all_lines(fin); +} + +} // namespace + +WordPieceToken::WordPieceToken(const std::string &encoded_word) + : is_prefix(true), is_special(false), is_malformed(false), word(vkcom::decode_utf8(encoded_word)) { + if (is_suffix_vocab(word)) { + is_prefix = false; + word.erase(word.begin(), word.begin() + 2); + } else if (is_special_token(word)) { + is_special = true; + } + + bool all_punctuation = true; + for (uint32_t code_point : word) { + if (code_point == vkcom::INVALID_UNICODE) { + is_malformed = true; + } + if (!vkcom::is_punctuation(code_point) && 
!vkcom::is_space(code_point)) { + all_punctuation = false; + } + } + if (word.empty()) { + throw std::runtime_error("Vocab word is empty"); + } + if (is_malformed || (all_punctuation && word.size() > 1)) { + is_malformed = true; + std::cerr << "Vocab word is malformed: " << encoded_word << std::endl; + } +} + +WordPieceVocabulary::WordPieceVocabulary(const std::vector &words) { + tokens.reserve(words.size()); + int token_id = 0; + max_token_len = 0; + for (const std::string &word : words) { + update_special_tokens(word, token_id); + WordPieceToken token(word); + max_token_len = std::max(max_token_len, token.word.size()); + tokens.push_back(std::move(token)); + ++token_id; + } +} + +void WordPieceVocabulary::update_special_tokens(const std::string &word, int token_id) { + if (word == UNK_TOKEN) { + special_tokens.unk_id = token_id; + } else if (word == PAD_TOKEN) { + special_tokens.pad_id = token_id; + } else if (word == BOS_TOKEN) { + special_tokens.bos_id = token_id; + } else if (word == EOS_TOKEN) { + special_tokens.eos_id = token_id; + } +} + +Encoder::Encoder(const std::string &vocab_path, int n_threads) + : Encoder(read_lines_helper(vocab_path), n_threads) {} + +Encoder::Encoder(std::vector vocab, int n_threads) + : vocab_(std::move(vocab)), word_piece_vocab_(vocab_), thread_pool_(n_threads) { + build_word_maps(); +} + +Status Encoder::encode_as_ids(const std::string &text_path, std::vector *ids) const { + try { + std::string text_str; + Status status = fast_read_file_utf8(text_path, &text_str); + if (!status.ok()) { + return status; + } + const std::vector text = parse_text(text_str.data(), text_str.size(), thread_pool_); + *ids = encode_parallel(text); + return Status(); + } catch (const std::exception &ex) { + return Status(1, ex.what()); + } catch (...) { + return Status(1, "Unknown error"); + } +} + +Status Encoder::encode_as_subwords(const std::string &text_path, + std::vector *subwords) const { + try { + std::string text_str; + Status status = fast_read_file_utf8(text_path, &text_str); + if (!status.ok()) { + return status; + } + const std::vector text = parse_text(text_str.data(), text_str.size(), thread_pool_); + std::vector ids = encode_parallel(text); + for (int id : ids) { + subwords->push_back(vocab_[id]); + } + return Status(); + } catch (const std::exception &ex) { + return Status(1, ex.what()); + } catch (...) { + return Status(1, "Unknown error"); + } +} + +Status Encoder::decode(const std::vector &ids, + std::vector *subwords, + const std::unordered_set *ignore_ids) const { + try { + for (int id : ids) { + if (!ignore_ids || ignore_ids->count(id) == 0) { + subwords->push_back(vocab_.at(id)); + } + } + return Status(); + } catch (const std::exception &ex) { + return Status(1, ex.what()); + } catch (...) { + return Status(1, "Unknown error"); + } +} + +bool Encoder::is_word_prefix(const std::vector &text, size_t index) { + return index == 0 || vkcom::is_spacing_char(text[index]) + || vkcom::is_spacing_char(text[index - 1]); +} + +void Encoder::build_word_maps() { + for (size_t i = 0; i < word_piece_vocab_.tokens.size(); i++) { + const auto &token = word_piece_vocab_.tokens[i]; + if (token.is_special || token.is_malformed) { + continue; + } + vkcom::VectorSegmentBuilder segment(token.word); + WordMap *word_to_id = token.is_prefix ? 
&prefix_to_id_ : &suffix_to_id_; + (*word_to_id)[segment.finish()] = static_cast<int>(i); + } +} + +std::vector<int> Encoder::encode_parallel(const std::vector<uint32_t> &text) const { + static const size_t kWorkBatch = 1000000; + + if (text.size() < 2 * kWorkBatch) { + return encode_impl(text, 0, text.size()); + } + + const size_t thread_count = std::min(thread_pool_.maxThreads(), text.size() / kWorkBatch); + const size_t work_batch = text.size() / thread_count + 1; + std::vector<std::vector<int>> per_thread_token_ids(thread_count); + size_t work_begin = 0; + for (size_t thread_id = 0; thread_id < thread_count && work_begin < text.size(); thread_id++) { + size_t work_end = std::min(text.size(), work_begin + work_batch); + while (work_end < text.size() && !vkcom::is_space(text[work_end])) { + ++work_end; + } + thread_pool_.submit([this, thread_id, work_begin, work_end, &per_thread_token_ids, &text] { + per_thread_token_ids[thread_id] = encode_impl(text, work_begin, work_end); + }); + work_begin = work_end; + } + + thread_pool_.waitCompletion(); + + size_t token_count = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + token_count += per_thread_token_ids[thread_id].size(); + } + std::vector<int> token_ids(token_count); + work_begin = 0; + for (size_t thread_id = 0; thread_id < thread_count; thread_id++) { + std::vector<int> &segment = per_thread_token_ids[thread_id]; + if (!segment.empty()) { + std::memcpy(token_ids.data() + work_begin, segment.data(), segment.size() * sizeof(int)); + work_begin += segment.size(); + } + } + + return token_ids; +} + +std::vector<int> +Encoder::encode_impl(const std::vector<uint32_t> &text, size_t begin, size_t end) const { + size_t max_len = std::min(word_piece_vocab_.max_token_len, end - begin); + if (begin == end) { + return {}; + } + if (word_piece_vocab_.tokens.empty()) { + throw std::runtime_error("WordPiece vocabulary is empty"); + } + if (max_len == 0) { + throw std::runtime_error("Maximum token length in WordPiece vocabulary is zero"); + } + const int unk_token_id = word_piece_vocab_.special_tokens.unk_id; + + std::vector<int> token_ids; + token_ids.reserve((end - begin) / max_len + 1); + + while (begin != end && vkcom::is_space(text[begin])) { + ++begin; + } + + size_t tokens_since_prefix = 0; + + while (begin != end) { + size_t word_len = 1; + if (!vkcom::is_punctuation(text[begin])) { + while (word_len < std::min(max_len, end - begin) + && !vkcom::is_spacing_char(text[begin + word_len])) { + ++word_len; + } + } + + const uint32_t *segment_begin = text.data() + static_cast(begin); + const uint32_t *segment_end = segment_begin + static_cast(word_len); + const WordMap *word_to_id = is_word_prefix(text, begin) ? 
&prefix_to_id_ : &suffix_to_id_; + + vkcom::VectorSegmentBuilder segment(segment_begin, segment_end); + while (!segment.empty()) { + auto it = word_to_id->find(segment.finish()); + if (it != word_to_id->end()) { + ++tokens_since_prefix; + token_ids.push_back(it->second); + begin += segment.size(); + break; + } else { + segment.pop_back(); + } + } + + if (segment.empty()) { + while (tokens_since_prefix > 0) { + token_ids.pop_back(); + --tokens_since_prefix; + } + token_ids.push_back(unk_token_id); + begin += word_len; + while (begin != end && !is_word_prefix(text, begin)) { + ++begin; + } + } else if (begin != end && is_word_prefix(text, begin)) { + tokens_since_prefix = 0; + } + + while (begin != end && vkcom::is_space(text[begin])) { + ++begin; + } + } + + return token_ids; +} + +} // namespace vkcom::wordpiece diff --git a/youtokentome/cpp/wordpiece.h b/youtokentome/cpp/wordpiece.h new file mode 100644 index 0000000..8ff7e34 --- /dev/null +++ b/youtokentome/cpp/wordpiece.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include +#include +#include + +#include "third_party/thread_pool/thread_pool.h" +#include "utils.h" + +namespace vkcom::wordpiece { + +struct WordPieceToken { + explicit WordPieceToken(const std::string &encoded_word); + + bool is_prefix; + bool is_special; + bool is_malformed; + std::vector word; +}; + +struct WordPieceVocabulary { + explicit WordPieceVocabulary(const std::vector &words); + + std::vector tokens; + vkcom::SpecialTokens special_tokens; + size_t max_token_len = 0; + + private: + void update_special_tokens(const std::string &word, int token_id); +}; + +class Encoder { + public: + explicit Encoder(const std::string &vocab_path, int n_threads); + + explicit Encoder(std::vector vocab, int n_threads); + + Status encode_as_ids(const std::string &text_path, std::vector *ids) const; + + Status encode_as_subwords(const std::string &text_path, std::vector *subwords) const; + + Status decode(const std::vector &ids, + std::vector *subwords, + const std::unordered_set *ignore_ids) const; + + Status id_to_subword(int id, std::string *subword) const; + + int subword_to_id(const std::string &token) const; + + private: + static bool is_word_prefix(const std::vector &text, size_t index); + + void build_word_maps(); + + std::vector encode_parallel(const std::vector &text) const; + std::vector encode_impl(const std::vector &text, size_t begin, size_t end) const; + + std::vector vocab_; + WordPieceVocabulary word_piece_vocab_; + + // TODO: flat_hash_map ? 
+ using WordMap = std::unordered_map; + WordMap prefix_to_id_; // no ## in word prefix + WordMap suffix_to_id_; // ## in word prefix + + mutable ThreadPool thread_pool_; +}; + +} // namespace vkcom::wordpiece diff --git a/youtokentome/cpp/yttm.pyx b/youtokentome/cpp/yttm.pyx index 1d7774d..f0a4f07 100644 --- a/youtokentome/cpp/yttm.pyx +++ b/youtokentome/cpp/yttm.pyx @@ -2,13 +2,11 @@ from libcpp.vector cimport vector from libcpp.unordered_set cimport unordered_set from libcpp.string cimport string from libcpp cimport bool -import os from pathlib import Path from typing import Collection -cdef extern from "bpe.h" namespace "vkcom": - +cdef extern from "utils.h" namespace "vkcom": cdef cppclass SpecialTokens: int pad_id int unk_id @@ -26,31 +24,29 @@ cdef extern from "bpe.h" namespace "vkcom": cdef extern from "bpe.h" namespace "vkcom": - Status train_bpe(const string &source_path, const string& model_path, int vocab_size, const BpeConfig& bpe_config) + Status bpe_train(const string &source_path, const string &model_path, int vocab_size, const BpeConfig &bpe_config) cdef extern from "bpe.h" namespace "vkcom": cdef cppclass BaseEncoder: - BaseEncoder(const string& model_path, int n_threads, Status* status) - - Status encode_as_ids(const vector[string] &sentences, vector[vector[int]]* ids, bool bos, bool eos, bool reverse, double dropout_prob) const - Status encode_as_subwords(const vector[string]& sentences, vector[vector[string]]* subwords, bool bos, bool eos, bool reverse, double dropout_prob) const + BaseEncoder(const string &model_path, int n_threads, Status *status) + Status encode_as_ids(const vector[string] &sentences, vector[vector[int]] *ids, bool bos, bool eos, bool reverse, double dropout_prob) const + Status encode_as_subwords(const vector[string] &sentences, vector[vector[string]] *subwords, bool bos, bool eos, bool reverse, double dropout_prob) const Status encode_cli(string output_type, bool stream, bool bos, bool eos, bool reverse, double dropout_prob) const - Status decode_cli(const unordered_set[int]* ignore_ids) const - - void vocab_cli(bool verbose) const - - Status id_to_subword(int id, string* subword) const + Status decode(const vector[vector[int]] &ids, vector[string] *output, const unordered_set[int] *ignore_ids) const + Status decode_cli(const unordered_set[int] *ignore_ids) const + Status id_to_subword(int id, string *subword) const int subword_to_id(const string &subword) const - Status decode(const vector[vector[int]]& ids, vector[string]* output, const unordered_set[int]* ignore_ids) const + int vocab_size() const vector[string] vocabulary() const + void vocab_cli(bool verbose) const cdef class BPE: - cdef BaseEncoder* encoder + cdef BaseEncoder *encoder def __dealloc__(self): del self.encoder @@ -80,7 +76,7 @@ cdef class BPE: bpe_config.special_tokens.bos_id = bos_id bpe_config.special_tokens.eos_id = eos_id - cdef Status status = train_bpe(data.encode(), model.encode(), vocab_size, bpe_config) + cdef Status status = bpe_train(data.encode(), model.encode(), vocab_size, bpe_config) if status.code != 0: raise ValueError(status.message.decode()) @@ -134,7 +130,6 @@ cdef class BPE: return subword.decode() def decode(self, ids, ignore_ids): - if not isinstance(ids, list): raise TypeError( "{} is not a list instance".format(type(ids)) @@ -180,3 +175,51 @@ cdef class BPE: def vocab_cli(self, verbose): self.encoder.vocab_cli(verbose) +cdef extern from "wordpiece.h" namespace "vkcom::wordpiece": + cdef cppclass Encoder: + Encoder(const string &vocab_path, int n_threads) 
+ + Status encode_as_ids(const string &text_path, vector[int] *ids) const + + Status encode_as_subwords(const string &text_path, vector[string] *subwords) const + + Status decode(const vector[int] &ids, vector[string] *subwords, const unordered_set[int] *ignore_ids) const + + Status id_to_subword(int id, string *subword) const + int subword_to_id(const string &subword) const + +cdef class WordPiece: + cdef Encoder *encoder + + def __dealloc__(self): + del self.encoder + + def __init__(self, vocab_path, n_threads=0): + self.encoder = new Encoder(vocab_path.encode(), n_threads) + + def encode(self, text_path, output_type): + cdef Status status + cdef vector[int] ids + cdef vector[string] subwords + if output_type == 'id': + status = self.encoder.encode_as_ids(text_path.encode(), &ids) + if status.code != 0: + raise ValueError(status.message.decode()) + return ids + elif output_type == 'subword': + status = self.encoder.encode_as_subwords(text_path.encode(), &subwords) + if status.code != 0: + raise ValueError(status.message.decode()) + return subwords + else: + raise ValueError('output_type must be equal to "id" or "subword"') + + def decode(self, ids, ignore_ids): + if ignore_ids is None: + ignore_ids = set() + cdef unordered_set[int] c_ignore_ids = unordered_set[int](ignore_ids) + cdef vector[string] subwords + cdef Status status = self.encoder.decode(ids, &subwords, &c_ignore_ids) + if status.code != 0: + raise ValueError(status.message.decode()) + return subwords diff --git a/youtokentome/youtokentome.py b/youtokentome/youtokentome.py index 593febf..7b3ded6 100644 --- a/youtokentome/youtokentome.py +++ b/youtokentome/youtokentome.py @@ -97,3 +97,40 @@ def __setstate__(self, dict): self.bpe_cython = _youtokentome_cython.BPE( model_path=self.model, n_threads=self.n_threads ) + + +class WordPiece: + def __init__(self, vocab_path: str, n_threads: int = 0): + self.vocab_path = vocab_path + self.n_threads = n_threads + + self.word_piece_cython = _youtokentome_cython.WordPiece( + vocab_path=vocab_path, n_threads=n_threads + ) + + def encode( + self, + text_path: str, + output_type: OutputType = OutputType.ID + ) -> Union[List[int], List[str]]: + output_type_str = "id" if output_type == OutputType.ID else "subword" + return self.word_piece_cython.encode(text_path, output_type_str) + + def decode( + self, + ids: List[int], + ignore_ids: Optional[Collection[int]] = None + ) -> List[str]: + return self.word_piece_cython.decode(ids, ignore_ids) + + def __getstate__(self): + return {"vocab_path": self.vocab_path, "n_threads": self.n_threads} + + def __setstate__(self, dict): + self.vocab_path = dict["vocab_path"] + self.n_threads = dict["n_threads"] + + self.word_piece_cython = _youtokentome_cython.WordPiece( + vocab_path=self.vocab_path, n_threads=self.n_threads + ) + diff --git a/youtokentome/yttm_cli.py b/youtokentome/yttm_cli.py index 7e66879..4ae8c2e 100644 --- a/youtokentome/yttm_cli.py +++ b/youtokentome/yttm_cli.py @@ -57,7 +57,7 @@ def main(): default=3, show_default=True, ) -def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eos_id): +def bpe_train(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eos_id): """Train BPE model.""" yttmc.BPE.train( data=data, @@ -105,7 +105,7 @@ def bpe(data, model, vocab_size, coverage, n_threads, pad_id, unk_id, bos_id, eo show_default=True, help="BPE-dropout probability (the probability of a merge being dropped)", ) -def encode(model, output_type, n_threads, bos, eos, reverse, stream, dropout_prob): +def 
bpe_encode(model, output_type, n_threads, bos, eos, reverse, stream, dropout_prob): """Encode text to ids or subwords.""" if n_threads < -1 or n_threads == 0: raise ValueError( @@ -143,7 +143,7 @@ def validate_ignore_ids(ctx, param, value): required=False, help="List of indices to ignore for decoding. Example: --ignore_ids=1,2,3", ) -def decode(model, ignore_ids): +def bpe_decode(model, ignore_ids): """Decode ids to text.""" bpe = yttmc.BPE(model) bpe.decode_cli(ignore_ids) @@ -157,13 +157,13 @@ def decode(model, ignore_ids): help="Path to file with learned model.", ) @click.option("--verbose", is_flag=True, help="Add merging rules.") -def vocab(model, verbose): +def bpe_vocab(model, verbose): """Print list of learned subwords.""" bpe = yttmc.BPE(model) bpe.vocab_cli(verbose) -main.add_command(bpe) -main.add_command(encode) -main.add_command(decode) -main.add_command(vocab) +main.add_command(bpe_train) +main.add_command(bpe_encode) +main.add_command(bpe_decode) +main.add_command(bpe_vocab)
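For reviewers who want to try the new Python wrapper end to end, here is a minimal usage sketch of the `WordPiece` class added in `youtokentome/youtokentome.py`. It assumes `WordPiece` is re-exported from the package root alongside `BPE`; the paths `vocab.txt` and `test_text.txt` are placeholders for a one-token-per-line vocabulary file (with `##`-prefixed suffix tokens and `[UNK]`/`[PAD]`/`[BOS]`/`[EOS]` recognized as special tokens) and a text file to tokenize.

```python
import youtokentome as yttm

# Placeholder paths (not part of this PR): an existing WordPiece vocabulary and
# a text file to tokenize. The current implementation only loads a vocabulary;
# it does not train one.
vocab_path = "vocab.txt"
text_path = "test_text.txt"

wp = yttm.WordPiece(vocab_path, n_threads=4)

# WordPiece.encode takes a path to a text file and returns a flat list of ids
# (or subword strings) for the whole file, produced in parallel by the thread pool.
ids = wp.encode(text_path, output_type=yttm.OutputType.ID)
subwords = wp.encode(text_path, output_type=yttm.OutputType.SUBWORD)

# decode maps ids back to vocabulary entries, optionally skipping ignore_ids.
print(wp.decode(ids, ignore_ids=None))
```

Note that, unlike the BPE wrapper, there is no `train` method here: the vocabulary must be produced elsewhere (for example, an existing BERT-style vocab file) and is looked up with the greedy longest-match rule implemented in `wordpiece.cpp`.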