From 836921fdeb498820b71dcc7b70e990e828f4c6bc Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Mon, 4 Mar 2024 18:49:02 +0100 Subject: [PATCH] Add UDOP (#22940) * First draft * More improvements * More improvements * More fixes * Fix copies * More improvements * More fixes * More improvements * Convert checkpoint * More improvements, set up tests * Fix more tests * Add UdopModel * More improvements * Fix equivalence test * More fixes * Redesign model * Extend conversion script * Use real inputs for conversion script * Add image processor * Improve conversion script * Add UdopTokenizer * Add fast tokenizer * Add converter * Update README's * Add processor * Add fully fledged tokenizer * Add fast tokenizer * Use processor in conversion script * Add tokenizer tests * Fix one more test * Fix more tests * Fix tokenizer tests * Enable fast tokenizer tests * Fix more tests * Fix additional_special_tokens of fast tokenizer * Fix tokenizer tests * Fix more tests * Fix equivalence test * Rename image to pixel_values * Rename seg_data to bbox * More renamings * Remove vis_special_token * More improvements * Add docs * Fix copied from * Update slow tokenizer * Update fast tokenizer design * Make text input optional * Add first draft of processor tests * Fix more processor tests * Fix decoder_start_token_id * Fix test_initialization * Add integration test * More improvements * Improve processor, add test * Add more copied from * Add more copied from * Add more copied from * Add more copied from * Remove print statement * Update README and auto mapping * Delete files * Delete another file * Remove code * Fix test * Fix docs * Remove asserts * Add doc tests * Include UDOP in exotic model tests * Add expected tesseract decodings * Add sentencepiece * Use same design as T5 * Add UdopEncoderModel * Add UdopEncoderModel to tests * More fixes * Fix fast tokenizer * Fix one more test * Remove parallelisable attribute * Fix copies * Remove legacy file * Copy from T5Tokenizer * Fix rebase * More fixes, copy from T5 * More fixes * Fix init * Use ArthurZ/udop for tests * Make all model tests pass * Remove UdopForConditionalGeneration from auto mapping * Fix more tests * fixups * more fixups * fix the tokenizers * remove un-necessary changes * nits * nits * replace truncate_sequences_boxes with truncate_sequences for fix-copies * nit current path * add a test for input ids * ids that we should get taken from c9f7a32f57440d90ff79890270d376a1cc0acb68 * nits converting * nits * apply ruff * nits * nits * style * fix slow order of addition * fix udop fast range as well * fixup * nits * Add docstrings * Fix gradient checkpointing * Update code examples * Skip tests * Update integration test * Address comment * Make fixup * Remove extra ids from tokenizer * Skip test * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update year * Address comment * Address more comments * Address comments * Add copied from * Update CI * Rename script * Update model id * Add AddedToken, skip tests * Update CI * Fix doc tests * Do not use Tesseract for the doc tests * Remove kwargs * Add original inputs * Update casting * Fix doc test * Update question * Update question * Use LayoutLMv3ImageProcessor * Update organization * Improve docs * Update forward signature * Make images optional * Remove deprecated device argument * Add comment, add add_prefix_space * More improvements * Remove kwargs --------- Co-authored-by: ArthurZucker Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .circleci/create_circleci_config.py | 2 + README.md | 1 + README_es.md | 1 + README_fr.md | 1 + README_hd.md | 1 + README_ja.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + docs/source/en/_toctree.yml | 2 + docs/source/en/index.md | 1 + docs/source/en/model_doc/udop.md | 102 + src/transformers/__init__.py | 26 + src/transformers/convert_slow_tokenizer.py | 12 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + .../models/auto/image_processing_auto.py | 1 + src/transformers/models/auto/modeling_auto.py | 1 + .../models/auto/tokenization_auto.py | 7 + src/transformers/models/udop/__init__.py | 98 + .../models/udop/configuration_udop.py | 162 ++ .../models/udop/convert_udop_to_hf.py | 213 ++ src/transformers/models/udop/modeling_udop.py | 2030 +++++++++++++++++ .../models/udop/processing_udop.py | 204 ++ .../models/udop/tokenization_udop.py | 1483 ++++++++++++ .../models/udop/tokenization_udop_fast.py | 1012 ++++++++ src/transformers/utils/dummy_pt_objects.py | 31 + .../utils/dummy_sentencepiece_objects.py | 7 + .../utils/dummy_tokenizers_objects.py | 7 + tests/models/udop/__init__.py | 0 tests/models/udop/test_modeling_udop.py | 567 +++++ tests/models/udop/test_processor_udop.py | 508 +++++ tests/models/udop/test_tokenization_udop.py | 1886 +++++++++++++++ utils/check_config_attributes.py | 2 + utils/check_repo.py | 2 + 35 files changed, 8378 insertions(+) create mode 100644 docs/source/en/model_doc/udop.md create mode 100644 src/transformers/models/udop/__init__.py create mode 100644 src/transformers/models/udop/configuration_udop.py create mode 100644 src/transformers/models/udop/convert_udop_to_hf.py create mode 100644 src/transformers/models/udop/modeling_udop.py create mode 100644 src/transformers/models/udop/processing_udop.py create mode 100644 src/transformers/models/udop/tokenization_udop.py create mode 100644 src/transformers/models/udop/tokenization_udop_fast.py create mode 100644 tests/models/udop/__init__.py create mode 100644 tests/models/udop/test_modeling_udop.py create mode 100644 tests/models/udop/test_processor_udop.py create mode 100644 tests/models/udop/test_tokenization_udop.py diff --git a/.circleci/create_circleci_config.py b/.circleci/create_circleci_config.py index 7f271ff0819f78..45a58737a8ddff 100644 --- a/.circleci/create_circleci_config.py +++ b/.circleci/create_circleci_config.py @@ -475,6 +475,7 @@ def job_name(self): "pip install -U --upgrade-strategy eager 'git+https://github.com/facebookresearch/detectron2.git'", "sudo apt install tesseract-ocr", "pip install -U --upgrade-strategy eager pytesseract", + "pip install --upgrade-strategy eager sentencepiece", "pip install -U --upgrade-strategy eager natten==0.15.1+torch210cpu -f https://shi-labs.com/natten/wheels", "pip install -U --upgrade-strategy eager python-Levenshtein", "pip install -U --upgrade-strategy eager opencv-python", @@ -485,6 +486,7 @@ def job_name(self): "tests/models/*layoutlmv*", "tests/models/*nat", "tests/models/deta", + "tests/models/udop", "tests/models/nougat", ], pytest_num_workers=1, diff --git a/README.md b/README.md index 54e228a1150266..30f7cd08a77643 100644 --- a/README.md +++ b/README.md @@ -511,6 +511,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei. 1. **[TVLT](https://huggingface.co/docs/transformers/model_doc/tvlt)** (from UNC Chapel Hill) released with the paper [TVLT: Textless Vision-Language Transformer](https://arxiv.org/abs/2209.14156) by Zineng Tang, Jaemin Cho, Yixin Nie, Mohit Bansal. 1. **[TVP](https://huggingface.co/docs/transformers/model_doc/tvp)** (from Intel) released with the paper [Text-Visual Prompting for Efficient 2D Temporal Video Grounding](https://arxiv.org/abs/2303.04995) by Yimeng Zhang, Xin Chen, Jinghan Jia, Sijia Liu, Ke Ding. +1. **[UDOP](https://huggingface.co/docs/transformers/main/model_doc/udop)** (from Microsoft Research) released with the paper [Unifying Vision, Text, and Layout for Universal Document Processing](https://arxiv.org/abs/2212.02623) by Zineng Tang, Ziyi Yang, Guoxin Wang, Yuwei Fang, Yang Liu, Chenguang Zhu, Michael Zeng, Cha Zhang, Mohit Bansal. 1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (from Google Research) released with the paper [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) by Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler 1. **[UMT5](https://huggingface.co/docs/transformers/model_doc/umt5)** (from Google Research) released with the paper [UniMax: Fairer and More Effective Language Sampling for Large-Scale Multilingual Pretraining](https://openreview.net/forum?id=kXwdL1cWOAi) by Hyung Won Chung, Xavier Garcia, Adam Roberts, Yi Tay, Orhan Firat, Sharan Narang, Noah Constant. 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang. diff --git a/README_es.md b/README_es.md index b3c6845000d2b4..6e808e0e2b1cf1 100644 --- a/README_es.md +++ b/README_es.md @@ -484,6 +484,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei. 1. **[TVLT](https://huggingface.co/docs/transformers/model_doc/tvlt)** (from UNC Chapel Hill) released with the paper [TVLT: Textless Vision-Language Transformer](https://arxiv.org/abs/2209.14156) by Zineng Tang, Jaemin Cho, Yixin Nie, Mohit Bansal. 1. **[TVP](https://huggingface.co/docs/transformers/model_doc/tvp)** (from Intel) released with the paper [Text-Visual Prompting for Efficient 2D Temporal Video Grounding](https://arxiv.org/abs/2303.04995) by Yimeng Zhang, Xin Chen, Jinghan Jia, Sijia Liu, Ke Ding. +1. **[UDOP](https://huggingface.co/docs/transformers/main/model_doc/udop)** (from Microsoft Research) released with the paper [Unifying Vision, Text, and Layout for Universal Document Processing](https://arxiv.org/abs/2212.02623) by Zineng Tang, Ziyi Yang, Guoxin Wang, Yuwei Fang, Yang Liu, Chenguang Zhu, Michael Zeng, Cha Zhang, Mohit Bansal. 1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (from Google Research) released with the paper [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) by Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler 1. **[UMT5](https://huggingface.co/docs/transformers/model_doc/umt5)** (from Google Research) released with the paper [UniMax: Fairer and More Effective Language Sampling for Large-Scale Multilingual Pretraining](https://openreview.net/forum?id=kXwdL1cWOAi) by Hyung Won Chung, Xavier Garcia, Adam Roberts, Yi Tay, Orhan Firat, Sharan Narang, Noah Constant. 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang. diff --git a/README_fr.md b/README_fr.md index 4b87eba5bbe1ba..3bd57830076a5f 100644 --- a/README_fr.md +++ b/README_fr.md @@ -505,6 +505,7 @@ Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=h 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (de Microsoft), publié dans l'article [TrOCR : Reconnaissance optique de caractères basée sur un transformateur avec des modèles pré-entraînés](https://arxiv.org/abs/2109.10282) par Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei. 1. **[TVLT](https://huggingface.co/docs/transformers/model_doc/tvlt)** (de l'UNC Chapel Hill) a été publié dans l'article [TVLT : Transformer Vision-Language sans texte](https://arxiv.org/abs/2209.14156) par Zineng Tang, Jaemin Cho, Yixin Nie, Mohit Bansal. 1. **[TVP](https://huggingface.co/docs/transformers/model_doc/tvp)** (d'Intel) a été publié dans l'article [Text-Visual Prompting for Efficient 2D Temporal Video Grounding](https://arxiv.org/abs/2303.04995) par Yimeng Zhang, Xin Chen, Jinghan Jia, Sijia Liu, Ke Ding. +1. **[UDOP](https://huggingface.co/docs/transformers/main/model_doc/udop)** (de Microsoft Research) publié dans l'article [Unifying Vision, Text, and Layout for Universal Document Processing](https://arxiv.org/abs/2212.02623) parZineng Tang, Ziyi Yang, Guoxin Wang, Yuwei Fang, Yang Liu, Chenguang Zhu, Michael Zeng, Cha Zhang, Mohit Bansal. 1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (de Google Research) a été publié dans l'article [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) par Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler. 1. **[UMT5](https://huggingface.co/docs/transformers/model_doc/umt5)** (de Google Research) a été publié dans l'article [UniMax : Échantillonnage linguistique plus équitable et plus efficace pour l'entraînement préalable multilingue à grande échelle](https://openreview.net/forum?id=kXwdL1cWOAi) par Hyung Won Chung, Xavier Garcia, Adam Roberts, Yi Tay, Orhan Firat, Sharan Narang, Noah Constant. 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (de Microsoft Research) a été publié dans l'article [UniSpeech : Apprentissage unifié de la représentation de la parole avec des données étiquetées et non étiquetées](https://arxiv.org/abs/2101.07597) par Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang. diff --git a/README_hd.md b/README_hd.md index e68d9d39ba6242..0353eb4d8fbda6 100644 --- a/README_hd.md +++ b/README_hd.md @@ -458,6 +458,7 @@ conda install conda-forge::transformers 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft) released with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei. 1. **[TVLT](https://huggingface.co/docs/transformers/model_doc/tvlt)** (from UNC Chapel Hill) released with the paper [TVLT: Textless Vision-Language Transformer](https://arxiv.org/abs/2209.14156) by Zineng Tang, Jaemin Cho, Yixin Nie, Mohit Bansal. 1. **[TVP](https://huggingface.co/docs/transformers/model_doc/tvp)** (from Intel) released with the paper [Text-Visual Prompting for Efficient 2D Temporal Video Grounding](https://arxiv.org/abs/2303.04995) by Yimeng Zhang, Xin Chen, Jinghan Jia, Sijia Liu, Ke Ding. +1. **[UDOP](https://huggingface.co/docs/transformers/main/model_doc/udop)** (Microsoft Research से) Zineng Tang, Ziyi Yang, Guoxin Wang, Yuwei Fang, Yang Liu, Chenguang Zhu, Michael Zeng, Cha Zhang, Mohit Bansal. द्वाराअनुसंधान पत्र [Unifying Vision, Text, and Layout for Universal Document Processing](https://arxiv.org/abs/2212.02623) के साथ जारी किया गया 1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (from Google Research) released with the paper [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) by Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler 1. **[UMT5](https://huggingface.co/docs/transformers/model_doc/umt5)** (Google Research से) Hyung Won Chung, Xavier Garcia, Adam Roberts, Yi Tay, Orhan Firat, Sharan Narang, Noah Constant. द्वाराअनुसंधान पत्र [UniMax: Fairer and More Effective Language Sampling for Large-Scale Multilingual Pretraining](https://openreview.net/forum?id=kXwdL1cWOAi) के साथ जारी किया गया 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (माइक्रोसॉफ्ट रिसर्च से) साथ में दिया गया पेपर [UniSpeech: यूनिफाइड स्पीच रिप्रेजेंटेशन लर्निंग विद लेबलेड एंड अनलेबल्ड डेटा](https://arxiv.org/abs/2101.07597) चेंगई वांग, यू वू, याओ कियान, केनिची कुमातानी, शुजी लियू, फुरु वेई, माइकल ज़ेंग, ज़ुएदोंग हुआंग द्वारा। diff --git a/README_ja.md b/README_ja.md index d314b07140f504..599865ab5a7d49 100644 --- a/README_ja.md +++ b/README_ja.md @@ -518,6 +518,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (Microsoft から), Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei から公開された研究論文: [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) 1. **[TVLT](https://huggingface.co/docs/transformers/model_doc/tvlt)** (from UNC Chapel Hill から), Zineng Tang, Jaemin Cho, Yixin Nie, Mohit Bansal から公開された研究論文: [TVLT: Textless Vision-Language Transformer](https://arxiv.org/abs/2209.14156) 1. **[TVP](https://huggingface.co/docs/transformers/model_doc/tvp)** (Intel から), Yimeng Zhang, Xin Chen, Jinghan Jia, Sijia Liu, Ke Ding から公開された研究論文: [Text-Visual Prompting for Efficient 2D Temporal Video Grounding](https://arxiv.org/abs/2303.04995) +1. **[UDOP](https://huggingface.co/docs/transformers/main/model_doc/udop)** (Microsoft Research から) Zineng Tang, Ziyi Yang, Guoxin Wang, Yuwei Fang, Yang Liu, Chenguang Zhu, Michael Zeng, Cha Zhang, Mohit Bansal. から公開された研究論文 [Unifying Vision, Text, and Layout for Universal Document Processing](https://arxiv.org/abs/2212.02623) 1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (Google Research から) Yi Tay, Mostafa Dehghani, Vinh Q から公開された研究論文: [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler 1. **[UMT5](https://huggingface.co/docs/transformers/model_doc/umt5)** (Google Research から) Hyung Won Chung, Xavier Garcia, Adam Roberts, Yi Tay, Orhan Firat, Sharan Narang, Noah Constant. から公開された研究論文 [UniMax: Fairer and More Effective Language Sampling for Large-Scale Multilingual Pretraining](https://openreview.net/forum?id=kXwdL1cWOAi) 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (Microsoft Research から) Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang から公開された研究論文: [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) diff --git a/README_ko.md b/README_ko.md index f8679087ad1787..e48159c7999339 100644 --- a/README_ko.md +++ b/README_ko.md @@ -433,6 +433,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (Microsoft 에서) Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei 의 [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) 논문과 함께 발표했습니다. 1. **[TVLT](https://huggingface.co/docs/transformers/model_doc/tvlt)** (from UNC Chapel Hill 에서) Zineng Tang, Jaemin Cho, Yixin Nie, Mohit Bansal 의 [TVLT: Textless Vision-Language Transformer](https://arxiv.org/abs/2209.14156) 논문과 함께 발표했습니다. 1. **[TVP](https://huggingface.co/docs/transformers/model_doc/tvp)** (Intel 에서) Yimeng Zhang, Xin Chen, Jinghan Jia, Sijia Liu, Ke Ding 의 [Text-Visual Prompting for Efficient 2D Temporal Video Grounding](https://arxiv.org/abs/2303.04995) 논문과 함께 발표했습니다. +1. **[UDOP](https://huggingface.co/docs/transformers/main/model_doc/udop)** (Microsoft Research 에서 제공)은 Zineng Tang, Ziyi Yang, Guoxin Wang, Yuwei Fang, Yang Liu, Chenguang Zhu, Michael Zeng, Cha Zhang, Mohit Bansal.의 [Unifying Vision, Text, and Layout for Universal Document Processing](https://arxiv.org/abs/2212.02623)논문과 함께 발표했습니다. 1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (Google Research 에서) Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzle 의 [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) 논문과 함께 발표했습니다. 1. **[UMT5](https://huggingface.co/docs/transformers/model_doc/umt5)** (Google Research 에서 제공)은 Hyung Won Chung, Xavier Garcia, Adam Roberts, Yi Tay, Orhan Firat, Sharan Narang, Noah Constant.의 [UniMax: Fairer and More Effective Language Sampling for Large-Scale Multilingual Pretraining](https://openreview.net/forum?id=kXwdL1cWOAi)논문과 함께 발표했습니다. 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (Microsoft Research 에서) Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang 의 [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) 논문과 함께 발표했습니다. diff --git a/README_zh-hans.md b/README_zh-hans.md index 1832870d52ff24..a9e1997da38c83 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -457,6 +457,7 @@ conda install conda-forge::transformers 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (来自 Microsoft) 伴随论文 [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) 由 Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei 发布。 1. **[TVLT](https://huggingface.co/docs/transformers/model_doc/tvlt)** (来自 UNC Chapel Hill) 伴随论文 [TVLT: Textless Vision-Language Transformer](https://arxiv.org/abs/2209.14156) 由 Zineng Tang, Jaemin Cho, Yixin Nie, Mohit Bansal 发布。 1. **[TVP](https://huggingface.co/docs/transformers/model_doc/tvp)** (来自 Intel) 伴随论文 [Text-Visual Prompting for Efficient 2D Temporal Video Grounding](https://arxiv.org/abs/2303.04995) 由 Yimeng Zhang, Xin Chen, Jinghan Jia, Sijia Liu, Ke Ding 发布. +1. **[UDOP](https://huggingface.co/docs/transformers/main/model_doc/udop)** (来自 Microsoft Research) 伴随论文 [Unifying Vision, Text, and Layout for Universal Document Processing](https://arxiv.org/abs/2212.02623) 由 Zineng Tang, Ziyi Yang, Guoxin Wang, Yuwei Fang, Yang Liu, Chenguang Zhu, Michael Zeng, Cha Zhang, Mohit Bansal 发布。 1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (from Google Research) released with the paper [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) by Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler 1. **[UMT5](https://huggingface.co/docs/transformers/model_doc/umt5)** (来自 Google Research) 伴随论文 [UniMax: Fairer and More Effective Language Sampling for Large-Scale Multilingual Pretraining](https://openreview.net/forum?id=kXwdL1cWOAi) 由 Hyung Won Chung, Xavier Garcia, Adam Roberts, Yi Tay, Orhan Firat, Sharan Narang, Noah Constant 发布。 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (来自 Microsoft Research) 伴随论文 [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) 由 Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 2bf31890f359d7..2c724f309ef304 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -469,6 +469,7 @@ conda install conda-forge::transformers 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft) released with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei. 1. **[TVLT](https://huggingface.co/docs/transformers/model_doc/tvlt)** (from UNC Chapel Hill) released with the paper [TVLT: Textless Vision-Language Transformer](https://arxiv.org/abs/2209.14156) by Zineng Tang, Jaemin Cho, Yixin Nie, Mohit Bansal. 1. **[TVP](https://huggingface.co/docs/transformers/model_doc/tvp)** (from Intel) released with the paper [Text-Visual Prompting for Efficient 2D Temporal Video Grounding](https://arxiv.org/abs/2303.04995) by Yimeng Zhang, Xin Chen, Jinghan Jia, Sijia Liu, Ke Ding. +1. **[UDOP](https://huggingface.co/docs/transformers/main/model_doc/udop)** (from Microsoft Research) released with the paper [Unifying Vision, Text, and Layout for Universal Document Processing](https://arxiv.org/abs/2212.02623) by Zineng Tang, Ziyi Yang, Guoxin Wang, Yuwei Fang, Yang Liu, Chenguang Zhu, Michael Zeng, Cha Zhang, Mohit Bansal. 1. **[UL2](https://huggingface.co/docs/transformers/model_doc/ul2)** (from Google Research) released with the paper [Unifying Language Learning Paradigms](https://arxiv.org/abs/2205.05131v1) by Yi Tay, Mostafa Dehghani, Vinh Q. Tran, Xavier Garcia, Dara Bahri, Tal Schuster, Huaixiu Steven Zheng, Neil Houlsby, Donald Metzler 1. **[UMT5](https://huggingface.co/docs/transformers/model_doc/umt5)** (from Google Research) released with the paper [UniMax: Fairer and More Effective Language Sampling for Large-Scale Multilingual Pretraining](https://openreview.net/forum?id=kXwdL1cWOAi) by Hyung Won Chung, Xavier Garcia, Adam Roberts, Yi Tay, Orhan Firat, Sharan Narang, Noah Constant. 1. **[UniSpeech](https://huggingface.co/docs/transformers/model_doc/unispeech)** (from Microsoft Research) released with the paper [UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597) by Chengyi Wang, Yu Wu, Yao Qian, Kenichi Kumatani, Shujie Liu, Furu Wei, Michael Zeng, Xuedong Huang. diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ff6e91dbcf25d6..76d8a2ba7d7d75 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -770,6 +770,8 @@ title: TVLT - local: model_doc/tvp title: TVP + - local: model_doc/udop + title: UDOP - local: model_doc/vilt title: ViLT - local: model_doc/vipllava diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 34995edec39c7d..36216962d2da34 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -279,6 +279,7 @@ Flax), PyTorch, and/or TensorFlow. | [TrOCR](model_doc/trocr) | ✅ | ❌ | ❌ | | [TVLT](model_doc/tvlt) | ✅ | ❌ | ❌ | | [TVP](model_doc/tvp) | ✅ | ❌ | ❌ | +| [UDOP](model_doc/udop) | ✅ | ❌ | ❌ | | [UL2](model_doc/ul2) | ✅ | ✅ | ✅ | | [UMT5](model_doc/umt5) | ✅ | ❌ | ❌ | | [UniSpeech](model_doc/unispeech) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/udop.md b/docs/source/en/model_doc/udop.md new file mode 100644 index 00000000000000..b84ec160f705cc --- /dev/null +++ b/docs/source/en/model_doc/udop.md @@ -0,0 +1,102 @@ + + +# UDOP + +## Overview + +The UDOP model was proposed in [Unifying Vision, Text, and Layout for Universal Document Processing](https://arxiv.org/abs/2212.02623) by Zineng Tang, Ziyi Yang, Guoxin Wang, Yuwei Fang, Yang Liu, Chenguang Zhu, Michael Zeng, Cha Zhang, Mohit Bansal. +UDOP adopts an encoder-decoder Transformer architecture based on [T5](t5) for document AI tasks like document image classification, document parsing and document visual question answering. + +The abstract from the paper is the following: + +We propose Universal Document Processing (UDOP), a foundation Document AI model which unifies text, image, and layout modalities together with varied task formats, including document understanding and generation. UDOP leverages the spatial correlation between textual content and document image to model image, text, and layout modalities with one uniform representation. With a novel Vision-Text-Layout Transformer, UDOP unifies pretraining and multi-domain downstream tasks into a prompt-based sequence generation scheme. UDOP is pretrained on both large-scale unlabeled document corpora using innovative self-supervised objectives and diverse labeled data. UDOP also learns to generate document images from text and layout modalities via masked image reconstruction. To the best of our knowledge, this is the first time in the field of document AI that one model simultaneously achieves high-quality neural document editing and content customization. Our method sets the state-of-the-art on 9 Document AI tasks, e.g., document understanding and QA, across diverse data domains like finance reports, academic papers, and websites. UDOP ranks first on the leaderboard of the Document Understanding Benchmark (DUE).* + + + + UDOP architecture. Taken from the original paper. + +## Usage tips + +- In addition to *input_ids*, [`UdopForConditionalGeneration`] also expects the input `bbox`, which are + the bounding boxes (i.e. 2D-positions) of the input tokens. These can be obtained using an external OCR engine such + as Google's [Tesseract](https://github.com/tesseract-ocr/tesseract) (there's a [Python wrapper](https://pypi.org/project/pytesseract/) available). Each bounding box should be in (x0, y0, x1, y1) format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, y1) represents the + position of the lower right corner. Note that one first needs to normalize the bounding boxes to be on a 0-1000 + scale. To normalize, you can use the following function: + +```python +def normalize_bbox(bbox, width, height): + return [ + int(1000 * (bbox[0] / width)), + int(1000 * (bbox[1] / height)), + int(1000 * (bbox[2] / width)), + int(1000 * (bbox[3] / height)), + ] +``` + +Here, `width` and `height` correspond to the width and height of the original document in which the token +occurs. Those can be obtained using the Python Image Library (PIL) library for example, as follows: + +```python +from PIL import Image + +# Document can be a png, jpg, etc. PDFs must be converted to images. +image = Image.open(name_of_your_document).convert("RGB") + +width, height = image.size +``` + +- At inference time, it's recommended to use the `generate` method to autoregressively generate text given a document image. +- One can use [`UdopProcessor`] to prepare images and text for the model. By default, this class uses the Tesseract engine to extract a list of words +and boxes (coordinates) from a given document. Its functionality is equivalent to that of [`LayoutLMv3Processor`], hence it supports passing either +`apply_ocr=False` in case you prefer to use your own OCR engine or `apply_ocr=True` in case you want the default OCR engine to be used. + +This model was contributed by [nielsr](https://huggingface.co/nielsr). +The original code can be found [here](https://github.com/microsoft/UDOP). + + +## UdopConfig + +[[autodoc]] UdopConfig + +## UdopTokenizer + +[[autodoc]] UdopTokenizer + - build_inputs_with_special_tokens + - get_special_tokens_mask + - create_token_type_ids_from_sequences + - save_vocabulary + +## UdopTokenizerFast + +[[autodoc]] UdopTokenizerFast + +## UdopProcessor + +[[autodoc]] UdopProcessor + - __call__ + +## UdopModel + +[[autodoc]] UdopModel + - forward + +## UdopForConditionalGeneration + +[[autodoc]] UdopForConditionalGeneration + - forward + +## UdopEncoderModel + +[[autodoc]] UdopEncoderModel + - forward \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 027cf495466c50..6cdd561b41e1ba 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -856,6 +856,11 @@ "TvpConfig", "TvpProcessor", ], + "models.udop": [ + "UDOP_PRETRAINED_CONFIG_ARCHIVE_MAP", + "UdopConfig", + "UdopProcessor", + ], "models.umt5": ["UMT5Config"], "models.unispeech": [ "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -1135,6 +1140,7 @@ _import_structure["models.speech_to_text"].append("Speech2TextTokenizer") _import_structure["models.speecht5"].append("SpeechT5Tokenizer") _import_structure["models.t5"].append("T5Tokenizer") + _import_structure["models.udop"].append("UdopTokenizer") _import_structure["models.xglm"].append("XGLMTokenizer") _import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer") _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer") @@ -1214,6 +1220,7 @@ _import_structure["models.splinter"].append("SplinterTokenizerFast") _import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast") _import_structure["models.t5"].append("T5TokenizerFast") + _import_structure["models.udop"].append("UdopTokenizerFast") _import_structure["models.whisper"].append("WhisperTokenizerFast") _import_structure["models.xglm"].append("XGLMTokenizerFast") _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast") @@ -3411,6 +3418,15 @@ "TvpPreTrainedModel", ] ) + _import_structure["models.udop"].extend( + [ + "UDOP_PRETRAINED_MODEL_ARCHIVE_LIST", + "UdopEncoderModel", + "UdopForConditionalGeneration", + "UdopModel", + "UdopPreTrainedModel", + ], + ) _import_structure["models.umt5"].extend( [ "UMT5EncoderModel", @@ -5640,6 +5656,7 @@ TvpConfig, TvpProcessor, ) + from .models.udop import UDOP_PRETRAINED_CONFIG_ARCHIVE_MAP, UdopConfig, UdopProcessor from .models.umt5 import UMT5Config from .models.unispeech import ( UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -5915,6 +5932,7 @@ from .models.speech_to_text import Speech2TextTokenizer from .models.speecht5 import SpeechT5Tokenizer from .models.t5 import T5Tokenizer + from .models.udop import UdopTokenizer from .models.xglm import XGLMTokenizer from .models.xlm_prophetnet import XLMProphetNetTokenizer from .models.xlm_roberta import XLMRobertaTokenizer @@ -5987,6 +6005,7 @@ from .models.splinter import SplinterTokenizerFast from .models.squeezebert import SqueezeBertTokenizerFast from .models.t5 import T5TokenizerFast + from .models.udop import UdopTokenizerFast from .models.whisper import WhisperTokenizerFast from .models.xglm import XGLMTokenizerFast from .models.xlm_roberta import XLMRobertaTokenizerFast @@ -7827,6 +7846,13 @@ TvpModel, TvpPreTrainedModel, ) + from .models.udop import ( + UDOP_PRETRAINED_MODEL_ARCHIVE_LIST, + UdopEncoderModel, + UdopForConditionalGeneration, + UdopModel, + UdopPreTrainedModel, + ) from .models.umt5 import ( UMT5EncoderModel, UMT5ForConditionalGeneration, diff --git a/src/transformers/convert_slow_tokenizer.py b/src/transformers/convert_slow_tokenizer.py index c44592f8a0f9fb..707bfae89db56f 100644 --- a/src/transformers/convert_slow_tokenizer.py +++ b/src/transformers/convert_slow_tokenizer.py @@ -1039,6 +1039,17 @@ def post_processor(self): ) +class UdopConverter(SpmConverter): + def post_processor(self): + return processors.TemplateProcessing( + single=["$A", ""], + pair=["$A", "", "$B", ""], + special_tokens=[ + ("", self.original_tokenizer.convert_tokens_to_ids("")), + ], + ) + + class WhisperConverter(Converter): def converted(self) -> Tokenizer: vocab = self.original_tokenizer.encoder @@ -1471,6 +1482,7 @@ def converted(self) -> Tokenizer: "SeamlessM4TTokenizer": SeamlessM4TConverter, "SqueezeBertTokenizer": BertConverter, "T5Tokenizer": T5Converter, + "UdopTokenizer": UdopConverter, "WhisperTokenizer": WhisperConverter, "XLMRobertaTokenizer": XLMRobertaConverter, "XLNetTokenizer": XLNetConverter, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index ebb3db25fb96be..89ca6ab2b8660c 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -220,6 +220,7 @@ trocr, tvlt, tvp, + udop, umt5, unispeech, unispeech_sat, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 7bc637f3e1060a..87ff925e55eaa1 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -231,6 +231,7 @@ ("trocr", "TrOCRConfig"), ("tvlt", "TvltConfig"), ("tvp", "TvpConfig"), + ("udop", "UdopConfig"), ("umt5", "UMT5Config"), ("unispeech", "UniSpeechConfig"), ("unispeech-sat", "UniSpeechSatConfig"), @@ -454,6 +455,7 @@ ("transfo-xl", "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("tvlt", "TVLT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("tvp", "TVP_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("udop", "UDOP_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("unispeech", "UNISPEECH_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("unispeech-sat", "UNISPEECH_SAT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("univnet", "UNIVNET_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -715,6 +717,7 @@ ("trocr", "TrOCR"), ("tvlt", "TVLT"), ("tvp", "TVP"), + ("udop", "UDOP"), ("ul2", "UL2"), ("umt5", "UMT5"), ("unispeech", "UniSpeech"), diff --git a/src/transformers/models/auto/image_processing_auto.py b/src/transformers/models/auto/image_processing_auto.py index aef894a425bae1..50e9266cdee161 100644 --- a/src/transformers/models/auto/image_processing_auto.py +++ b/src/transformers/models/auto/image_processing_auto.py @@ -108,6 +108,7 @@ ("timesformer", "VideoMAEImageProcessor"), ("tvlt", "TvltImageProcessor"), ("tvp", "TvpImageProcessor"), + ("udop", "LayoutLMv3ImageProcessor"), ("upernet", "SegformerImageProcessor"), ("van", "ConvNextImageProcessor"), ("videomae", "VideoMAEImageProcessor"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 05b519d2bcd16b..0d28d224f19106 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -219,6 +219,7 @@ ("transfo-xl", "TransfoXLModel"), ("tvlt", "TvltModel"), ("tvp", "TvpModel"), + ("udop", "UdopModel"), ("umt5", "UMT5Model"), ("unispeech", "UniSpeechModel"), ("unispeech-sat", "UniSpeechSatModel"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 2c21f1cd529c74..d586068fb9c095 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -418,6 +418,13 @@ ("tapex", ("TapexTokenizer", None)), ("transfo-xl", ("TransfoXLTokenizer", None)), ("tvp", ("BertTokenizer", "BertTokenizerFast" if is_tokenizers_available() else None)), + ( + "udop", + ( + "UdopTokenizer" if is_sentencepiece_available() else None, + "UdopTokenizerFast" if is_tokenizers_available() else None, + ), + ), ( "umt5", ( diff --git a/src/transformers/models/udop/__init__.py b/src/transformers/models/udop/__init__.py new file mode 100644 index 00000000000000..5066fde6af1d15 --- /dev/null +++ b/src/transformers/models/udop/__init__.py @@ -0,0 +1,98 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_sentencepiece_available, + is_tokenizers_available, + is_torch_available, +) + + +_import_structure = { + "configuration_udop": ["UDOP_PRETRAINED_CONFIG_ARCHIVE_MAP", "UdopConfig"], + "processing_udop": ["UdopProcessor"], +} + +try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_udop"] = ["UdopTokenizer"] + +try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_udop_fast"] = ["UdopTokenizerFast"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_udop"] = [ + "UDOP_PRETRAINED_MODEL_ARCHIVE_LIST", + "UdopForConditionalGeneration", + "UdopPreTrainedModel", + "UdopModel", + "UdopEncoderModel", + ] + +if TYPE_CHECKING: + from .configuration_udop import UDOP_PRETRAINED_CONFIG_ARCHIVE_MAP, UdopConfig + from .processing_udop import UdopProcessor + + try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_udop import UdopTokenizer + + try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_udop_fast import UdopTokenizerFast + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_udop import ( + UDOP_PRETRAINED_MODEL_ARCHIVE_LIST, + UdopEncoderModel, + UdopForConditionalGeneration, + UdopModel, + UdopPreTrainedModel, + ) + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/udop/configuration_udop.py b/src/transformers/models/udop/configuration_udop.py new file mode 100644 index 00000000000000..8647a7bae29acf --- /dev/null +++ b/src/transformers/models/udop/configuration_udop.py @@ -0,0 +1,162 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" UDOP model configuration""" + + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +UDOP_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/udop-large": "https://huggingface.co/microsoft/udop-large/resolve/main/config.json", +} + + +class UdopConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`UdopForConditionalGeneration`]. It is used to + instantiate a UDOP model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the UDOP + [microsoft/udop-large](https://huggingface.co/microsoft/udop-large) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Arguments: + vocab_size (`int`, *optional*, defaults to 33201): + Vocabulary size of the UDOP model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`UdopForConditionalGeneration`]. + d_model (`int`, *optional*, defaults to 1024): + Size of the encoder layers and the pooler layer. + d_kv (`int`, *optional*, defaults to 64): + Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will + be defined as `num_heads * d_kv`. + d_ff (`int`, *optional*, defaults to 4096): + Size of the intermediate feed forward layer in each `UdopBlock`. + num_layers (`int`, *optional*, defaults to 24): + Number of hidden layers in the Transformer encoder and decoder. + num_decoder_layers (`int`, *optional*): + Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set. + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder and decoder. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + The number of buckets to use for each attention layer. + relative_attention_max_distance (`int`, *optional*, defaults to 128): + The maximum distance of the longer sequences for the bucket separation. + relative_bias_args (`List[dict]`, *optional*, defaults to `[{'type': '1d'}, {'type': 'horizontal'}, {'type': 'vertical'}]`): + A list of dictionaries containing the arguments for the relative bias layers. + dropout_rate (`float`, *optional*, defaults to 0.1): + The ratio for all dropout layers. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + initializer_factor (`float`, *optional*, defaults to 1.0): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + feed_forward_proj (`string`, *optional*, defaults to `"relu"`): + Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. Udopv1.1 uses the + `"gated-gelu"` feed forward projection. Original Udop uses `"relu"`. + is_encoder_decoder (`bool`, *optional*, defaults to `True`): + Whether the model should behave as an encoder/decoder or not. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 1): + The id of the end-of-sequence token in the vocabulary. + max_2d_position_embeddings (`int`, *optional*, defaults to 1024): + The maximum absolute position embeddings for relative position encoding. + image_size (`int`, *optional*, defaults to 224): + The size of the input images. + patch_size (`int`, *optional*, defaults to 16): + The patch size used by the vision encoder. + num_channels (`int`, *optional*, defaults to 3): + The number of channels in the input images. + """ + + model_type = "udop" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=33201, + d_model=1024, + d_kv=64, + d_ff=4096, + num_layers=24, + num_decoder_layers=None, + num_heads=16, + relative_attention_num_buckets=32, + relative_attention_max_distance=128, + relative_bias_args=[{"type": "1d"}, {"type": "horizontal"}, {"type": "vertical"}], + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="relu", + is_encoder_decoder=True, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + max_2d_position_embeddings=1024, + image_size=224, + patch_size=16, + num_channels=3, + **kwargs, + ): + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.relative_attention_max_distance = relative_attention_max_distance + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + self.use_cache = use_cache + + # UDOP attributes + self.max_2d_position_embeddings = max_2d_position_embeddings + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + if not isinstance(relative_bias_args, list): + raise ValueError("`relative_bias_args` should be a list of dictionaries.") + self.relative_bias_args = relative_bias_args + + act_info = self.feed_forward_proj.split("-") + self.dense_act_fn = act_info[-1] + self.is_gated_act = act_info[0] == "gated" + + if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2: + raise ValueError( + f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer." + "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. " + "'gated-gelu' or 'relu'" + ) + + super().__init__( + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) diff --git a/src/transformers/models/udop/convert_udop_to_hf.py b/src/transformers/models/udop/convert_udop_to_hf.py new file mode 100644 index 00000000000000..f9cf07f1286bf1 --- /dev/null +++ b/src/transformers/models/udop/convert_udop_to_hf.py @@ -0,0 +1,213 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert UDOP checkpoints from the original repository. URL: https://github.com/microsoft/i-Code/tree/main/i-Code-Doc""" + + +import argparse + +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from torchvision import transforms as T + +from transformers import ( + LayoutLMv3ImageProcessor, + UdopConfig, + UdopForConditionalGeneration, + UdopProcessor, + UdopTokenizer, +) +from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def original_transform(image, image_size=224): + transform = T.Compose( + [ + T.Resize([image_size, image_size]), + T.ToTensor(), + T.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + ] + ) + + image = transform(image) + return image + + +def get_image(): + filepath = hf_hub_download( + repo_id="hf-internal-testing/fixtures_docvqa", filename="document_2.png", repo_type="dataset" + ) + image = Image.open(filepath).convert("RGB") + + return image + + +def prepare_dummy_inputs(tokenizer, image_processor): + prompt = "Question answering. What is the name of the company?" + prompt = "Question answering. In which year is the report made?" + prompt_ids = tokenizer.encode(prompt, add_special_tokens=False) + + image = get_image() + # words, boxes = apply_tesseract(image, lang=None) + # fmt: off + words = ['7', 'ITC', 'Limited', 'REPORT', 'AND', 'ACCOUNTS', '2013', 'ITC’s', 'Brands:', 'An', 'Asset', 'for', 'the', 'Nation', 'The', 'consumer', 'needs', 'and', 'aspirations', 'they', 'fulfil,', 'the', 'benefit', 'they', 'generate', 'for', 'millions', 'across', 'ITC’s', 'value', 'chains,', 'the', 'future-ready', 'capabilities', 'that', 'support', 'them,', 'and', 'the', 'value', 'that', 'they', 'create', 'for', 'the', 'country,', 'have', 'made', 'ITC’s', 'brands', 'national', 'assets,', 'adding', 'to', 'India’s', 'competitiveness.', 'It', 'is', 'ITC’s', 'aspiration', 'to', 'be', 'the', 'No', '1', 'FMCG', 'player', 'in', 'the', 'country,', 'driven', 'by', 'its', 'new', 'FMCG', 'businesses.', 'A', 'recent', 'Nielsen', 'report', 'has', 'highlighted', 'that', "ITC's", 'new', 'FMCG', 'businesses', 'are', 'the', 'fastest', 'growing', 'among', 'the', 'top', 'consumer', 'goods', 'companies', 'operating', 'in', 'India.', 'ITC', 'takes', 'justifiable', 'pride', 'that,', 'along', 'with', 'generating', 'economic', 'value,', 'these', 'celebrated', 'Indian', 'brands', 'also', 'drive', 'the', 'creation', 'of', 'larger', 'societal', 'capital', 'through', 'the', 'virtuous', 'cycle', 'of', 'sustainable', 'and', 'inclusive', 'growth.', 'DI', 'WILLS', '*', ';', 'LOVE', 'DELIGHTFULLY', 'SOFT', 'SKIN?', 'aia', 'Ans', 'Source:', 'https://www.industrydocuments.ucsf.edu/docs/snbx0223'] + boxes = [[0, 45, 67, 80], [72, 56, 109, 67], [116, 56, 189, 67], [198, 59, 253, 66], [257, 59, 285, 66], [289, 59, 365, 66], [372, 59, 407, 66], [74, 136, 161, 158], [175, 137, 306, 158], [318, 137, 363, 158], [374, 137, 472, 158], [483, 136, 529, 158], [540, 137, 593, 158], [608, 137, 717, 158], [73, 194, 100, 203], [106, 196, 177, 203], [183, 194, 227, 203], [233, 194, 259, 203], [265, 194, 344, 205], [74, 211, 104, 222], [109, 210, 141, 221], [147, 211, 169, 220], [175, 210, 223, 220], [229, 211, 259, 222], [265, 211, 329, 222], [334, 210, 352, 220], [74, 227, 127, 236], [133, 229, 180, 236], [187, 227, 221, 236], [226, 227, 264, 236], [270, 227, 320, 237], [327, 227, 349, 236], [74, 243, 161, 254], [166, 243, 249, 254], [254, 243, 281, 252], [286, 244, 342, 254], [74, 260, 112, 270], [119, 260, 145, 269], [151, 260, 174, 269], [179, 260, 217, 269], [222, 260, 249, 269], [254, 260, 285, 271], [290, 260, 335, 269], [340, 259, 359, 269], [74, 276, 95, 284], [101, 276, 156, 287], [164, 276, 198, 284], [203, 276, 244, 284], [251, 275, 285, 284], [291, 276, 340, 284], [74, 292, 129, 301], [135, 292, 185, 302], [192, 292, 242, 303], [248, 292, 261, 301], [267, 292, 312, 301], [74, 308, 195, 319], [75, 335, 82, 344], [88, 335, 98, 344], [105, 335, 138, 344], [144, 335, 214, 346], [220, 336, 233, 344], [239, 335, 256, 344], [262, 335, 283, 344], [290, 335, 309, 344], [316, 335, 320, 344], [74, 351, 119, 360], [126, 352, 170, 362], [176, 352, 186, 360], [192, 352, 214, 360], [220, 352, 276, 362], [282, 352, 326, 360], [333, 352, 349, 362], [74, 368, 89, 377], [95, 370, 124, 377], [129, 367, 175, 377], [181, 368, 266, 377], [272, 368, 283, 376], [289, 368, 333, 377], [74, 384, 126, 393], [134, 385, 175, 395], [181, 384, 206, 393], [212, 384, 292, 395], [298, 384, 325, 393], [330, 384, 366, 393], [74, 403, 103, 409], [109, 400, 154, 409], [161, 401, 241, 409], [247, 403, 269, 409], [275, 401, 296, 409], [302, 400, 349, 409], [74, 417, 131, 428], [137, 419, 186, 428], [192, 417, 214, 426], [219, 417, 242, 428], [248, 419, 319, 426], [74, 433, 119, 444], [125, 433, 204, 444], [210, 433, 278, 444], [285, 433, 295, 441], [302, 433, 340, 442], [75, 449, 98, 458], [104, 449, 142, 458], [146, 449, 215, 460], [221, 449, 258, 460], [263, 449, 293, 459], [300, 449, 339, 460], [74, 466, 101, 474], [108, 466, 185, 476], [191, 466, 261, 474], [267, 466, 309, 476], [315, 466, 354, 474], [74, 482, 151, 491], [158, 482, 201, 491], [208, 482, 258, 491], [263, 482, 292, 491], [298, 482, 333, 491], [338, 482, 360, 491], [74, 498, 131, 507], [137, 498, 150, 507], [156, 498, 197, 509], [202, 498, 257, 507], [263, 498, 310, 509], [74, 515, 128, 525], [134, 515, 156, 523], [161, 515, 218, 523], [223, 515, 261, 525], [267, 514, 280, 523], [74, 531, 156, 540], [162, 531, 188, 540], [195, 531, 257, 540], [263, 531, 315, 542], [871, 199, 878, 202], [883, 199, 908, 202], [894, 251, 904, 257], [841, 268, 841, 270], [784, 373, 811, 378], [816, 373, 896, 378], [784, 381, 811, 387], [815, 381, 847, 387], [645, 908, 670, 915], [692, 908, 712, 915], [220, 984, 285, 993], [293, 983, 779, 996]] + # fmt: on + text_list = [] + bbox_list = [] + for text, box in zip(words, boxes): + if text == "": + continue + sub_tokens = tokenizer.tokenize(text) + for sub_token in sub_tokens: + text_list.append(sub_token) + bbox_list.append(box) + + input_ids = tokenizer.convert_tokens_to_ids(text_list) + + input_ids = prompt_ids + input_ids + bbox = [[0, 0, 0, 0]] * len(prompt_ids) + bbox_list + + pixel_values = image_processor(image, return_tensors="pt").pixel_values + original_pixel_values = original_transform(image, image_size=image_processor.size["height"]).unsqueeze(0) + # verify pixel values + assert torch.allclose(original_pixel_values, pixel_values) + print("Pixel values are ok!") + + return torch.tensor(input_ids).unsqueeze(0), torch.tensor(bbox).unsqueeze(0).float(), pixel_values + + +def convert_udop_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False): + # model_name to checkpoint_path + name_to_checkpoint_path = { + "udop-large": "/Users/nielsrogge/Documents/UDOP/udop-unimodel-large-224/pytorch_model.bin", + "udop-large-512": "/Users/nielsrogge/Documents/UDOP/udop-unimodel-large-512/pytorch_model.bin", + "udop-large-512-300k": "/Users/nielsrogge/Documents/UDOP/udop-unimodel-large-512-300k-steps/pytorch_model.bin", + } + + # load original state dict + checkpoint_path = name_to_checkpoint_path[model_name] + state_dict = torch.load(checkpoint_path, map_location="cpu") + + print("Checkpoint path:", checkpoint_path) + + # create HF model + image_size = 512 if "512" in model_name else 224 + config = UdopConfig(decoder_start_token_id=0, image_size=image_size) + model = UdopForConditionalGeneration(config) + model.eval() + + # rename keys + state_dict = {k.replace("cell2dembedding", "cell_2d_embedding"): v for k, v in state_dict.items()} + + # load weights + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + print("Missing keys:", missing_keys) + print("Unexpected keys:", unexpected_keys) + assert missing_keys == ["encoder.embed_patches.proj.weight", "encoder.embed_patches.proj.bias"] + assert unexpected_keys == ["pos_embed"] + + # prepare dummy inputs + tokenizer = UdopTokenizer.from_pretrained("t5-base", legacy=True) + size = {"height": image_size, "width": image_size} + image_processor = LayoutLMv3ImageProcessor( + image_mean=IMAGENET_DEFAULT_MEAN, image_std=IMAGENET_DEFAULT_STD, size=size + ) + processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer) + input_ids, bbox, image = prepare_dummy_inputs(tokenizer, image_processor) + prompt = "Question answering. In which year is the report made?" + encoding = processor(images=get_image(), text=prompt, return_tensors="pt") + + input_ids = encoding.input_ids + try: + EXPECTED_INPUT_IDS = torch.tensor([[11860, 18243, 5, 86, 84, 215, 19, 8, 934, 263, 58, 1, 489, 27, 3838, 7363, 4083, 14536, 3430, 5686, 5911, 17161, 134, 2038, 27, 3838, 22, 7, 4688, 7, 10, 389, 18202, 21, 8, 11046, 37, 3733, 523, 11, 38, 2388, 1628, 3, 13133, 23334, 6, 8, 1656, 79, 3806, 21, 4040, 640, 27, 3838, 22, 7, 701, 16534, 6, 8, 3, 76, 2693, 18, 23015, 5644, 24, 380, 3, 6015, 6, 11, 8, 701, 24, 79, 482, 21, 3, 88, 684, 6, 43, 263, 27, 3838, 22, 7, 3635, 1157, 4089, 6, 2651, 12, 1547, 22, 7, 3265, 655, 5, 19, 27, 3838, 22, 7, 38, 2388, 257, 12, 36, 8, 465, 209, 13409, 12150, 1959, 16, 8, 684, 6, 6737, 57, 165, 126, 13409, 12150, 1623, 5, 71, 1100, 30298, 934, 65, 12566, 24, 27, 3838, 31, 7, 126, 13409, 12150, 1623, 33, 8, 10391, 1710, 859, 8, 420, 3733, 4968, 688, 2699, 16, 1547, 5, 27, 3838, 1217, 131, 99, 23, 179, 6064, 24, 6, 590, 28, 3, 11600, 1456, 701, 6, 175, 9443, 2557, 3635, 92, 1262, 8, 3409, 13, 2186, 3, 27908, 1784, 190, 8, 3, 5771, 17, 13281, 4005, 13, 5086, 11, 13066, 1170, 5, 10826, 16309, 134, 3, 2, 276, 26, 3, 55, 391, 13570, 5, 10315, 309, 3577, 19114, 371, 4254, 5121, 5055, 6245, 3, 10047, 3162, 58, 3, 9, 61, 1713, 2703, 476, 667, 25158, 301, 6058, 6038, 476, 3765, 9149, 10, 4893, 1303, 1986, 5, 13580, 7, 8224, 28244, 7, 5, 76, 75, 7, 89, 5, 15, 1259, 87, 7171, 7, 87, 7, 29, 115, 226, 4305, 2773, 1]]) # fmt: skip + torch.testing.assert_close(EXPECTED_INPUT_IDS, input_ids) + bbox = encoding.bbox.float() + pixel_values = encoding.pixel_values + except Exception: + print("Input_ids don't match, preparing dummy inputs") + input_ids, bbox, pixel_values = prepare_dummy_inputs(tokenizer, image_processor) + + # Verify single forward pass + print("Testing single forward pass..") + with torch.no_grad(): + decoder_input_ids = torch.tensor([[101]]) + outputs = model(input_ids=input_ids, bbox=bbox, pixel_values=pixel_values, decoder_input_ids=decoder_input_ids) + print("Shape of logits:", outputs.logits.shape) + print("First values of logits:", outputs.logits[0, :3, :3]) + + # tensor([[-18.5262, 1.5087, -15.7051]]) on linux + # tensor([[-19.4976, 0.8515, -17.1873]]) on mac + try: + assert torch.allclose(outputs.logits[0, :3, :3], torch.tensor([[-18.5262, 1.5087, -15.7051]]), atol=1e-4) + print("Looks ok!") + except Exception: + print("logits don't match let's try to generate") + + # Verify autoregressive decoding + print("Testing generation...") + model_kwargs = {"bbox": bbox, "pixel_values": pixel_values} + outputs = model.generate(input_ids=input_ids, **model_kwargs, max_new_tokens=20) + + print("Generated:", tokenizer.batch_decode(outputs, skip_special_tokens=True)) + + # autoregressive decoding with original input data + print("Testing generation with original inputs...") + filepath = hf_hub_download(repo_id="nielsr/test-image", filename="input_ids_udop.pt", repo_type="dataset") + input_ids = torch.load(filepath) + filepath = hf_hub_download(repo_id="nielsr/test-image", filename="bbox_udop.pt", repo_type="dataset") + bbox = torch.load(filepath) + pixel_values_filename = "pixel_values_udop_512.pt" if "512" in model_name else "pixel_values_udop_224.pt" + filepath = hf_hub_download(repo_id="nielsr/test-image", filename=pixel_values_filename, repo_type="dataset") + pixel_values = torch.load(filepath) + + print("Decoded input ids:", tokenizer.decode(input_ids[0], skip_special_tokens=True)) + print("Bbox shape:", bbox.shape) + + model_kwargs = {"bbox": bbox, "pixel_values": pixel_values} + outputs = model.generate(input_ids=input_ids, **model_kwargs, max_new_tokens=20) + generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] + print("Generated:", generated_text) + + if pytorch_dump_folder_path is not None: + model.save_pretrained(pytorch_dump_folder_path) + tokenizer.save_pretrained(pytorch_dump_folder_path) + + if push_to_hub: + model.push_to_hub(f"microsoft/{model_name}") + processor.push_to_hub(f"microsoft/{model_name}") + # BIG note here: to save the fast tokenizer files in the repo on the hub, you need to do the following: + # see https://discuss.huggingface.co/t/convert-slow-xlmrobertatokenizer-to-fast-one/20876 + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--model_name", + default="udop-large", + type=str, + choices=["udop-large", "udop-large-512", "udop-large-512-300k"], + help=("Name of the UDOP model you'd like to convert."), + ) + parser.add_argument( + "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + ) + parser.add_argument( + "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub." + ) + + args = parser.parse_args() + convert_udop_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub) diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py new file mode 100644 index 00000000000000..62192eea7f5a5e --- /dev/null +++ b/src/transformers/models/udop/modeling_udop.py @@ -0,0 +1,2030 @@ +# coding=utf-8 +# Copyright 2024 Microsoft Research and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch UDOP model.""" + +import collections +import logging +import math +import random +from abc import ABC, abstractmethod +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import torch +from torch import Tensor, nn +from torch.nn import CrossEntropyLoss + +from transformers import UdopConfig +from transformers.modeling_outputs import ( + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) + +from ...activations import ACT2FN +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) + + +logger = logging.getLogger(__name__) + +UDOP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/udop-large", + # See all UDOP models at https://huggingface.co/models?filter=udop +] + + +_CONFIG_FOR_DOC = "UdopConfig" + + +UDOP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Args: + config ([`UdopConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +UDOP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. UDOP is a model with relative position embeddings so + you should be able to pad the inputs on both the right and the left. Indices can be obtained using + [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for detail. + [What are input IDs?](../glossary#input-ids) + + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + + bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*): + Bounding boxes of each input sequence tokens. Selected in the range `[0, + config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1) + format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, + y1) represents the position of the lower right corner. + + Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS] + token. See `pixel_values` for `patch_sequence_length`. + + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Batch of document images. Each image is divided into patches of shape `(num_channels, config.patch_size, + config.patch_size)` and the total number of patches (=`patch_sequence_length`) equals to `((height / + config.patch_size) * (width / config.patch_size))`. + + visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*): + Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model. + + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using + [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + [What are decoder input IDs?](../glossary#decoder-input-ids) T5 uses the `pad_token_id` as the starting + token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last + `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare + `decoder_input_ids` for pretraining take a look at [T5 Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. If + `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value of + `inputs_embeds`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +UDOP_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*): + Bounding boxes of each input sequence tokens. Selected in the range `[0, + config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1) + format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1, + y1) represents the position of the lower right corner. + + Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS] + token. See `pixel_values` for `patch_sequence_length`. + + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Batch of document images. Each image is divided into patches of shape `(num_channels, config.patch_size, + config.patch_size)` and the total number of patches (=`patch_sequence_length`) equals to `((height / + config.patch_size) * (width / config.patch_size))`. + + visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*): + Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model. + + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@dataclass +class BaseModelOutputWithAttentionMask(ModelOutput): + """ + Class for the model's outputs that may also contain a past key/values (to speed up sequential decoding). Includes + an additional attention mask. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. If `past_key_values` is used only + the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or + when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the + self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) + that can be used (see `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or + when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of + the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when + `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in + the self-attention heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and + `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax, + used to compute the weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor = None + attention_mask: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +def get_visual_bbox(image_size=224, patch_size=16): + image_feature_pool_shape = [image_size // patch_size, image_size // patch_size] + visual_bbox_x = torch.arange(0, 1.0 * (image_feature_pool_shape[1] + 1), 1.0) + visual_bbox_x /= image_feature_pool_shape[1] + + visual_bbox_y = torch.arange(0, 1.0 * (image_feature_pool_shape[0] + 1), 1.0) + visual_bbox_y /= image_feature_pool_shape[0] + + visual_bbox_input = torch.stack( + [ + visual_bbox_x[:-1].repeat(image_feature_pool_shape[0], 1), + visual_bbox_y[:-1].repeat(image_feature_pool_shape[1], 1).transpose(0, 1), + visual_bbox_x[1:].repeat(image_feature_pool_shape[0], 1), + visual_bbox_y[1:].repeat(image_feature_pool_shape[1], 1).transpose(0, 1), + ], + dim=-1, + ) + + visual_bbox_input = visual_bbox_input.view(-1, 4) + + return visual_bbox_input + + +def pad_sequence(seq, target_len, pad_value=0): + if isinstance(seq, torch.Tensor): + n = seq.shape[0] + else: + n = len(seq) + seq = torch.tensor(seq) + m = target_len - n + if m > 0: + ret = torch.stack([pad_value] * m).to(seq) + seq = torch.cat([seq, ret], dim=0) + return seq[:target_len] + + +def combine_image_text_embeddings( + image_embeddings, + inputs_embeds, + bbox, + visual_bbox, + attention_mask=None, + num_patches=14, + max_len=0, + image_size=224, + patch_size=16, +): + """ + Combine the image and text embeddings for the input to the encoder/decoder of UDOP. + + First, the image embeddings are created by checking for each visual patch if it is inside the bounding box of a + token. If it is, the visual patch is combined with the token embedding. Then, the visual bounding boxes are combined + with the text bounding boxes. Finally, the visual bounding boxes are combined with the text attention mask. + """ + + sequence_length = num_patches + ocr_points_x = torch.clip( + torch.floor((bbox[:, :, 0] + bbox[:, :, 2]) / 2.0 * sequence_length).long(), 0, sequence_length - 1 + ) + ocr_points_y = ( + torch.clip(torch.floor((bbox[:, :, 1] + bbox[:, :, 3]) / 2.0 * sequence_length).long(), 0, sequence_length - 1) + * sequence_length + ) + ocr_points = ocr_points_x + ocr_points_y + # make sure bounding boxes are of type float to calculate means + bbox = bbox.to(torch.float64) + target_seg = (bbox.mean(-1) == 0.0) | (bbox.mean(-1) == 1.0) + repeated_vision_embeds = torch.gather( + image_embeddings, 1, ocr_points.unsqueeze(-1).repeat(1, 1, image_embeddings.size(-1)) + ) + repeated_vision_embeds[target_seg] = 0.0 + inputs_embeds += repeated_vision_embeds + + patch_inds = torch.full_like(image_embeddings[:, :, 0], True).bool() + ind = torch.cat( + [ + torch.arange(len(ocr_points))[:, None].repeat(1, ocr_points.size(-1))[:, :, None].to(ocr_points), + ocr_points[:, :, None], + ], + dim=-1, + ) + ind = ind.flatten(0, 1) + rows, cols = zip(*ind) + patch_inds[rows, cols] = False + + input_vision_patches = [image_embeddings[i][patch_inds[i]] for i in range(len(patch_inds))] + + if visual_bbox is None: + visual_bbox = get_visual_bbox(image_size=image_size, patch_size=patch_size) + visual_bbox = visual_bbox.unsqueeze(0).repeat(image_embeddings.size(0), 1, 1) + visual_bbox = visual_bbox.to(image_embeddings.device) + + visual_bbox = [visual_bbox[i][patch_inds[i]] for i in range(len(patch_inds))] + if attention_mask is not None: + visual_attention_mask = [torch.tensor([1] * len(item)).to(attention_mask) for item in visual_bbox] + + if max_len == 0: + max_len = image_embeddings.size(1) + else: + max_len = max_len - inputs_embeds.size(1) + inputs_vision_patches = torch.stack( + [pad_sequence(item, max_len, torch.zeros_like(image_embeddings[0, 0])) for item in input_vision_patches] + ) + visual_bbox = torch.stack([pad_sequence(item, max_len, torch.zeros_like(bbox[0, 0])) for item in visual_bbox]) + if attention_mask is not None: + visual_attention_mask = torch.stack( + [pad_sequence(item, max_len, torch.zeros_like(attention_mask[0, 0])) for item in visual_attention_mask] + ) + + inputs_embeds = torch.cat([inputs_embeds, inputs_vision_patches], 1) + bbox = torch.cat([bbox, visual_bbox], 1) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask, visual_attention_mask], 1) + return inputs_embeds, bbox, attention_mask + + +class UdopPatchEmbeddings(nn.Module): + """2D Image to Patch Embeddings""" + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.proj = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, pixel_values): + batch_size, num_channels, height, width = pixel_values.shape + if height != self.image_size[0] or width != self.image_size[1]: + raise ValueError( + f"Input image size ({height}*{width}) doesn't match model" + f" ({self.image_size[0]}*{self.image_size[1]})." + ) + embeddings = self.proj(pixel_values) + embeddings = embeddings.flatten(2).transpose(1, 2) + return embeddings + + +class UdopPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. Based on `T5PreTrainedModel`. + """ + + config_class = UdopConfig + base_model_prefix = "transformer" + supports_gradient_checkpointing = True + _no_split_modules = ["UdopBlock"] + _keep_in_fp32_modules = ["wo"] + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, UdopLayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=factor) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Conv2d): + # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid + # `trunc_normal_cpu` not implemented in `half` issues + module.weight.data = nn.init.trunc_normal_(module.weight.data.to(torch.float32), mean=0.0, std=factor).to( + module.weight.dtype + ) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, RelativePositionBiasBase): + factor = self.config.initializer_factor + d_model = self.config.d_model + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + elif isinstance(module, UdopModel): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, UdopForConditionalGeneration): + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, UdopDenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, UdopDenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, UdopAttention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) + + # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetPreTrainedModel._shift_right with ProphetNet->Udop + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In Udop it is usually set to the" + " pad_token_id. See Udop docs for more information" + ) + + # shift inputs to the right + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values" + + return shifted_input_ids + + +# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Udop +class UdopLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the Udop style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + # Udop uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->Udop +class UdopDenseActDense(nn.Module): + def __init__(self, config: UdopConfig): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Udop +class UdopDenseGatedActDense(nn.Module): + def __init__(self, config: UdopConfig): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + + # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32. + # See https://github.com/huggingface/transformers/issues/20287 + # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`` + if ( + isinstance(self.wo.weight, torch.Tensor) + and hidden_states.dtype != self.wo.weight.dtype + and self.wo.weight.dtype != torch.int8 + ): + hidden_states = hidden_states.to(self.wo.weight.dtype) + + hidden_states = self.wo(hidden_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->Udop +class UdopLayerFF(nn.Module): + def __init__(self, config: UdopConfig): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = UdopDenseGatedActDense(config) + else: + self.DenseReluDense = UdopDenseActDense(config) + + self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Udop +class UdopAttention(nn.Module): + def __init__(self, config: UdopConfig, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + if len(past_key_value) != 2: + raise ValueError( + f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + ) + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop +class UdopLayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = UdopAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop +class UdopLayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False) + self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop +class UdopBlock(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(UdopLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(UdopLayerCrossAttention(config)) + + self.layer.append(UdopLayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + if past_key_value is not None: + if not self.is_decoder: + logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class UdopCellEmbeddings(nn.Module): + def __init__(self, max_2d_position_embeddings=501, hidden_size=1024): + super(UdopCellEmbeddings, self).__init__() + self.max_2d_position_embeddings = max_2d_position_embeddings + + self.x_position_embeddings = nn.Embedding(max_2d_position_embeddings, hidden_size) + self.y_position_embeddings = nn.Embedding(max_2d_position_embeddings, hidden_size) + + def forward(self, bbox): + bbox = torch.clip(bbox, 0.0, 1.0) + bbox = (bbox * (self.max_2d_position_embeddings - 1)).long() + left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) + upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) + right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) + lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) + + embeddings = ( + left_position_embeddings + + upper_position_embeddings + + right_position_embeddings + + lower_position_embeddings + ) + + return embeddings + + +# get function for bucket computation +# protected member access seems to be lesser evil than copy paste whole function +get_relative_position_bucket = UdopAttention._relative_position_bucket +AUGMENTATION_RANGE = (0.80, 1.25) + + +class RelativePositionBiasBase(nn.Module, ABC): + """ + Base class of relative biases. + + Args: + num_heads (`int`): + Number of attention heads in the model, it will create embeddings of size `num_heads`, which will be added to the scores of each token pair. + relative_attention_num_buckets (`int`, *optional*, defaults to 32): + Pair token metric (distance in the sequence, distance in pixels etc.) will be bucketed, parameter is defining number of such + buckets. + bidirectional (`bool`, *optional*, defaults to `True`): + Whether the distance should be bidirectional for a pair of tokens. If `False`, then distance(tok1, tok2) == distance(tok2, tok1). + scaling_factor (`int`, *optional*, defaults to 1): + Defining factor which will be used to scale relative distance. + max_distance (`int`, *optional*, defaults to 128): + All distances above this value will end up in the one/same bucket. + augmentation (`bool`, *optional*, defaults to `False`): + Whether to multiply relative distances by a random scalar. + expand (`bool`, *optional*, defaults to `False`): + Whether to expand an existing pretrained model with subsequent additions of prefix_bucket. + """ + + def __init__( + self, + num_heads=None, + relative_attention_num_buckets=32, + bidirectional=True, + scaling_factor=1, + max_distance=128, + level="tokens", + augmentation=False, + prefix_bucket=False, + expand=False, + ): + super(RelativePositionBiasBase, self).__init__() + self.prefix_bucket = prefix_bucket + self.augmentation = augmentation + self.level = level + self.max_distance = max_distance + self.scaling_factor = scaling_factor + self.bidirectional = bidirectional + self.num_heads = num_heads + self.expand = expand + self.relative_attention_num_buckets = relative_attention_num_buckets + extra_head = 2 if prefix_bucket and not self.expand else 0 + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets + extra_head, self.num_heads) + + @abstractmethod + def prepare_input( + self, + attention_mask: Optional[Tensor] = None, + bbox: Optional[Dict[str, Any]] = None, + ) -> Tensor: + pass + + def get_bucket(self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None) -> Tensor: + relative_position = self.prepare_input(attention_mask, bbox) + rp_bucket: Tensor = get_relative_position_bucket( + relative_position, + bidirectional=self.bidirectional, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.max_distance, + ) + return rp_bucket + + def get_relative_position(self, positions): + context_position = positions[:, :, None] + memory_position = positions[:, None, :] + relative_position = memory_position - context_position + if self.augmentation and self.training: + relative_position *= random.uniform(*AUGMENTATION_RANGE) + relative_position *= self.scaling_factor + + return relative_position.to(torch.long) + + def forward(self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None) -> Tensor: + # re-using pretrained model with subsequent addition of prefix_bucket + if self.expand and self.prefix_bucket: + new_bias = nn.Embedding(self.relative_attention_num_buckets + 2, self.num_heads) + new_bias.weight.data[: self.relative_attention_num_buckets] = self.relative_attention_bias.weight.data + new_bias.weight.data[self.relative_attention_num_buckets :] = 0.1 + self.relative_attention_bias = new_bias + self.expand = False + + rp_bucket = self.get_bucket(attention_mask, bbox) + + if self.prefix_bucket: + if rp_bucket.size(0) == 1 and attention_mask.size(0) > 1: + rp_bucket = rp_bucket.repeat(attention_mask.size(0), 1, 1) + # based on assumption that prefix bboxes are negative + is_prefix = bbox[:, :, 1] < 0 + num_prefix = is_prefix.sum(-1) + for idx, num_prefix_row in enumerate(num_prefix.cpu().numpy()): + rp_bucket[idx, :num_prefix_row, num_prefix_row:] = self.relative_attention_num_buckets + rp_bucket[idx, num_prefix_row:, :num_prefix_row] = self.relative_attention_num_buckets + 1 + + values: Tensor = self.relative_attention_bias(rp_bucket) + if values.dim() != 4: + raise ValueError("Wrong dimension of values tensor") + values = values.permute([0, 3, 1, 2]) + + return values + + +class RelativePositionBias1D(RelativePositionBiasBase): + def __init__(self, scaling_factor=1, max_distance=128, **kwargs): + """ + Reimplementation of T5 relative position bias. Distance between given tokens is their distance in the sequence. + Parameters are the same as in base class + """ + super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs) + + def prepare_input(self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None) -> Tensor: + if self.scaling_factor != 1: + raise ValueError("No need to scale 1d features") + relative_position = self.get_relative_position( + torch.arange(attention_mask.size(1), dtype=torch.long, device=attention_mask.device)[None, :] + ) + + return relative_position + + +class RelativePositionBiasHorizontal(RelativePositionBiasBase): + def __init__(self, scaling_factor=100, max_distance=100, **kwargs): + """ + Represents in the bucket embeddings horizontal distance between two tokens. Parameters are the same as in base + class + """ + super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs) + + def prepare_input(self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None) -> Tensor: + if not self.scaling_factor > 1.0: + raise ValueError("Need to scale the values of bboxes, as there are in small (0,1) range") + if bbox is None: + raise ValueError("Bbox is required for horizontal relative position bias") + # get x positions of left point of bbox + horizontal_position: Tensor = bbox[:, :, [0, 2]].mean(dim=-1) + + return self.get_relative_position(horizontal_position) + + +class RelativePositionBiasVertical(RelativePositionBiasBase): + def __init__(self, scaling_factor=100, max_distance=100, **kwargs): + """ + Represents in the bucket embeddings vertical distance between two tokens. Parameters are the same as in base + class + """ + super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs) + + def prepare_input(self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None) -> Tensor: + if not self.scaling_factor > 1.0: + raise ValueError("Need to scale the values of bboxes, as there are in small (0,1) range") + if bbox is None: + raise ValueError("Bbox is required for vertical relative position bias") + # get y positions of middle of bbox + vertical_position: Tensor = bbox[:, :, [1, 3]].mean(dim=-1) + + return self.get_relative_position(vertical_position) + + +class RelativePositionBiasAggregated(nn.Module): + def __init__(self, modules: Sequence[RelativePositionBiasBase]): + """ + Class which sums up various computed biases. + + Args: + modules (Sequence[RelativePositionBiasBase]): + List of relative bias modules. + """ + super().__init__() + self.biases = nn.ModuleList(modules) + + def forward( + self, attention_mask: Optional[Tensor] = None, bbox: Optional[Dict[str, Any]] = None + ) -> Union[float, Tensor]: + output = 0.0 + for bias in self.biases: # type: ignore + output = bias(attention_mask, bbox) + output + + return output + + +BIAS_CLASSES = { + "1d": RelativePositionBias1D, + "horizontal": RelativePositionBiasHorizontal, + "vertical": RelativePositionBiasVertical, +} + + +def create_relative_bias(config: UdopConfig) -> Sequence[RelativePositionBiasBase]: + """ + Creates empty list or one/multiple relative biases. + + :param config: Model's configuration :return: Sequence with created bias modules. + """ + bias_list = [] + if hasattr(config, "relative_bias_args"): + for bias_kwargs_org in config.relative_bias_args: + bias_kwargs = deepcopy(bias_kwargs_org) + bias_type = bias_kwargs.pop("type") + model_num_heads = config.num_heads if hasattr(config, "num_heads") else config.num_attention_heads + if "num_heads" in bias_kwargs: + if bias_kwargs["num_heads"] != model_num_heads: + raise ValueError("Number of heads must match num of heads in the model") + else: + bias_kwargs["num_heads"] = model_num_heads + bias_list.append(BIAS_CLASSES[bias_type](**bias_kwargs)) # type: ignore + + return bias_list + + +class UdopStack(UdopPreTrainedModel): + """ + This class is based on `T5Stack`, but modified to take into account the image modality as well as 2D position + embeddings. + """ + + def __init__(self, config, embed_tokens=None, embed_patches=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.embed_patches = embed_patches + self.is_decoder = config.is_decoder + self._max_length = config.max_length + self.num_layers = config.num_layers + + self.block = nn.ModuleList( + [UdopBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(self.num_layers)] + ) + self.final_layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) + + self.dropout = nn.Dropout(config.dropout_rate) + + if not self.is_decoder: + self.cell_2d_embedding = UdopCellEmbeddings(config.max_2d_position_embeddings, config.hidden_size) + + # get weights from encoder position bias + self.relative_bias = self._get_relative_bias(config) + + # tie weights of original position bias of encoder + for bias in self.relative_bias.biases: + if isinstance(bias, RelativePositionBias1D): + self._tie_or_clone_weights( + bias.relative_attention_bias, self.block[0].layer[0].SelfAttention.relative_attention_bias + ) + + @staticmethod + def _get_relative_bias(config: UdopConfig) -> RelativePositionBiasAggregated: + relative_bias_list = create_relative_bias(config) + return RelativePositionBiasAggregated(relative_bias_list) + + def get_input_embeddings(self): + return self.embed_tokens + + def get_output_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + bbox=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + pixel_values=None, + visual_bbox=None, + image_embeddings=None, + position_bias=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # input embeddings processing + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None and torch.numel(input_ids) > 0: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is None and input_ids is not None and torch.numel(input_ids) == 0: + input_ids = torch.full((4, 1024), self.config.pad_token_id, device=input_ids.device, dtype=input_ids.dtype) + attention_mask = torch.zeros((4, 1024), device=input_ids.device, dtype=input_ids.dtype) + bbox = torch.zeros((4, 1024, 4), device=input_ids.device, dtype=input_ids.dtype) + input_shape = input_ids.size() + position_bias = torch.zeros_like(self.get_extended_attention_mask(attention_mask, input_shape)) + # encoder_attention_mask = attention_mask + logger.warning("Empty batch") + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to intialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + if pixel_values is not None: + image_embeddings = self.embed_patches(pixel_values) + + if image_embeddings is not None: + # combine visual and OCR text embeddings + num_patches = self.config.image_size // self.config.patch_size + inputs_embeds, bbox, attention_mask = combine_image_text_embeddings( + image_embeddings, + inputs_embeds, + bbox, + visual_bbox, + attention_mask, + num_patches, + 0, + self.config.image_size, + self.config.patch_size, + ) + input_shape = inputs_embeds.size()[:-1] + + if not self.is_decoder and bbox is not None: + inputs_embeds += self.cell_2d_embedding(bbox) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self) + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + if self.is_decoder and encoder_attention_mask is not None: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + + if self.is_decoder: # modified lines + position_bias = None + else: + position_bias = self.relative_bias(attention_mask=attention_mask, bbox=bbox) + position_bias = position_bias + extended_attention_mask + encoder_decoder_position_bias = None + + hidden_states = inputs_embeds + + hidden_states = self.dropout(hidden_states) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=head_mask[i], + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) + if use_cache is False: # MP fixes + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention weights), + # (self-attention position bias), (cross-attention weights), (cross-attention position bias) + + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + attention_mask, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return BaseModelOutputWithAttentionMask( + last_hidden_state=hidden_states, + attention_mask=attention_mask, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + "The bare UDOP encoder-decoder Transformer outputting raw hidden-states without any specific head on top.", + UDOP_START_DOCSTRING, +) +class UdopModel(UdopPreTrainedModel): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + "encoder.embed_patches.proj.weight", + "encoder.embed_patches.proj.bias", + "encoder.relative_bias.biases.0.relative_attention_bias.weight", + "decoder.relative_bias.biases.0.relative_attention_bias.weight", + ] + + def __init__(self, config): + super(UdopModel, self).__init__(config) + + # text and image embeddings + self.shared = nn.Embedding(config.vocab_size, config.d_model) + self.patch_embed = UdopPatchEmbeddings(config) + + encoder_config = deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed) + + decoder_config = deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = UdopStack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(UDOP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Tensor = None, + attention_mask: Tensor = None, + bbox: Dict[str, Any] = None, + pixel_values: Optional[Tensor] = None, + visual_bbox: Dict[str, Any] = None, + decoder_input_ids: Optional[Tensor] = None, + decoder_attention_mask: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + encoder_outputs: Optional[Tensor] = None, + past_key_values: Optional[Tensor] = None, + head_mask: Optional[Tensor] = None, + decoder_inputs_embeds: Optional[Tensor] = None, + decoder_head_mask: Optional[Tensor] = None, + cross_attn_head_mask: Optional[Tensor] = None, + use_cache=True, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Tuple[Tensor, ...]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoProcessor, AutoModel + >>> from datasets import load_dataset + >>> import torch + + >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False) + >>> model = AutoModel.from_pretrained("microsoft/udop-large") + + >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") + >>> example = dataset[0] + >>> image = example["image"] + >>> words = example["tokens"] + >>> boxes = example["bboxes"] + >>> inputs = processor(image, words, boxes=boxes, return_tensors="pt") + + >>> decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]]) + + >>> # forward pass + >>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + >>> list(last_hidden_states.shape) + [1, 1, 1024] + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + bbox=bbox, + pixel_values=pixel_values, + visual_bbox=visual_bbox, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + encoder_attention_mask = encoder_outputs.attention_mask if return_dict else encoder_outputs[1] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + # we filter out the attention mask + decoder_outputs = tuple(value for idx, value in enumerate(decoder_outputs) if idx != 1) + encoder_outputs = tuple(value for idx, value in enumerate(encoder_outputs) if idx != 1) + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The UDOP encoder-decoder Transformer with a language modeling head on top, enabling to generate text given document + images and an optional prompt. + + This class is based on [`T5ForConditionalGeneration`], extended to deal with images and layout (2D) data.""", + UDOP_START_DOCSTRING, +) +class UdopForConditionalGeneration(UdopPreTrainedModel): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + "encoder.embed_patches.proj.weight", + "encoder.embed_patches.proj.bias", + "encoder.relative_bias.biases.0.relative_attention_bias.weight", + "decoder.relative_bias.biases.0.relative_attention_bias.weight", + "lm_head.weight", + ] + + def __init__(self, config): + super(UdopForConditionalGeneration, self).__init__(config) + + # text and image embeddings + self.shared = nn.Embedding(config.vocab_size, config.d_model) + self.patch_embed = UdopPatchEmbeddings(config) + + encoder_config = deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed) + + decoder_config = deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = UdopStack(decoder_config, self.shared) + + # The weights of the language modeling head are shared with those of the encoder and decoder + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(UDOP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Tensor = None, + attention_mask: Tensor = None, + bbox: Dict[str, Any] = None, + pixel_values: Optional[Tensor] = None, + visual_bbox: Dict[str, Any] = None, + decoder_input_ids: Optional[Tensor] = None, + decoder_attention_mask: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + encoder_outputs: Optional[Tensor] = None, + past_key_values: Optional[Tensor] = None, + head_mask: Optional[Tensor] = None, + decoder_inputs_embeds: Optional[Tensor] = None, + decoder_head_mask: Optional[Tensor] = None, + cross_attn_head_mask: Optional[Tensor] = None, + use_cache=True, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[Tensor] = None, + ) -> Tuple[Tensor, ...]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size - + 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size]`. + + Returns: + + Examples: + + ```python + >>> from transformers import AutoProcessor, UdopForConditionalGeneration + >>> from datasets import load_dataset + + >>> # load model and processor + >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False) + >>> model = UdopForConditionalGeneration.from_pretrained("microsoft/udop-large") + + >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") + >>> example = dataset[0] + >>> image = example["image"] + >>> words = example["tokens"] + >>> boxes = example["bboxes"] + >>> question = "Question answering. What is the date on the form?" + >>> encoding = processor(image, question, words, boxes=boxes, return_tensors="pt") + + >>> # autoregressive generation + >>> predicted_ids = model.generate(**encoding) + >>> print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]) + 9/30/92 + ```""" + + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if decoder_input_ids is None and labels is not None: + decoder_input_ids = self._shift_right(labels) + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + bbox=bbox, + visual_bbox=visual_bbox, + pixel_values=pixel_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + encoder_attention_mask = encoder_outputs.attention_mask if return_dict else encoder_outputs[1] + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.config.d_model**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[2:] + (encoder_outputs[0],) + encoder_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + "bbox": kwargs.get("bbox", None), + "pixel_values": kwargs.get("pixel_values", None), + "visual_bbox": kwargs.get("visual_bbox", None), + } + + # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + if reordered_layer_past_states[0].shape != layer_past_states[0].shape: + raise ValueError( + f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched" + ) + if len(reordered_layer_past_states) != len(layer_past_states): + raise ValueError( + f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched" + ) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare UDOP Model transformer outputting encoder's raw hidden-states without any specific head on top.", + UDOP_START_DOCSTRING, +) +class UdopEncoderModel(UdopPreTrainedModel): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "encoder.embed_patches.proj.weight", + "encoder.embed_patches.proj.bias", + "encoder.relative_bias.biases.0.relative_attention_bias.weight", + ] + + def __init__(self, config: UdopConfig): + super().__init__(config) + + # text and image embeddings + self.shared = nn.Embedding(config.vocab_size, config.d_model) + self.patch_embed = UdopPatchEmbeddings(config) + + encoder_config = deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(UDOP_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithAttentionMask, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Tensor = None, + bbox: Dict[str, Any] = None, + attention_mask: Tensor = None, + pixel_values: Optional[Tensor] = None, + visual_bbox: Dict[str, Any] = None, + head_mask: Optional[Tensor] = None, + inputs_embeds: Optional[Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithAttentionMask]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import AutoProcessor, UdopEncoderModel + >>> from huggingface_hub import hf_hub_download + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False) + >>> model = UdopEncoderModel.from_pretrained("microsoft/udop-large") + + >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train") + >>> example = dataset[0] + >>> image = example["image"] + >>> words = example["tokens"] + >>> boxes = example["bboxes"] + >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt") + + >>> outputs = model(**encoding) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_outputs = self.encoder( + input_ids=input_ids, + bbox=bbox, + visual_bbox=visual_bbox, + pixel_values=pixel_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/src/transformers/models/udop/processing_udop.py b/src/transformers/models/udop/processing_udop.py new file mode 100644 index 00000000000000..2902541d6f5b46 --- /dev/null +++ b/src/transformers/models/udop/processing_udop.py @@ -0,0 +1,204 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for UDOP. +""" + +from typing import List, Optional, Union + +from ...image_utils import ImageInput +from ...processing_utils import ProcessorMixin +from ...tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from ...utils import TensorType + + +class UdopProcessor(ProcessorMixin): + r""" + Constructs a UDOP processor which combines a LayoutLMv3 image processor and a UDOP tokenizer into a single processor. + + [`UdopProcessor`] offers all the functionalities you need to prepare data for the model. + + It first uses [`LayoutLMv3ImageProcessor`] to resize, rescale and normalize document images, and optionally applies OCR + to get words and normalized bounding boxes. These are then provided to [`UdopTokenizer`] or [`UdopTokenizerFast`], + which turns the words and bounding boxes into token-level `input_ids`, `attention_mask`, `token_type_ids`, `bbox`. + Optionally, one can provide integer `word_labels`, which are turned into token-level `labels` for token + classification tasks (such as FUNSD, CORD). + + Additionally, it also supports passing `text_target` and `text_pair_target` to the tokenizer, which can be used to + prepare labels for language modeling tasks. + + Args: + image_processor (`LayoutLMv3ImageProcessor`): + An instance of [`LayoutLMv3ImageProcessor`]. The image processor is a required input. + tokenizer (`UdopTokenizer` or `UdopTokenizerFast`): + An instance of [`UdopTokenizer`] or [`UdopTokenizerFast`]. The tokenizer is a required input. + """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "LayoutLMv3ImageProcessor" + tokenizer_class = ("UdopTokenizer", "UdopTokenizerFast") + + def __init__(self, image_processor, tokenizer): + super().__init__(image_processor, tokenizer) + + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, + boxes: Union[List[List[int]], List[List[List[int]]]] = None, + word_labels: Optional[Union[List[int], List[List[int]]]] = None, + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + ] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = False, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchEncoding: + """ + This method first forwards the `images` argument to [`~UdopImageProcessor.__call__`]. In case + [`UdopImageProcessor`] was initialized with `apply_ocr` set to `True`, it passes the obtained words and + bounding boxes along with the additional arguments to [`~UdopTokenizer.__call__`] and returns the output, + together with the prepared `pixel_values`. In case [`UdopImageProcessor`] was initialized with `apply_ocr` set + to `False`, it passes the words (`text`/``text_pair`) and `boxes` specified by the user along with the + additional arguments to [`~UdopTokenizer.__call__`] and returns the output, together with the prepared + `pixel_values`. + + Alternatively, one can pass `text_target` and `text_pair_target` to prepare the targets of UDOP. + + Please refer to the docstring of the above two methods for more information. + """ + # verify input + if self.image_processor.apply_ocr and (boxes is not None): + raise ValueError( + "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True." + ) + + if self.image_processor.apply_ocr and (word_labels is not None): + raise ValueError( + "You cannot provide word labels if you initialized the image processor with apply_ocr set to True." + ) + + if return_overflowing_tokens is True and return_offsets_mapping is False: + raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.") + + if text_target is not None: + # use the processor to prepare the targets of UDOP + return self.tokenizer( + text_target=text_target, + text_pair_target=text_pair_target, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + return_tensors=return_tensors, + ) + + else: + # use the processor to prepare the inputs of UDOP + # first, apply the image processor + features = self.image_processor(images=images, return_tensors=return_tensors) + + # second, apply the tokenizer + if text is not None and self.image_processor.apply_ocr and text_pair is None: + if isinstance(text, str): + text = [text] # add batch dimension (as the image processor always adds a batch dimension) + text_pair = features["words"] + + encoded_inputs = self.tokenizer( + text=text if text is not None else features["words"], + text_pair=text_pair if text_pair is not None else None, + boxes=boxes if boxes is not None else features["boxes"], + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + return_tensors=return_tensors, + ) + + # add pixel values + pixel_values = features.pop("pixel_values") + if return_overflowing_tokens is True: + pixel_values = self.get_overflowing_images(pixel_values, encoded_inputs["overflow_to_sample_mapping"]) + encoded_inputs["pixel_values"] = pixel_values + + return encoded_inputs + + # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.get_overflowing_images + def get_overflowing_images(self, images, overflow_to_sample_mapping): + # in case there's an overflow, ensure each `input_ids` sample is mapped to its corresponding image + images_with_overflow = [] + for sample_idx in overflow_to_sample_mapping: + images_with_overflow.append(images[sample_idx]) + + if len(images_with_overflow) != len(overflow_to_sample_mapping): + raise ValueError( + "Expected length of images to be the same as the length of `overflow_to_sample_mapping`, but got" + f" {len(images_with_overflow)} and {len(overflow_to_sample_mapping)}" + ) + + return images_with_overflow + + # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.batch_decode + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.decode + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer + to the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.layoutlmv3.processing_layoutlmv3.LayoutLMv3Processor.model_input_names + def model_input_names(self): + return ["input_ids", "bbox", "attention_mask", "pixel_values"] diff --git a/src/transformers/models/udop/tokenization_udop.py b/src/transformers/models/udop/tokenization_udop.py new file mode 100644 index 00000000000000..10e92db48cebba --- /dev/null +++ b/src/transformers/models/udop/tokenization_udop.py @@ -0,0 +1,1483 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +""" Tokenization classes for UDOP model.""" + + +import os +import re +import warnings +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import sentencepiece as spm + +from ...tokenization_utils import PreTrainedTokenizer +from ...tokenization_utils_base import ( + AddedToken, + BatchEncoding, + EncodedInput, + PreTokenizedInput, + TextInput, + TextInputPair, + TruncationStrategy, +) +from ...utils import PaddingStrategy, TensorType, add_end_docstrings, logging + + +logger = logging.get_logger(__name__) + + +SPIECE_UNDERLINE = "▁" + + +UDOP_ENCODE_KWARGS_DOCSTRING = r""" + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to encode the sequences with the special tokens relative to their model. + padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are token type IDs?](../glossary#token-type-ids) + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + return_overflowing_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch + of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead + of returning overflowing tokens. + return_special_tokens_mask (`bool`, *optional*, defaults to `False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*, defaults to `False`): + Whether or not to return `(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using + Python's tokenizer, this method will raise `NotImplementedError`. + return_length (`bool`, *optional*, defaults to `False`): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + **kwargs: passed to the `self.tokenize()` method + + Return: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + + [What are input IDs?](../glossary#input-ids) + + - **bbox** -- List of bounding boxes to be fed to a model. + + - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or + if *"token_type_ids"* is in `self.model_input_names`). + + [What are token type IDs?](../glossary#token-type-ids) + + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`). + + [What are attention masks?](../glossary#attention-mask) + + - **labels** -- List of labels to be fed to a model. (when `word_labels` is specified). + - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and + `return_overflowing_tokens=True`). + - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and + `return_overflowing_tokens=True`). + - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying + regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`). + - **length** -- The length of the inputs (when `return_length=True`). +""" + +VOCAB_FILES_NAMES = {"vocab_file": "spiece.model", "tokenizer_file": "tokenizer.json"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "microsoft/udop-large": "https://huggingface.co/microsoft/udop-large/resolve/main/spiece.model", + }, + "tokenizer_file": { + "microsoft/udop-large": "https://huggingface.co/microsoft/udop-large/resolve/main/tokenizer.json", + }, +} + + +# TODO(PVP) - this should be removed in Transformers v5 +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "microsoft/udop-large": 512, +} + + +class UdopTokenizer(PreTrainedTokenizer): + """ + Adapted from [`LayoutXLMTokenizer`] and [`T5Tokenizer`]. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + sep_token_box (`List[int]`, *optional*, defaults to `[1000, 1000, 1000, 1000]`): + The bounding box to use for the special [SEP] token. + pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`): + The bounding box to use for the special [PAD] token. + pad_token_label (`int`, *optional*, defaults to -100): + The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's + CrossEntropyLoss. + only_label_first_subword (`bool`, *optional*, defaults to `True`): + Whether or not to only label the first subword, in case word labels are provided. + additional_special_tokens (`List[str]`, *optional*, defaults to `["NOTUSED", "NOTUSED"]`): + Additional special tokens used by the tokenizer. + + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + legacy (`bool`, *optional*, defaults to `True`): + Whether or not the `legacy` behaviour of the tokenizer should be used. Legacy is before the merge of #24622 + which includes fixes to properly handle tokens that appear after special tokens. A simple example: + - `legacy=True`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True) + >>> tokenizer.encode("Hello .") + [8774, 32099, 3, 5, 1] + ``` + - `legacy=False`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False) + >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here + [8774, 32099, 5, 1] + ``` + Checkout the pull request and the issue [here](https://github.com/huggingface/transformers/pull/24565) for + more details. + add_prefix_space (`bool`, *optional*, defaults to `True`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. + + + Attributes: + sp_model (`SentencePieceProcessor`): + The *SentencePiece* processor that is used for every conversion (string, tokens and IDs). + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + eos_token="", + unk_token="", + sep_token="", + pad_token="", + sep_token_box=[1000, 1000, 1000, 1000], + pad_token_box=[0, 0, 0, 0], + pad_token_label=-100, + only_label_first_subword=True, + additional_special_tokens=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + legacy=True, + add_prefix_space=True, + **kwargs, + ) -> None: + eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token + sep_token = AddedToken(sep_token, special=True) if isinstance(sep_token, str) else sep_token + pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token + + self.legacy = legacy + self.add_prefix_space = add_prefix_space + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.vocab_file = vocab_file + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + + # additional properties + self.sep_token_box = sep_token_box + self.pad_token_box = pad_token_box + self.pad_token_label = pad_token_label + self.only_label_first_subword = only_label_first_subword + + super().__init__( + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + sep_token_box=sep_token_box, + pad_token_box=pad_token_box, + pad_token_label=pad_token_label, + only_label_first_subword=only_label_first_subword, + additional_special_tokens=additional_special_tokens, + sp_model_kwargs=self.sp_model_kwargs, + legacy=legacy, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def vocab_size(self): + return len(self.sp_model) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + # normal case: some special tokens + if token_ids_1 is None: + return ([0] * len(token_ids_0)) + [1] + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_sentinel_tokens + def get_sentinel_tokens(self): + return list( + set(filter(lambda x: bool(re.search(r"", x)) is not None, self.additional_special_tokens)) + ) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_sentinel_token_ids + def get_sentinel_token_ids(self): + return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present + def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]: + """Do not add eos again if user already added it.""" + if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id: + warnings.warn( + f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated" + " eos tokens being added." + ) + return token_ids + else: + return token_ids + [self.eos_token_id] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A sequence has the following format: + + - single sequence: `X ` + - pair of sequences: `A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + token_ids_0 = self._add_eos_if_not_present(token_ids_0) + if token_ids_1 is None: + return token_ids_0 + else: + token_ids_1 = self._add_eos_if_not_present(token_ids_1) + return token_ids_0 + token_ids_1 + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__ + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + tokens = self.sp_model.encode(text, out_type=str) + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.sp_model.IdToPiece(index) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.convert_tokens_to_string + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string.strip() + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + @add_end_docstrings(UDOP_ENCODE_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, + boxes: Union[List[List[int]], List[List[List[int]]]] = None, + word_labels: Optional[Union[List[int], List[List[int]]]] = None, + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + ] = None, + **kwargs, + ) -> BatchEncoding: + if text is None and text_target is None: + raise ValueError("You need to specify either `text` or `text_target`.") + if text is not None: + # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the + # input mode in this case. + if not self._in_target_context_manager: + self._switch_to_input_mode() + encodings = self.call_boxes(text=text, text_pair=text_pair, boxes=boxes, word_labels=word_labels, **kwargs) + if text_target is not None: + self._switch_to_target_mode() + target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **kwargs) + # Leave back tokenizer in input mode + self._switch_to_input_mode() + + if text_target is None: + return encodings + elif text is None: + return target_encodings + else: + encodings["labels"] = target_encodings["input_ids"] + return encodings + + def call_boxes( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], + text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, + boxes: Union[List[List[int]], List[List[List[int]]]] = None, + word_labels: Optional[Union[List[int], List[List[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences with word-level normalized bounding boxes and optional labels. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings + (words of a single example or questions of a batch of examples) or a list of list of strings (batch of + words). + text_pair (`List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence should be a list of strings + (pretokenized string). + boxes (`List[List[int]]`, `List[List[List[int]]]`): + Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale. + word_labels (`List[int]`, `List[List[int]]`, *optional*): + Word-level integer labels (for token classification tasks such as FUNSD, CORD). + """ + + # Input type checking for clearer error + def _is_valid_text_input(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... empty + return True + elif isinstance(t[0], str): + # ... list of strings + return True + elif isinstance(t[0], (list, tuple)): + # ... list with an empty list or with a list of strings + return len(t[0]) == 0 or isinstance(t[0][0], str) + else: + return False + else: + return False + + if text_pair is not None: + # in case text + text_pair are provided, text = questions, text_pair = words + if not _is_valid_text_input(text): + raise ValueError("text input must of type `str` (single example) or `List[str]` (batch of examples). ") + if not isinstance(text_pair, (list, tuple)): + raise ValueError( + "words must of type `List[str]` (single pretokenized example), " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + else: + # in case only text is provided => must be words + if not isinstance(text, (list, tuple)): + raise ValueError( + "Words must of type `List[str]` (single pretokenized example), " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + + if text_pair is not None: + is_batched = isinstance(text, (list, tuple)) + else: + is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) + + words = text if text_pair is None else text_pair + if boxes is None: + raise ValueError("You must provide corresponding bounding boxes") + if is_batched: + if len(words) != len(boxes): + raise ValueError("You must provide words and boxes for an equal amount of examples") + for words_example, boxes_example in zip(words, boxes): + if len(words_example) != len(boxes_example): + raise ValueError("You must provide as many words as there are bounding boxes") + else: + if len(words) != len(boxes): + raise ValueError("You must provide as many words as there are bounding boxes") + + if is_batched: + if text_pair is not None and len(text) != len(text_pair): + raise ValueError( + f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:" + f" {len(text_pair)}." + ) + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + is_pair = bool(text_pair is not None) + return self.batch_encode_plus_boxes( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus_boxes( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def batch_encode_plus_boxes( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + ], + is_pair: bool = None, + boxes: Optional[List[List[List[int]]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a list of sequences or a list of pairs of sequences. + + Args: + batch_text_or_text_pairs (`List[str]`, `List[Tuple[str, str]]`, `List[List[str]]`, `List[Tuple[List[str], List[str]]]`, and for not-fast tokenizers, also `List[List[int]]`, `List[Tuple[List[int], List[int]]]`): + Batch of sequences or pair of sequences to be encoded. This can be a list of + string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see + details in `encode_plus`). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus_boxes( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def encode_boxes( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> List[int]: + """ + Args: + Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. Same as doing + `self.convert_tokens_to_ids(self.tokenize(text))`. + text (`str`, `List[str]` or `List[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + text_pair (`str`, `List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + """ + encoded_inputs = self.encode_plus_boxes( + text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + return_tensors=return_tensors, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + def encode_plus_boxes( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. + + + + This method is deprecated, `__call__` should be used instead. + + + + Args: + text (`str`, `List[str]` or `List[int]` (the latter only for not-fast tokenizers)): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + text_pair (`str`, `List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus_boxes( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus_boxes( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + ], + is_pair: bool = None, + boxes: Optional[List[List[List[int]]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast." + ) + + batch_outputs = self._batch_prepare_for_model_boxes( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=return_tensors, + verbose=verbose, + ) + + return BatchEncoding(batch_outputs) + + @add_end_docstrings(UDOP_ENCODE_KWARGS_DOCSTRING) + def _batch_prepare_for_model_boxes( + self, + batch_text_or_text_pairs, + is_pair: bool = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_length: bool = False, + verbose: bool = True, + ) -> BatchEncoding: + """ + Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It + adds special tokens, truncates sequences if overflowing while taking into account the special tokens and + manages a moving window (with user defined stride) for overflowing tokens + + Args: + batch_ids_pairs: list of tokenized input ids or input ids pairs + """ + + batch_outputs = {} + for idx, example in enumerate(zip(batch_text_or_text_pairs, boxes)): + batch_text_or_text_pair, boxes_example = example + outputs = self.prepare_for_model_boxes( + batch_text_or_text_pair[0] if is_pair else batch_text_or_text_pair, + batch_text_or_text_pair[1] if is_pair else None, + boxes_example, + word_labels=word_labels[idx] if word_labels is not None else None, + add_special_tokens=add_special_tokens, + padding=PaddingStrategy.DO_NOT_PAD.value, # we pad in batch afterward + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=None, # we pad in batch afterward + return_attention_mask=False, # we pad in batch afterward + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + return_tensors=None, # We convert the whole batch to tensors at the end + prepend_batch_axis=False, + verbose=verbose, + ) + + for key, value in outputs.items(): + if key not in batch_outputs: + batch_outputs[key] = [] + batch_outputs[key].append(value) + + batch_outputs = self.pad( + batch_outputs, + padding=padding_strategy.value, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + batch_outputs = BatchEncoding(batch_outputs, tensor_type=return_tensors) + + return batch_outputs + + def _encode_plus_boxes( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if return_offsets_mapping: + raise NotImplementedError( + "return_offset_mapping is not available when using Python tokenizers. " + "To use this feature, change your tokenizer to one deriving from " + "transformers.PreTrainedTokenizerFast. " + "More information on available tokenizers at " + "https://github.com/huggingface/transformers/pull/2674" + ) + + return self.prepare_for_model_boxes( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding_strategy.value, + truncation=truncation_strategy.value, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + prepend_batch_axis=True, + return_attention_mask=return_attention_mask, + return_token_type_ids=return_token_type_ids, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_length=return_length, + verbose=verbose, + ) + + @add_end_docstrings(UDOP_ENCODE_KWARGS_DOCSTRING) + def prepare_for_model_boxes( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + prepend_batch_axis: bool = False, + **kwargs, + ) -> BatchEncoding: + """ + Prepares a sequence or a pair of sequences so that it can be used by the model. It adds special tokens, + truncates sequences if overflowing while taking into account the special tokens and manages a moving window + (with user defined stride) for overflowing tokens. + + Word-level `boxes` are turned into token-level `bbox`. If provided, word-level `word_labels` are turned into + token-level `labels`. The word label is used for the first token of the word, while remaining tokens are + labeled with -100, such that they will be ignored by the loss function. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The first sequence to be encoded. This can be a string, a list of strings or a list of list of strings. + text_pair (`List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a list of strings (words of a single example) or a + list of list of strings (words of a batch of examples). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + tokens = [] + pair_tokens = [] + token_boxes = [] + pair_token_boxes = [] + labels = [] + + if text_pair is None: + if word_labels is None: + # CASE 1: document image classification (training + inference) + CASE 2: token classification (inference) + for word, box in zip(text, boxes): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + tokens.extend(word_tokens) + token_boxes.extend([box] * len(word_tokens)) + else: + # CASE 2: token classification (training) + for word, box, label in zip(text, boxes, word_labels): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + tokens.extend(word_tokens) + token_boxes.extend([box] * len(word_tokens)) + if self.only_label_first_subword: + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + labels.extend([label] + [self.pad_token_label] * (len(word_tokens) - 1)) + else: + labels.extend([label] * len(word_tokens)) + else: + # CASE 3: document visual question answering (inference) + # text = question + # text_pair = words + tokens = self.tokenize(text) + token_boxes = [self.pad_token_box for _ in range(len(tokens))] + + for word, box in zip(text_pair, boxes): + if len(word) < 1: # skip empty words + continue + word_tokens = self.tokenize(word) + pair_tokens.extend(word_tokens) + pair_token_boxes.extend([box] * len(word_tokens)) + + # Create ids + pair_ids + ids = self.convert_tokens_to_ids(tokens) + pair_ids = self.convert_tokens_to_ids(pair_tokens) if pair_tokens else None + + # Compute the total size of the returned encodings + pair = bool(pair_ids is not None) + len_ids = len(ids) + len_pair_ids = len(pair_ids) if pair else 0 + total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0) + + # Truncation: Handle max sequence length + overflowing_tokens = [] + overflowing_token_boxes = [] + overflowing_labels = [] + if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length: + ( + ids, + token_boxes, + pair_ids, + pair_token_boxes, + labels, + overflowing_tokens, + overflowing_token_boxes, + overflowing_labels, + ) = self.truncate_sequences( + ids, + token_boxes, + pair_ids=pair_ids, + pair_token_boxes=pair_token_boxes, + labels=labels, + num_tokens_to_remove=total_len - max_length, + truncation_strategy=truncation_strategy, + stride=stride, + ) + + if return_token_type_ids and not add_special_tokens: + raise ValueError( + "Asking to return token_type_ids while setting add_special_tokens to False " + "results in an undefined behavior. Please set add_special_tokens to True or " + "set return_token_type_ids to None." + ) + + # Load from model defaults + if return_token_type_ids is None: + return_token_type_ids = "token_type_ids" in self.model_input_names + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + encoded_inputs = {} + + if return_overflowing_tokens: + encoded_inputs["overflowing_tokens"] = overflowing_tokens + encoded_inputs["overflowing_token_boxes"] = overflowing_token_boxes + encoded_inputs["overflowing_labels"] = overflowing_labels + encoded_inputs["num_truncated_tokens"] = total_len - max_length + + # Add special tokens + if add_special_tokens: + sequence = self.build_inputs_with_special_tokens(ids, pair_ids) + token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids) + token_boxes = token_boxes + [self.sep_token_box] + if pair_token_boxes: + pair_token_boxes = pair_token_boxes + [self.sep_token_box] + if labels: + labels = labels + [self.pad_token_label] + else: + sequence = ids + pair_ids if pair else ids + token_type_ids = [0] * len(ids) + ([0] * len(pair_ids) if pair else []) + + # Build output dictionary + encoded_inputs["input_ids"] = sequence + encoded_inputs["bbox"] = token_boxes + pair_token_boxes + if return_token_type_ids: + encoded_inputs["token_type_ids"] = token_type_ids + if return_special_tokens_mask: + if add_special_tokens: + encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids) + else: + encoded_inputs["special_tokens_mask"] = [0] * len(sequence) + + if labels: + encoded_inputs["labels"] = labels + + # Check lengths + self._eventual_warn_about_too_long_sequence(encoded_inputs["input_ids"], max_length, verbose) + + # Padding + if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask: + encoded_inputs = self.pad( + encoded_inputs, + max_length=max_length, + padding=padding_strategy.value, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask, + ) + + if return_length: + encoded_inputs["length"] = len(encoded_inputs["input_ids"]) + + batch_outputs = BatchEncoding( + encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis + ) + + return batch_outputs + + # Copied from transformers.models.layoutxlm.tokenization_layoutxlm.LayoutXLMTokenizer.truncate_sequences + def truncate_sequences( + self, + ids: List[int], + token_boxes: List[List[int]], + pair_ids: Optional[List[int]] = None, + pair_token_boxes: Optional[List[List[int]]] = None, + labels: Optional[List[int]] = None, + num_tokens_to_remove: int = 0, + truncation_strategy: Union[str, TruncationStrategy] = "longest_first", + stride: int = 0, + ) -> Tuple[List[int], List[int], List[int]]: + """ + Truncates a sequence pair in-place following the strategy. + + Args: + ids (`List[int]`): + Tokenized input ids of the first sequence. Can be obtained from a string by chaining the `tokenize` and + `convert_tokens_to_ids` methods. + token_boxes (`List[List[int]]`): + Bounding boxes of the first sequence. + pair_ids (`List[int]`, *optional*): + Tokenized input ids of the second sequence. Can be obtained from a string by chaining the `tokenize` + and `convert_tokens_to_ids` methods. + pair_token_boxes (`List[List[int]]`, *optional*): + Bounding boxes of the second sequence. + labels (`List[int]`, *optional*): + Labels of the first sequence (for token classification tasks). + num_tokens_to_remove (`int`, *optional*, defaults to 0): + Number of tokens to remove using the truncation strategy. + truncation_strategy (`str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + The strategy to follow for truncation. Can be: + + - `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will truncate + token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a + batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths greater + than the model maximum admissible input size). + stride (`int`, *optional*, defaults to 0): + If set to a positive number, the overflowing tokens returned will contain some tokens from the main + sequence returned. The value of this argument defines the number of additional tokens. + + Returns: + `Tuple[List[int], List[int], List[int]]`: The truncated `ids`, the truncated `pair_ids` and the list of + overflowing tokens. + """ + if num_tokens_to_remove <= 0: + return ids, token_boxes, pair_ids, pair_token_boxes, labels, [], [], [] + + if not isinstance(truncation_strategy, TruncationStrategy): + truncation_strategy = TruncationStrategy(truncation_strategy) + + overflowing_tokens = [] + overflowing_token_boxes = [] + overflowing_labels = [] + if truncation_strategy == TruncationStrategy.LONGEST_FIRST: + for _ in range(num_tokens_to_remove): + if pair_ids is None or len(ids) > len(pair_ids): + if not overflowing_tokens: + window_len = min(len(ids), stride + 1) + else: + window_len = 1 + overflowing_tokens.extend(ids[-window_len:]) + overflowing_token_boxes.extend(token_boxes[-window_len:]) + overflowing_labels.extend(labels[-window_len:]) + ids = ids[:-1] + token_boxes = token_boxes[:-1] + labels = labels[:-1] + else: + if not overflowing_tokens: + window_len = min(len(pair_ids), stride + 1) + else: + window_len = 1 + overflowing_tokens.extend(pair_ids[-window_len:]) + overflowing_token_boxes.extend(pair_token_boxes[-window_len:]) + pair_ids = pair_ids[:-1] + pair_token_boxes = pair_token_boxes[:-1] + elif truncation_strategy == TruncationStrategy.ONLY_FIRST: + if len(ids) > num_tokens_to_remove: + window_len = min(len(ids), stride + num_tokens_to_remove) + overflowing_tokens = ids[-window_len:] + overflowing_token_boxes = token_boxes[-window_len:] + overflowing_labels = labels[-window_len:] + ids = ids[:-num_tokens_to_remove] + token_boxes = token_boxes[:-num_tokens_to_remove] + labels = labels[:-num_tokens_to_remove] + else: + logger.error( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the first sequence has a length {len(ids)}. " + f"Please select another truncation strategy than {truncation_strategy}, " + "for instance 'longest_first' or 'only_second'." + ) + elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None: + if len(pair_ids) > num_tokens_to_remove: + window_len = min(len(pair_ids), stride + num_tokens_to_remove) + overflowing_tokens = pair_ids[-window_len:] + overflowing_token_boxes = pair_token_boxes[-window_len:] + pair_ids = pair_ids[:-num_tokens_to_remove] + pair_token_boxes = pair_token_boxes[:-num_tokens_to_remove] + else: + logger.error( + f"We need to remove {num_tokens_to_remove} to truncate the input " + f"but the second sequence has a length {len(pair_ids)}. " + f"Please select another truncation strategy than {truncation_strategy}, " + "for instance 'longest_first' or 'only_first'." + ) + + return ( + ids, + token_boxes, + pair_ids, + pair_token_boxes, + labels, + overflowing_tokens, + overflowing_token_boxes, + overflowing_labels, + ) + + # Copied from transformers.models.layoutxlm.tokenization_layoutxlm.LayoutXLMTokenizer._pad + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(required_input) + + if needs_to_be_padded: + difference = max_length - len(required_input) + if self.padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference + if "labels" in encoded_inputs: + encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference + elif self.padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"] + if "labels" in encoded_inputs: + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + + return encoded_inputs diff --git a/src/transformers/models/udop/tokenization_udop_fast.py b/src/transformers/models/udop/tokenization_udop_fast.py new file mode 100644 index 00000000000000..ee0697595508a7 --- /dev/null +++ b/src/transformers/models/udop/tokenization_udop_fast.py @@ -0,0 +1,1012 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +""" Tokenization classes for UDOP model.""" + + +import os +from shutil import copyfile +from typing import Dict, List, Optional, Tuple, Union + +from ...tokenization_utils_base import ( + BatchEncoding, + EncodedInput, + PreTokenizedInput, + TextInput, + TextInputPair, + TruncationStrategy, +) +from ...tokenization_utils_fast import PreTrainedTokenizerFast +from ...utils import PaddingStrategy, TensorType, add_end_docstrings, is_sentencepiece_available, logging +from ..udop.tokenization_udop import ( + PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES, + PRETRAINED_VOCAB_FILES_MAP, + VOCAB_FILES_NAMES, +) + + +if is_sentencepiece_available(): + from .tokenization_udop import UdopTokenizer +else: + UdopTokenizer = None + + +logger = logging.get_logger(__name__) + +UDOP_ENCODE_KWARGS_DOCSTRING = r""" + add_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to encode the sequences with the special tokens relative to their model. + padding (`bool`, `str` or [`~file_utils.PaddingStrategy`], *optional*, defaults to `False`): + Activates and controls padding. Accepts the following values: + + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). + - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + truncation (`bool`, `str` or [`~tokenization_utils_base.TruncationStrategy`], *optional*, defaults to `False`): + Activates and controls truncation. Accepts the following values: + + - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will + truncate token by token, removing a token from the longest sequence in the pair if a pair of + sequences (or a batch of pairs) is provided. + - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the + maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths + greater than the model maximum admissible input size). + max_length (`int`, *optional*): + Controls the maximum length to use by one of the truncation/padding parameters. + + If left unset or set to `None`, this will use the predefined model maximum length if a maximum length + is required by one of the truncation/padding parameters. If the model has no specific maximum input + length (like XLNet) truncation/padding to a maximum length will be deactivated. + stride (`int`, *optional*, defaults to 0): + If set to a number along with `max_length`, the overflowing tokens returned when + `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence + returned to provide some overlap between truncated and overflowing sequences. The value of this + argument defines the number of overlapping tokens. + pad_to_multiple_of (`int`, *optional*): + If set will pad the sequence to a multiple of the provided value. This is especially useful to enable + the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta). + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of list of python integers. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return Numpy `np.ndarray` objects. + return_token_type_ids (`bool`, *optional*): + Whether to return token type IDs. If left to the default, will return the token type IDs according to + the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are token type IDs?](../glossary#token-type-ids) + return_attention_mask (`bool`, *optional*): + Whether to return the attention mask. If left to the default, will return the attention mask according + to the specific tokenizer's default, defined by the `return_outputs` attribute. + + [What are attention masks?](../glossary#attention-mask) + return_overflowing_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to return overflowing token sequences. If a pair of sequences of input ids (or a batch + of pairs) is provided with `truncation_strategy = longest_first` or `True`, an error is raised instead + of returning overflowing tokens. + return_special_tokens_mask (`bool`, *optional*, defaults to `False`): + Whether or not to return special tokens mask information. + return_offsets_mapping (`bool`, *optional*, defaults to `False`): + Whether or not to return `(char_start, char_end)` for each token. + + This is only available on fast tokenizers inheriting from [`PreTrainedTokenizerFast`], if using + Python's tokenizer, this method will raise `NotImplementedError`. + return_length (`bool`, *optional*, defaults to `False`): + Whether or not to return the lengths of the encoded inputs. + verbose (`bool`, *optional*, defaults to `True`): + Whether or not to print more information and warnings. + **kwargs: passed to the `self.tokenize()` method + + Return: + [`BatchEncoding`]: A [`BatchEncoding`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + + [What are input IDs?](../glossary#input-ids) + + - **bbox** -- List of bounding boxes to be fed to a model. + + - **token_type_ids** -- List of token type ids to be fed to a model (when `return_token_type_ids=True` or + if *"token_type_ids"* is in `self.model_input_names`). + + [What are token type IDs?](../glossary#token-type-ids) + + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names`). + + [What are attention masks?](../glossary#attention-mask) + + - **labels** -- List of labels to be fed to a model. (when `word_labels` is specified). + - **overflowing_tokens** -- List of overflowing tokens sequences (when a `max_length` is specified and + `return_overflowing_tokens=True`). + - **num_truncated_tokens** -- Number of tokens truncated (when a `max_length` is specified and + `return_overflowing_tokens=True`). + - **special_tokens_mask** -- List of 0s and 1s, with 1 specifying added special tokens and 0 specifying + regular sequence tokens (when `add_special_tokens=True` and `return_special_tokens_mask=True`). + - **length** -- The length of the inputs (when `return_length=True`). +""" + + +class UdopTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a "fast" UDOP tokenizer (backed by HuggingFace's *tokenizers* library). Adapted from + [`LayoutXLMTokenizer`] and [`T5Tokenizer`]. Based on + [BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models). + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`, *optional*): + Path to the vocabulary file. + + tokenizer_file (`str`, *optional*): + Path to the tokenizer file. + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + sep_token_box (`List[int]`, *optional*, defaults to `[1000, 1000, 1000, 1000]`): + The bounding box to use for the special [SEP] token. + pad_token_box (`List[int]`, *optional*, defaults to `[0, 0, 0, 0]`): + The bounding box to use for the special [PAD] token. + pad_token_label (`int`, *optional*, defaults to -100): + The label to use for padding tokens. Defaults to -100, which is the `ignore_index` of PyTorch's + CrossEntropyLoss. + only_label_first_subword (`bool`, *optional*, defaults to `True`): + Whether or not to only label the first subword, in case word labels are provided. + additional_special_tokens (`List[str]`, *optional*, defaults to `["NOTUSED", "NOTUSED"]`): + Additional special tokens used by the tokenizer. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + slow_tokenizer_class = UdopTokenizer + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + eos_token="", + sep_token="", + unk_token="", + pad_token="", + sep_token_box=[1000, 1000, 1000, 1000], + pad_token_box=[0, 0, 0, 0], + pad_token_label=-100, + only_label_first_subword=True, + additional_special_tokens=None, + **kwargs, + ): + super().__init__( + vocab_file, + tokenizer_file=tokenizer_file, + eos_token=eos_token, + sep_token=sep_token, + unk_token=unk_token, + pad_token=pad_token, + sep_token_box=sep_token_box, + pad_token_box=pad_token_box, + pad_token_label=pad_token_label, + only_label_first_subword=only_label_first_subword, + additional_special_tokens=additional_special_tokens, + **kwargs, + ) + + self.vocab_file = vocab_file + + # additional properties + self.sep_token_box = sep_token_box + self.pad_token_box = pad_token_box + self.pad_token_label = pad_token_label + self.only_label_first_subword = only_label_first_subword + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + @add_end_docstrings(UDOP_ENCODE_KWARGS_DOCSTRING) + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, + boxes: Union[List[List[int]], List[List[List[int]]]] = None, + word_labels: Optional[Union[List[int], List[List[int]]]] = None, + text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + text_pair_target: Optional[ + Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] + ] = None, + **kwargs, + ) -> BatchEncoding: + if text is None and text_target is None: + raise ValueError("You need to specify either `text` or `text_target`.") + if text is not None: + # The context manager will send the inputs as normal texts and not text_target, but we shouldn't change the + # input mode in this case. + if not self._in_target_context_manager: + self._switch_to_input_mode() + encodings = self.call_boxes(text=text, text_pair=text_pair, boxes=boxes, word_labels=word_labels, **kwargs) + if text_target is not None: + self._switch_to_target_mode() + target_encodings = self._call_one(text=text_target, text_pair=text_pair_target, **kwargs) + # Leave back tokenizer in input mode + self._switch_to_input_mode() + + if text_target is None: + return encodings + elif text is None: + return target_encodings + else: + encodings["labels"] = target_encodings["input_ids"] + return encodings + + @add_end_docstrings(UDOP_ENCODE_KWARGS_DOCSTRING) + def call_boxes( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], + text_pair: Optional[Union[PreTokenizedInput, List[PreTokenizedInput]]] = None, + boxes: Union[List[List[int]], List[List[List[int]]]] = None, + word_labels: Optional[Union[List[int], List[List[int]]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Main method to tokenize and prepare for the model one or several sequence(s) or one or several pair(s) of + sequences with word-level normalized bounding boxes and optional labels. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string, a list of strings + (words of a single example or questions of a batch of examples) or a list of list of strings (batch of + words). + text_pair (`List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence should be a list of strings + (pretokenized string). + boxes (`List[List[int]]`, `List[List[List[int]]]`): + Word-level bounding boxes. Each bounding box should be normalized to be on a 0-1000 scale. + word_labels (`List[int]`, `List[List[int]]`, *optional*): + Word-level integer labels (for token classification tasks such as FUNSD, CORD). + """ + + # Input type checking for clearer error + def _is_valid_text_input(t): + if isinstance(t, str): + # Strings are fine + return True + elif isinstance(t, (list, tuple)): + # List are fine as long as they are... + if len(t) == 0: + # ... empty + return True + elif isinstance(t[0], str): + # ... list of strings + return True + elif isinstance(t[0], (list, tuple)): + # ... list with an empty list or with a list of strings + return len(t[0]) == 0 or isinstance(t[0][0], str) + else: + return False + else: + return False + + if text_pair is not None: + # in case text + text_pair are provided, text = questions, text_pair = words + if not _is_valid_text_input(text): + raise ValueError("text input must of type `str` (single example) or `List[str]` (batch of examples). ") + if not isinstance(text_pair, (list, tuple)): + raise ValueError( + "words must of type `List[str]` (single pretokenized example), " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + else: + # in case only text is provided => must be words + if not isinstance(text, (list, tuple)): + raise ValueError( + "Words must of type `List[str]` (single pretokenized example), " + "or `List[List[str]]` (batch of pretokenized examples)." + ) + + if text_pair is not None: + is_batched = isinstance(text, (list, tuple)) + else: + is_batched = isinstance(text, (list, tuple)) and text and isinstance(text[0], (list, tuple)) + + words = text if text_pair is None else text_pair + if boxes is None: + raise ValueError("You must provide corresponding bounding boxes") + if is_batched: + if len(words) != len(boxes): + raise ValueError("You must provide words and boxes for an equal amount of examples") + for words_example, boxes_example in zip(words, boxes): + if len(words_example) != len(boxes_example): + raise ValueError("You must provide as many words as there are bounding boxes") + else: + if len(words) != len(boxes): + raise ValueError("You must provide as many words as there are bounding boxes") + + if is_batched: + if text_pair is not None and len(text) != len(text_pair): + raise ValueError( + f"batch length of `text`: {len(text)} does not match batch length of `text_pair`:" + f" {len(text_pair)}." + ) + batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text + is_pair = bool(text_pair is not None) + return self.batch_encode_plus_boxes( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + else: + return self.encode_plus_boxes( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Copied from transformers.models.layoutxlm.tokenization_layoutxlm_fast.LayoutXLMTokenizerFast.tokenize + def tokenize(self, text: str, pair: Optional[str] = None, add_special_tokens: bool = False, **kwargs) -> List[str]: + batched_input = [(text, pair)] if pair else [text] + encodings = self._tokenizer.encode_batch( + batched_input, add_special_tokens=add_special_tokens, is_pretokenized=False, **kwargs + ) + + return encodings[0].tokens + + def batch_encode_plus_boxes( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + ], + is_pair: bool = None, + boxes: Optional[List[List[List[int]]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a list of sequences or a list of pairs of sequences. + + + + This method is deprecated, `__call__` should be used instead. + + + + Args: + batch_text_or_text_pairs (`List[str]`, `List[Tuple[str, str]]`, `List[List[str]]`, `List[Tuple[List[str], List[str]]]`, and for not-fast tokenizers, also `List[List[int]]`, `List[Tuple[List[int], List[int]]]`): + Batch of sequences or pair of sequences to be encoded. This can be a list of + string/string-sequences/int-sequences or a list of pair of string/string-sequences/int-sequence (see + details in `encode_plus`). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._batch_encode_plus_boxes( + batch_text_or_text_pairs=batch_text_or_text_pairs, + is_pair=is_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + def _batch_encode_plus_boxes( + self, + batch_text_or_text_pairs: Union[ + List[TextInput], + List[TextInputPair], + List[PreTokenizedInput], + ], + is_pair: bool = None, + boxes: Optional[List[List[List[int]]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[str] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + if not isinstance(batch_text_or_text_pairs, list): + raise TypeError(f"batch_text_or_text_pairs has to be a list (got {type(batch_text_or_text_pairs)})") + + # Set the truncation and padding strategy and restore the initial configuration + self.set_truncation_and_padding( + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + ) + + if is_pair: + batch_text_or_text_pairs = [(text.split(), text_pair) for text, text_pair in batch_text_or_text_pairs] + + encodings = self._tokenizer.encode_batch( + batch_text_or_text_pairs, + add_special_tokens=add_special_tokens, + is_pretokenized=True, # we set this to True as LayoutLMv2 always expects pretokenized inputs + ) + + # Convert encoding to dict + # `Tokens` has type: Tuple[ + # List[Dict[str, List[List[int]]]] or List[Dict[str, 2D-Tensor]], + # List[EncodingFast] + # ] + # with nested dimensions corresponding to batch, overflows, sequence length + tokens_and_encodings = [ + self._convert_encoding( + encoding=encoding, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=True + if word_labels is not None + else return_offsets_mapping, # we use offsets to create the labels + return_length=return_length, + verbose=verbose, + ) + for encoding in encodings + ] + + # Convert the output to have dict[list] from list[dict] and remove the additional overflows dimension + # From (variable) shape (batch, overflows, sequence length) to ~ (batch * overflows, sequence length) + # (we say ~ because the number of overflow varies with the example in the batch) + # + # To match each overflowing sample with the original sample in the batch + # we add an overflow_to_sample_mapping array (see below) + sanitized_tokens = {} + for key in tokens_and_encodings[0][0].keys(): + stack = [e for item, _ in tokens_and_encodings for e in item[key]] + sanitized_tokens[key] = stack + sanitized_encodings = [e for _, item in tokens_and_encodings for e in item] + + # If returning overflowing tokens, we need to return a mapping + # from the batch idx to the original sample + if return_overflowing_tokens: + overflow_to_sample_mapping = [] + for i, (toks, _) in enumerate(tokens_and_encodings): + overflow_to_sample_mapping += [i] * len(toks["input_ids"]) + sanitized_tokens["overflow_to_sample_mapping"] = overflow_to_sample_mapping + + for input_ids in sanitized_tokens["input_ids"]: + self._eventual_warn_about_too_long_sequence(input_ids, max_length, verbose) + + # create the token boxes + token_boxes = [] + for batch_index in range(len(sanitized_tokens["input_ids"])): + if return_overflowing_tokens: + original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index] + else: + original_index = batch_index + token_boxes_example = [] + for id, sequence_id, word_id in zip( + sanitized_tokens["input_ids"][batch_index], + sanitized_encodings[batch_index].sequence_ids, + sanitized_encodings[batch_index].word_ids, + ): + if word_id is not None: + if is_pair and sequence_id == 0: + token_boxes_example.append(self.pad_token_box) + else: + token_boxes_example.append(boxes[original_index][word_id]) + else: + if id == self.sep_token_id: + token_boxes_example.append(self.sep_token_box) + elif id == self.pad_token_id: + token_boxes_example.append(self.pad_token_box) + else: + raise ValueError("Id not recognized") + token_boxes.append(token_boxes_example) + + sanitized_tokens["bbox"] = token_boxes + + # optionally, create the labels + if word_labels is not None: + labels = [] + for batch_index in range(len(sanitized_tokens["input_ids"])): + if return_overflowing_tokens: + original_index = sanitized_tokens["overflow_to_sample_mapping"][batch_index] + else: + original_index = batch_index + labels_example = [] + previous_token_empty = False + for id, offset, word_id in zip( + sanitized_tokens["input_ids"][batch_index], + sanitized_tokens["offset_mapping"][batch_index], + sanitized_encodings[batch_index].word_ids, + ): + if word_id is not None: + if self.only_label_first_subword: + if offset[0] == 0 and not previous_token_empty: + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + labels_example.append(word_labels[original_index][word_id]) + else: + labels_example.append(self.pad_token_label) + else: + labels_example.append(word_labels[original_index][word_id]) + if self.decode(id) == "": + previous_token_empty = True + else: + previous_token_empty = False + else: + labels_example.append(self.pad_token_label) + labels.append(labels_example) + + sanitized_tokens["labels"] = labels + # finally, remove offsets if the user didn't want them + if not return_offsets_mapping: + del sanitized_tokens["offset_mapping"] + + return BatchEncoding(sanitized_tokens, sanitized_encodings, tensor_type=return_tensors) + + def _encode_plus_boxes( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[int]] = None, + add_special_tokens: bool = True, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE, + max_length: Optional[int] = None, + stride: int = 0, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[bool] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + # make it a batched input + # 2 options: + # 1) only text, in case text must be a list of str + # 2) text + text_pair, in which case text = str and text_pair a list of str + batched_input = [(text, text_pair)] if text_pair else [text] + batched_boxes = [boxes] + batched_word_labels = [word_labels] if word_labels is not None else None + batched_output = self._batch_encode_plus_boxes( + batched_input, + is_pair=bool(text_pair is not None), + boxes=batched_boxes, + word_labels=batched_word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Return tensor is None, then we can remove the leading batch axis + # Overflowing tokens are returned as a batch of output so we keep them in this case + if return_tensors is None and not return_overflowing_tokens: + batched_output = BatchEncoding( + { + key: value[0] if len(value) > 0 and isinstance(value[0], list) else value + for key, value in batched_output.items() + }, + batched_output.encodings, + ) + + self._eventual_warn_about_too_long_sequence(batched_output["input_ids"], max_length, verbose) + + return batched_output + + def encode_boxes( + self, + text: Union[TextInput, PreTokenizedInput, EncodedInput], + text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs, + ) -> List[int]: + """ + Args: + Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary. Same as doing + `self.convert_tokens_to_ids(self.tokenize(text))`. + text (`str`, `List[str]` or `List[int]`): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + text_pair (`str`, `List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + """ + encoded_inputs = self.encode_plus_boxes( + text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding=padding, + truncation=truncation, + max_length=max_length, + stride=stride, + return_tensors=return_tensors, + **kwargs, + ) + + return encoded_inputs["input_ids"] + + def encode_plus_boxes( + self, + text: Union[TextInput, PreTokenizedInput], + text_pair: Optional[PreTokenizedInput] = None, + boxes: Optional[List[List[int]]] = None, + word_labels: Optional[List[List[int]]] = None, + add_special_tokens: bool = True, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Union[bool, str, TruncationStrategy] = None, + max_length: Optional[int] = None, + stride: int = 0, + is_split_into_words: bool = False, + pad_to_multiple_of: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + return_token_type_ids: Optional[bool] = None, + return_attention_mask: Optional[bool] = None, + return_overflowing_tokens: bool = False, + return_special_tokens_mask: bool = False, + return_offsets_mapping: bool = False, + return_length: bool = False, + verbose: bool = True, + **kwargs, + ) -> BatchEncoding: + """ + Tokenize and prepare for the model a sequence or a pair of sequences. + + + + This method is deprecated, `__call__` should be used instead. + + + + Args: + text (`str`, `List[str]` or `List[int]` (the latter only for not-fast tokenizers)): + The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the + `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + text_pair (`str`, `List[str]` or `List[int]`, *optional*): + Optional second sequence to be encoded. This can be a string, a list of strings (tokenized string using + the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids` + method). + """ + + # Backward compatibility for 'truncation_strategy', 'pad_to_max_length' + padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies( + padding=padding, + truncation=truncation, + max_length=max_length, + pad_to_multiple_of=pad_to_multiple_of, + verbose=verbose, + **kwargs, + ) + + return self._encode_plus_boxes( + text=text, + text_pair=text_pair, + boxes=boxes, + word_labels=word_labels, + add_special_tokens=add_special_tokens, + padding_strategy=padding_strategy, + truncation_strategy=truncation_strategy, + max_length=max_length, + stride=stride, + is_split_into_words=is_split_into_words, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors=return_tensors, + return_token_type_ids=return_token_type_ids, + return_attention_mask=return_attention_mask, + return_overflowing_tokens=return_overflowing_tokens, + return_special_tokens_mask=return_special_tokens_mask, + return_offsets_mapping=return_offsets_mapping, + return_length=return_length, + verbose=verbose, + **kwargs, + ) + + # Copied from transformers.models.layoutxlm.tokenization_layoutxlm_fast.LayoutXLMTokenizerFast._pad + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + if return_attention_mask is None: + return_attention_mask = "attention_mask" in self.model_input_names + + required_input = encoded_inputs[self.model_input_names[0]] + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if return_attention_mask and "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * len(required_input) + + if needs_to_be_padded: + difference = max_length - len(required_input) + if self.padding_side == "right": + if return_attention_mask: + encoded_inputs["attention_mask"] = encoded_inputs["attention_mask"] + [0] * difference + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = ( + encoded_inputs["token_type_ids"] + [self.pad_token_type_id] * difference + ) + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = encoded_inputs["bbox"] + [self.pad_token_box] * difference + if "labels" in encoded_inputs: + encoded_inputs["labels"] = encoded_inputs["labels"] + [self.pad_token_label] * difference + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"] + [1] * difference + encoded_inputs[self.model_input_names[0]] = required_input + [self.pad_token_id] * difference + elif self.padding_side == "left": + if return_attention_mask: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "bbox" in encoded_inputs: + encoded_inputs["bbox"] = [self.pad_token_box] * difference + encoded_inputs["bbox"] + if "labels" in encoded_inputs: + encoded_inputs["labels"] = [self.pad_token_label] * difference + encoded_inputs["labels"] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + else: + raise ValueError("Invalid padding strategy:" + str(self.padding_side)) + + return encoded_inputs + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An XLM-RoBERTa sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return token_ids_0 + [self.sep_token_id] + sep = [self.sep_token_id] + return token_ids_0 + sep + token_ids_1 + sep + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does + not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + sep) * [0] + return len(token_ids_0 + sep + token_ids_1 + sep) * [0] + + # Copied from transformers.models.layoutxlm.tokenization_layoutxlm_fast.LayoutXLMTokenizerFast.save_vocabulary + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory.") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5c635cf7af2c1c..8f7deb28327abc 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8341,6 +8341,37 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +UDOP_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class UdopEncoderModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class UdopForConditionalGeneration(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class UdopModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class UdopPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class UMT5EncoderModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_sentencepiece_objects.py b/src/transformers/utils/dummy_sentencepiece_objects.py index 5103626b263d35..33ee907a741f18 100644 --- a/src/transformers/utils/dummy_sentencepiece_objects.py +++ b/src/transformers/utils/dummy_sentencepiece_objects.py @@ -219,6 +219,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["sentencepiece"]) +class UdopTokenizer(metaclass=DummyObject): + _backends = ["sentencepiece"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["sentencepiece"]) + + class XGLMTokenizer(metaclass=DummyObject): _backends = ["sentencepiece"] diff --git a/src/transformers/utils/dummy_tokenizers_objects.py b/src/transformers/utils/dummy_tokenizers_objects.py index 5d792a0bbacde6..42b4397622f31d 100644 --- a/src/transformers/utils/dummy_tokenizers_objects.py +++ b/src/transformers/utils/dummy_tokenizers_objects.py @@ -408,6 +408,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["tokenizers"]) +class UdopTokenizerFast(metaclass=DummyObject): + _backends = ["tokenizers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tokenizers"]) + + class WhisperTokenizerFast(metaclass=DummyObject): _backends = ["tokenizers"] diff --git a/tests/models/udop/__init__.py b/tests/models/udop/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/udop/test_modeling_udop.py b/tests/models/udop/test_modeling_udop.py new file mode 100644 index 00000000000000..3947da62cc6fe6 --- /dev/null +++ b/tests/models/udop/test_modeling_udop.py @@ -0,0 +1,567 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +import unittest + +from huggingface_hub import hf_hub_download + +from transformers import UdopConfig, is_torch_available, is_vision_available +from transformers.testing_utils import ( + require_sentencepiece, + require_tokenizers, + require_torch, + require_vision, + slow, + torch_device, +) +from transformers.utils import cached_property + +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, ids_tensor +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import UdopEncoderModel, UdopForConditionalGeneration, UdopModel, UdopProcessor + from transformers.models.udop.modeling_udop import UDOP_PRETRAINED_MODEL_ARCHIVE_LIST + + +if is_vision_available(): + from PIL import Image + + +class UdopModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + encoder_seq_length=7, + decoder_seq_length=9, + # For common tests + is_training=True, + use_attention_mask=True, + use_labels=True, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=32, + dropout_rate=0.1, + initializer_factor=0.002, + eos_token_id=1, + pad_token_id=0, + scope=None, + decoder_layers=None, + range_bbox=1000, + decoder_start_token_id=0, + ): + self.parent = parent + self.batch_size = batch_size + self.encoder_seq_length = encoder_seq_length + self.decoder_seq_length = decoder_seq_length + # For common tests + self.seq_length = self.decoder_seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.scope = None + self.decoder_layers = decoder_layers + self.range_bbox = range_bbox + self.decoder_start_token_id = decoder_start_token_id + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.encoder_seq_length], self.vocab_size) + bbox = ids_tensor([self.batch_size, self.encoder_seq_length, 4], self.range_bbox).float() + # Ensure that bbox is legal + for i in range(bbox.shape[0]): + for j in range(bbox.shape[1]): + if bbox[i, j, 3] < bbox[i, j, 1]: + t = bbox[i, j, 3] + bbox[i, j, 3] = bbox[i, j, 1] + bbox[i, j, 1] = t + if bbox[i, j, 2] < bbox[i, j, 0]: + t = bbox[i, j, 2] + bbox[i, j, 2] = bbox[i, j, 0] + bbox[i, j, 0] = t + decoder_input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + attention_mask = None + decoder_attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.encoder_seq_length], vocab_size=2) + decoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2) + + lm_labels = None + if self.use_labels: + lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size) + + config = self.get_config() + + return ( + config, + input_ids, + bbox, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) + + def get_config(self): + return UdopConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + ) + + def create_and_check_model( + self, + config, + input_ids, + bbox, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = UdopModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + bbox=bbox, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + ) + result = model(input_ids=input_ids, bbox=bbox, decoder_input_ids=decoder_input_ids) + decoder_output = result.last_hidden_state + decoder_past = result.past_key_values + encoder_output = result.encoder_last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size)) + self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size)) + # There should be `num_layers` key value embeddings stored in decoder_past + self.parent.assertEqual(len(decoder_past), config.num_layers) + # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple + self.parent.assertEqual(len(decoder_past[0]), 4) + + def create_and_check_with_lm_head( + self, + config, + input_ids, + bbox, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = UdopForConditionalGeneration(config=config).to(torch_device).eval() + outputs = model( + input_ids=input_ids, + bbox=bbox, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + labels=lm_labels, + ) + self.parent.assertEqual(len(outputs), 4) + self.parent.assertEqual(outputs["logits"].size(), (self.batch_size, self.decoder_seq_length, self.vocab_size)) + self.parent.assertEqual(outputs["loss"].size(), ()) + + def create_and_check_generate_with_past_key_values( + self, + config, + input_ids, + bbox, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ): + model = UdopForConditionalGeneration(config=config).to(torch_device).eval() + torch.manual_seed(0) + output_without_past_cache = model.generate( + input_ids[:1], bbox=bbox[:1, :, :], num_beams=2, max_length=5, do_sample=True, use_cache=False + ) + torch.manual_seed(0) + output_with_past_cache = model.generate( + input_ids[:1], bbox=bbox[:1, :, :], num_beams=2, max_length=5, do_sample=True + ) + self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + bbox, + decoder_input_ids, + attention_mask, + decoder_attention_mask, + lm_labels, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "bbox": bbox, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "use_cache": False, + } + return config, inputs_dict + + +@require_torch +class UdopModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + UdopModel, + UdopForConditionalGeneration, + ) + if is_torch_available() + else () + ) + all_generative_model_classes = (UdopForConditionalGeneration,) if is_torch_available() else () + pipeline_model_mapping = {"feature-extraction": UdopModel} if is_torch_available() else {} + fx_compatible = False + test_pruning = False + test_torchscript = False + test_head_masking = False + test_resize_embeddings = True + test_model_parallel = False + is_encoder_decoder = True + # The small UDOP model needs higher percentages for CPU/MP tests + model_split_percents = [0.8, 0.9] + + def setUp(self): + self.model_tester = UdopModelTester(self) + self.config_tester = ConfigTester(self, config_class=UdopConfig, d_model=37) + + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = copy.deepcopy(inputs_dict) + if model_class.__name__ == "UdopForConditionalGeneration": + if return_labels: + inputs_dict["labels"] = torch.zeros( + (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device + ) + + return inputs_dict + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_with_lm_head(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_with_lm_head(*config_and_inputs) + + def test_generate_with_past_key_values(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_generate_with_past_key_values(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + @unittest.skip("Gradient checkpointing is not supported by this model") + def test_training_gradient_checkpointing(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant(self): + pass + + @unittest.skip( + reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124" + ) + def test_training_gradient_checkpointing_use_reentrant_false(self): + pass + + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.forward) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = sorted([*signature.parameters.keys()]) + + expected_arg_names = [ + "attention_mask", + "bbox", + "cross_attn_head_mask", + "decoder_attention_mask", + "decoder_head_mask", + "decoder_input_ids", + "decoder_inputs_embeds", + "encoder_outputs", + "head_mask", + "input_ids", + "inputs_embeds", + ] + if model_class in self.all_generative_model_classes: + expected_arg_names.append( + "labels", + ) + expected_arg_names = sorted(expected_arg_names) + self.assertListEqual(sorted(arg_names[: len(expected_arg_names)]), expected_arg_names) + + @unittest.skip( + "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" + ) + def test_save_load_low_cpu_mem_usage(self): + pass + + @slow + def test_model_from_pretrained(self): + for model_name in UDOP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = UdopForConditionalGeneration.from_pretrained(model_name) + self.assertIsNotNone(model) + + +class UdopEncoderOnlyModelTester: + def __init__( + self, + parent, + vocab_size=99, + batch_size=13, + seq_length=7, + # For common tests + is_training=False, + use_attention_mask=True, + hidden_size=32, + num_hidden_layers=5, + decoder_layers=2, + num_attention_heads=4, + d_ff=37, + relative_attention_num_buckets=32, + dropout_rate=0.1, + initializer_factor=0.002, + eos_token_id=1, + pad_token_id=0, + scope=None, + range_bbox=1000, + ): + self.parent = parent + self.batch_size = batch_size + # For common tests + self.seq_length = seq_length + self.is_training = is_training + self.use_attention_mask = use_attention_mask + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.decoder_layers = decoder_layers + self.num_attention_heads = num_attention_heads + self.d_ff = d_ff + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.initializer_factor = initializer_factor + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.scope = None + self.range_bbox = range_bbox + + def get_config(self): + return UdopConfig( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + is_encoder_decoder=False, + ) + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + bbox = ids_tensor([self.batch_size, self.seq_length, 4], self.range_bbox).float() + # Ensure that bbox is legal + for i in range(bbox.shape[0]): + for j in range(bbox.shape[1]): + if bbox[i, j, 3] < bbox[i, j, 1]: + t = bbox[i, j, 3] + bbox[i, j, 3] = bbox[i, j, 1] + bbox[i, j, 1] = t + if bbox[i, j, 2] < bbox[i, j, 0]: + t = bbox[i, j, 2] + bbox[i, j, 2] = bbox[i, j, 0] + bbox[i, j, 0] = t + + attention_mask = None + if self.use_attention_mask: + attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + config = self.get_config() + + return ( + config, + input_ids, + bbox, + attention_mask, + ) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + bbox, + attention_mask, + ) = config_and_inputs + + inputs_dict = { + "input_ids": input_ids, + "bbox": bbox, + "attention_mask": attention_mask, + } + return config, inputs_dict + + def create_and_check_model( + self, + config, + input_ids, + bbox, + attention_mask, + ): + model = UdopEncoderModel(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids=input_ids, + bbox=bbox, + attention_mask=attention_mask, + ) + encoder_output = result.last_hidden_state + + self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.seq_length, self.hidden_size)) + + def create_and_check_model_fp16_forward( + self, + config, + input_ids, + attention_mask, + ): + model = UdopEncoderModel(config=config).to(torch_device).half().eval() + output = model(input_ids, attention_mask=attention_mask)["last_hidden_state"] + self.parent.assertFalse(torch.isnan(output).any().item()) + + +class UdopEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase): + all_model_classes = (UdopEncoderModel,) if is_torch_available() else () + test_pruning = False + test_torchscript = False + test_head_masking = False + test_resize_embeddings = False + test_model_parallel = True + all_parallelizable_model_classes = (UdopEncoderModel,) if is_torch_available() else () + + def setUp(self): + self.model_tester = UdopEncoderOnlyModelTester(self) + self.config_tester = ConfigTester(self, config_class=UdopConfig, d_model=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") + def test_model_fp16_forward(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) + + @unittest.skip( + "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" + ) + def test_save_load_low_cpu_mem_usage(self): + pass + + +@require_torch +@require_sentencepiece +@require_tokenizers +@require_vision +@slow +class UdopModelIntegrationTests(unittest.TestCase): + @cached_property + def image(self): + filepath = hf_hub_download( + repo_id="hf-internal-testing/fixtures_docvqa", filename="document_2.png", repo_type="dataset" + ) + image = Image.open(filepath).convert("RGB") + + return image + + @cached_property + def processor(self): + return UdopProcessor.from_pretrained("microsoft/udop-large") + + @cached_property + def model(self): + return UdopForConditionalGeneration.from_pretrained("microsoft/udop-large").to(torch_device) + + def test_conditional_generation(self): + processor = self.processor + model = self.model + + prompt = "Question answering. In which year is the report made?" + encoding = processor(images=self.image, text=prompt, return_tensors="pt") + + predicted_ids = model.generate(**encoding) + + predicted_text = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] + self.assertEquals(predicted_text, "2013") diff --git a/tests/models/udop/test_processor_udop.py b/tests/models/udop/test_processor_udop.py new file mode 100644 index 00000000000000..05855991b185ea --- /dev/null +++ b/tests/models/udop/test_processor_udop.py @@ -0,0 +1,508 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import tempfile +import unittest +from typing import List + +import numpy as np + +from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast +from transformers.models.udop import UdopTokenizer, UdopTokenizerFast +from transformers.testing_utils import ( + require_pytesseract, + require_sentencepiece, + require_tokenizers, + require_torch, + slow, +) +from transformers.utils import FEATURE_EXTRACTOR_NAME, cached_property, is_pytesseract_available, is_torch_available + + +if is_torch_available(): + import torch + + +if is_pytesseract_available(): + from PIL import Image + + from transformers import LayoutLMv3ImageProcessor, UdopProcessor + + +@require_pytesseract +@require_sentencepiece +@require_tokenizers +class UdopProcessorTest(unittest.TestCase): + tokenizer_class = UdopTokenizer + rust_tokenizer_class = UdopTokenizerFast + maxDiff = None + + def setUp(self): + image_processor_map = { + "do_resize": True, + "size": 224, + "apply_ocr": True, + } + + self.tmpdirname = tempfile.mkdtemp() + self.feature_extraction_file = os.path.join(self.tmpdirname, FEATURE_EXTRACTOR_NAME) + with open(self.feature_extraction_file, "w", encoding="utf-8") as fp: + fp.write(json.dumps(image_processor_map) + "\n") + + self.tokenizer_pretrained_name = "microsoft/udop-large" + + def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer: + return self.tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs) + + def get_rust_tokenizer(self, **kwargs) -> PreTrainedTokenizerFast: + return self.rust_tokenizer_class.from_pretrained(self.tokenizer_pretrained_name, **kwargs) + + def get_tokenizers(self, **kwargs) -> List[PreTrainedTokenizerBase]: + return [self.get_tokenizer(**kwargs), self.get_rust_tokenizer(**kwargs)] + + def get_image_processor(self, **kwargs): + return LayoutLMv3ImageProcessor.from_pretrained(self.tmpdirname, **kwargs) + + def tearDown(self): + shutil.rmtree(self.tmpdirname) + + def prepare_image_inputs(self): + """This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True, + or a list of PyTorch tensors if one specifies torchify=True. + """ + + image_inputs = [np.random.randint(255, size=(3, 30, 400), dtype=np.uint8)] + + image_inputs = [Image.fromarray(np.moveaxis(x, 0, -1)) for x in image_inputs] + + return image_inputs + + def test_save_load_pretrained_default(self): + image_processor = self.get_image_processor() + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer) + + processor.save_pretrained(self.tmpdirname) + processor = UdopProcessor.from_pretrained(self.tmpdirname) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer.get_vocab()) + self.assertIsInstance(processor.tokenizer, (UdopTokenizer, UdopTokenizerFast)) + + self.assertEqual(processor.image_processor.to_json_string(), image_processor.to_json_string()) + self.assertIsInstance(processor.image_processor, LayoutLMv3ImageProcessor) + + def test_save_load_pretrained_additional_features(self): + processor = UdopProcessor(image_processor=self.get_image_processor(), tokenizer=self.get_tokenizer()) + processor.save_pretrained(self.tmpdirname) + + # slow tokenizer + tokenizer_add_kwargs = self.get_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + image_processor_add_kwargs = self.get_image_processor(do_resize=False, size=30) + + processor = UdopProcessor.from_pretrained( + self.tmpdirname, + use_fast=False, + bos_token="(BOS)", + eos_token="(EOS)", + do_resize=False, + size=30, + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, UdopTokenizer) + + self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.image_processor, LayoutLMv3ImageProcessor) + + # fast tokenizer + tokenizer_add_kwargs = self.get_rust_tokenizer(bos_token="(BOS)", eos_token="(EOS)") + image_processor_add_kwargs = self.get_image_processor(do_resize=False, size=30) + + processor = UdopProcessor.from_pretrained( + self.tmpdirname, use_xlm=True, bos_token="(BOS)", eos_token="(EOS)", do_resize=False, size=30 + ) + + self.assertEqual(processor.tokenizer.get_vocab(), tokenizer_add_kwargs.get_vocab()) + self.assertIsInstance(processor.tokenizer, UdopTokenizerFast) + + self.assertEqual(processor.image_processor.to_json_string(), image_processor_add_kwargs.to_json_string()) + self.assertIsInstance(processor.image_processor, LayoutLMv3ImageProcessor) + + def test_model_input_names(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = UdopProcessor(tokenizer=tokenizer, image_processor=image_processor) + + input_str = "lower newer" + image_input = self.prepare_image_inputs() + + inputs = processor(text=input_str, images=image_input) + + self.assertListEqual(list(inputs.keys()), processor.model_input_names) + + def test_text_target(self): + image_processor = self.get_image_processor() + tokenizer = self.get_tokenizer() + + processor = UdopProcessor(tokenizer=tokenizer, image_processor=image_processor) + + text = "hello world" + expected_decoding = "hello world" + + encoding_processor = processor(text_target=text) + encoding_tokenizer = tokenizer(text_target=text) + + self.assertListEqual(encoding_processor["input_ids"], [21820, 296, 1]) + self.assertListEqual(encoding_processor["attention_mask"], [1, 1, 1]) + self.assertDictEqual(dict(encoding_processor), dict(encoding_tokenizer)) + self.assertEqual(tokenizer.decode(encoding_processor["input_ids"]), expected_decoding) + + @slow + def test_overflowing_tokens(self): + # In the case of overflowing tokens, test that we still have 1-to-1 mapping between the images and input_ids (sequences that are too long are broken down into multiple sequences). + + from datasets import load_dataset + + # set up + datasets = load_dataset("nielsr/funsd") + processor = UdopProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False) + + def preprocess_data(examples): + images = [Image.open(path).convert("RGB") for path in examples["image_path"]] + words = examples["words"] + boxes = examples["bboxes"] + word_labels = examples["ner_tags"] + encoded_inputs = processor( + images, + words, + boxes=boxes, + word_labels=word_labels, + max_length=512, + padding="max_length", + truncation=True, + return_overflowing_tokens=True, + stride=50, + return_offsets_mapping=True, + return_tensors="pt", + ) + return encoded_inputs + + train_data = preprocess_data(datasets["train"]) + + self.assertEqual(len(train_data["pixel_values"]), len(train_data["input_ids"])) + + +# different use cases tests +@require_sentencepiece +@require_torch +@require_pytesseract +class UdopProcessorIntegrationTests(unittest.TestCase): + @cached_property + def get_images(self): + # we verify our implementation on 2 document images from the DocVQA dataset + from datasets import load_dataset + + ds = load_dataset("hf-internal-testing/fixtures_docvqa", split="test") + + image_1 = Image.open(ds[0]["file"]).convert("RGB") + image_2 = Image.open(ds[1]["file"]).convert("RGB") + + return image_1, image_2 + + @cached_property + def get_tokenizers(self): + slow_tokenizer = UdopTokenizer.from_pretrained("microsoft/udop-large") + fast_tokenizer = UdopTokenizerFast.from_pretrained("microsoft/udop-large") + return [slow_tokenizer, fast_tokenizer] + + @slow + def test_processor_case_1(self): + # case 1: document image classification (training, inference) + token classification (inference), apply_ocr = True + + image_processor = LayoutLMv3ImageProcessor() + tokenizers = self.get_tokenizers + images = self.get_images + + for tokenizer in tokenizers: + processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer) + + # not batched + input_image_processor = image_processor(images[0], return_tensors="pt") + input_processor = processor(images[0], return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(input_processor.keys()) + self.assertListEqual(actual_keys, expected_keys) + + # verify pixel_values + self.assertTrue( + torch.allclose(input_image_processor["pixel_values"], input_processor["pixel_values"], atol=1e-2) + ) + + # verify input_ids + # this was obtained with Tesseract 4.1.1 + # fmt: off + expected_decoding = "11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President “Introductory Remarks” Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231 + # fmt: on + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # batched + input_image_processor = image_processor(images, return_tensors="pt") + input_processor = processor(images, padding=True, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(input_processor.keys()) + self.assertListEqual(actual_keys, expected_keys) + + # verify pixel_values + self.assertTrue( + torch.allclose(input_image_processor["pixel_values"], input_processor["pixel_values"], atol=1e-2) + ) + + # verify input_ids + # this was obtained with Tesseract 4.1.1 + # fmt: off + expected_decoding = "7 ITC Limited REPORT AND ACCOUNTS 2013 ITC’s Brands: An Asset for the Nation The consumer needs and aspirations they fulfil, the benefit they generate for millions across ITC’s value chains, the future-ready capabilities that support them, and the value that they create for the country, have made ITC’s brands national assets, adding to India’s competitiveness. It is ITC’s aspiration to be the No 1 FMCG player in the country, driven by its new FMCG businesses. A recent Nielsen report has highlighted that ITC's new FMCG businesses are the fastest growing among the top consumer goods companies operating in India. ITC takes justifiable pride that, along with generating economic value, these celebrated Indian brands also drive the creation of larger societal capital through the virtuous cycle of sustainable and inclusive growth. DI WILLS * ; LOVE DELIGHTFULLY SOFT SKIN? aia Ans Source: https://www.industrydocuments.ucsf.edu/docs/snbx0223" # noqa: E231 + # fmt: on + decoding = processor.decode(input_processor.input_ids[1].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + @slow + def test_processor_case_2(self): + # case 2: document image classification (training, inference) + token classification (inference), apply_ocr=False + + image_processor = LayoutLMv3ImageProcessor(apply_ocr=False) + tokenizers = self.get_tokenizers + images = self.get_images + + for tokenizer in tokenizers: + processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer) + + # not batched + words = ["hello", "world"] + boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] + input_processor = processor(images[0], words, boxes=boxes, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = list(input_processor.keys()) + for key in expected_keys: + self.assertIn(key, actual_keys) + + # verify input_ids + expected_decoding = "hello world" + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # batched + words = [["hello", "world"], ["my", "name", "is", "niels"]] + boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]] + input_processor = processor(images, words, boxes=boxes, padding=True, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(input_processor.keys()) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + expected_decoding = "hello world" + decoding = processor.decode(input_processor.input_ids[0].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # verify bbox + expected_bbox = [ + [3, 2, 5, 1], + [6, 7, 4, 2], + [3, 9, 2, 4], + [1, 1, 2, 3], + [1, 1, 2, 3], + [1, 1, 2, 3], + [1000, 1000, 1000, 1000], + ] + self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox) + + @slow + def test_processor_case_3(self): + # case 3: token classification (training), apply_ocr=False + + image_processor = LayoutLMv3ImageProcessor(apply_ocr=False) + tokenizers = self.get_tokenizers + images = self.get_images + + for tokenizer in tokenizers: + processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer) + + # not batched + words = ["weirdly", "world"] + boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] + word_labels = [1, 2] + input_processor = processor(images[0], words, boxes=boxes, word_labels=word_labels, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"] + actual_keys = sorted(input_processor.keys()) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + expected_decoding = "weirdly world" + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # verify labels + expected_labels = [1, -100, 2, -100] + self.assertListEqual(input_processor.labels.squeeze().tolist(), expected_labels) + + # batched + words = [["hello", "world"], ["my", "name", "is", "niels"]] + boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]] + word_labels = [[1, 2], [6, 3, 10, 2]] + input_processor = processor( + images, words, boxes=boxes, word_labels=word_labels, padding=True, return_tensors="pt" + ) + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"] + actual_keys = sorted(input_processor.keys()) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + expected_decoding = "my name is niels" + decoding = processor.decode(input_processor.input_ids[1].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # verify bbox + expected_bbox = [ + [3, 2, 5, 1], + [6, 7, 4, 2], + [3, 9, 2, 4], + [1, 1, 2, 3], + [1, 1, 2, 3], + [1, 1, 2, 3], + [1000, 1000, 1000, 1000], + ] + self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox) + + # verify labels + expected_labels = [6, 3, 10, 2, -100, -100, -100] + self.assertListEqual(input_processor.labels[1].tolist(), expected_labels) + + @slow + def test_processor_case_4(self): + # case 4: visual question answering (inference), apply_ocr=True + + image_processor = LayoutLMv3ImageProcessor() + tokenizers = self.get_tokenizers + images = self.get_images + + for tokenizer in tokenizers: + processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer) + + # not batched + question = "What's his name?" + input_processor = processor(images[0], question, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(input_processor.keys()) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + # this was obtained with Tesseract 4.1.1 + # fmt: off + expected_decoding = "What's his name? 11:14 to 11:39 a.m 11:39 to 11:44 a.m. 11:44 a.m. to 12:25 p.m. 12:25 to 12:58 p.m. 12:58 to 4:00 p.m. 2:00 to 5:00 p.m. Coffee Break Coffee will be served for men and women in the lobby adjacent to exhibit area. Please move into exhibit area. (Exhibits Open) TRRF GENERAL SESSION (PART |) Presiding: Lee A. Waller TRRF Vice President “Introductory Remarks” Lee A. Waller, TRRF Vice Presi- dent Individual Interviews with TRRF Public Board Members and Sci- entific Advisory Council Mem- bers Conducted by TRRF Treasurer Philip G. Kuehn to get answers which the public refrigerated warehousing industry is looking for. Plus questions from the floor. Dr. Emil M. Mrak, University of Cal- ifornia, Chairman, TRRF Board; Sam R. Cecil, University of Georgia College of Agriculture; Dr. Stanley Charm, Tufts University School of Medicine; Dr. Robert H. Cotton, ITT Continental Baking Company; Dr. Owen Fennema, University of Wis- consin; Dr. Robert E. Hardenburg, USDA. Questions and Answers Exhibits Open Capt. Jack Stoney Room TRRF Scientific Advisory Council Meeting Ballroom Foyer" # noqa: E231 + # fmt: on + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # batched + questions = ["How old is he?", "what's the time"] + input_processor = processor( + images, questions, padding="max_length", max_length=20, truncation=True, return_tensors="pt" + ) + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(input_processor.keys()) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + # this was obtained with Tesseract 4.1.1 + expected_decoding = "what's the time 7 ITC Limited REPORT AND ACCOUNTS 2013 I" + decoding = processor.decode(input_processor.input_ids[1].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # verify bbox + # fmt: off + expected_bbox = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1000, 1000, 1000, 1000], [0, 45, 67, 80], [72, 56, 109, 67], [72, 56, 109, 67], [116, 56, 189, 67], [198, 59, 253, 66], [198, 59, 253, 66], [257, 59, 285, 66], [289, 59, 365, 66], [289, 59, 365, 66], [289, 59, 365, 66], [289, 59, 365, 66], [372, 59, 407, 66], [74, 136, 161, 158], [1000, 1000, 1000, 1000]] # noqa: E231 + # fmt: on + self.assertListEqual(input_processor.bbox[1].tolist(), expected_bbox) + + @slow + def test_processor_case_5(self): + # case 5: visual question answering (inference), apply_ocr=False + + image_processor = LayoutLMv3ImageProcessor(apply_ocr=False) + tokenizers = self.get_tokenizers + images = self.get_images + + for tokenizer in tokenizers: + processor = UdopProcessor(image_processor=image_processor, tokenizer=tokenizer) + + # not batched + question = "What's his name?" + words = ["hello", "world"] + boxes = [[1, 2, 3, 4], [5, 6, 7, 8]] + input_processor = processor(images[0], question, words, boxes, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(input_processor.keys()) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + expected_decoding = "What's his name? hello world" + decoding = processor.decode(input_processor.input_ids.squeeze().tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # batched + questions = ["How old is he?", "what's the time"] + words = [["hello", "world"], ["my", "name", "is", "niels"]] + boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[3, 2, 5, 1], [6, 7, 4, 2], [3, 9, 2, 4], [1, 1, 2, 3]]] + input_processor = processor(images, questions, words, boxes, padding=True, return_tensors="pt") + + # verify keys + expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"] + actual_keys = sorted(input_processor.keys()) + self.assertListEqual(actual_keys, expected_keys) + + # verify input_ids + expected_decoding = "How old is he? hello world" + decoding = processor.decode(input_processor.input_ids[0].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + expected_decoding = "what's the time my name is niels" + decoding = processor.decode(input_processor.input_ids[1].tolist()) + self.assertSequenceEqual(decoding, expected_decoding) + + # verify bbox + expected_bbox = [[3, 9, 2, 4], [1, 1, 2, 3], [1, 1, 2, 3], [1, 1, 2, 3], [1000, 1000, 1000, 1000]] + self.assertListEqual(input_processor.bbox[1].tolist()[-5:], expected_bbox) diff --git a/tests/models/udop/test_tokenization_udop.py b/tests/models/udop/test_tokenization_udop.py new file mode 100644 index 00000000000000..e9d41c5b77a872 --- /dev/null +++ b/tests/models/udop/test_tokenization_udop.py @@ -0,0 +1,1886 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import shutil +import tempfile +import unittest +from typing import List + +from transformers import ( + AddedToken, + SpecialTokensMixin, + UdopTokenizerFast, + is_tf_available, + is_torch_available, + logging, +) +from transformers.models.udop.tokenization_udop import UdopTokenizer +from transformers.testing_utils import ( + get_tests_dir, + is_pt_tf_cross_test, + require_pandas, + require_sentencepiece, + require_tokenizers, + require_torch, + slow, +) + +from ...test_tokenization_common import ( + SMALL_TRAINING_CORPUS, + TokenizerTesterMixin, + filter_non_english, + merge_model_tokenizer_mappings, +) + + +logger = logging.get_logger(__name__) +SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model") + + +@require_sentencepiece +@require_tokenizers +@require_pandas +class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + tokenizer_class = UdopTokenizer + rust_tokenizer_class = UdopTokenizerFast + test_rust_tokenizer = True + from_pretrained_filter = filter_non_english + test_seq2seq = False + test_sentencepiece = True + + def get_words_and_boxes(self): + words = ["a", "weirdly", "test", "hello"] + boxes = [[423, 237, 440, 251], [427, 272, 441, 287], [419, 115, 437, 129], [961, 885, 992, 912]] + + return words, boxes + + def get_words_and_boxes_batch(self): + words = [["a", "weirdly", "test"], ["hello", "my", "name", "is", "bob"]] + boxes = [ + [[423, 237, 440, 251], [427, 272, 441, 287], [419, 115, 437, 129]], + [[961, 885, 992, 912], [256, 38, 330, 58], [256, 38, 330, 58], [336, 42, 353, 57], [34, 42, 66, 69]], + ] + + return words, boxes + + def get_question_words_and_boxes(self): + question = "what's his name?" + words = ["a", "weirdly", "test"] + boxes = [[423, 237, 440, 251], [427, 272, 441, 287], [419, 115, 437, 129]] + + return question, words, boxes + + def get_question_words_and_boxes_batch(self): + questions = ["what's his name?", "how is he called?"] + words = [["a", "weirdly", "test"], ["what", "a", "laif", "gastn"]] + boxes = [ + [[423, 237, 440, 251], [427, 272, 441, 287], [419, 115, 437, 129]], + [[256, 38, 330, 58], [256, 38, 330, 58], [336, 42, 353, 57], [34, 42, 66, 69]], + ] + + return questions, words, boxes + + def setUp(self): + super().setUp() + + # We have a SentencePiece fixture for testing + tokenizer = UdopTokenizer(SAMPLE_VOCAB, keep_accents=True) + tokenizer.save_pretrained(self.tmpdirname) + + def get_input_output_texts(self, tokenizer): + input_text = "UNwant\u00E9d,running" + output_text = "unwanted, running" + return input_text, output_text + + # override test in `test_tokenization_common.py` because of the required input format of the `__call__`` method of + # this tokenizer + def test_save_sentencepiece_tokenizer(self) -> None: + if not self.test_sentencepiece or not self.test_slow_tokenizer: + return + # We want to verify that we will be able to save the tokenizer even if the original files that were used to + # build the tokenizer have been deleted in the meantime. + words, boxes = self.get_words_and_boxes() + + tokenizer_slow_1 = self.get_tokenizer() + encoding_tokenizer_slow_1 = tokenizer_slow_1( + words, + boxes=boxes, + ) + + tmpdirname_1 = tempfile.mkdtemp() + tmpdirname_2 = tempfile.mkdtemp() + + tokenizer_slow_1.save_pretrained(tmpdirname_1) + tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1) + encoding_tokenizer_slow_2 = tokenizer_slow_2( + words, + boxes=boxes, + ) + + shutil.rmtree(tmpdirname_1) + tokenizer_slow_2.save_pretrained(tmpdirname_2) + + tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2) + encoding_tokenizer_slow_3 = tokenizer_slow_3( + words, + boxes=boxes, + ) + shutil.rmtree(tmpdirname_2) + + self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2) + self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3) + + @slow + def test_sequence_builders(self): + tokenizer = self.tokenizer_class.from_pretrained("microsoft/udop-large") + + question, words, boxes = self.get_question_words_and_boxes() + + text = tokenizer.encode_boxes( + question.split(), + boxes=[tokenizer.pad_token_box for _ in range(len(question.split()))], + add_special_tokens=False, + ) + text_2 = tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=False) + + encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2) + + assert encoded_pair == text + [1] + text_2 + [1] + + def test_add_special_tokens(self): + tokenizers: List[UdopTokenizer] = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + special_token = "[SPECIAL_TOKEN]" + special_token_box = [1000, 1000, 1000, 1000] + + tokenizer.add_special_tokens({"cls_token": special_token}) + encoded_special_token = tokenizer.encode_boxes( + [special_token], boxes=[special_token_box], add_special_tokens=False + ) + self.assertEqual(len(encoded_special_token), 1) + + decoded = tokenizer.decode(encoded_special_token, skip_special_tokens=True) + self.assertTrue(special_token not in decoded) + + def test_add_tokens_tokenizer(self): + tokenizers: List[UdopTokenizer] = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + vocab_size = tokenizer.vocab_size + all_size = len(tokenizer) + + self.assertNotEqual(vocab_size, 0) + + # We usually have added tokens from the start in tests because our vocab fixtures are + # smaller than the original vocabs - let's not assert this + # self.assertEqual(vocab_size, all_size) + + new_toks = ["aaaaa", "bbbbbb", "cccccccccdddddddd"] + added_toks = tokenizer.add_tokens(new_toks) + vocab_size_2 = tokenizer.vocab_size + all_size_2 = len(tokenizer) + + self.assertNotEqual(vocab_size_2, 0) + self.assertEqual(vocab_size, vocab_size_2) + self.assertEqual(added_toks, len(new_toks)) + self.assertEqual(all_size_2, all_size + len(new_toks)) + + words = "aaaaa bbbbbb low cccccccccdddddddd l".split() + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))] + + tokens = tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=False) + + self.assertGreaterEqual(len(tokens), 4) + self.assertGreater(tokens[0], tokenizer.vocab_size - 1) + self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) + + new_toks_2 = {"eos_token": ">>>>|||<||<<|<<", "pad_token": "<<<<<|||>|>>>>|>"} + added_toks_2 = tokenizer.add_special_tokens(new_toks_2) + vocab_size_3 = tokenizer.vocab_size + all_size_3 = len(tokenizer) + + self.assertNotEqual(vocab_size_3, 0) + self.assertEqual(vocab_size, vocab_size_3) + self.assertEqual(added_toks_2, len(new_toks_2)) + self.assertEqual(all_size_3, all_size_2 + len(new_toks_2)) + + words = ">>>>|||<||<<|<< aaaaabbbbbb low cccccccccdddddddd <<<<<|||>|>>>>|> l".split() + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))] + + tokens = tokenizer.encode_boxes( + words, + boxes=boxes, + add_special_tokens=False, + ) + + self.assertGreaterEqual(len(tokens), 6) + self.assertGreater(tokens[0], tokenizer.vocab_size - 1) + self.assertGreater(tokens[0], tokens[1]) + self.assertGreater(tokens[-2], tokenizer.vocab_size - 1) + self.assertGreater(tokens[-2], tokens[-3]) + self.assertEqual(tokens[0], tokenizer.eos_token_id) + self.assertEqual(tokens[-2], tokenizer.pad_token_id) + + @require_tokenizers + def test_encode_decode_with_spaces(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + + new_toks = [AddedToken("[ABC]", normalized=False), AddedToken("[DEF]", normalized=False)] + tokenizer.add_tokens(new_toks) + input = "[ABC][DEF][ABC][DEF]" + if self.space_between_special_tokens: + output = "[ABC] [DEF] [ABC] [DEF]" + else: + output = input + encoded = tokenizer.encode_boxes(input.split(), boxes=boxes, add_special_tokens=False) + decoded = tokenizer.decode(encoded, spaces_between_special_tokens=self.space_between_special_tokens) + self.assertIn(decoded, [output, output.lower()]) + + def test_encode_plus_with_padding(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, words) + + padding_size = 10 + padding_idx = tokenizer.pad_token_id + + encoded_sequence = tokenizer.encode_plus_boxes(words, boxes=boxes, return_special_tokens_mask=True) + input_ids = encoded_sequence["input_ids"] + special_tokens_mask = encoded_sequence["special_tokens_mask"] + sequence_length = len(input_ids) + + # Test 'longest' and 'no_padding' don't do anything + tokenizer.padding_side = "right" + + not_padded_sequence = tokenizer.encode_plus_boxes( + words, + boxes=boxes, + padding=False, + return_special_tokens_mask=True, + ) + not_padded_input_ids = not_padded_sequence["input_ids"] + + not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"] + not_padded_sequence_length = len(not_padded_input_ids) + + self.assertTrue(sequence_length == not_padded_sequence_length) + self.assertTrue(input_ids == not_padded_input_ids) + self.assertTrue(special_tokens_mask == not_padded_special_tokens_mask) + + not_padded_sequence = tokenizer.encode_plus_boxes( + words, + boxes=boxes, + padding=False, + return_special_tokens_mask=True, + ) + not_padded_input_ids = not_padded_sequence["input_ids"] + + not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"] + not_padded_sequence_length = len(not_padded_input_ids) + + self.assertTrue(sequence_length == not_padded_sequence_length) + self.assertTrue(input_ids == not_padded_input_ids) + self.assertTrue(special_tokens_mask == not_padded_special_tokens_mask) + + # Test right padding + tokenizer.padding_side = "right" + + right_padded_sequence = tokenizer.encode_plus_boxes( + words, + boxes=boxes, + max_length=sequence_length + padding_size, + padding="max_length", + return_special_tokens_mask=True, + ) + right_padded_input_ids = right_padded_sequence["input_ids"] + + right_padded_special_tokens_mask = right_padded_sequence["special_tokens_mask"] + right_padded_sequence_length = len(right_padded_input_ids) + + self.assertTrue(sequence_length + padding_size == right_padded_sequence_length) + self.assertTrue(input_ids + [padding_idx] * padding_size == right_padded_input_ids) + self.assertTrue(special_tokens_mask + [1] * padding_size == right_padded_special_tokens_mask) + + # Test left padding + tokenizer.padding_side = "left" + left_padded_sequence = tokenizer.encode_plus_boxes( + words, + boxes=boxes, + max_length=sequence_length + padding_size, + padding="max_length", + return_special_tokens_mask=True, + ) + left_padded_input_ids = left_padded_sequence["input_ids"] + left_padded_special_tokens_mask = left_padded_sequence["special_tokens_mask"] + left_padded_sequence_length = len(left_padded_input_ids) + + self.assertTrue(sequence_length + padding_size == left_padded_sequence_length) + self.assertTrue([padding_idx] * padding_size + input_ids == left_padded_input_ids) + self.assertTrue([1] * padding_size + special_tokens_mask == left_padded_special_tokens_mask) + + if "token_type_ids" in tokenizer.model_input_names: + token_type_ids = encoded_sequence["token_type_ids"] + left_padded_token_type_ids = left_padded_sequence["token_type_ids"] + right_padded_token_type_ids = right_padded_sequence["token_type_ids"] + + assert token_type_ids + [0] * padding_size == right_padded_token_type_ids + assert [0] * padding_size + token_type_ids == left_padded_token_type_ids + + if "attention_mask" in tokenizer.model_input_names: + attention_mask = encoded_sequence["attention_mask"] + right_padded_attention_mask = right_padded_sequence["attention_mask"] + left_padded_attention_mask = left_padded_sequence["attention_mask"] + + self.assertTrue(attention_mask + [0] * padding_size == right_padded_attention_mask) + self.assertTrue([0] * padding_size + attention_mask == left_padded_attention_mask) + + def test_internal_consistency(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + + tokens = [] + for word in words: + tokens.extend(tokenizer.tokenize(word)) + ids = tokenizer.convert_tokens_to_ids(tokens) + ids_2 = tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=False) + self.assertListEqual(ids, ids_2) + + tokens_2 = tokenizer.convert_ids_to_tokens(ids) + self.assertNotEqual(len(tokens_2), 0) + text_2 = tokenizer.decode(ids) + self.assertIsInstance(text_2, str) + + output_text = "a weirdly test hello" + self.assertEqual(text_2, output_text) + + def test_mask_output(self): + tokenizers = self.get_tokenizers(fast=False, do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + + if ( + tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer" + and "token_type_ids" in tokenizer.model_input_names + ): + information = tokenizer.encode_plus_boxes(words, boxes=boxes, add_special_tokens=True) + sequences, mask = information["input_ids"], information["token_type_ids"] + self.assertEqual(len(sequences), len(mask)) + + def test_number_of_added_tokens(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + # test 1: single sequence + words, boxes = self.get_words_and_boxes() + + sequences = tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=False) + attached_sequences = tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=True) + + # Method is implemented (e.g. not GPT-2) + if len(attached_sequences) != 2: + self.assertEqual( + tokenizer.num_special_tokens_to_add(pair=False), len(attached_sequences) - len(sequences) + ) + + # test 2: two sequences + question, words, boxes = self.get_question_words_and_boxes() + + sequences = tokenizer.encode_boxes(question, words, boxes=boxes, add_special_tokens=False) + attached_sequences = tokenizer.encode_boxes(question, words, boxes=boxes, add_special_tokens=True) + + # Method is implemented (e.g. not GPT-2) + if len(attached_sequences) != 2: + self.assertEqual( + tokenizer.num_special_tokens_to_add(pair=True), len(attached_sequences) - len(sequences) + ) + + def test_padding_to_max_length(self): + """We keep this test for backward compatibility but it should be removed when `pad_to_max_length` will be deprecated""" + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + padding_size = 10 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, words) + + padding_idx = tokenizer.pad_token_id + + # Check that it correctly pads when a maximum length is specified along with the padding flag set to True + tokenizer.padding_side = "right" + encoded_sequence = tokenizer.encode_boxes(words, boxes=boxes) + sequence_length = len(encoded_sequence) + # FIXME: the next line should be padding(max_length) to avoid warning + padded_sequence = tokenizer.encode_boxes( + words, boxes=boxes, max_length=sequence_length + padding_size, pad_to_max_length=True + ) + padded_sequence_length = len(padded_sequence) + assert sequence_length + padding_size == padded_sequence_length + assert encoded_sequence + [padding_idx] * padding_size == padded_sequence + + # Check that nothing is done when a maximum length is not specified + encoded_sequence = tokenizer.encode_boxes(words, boxes=boxes) + sequence_length = len(encoded_sequence) + + tokenizer.padding_side = "right" + padded_sequence_right = tokenizer.encode_boxes(words, boxes=boxes, pad_to_max_length=True) + padded_sequence_right_length = len(padded_sequence_right) + assert sequence_length == padded_sequence_right_length + assert encoded_sequence == padded_sequence_right + + def test_padding(self, max_length=50): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id) + pad_token_id = tokenizer_p.pad_token_id + + # Encode - Simple input + words, boxes = self.get_words_and_boxes() + input_r = tokenizer_r.encode_boxes(words, boxes=boxes, max_length=max_length, pad_to_max_length=True) + input_p = tokenizer_p.encode_boxes(words, boxes=boxes, max_length=max_length, pad_to_max_length=True) + self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id) + input_r = tokenizer_r.encode_boxes(words, boxes=boxes, max_length=max_length, padding="max_length") + input_p = tokenizer_p.encode_boxes(words, boxes=boxes, max_length=max_length, padding="max_length") + self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id) + + input_r = tokenizer_r.encode_boxes(words, boxes=boxes, padding="longest") + input_p = tokenizer_p.encode_boxes(words, boxes=boxes, padding=True) + self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id) + + # Encode - Pair input + question, words, boxes = self.get_question_words_and_boxes() + input_r = tokenizer_r.encode_boxes( + question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True + ) + input_p = tokenizer_p.encode_boxes( + question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True + ) + self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id) + input_r = tokenizer_r.encode_boxes( + question, words, boxes=boxes, max_length=max_length, padding="max_length" + ) + input_p = tokenizer_p.encode_boxes( + question, words, boxes=boxes, max_length=max_length, padding="max_length" + ) + self.assert_padded_input_match(input_r, input_p, max_length, pad_token_id) + input_r = tokenizer_r.encode_boxes(question, words, boxes=boxes, padding=True) + input_p = tokenizer_p.encode_boxes(question, words, boxes=boxes, padding="longest") + self.assert_padded_input_match(input_r, input_p, len(input_r), pad_token_id) + + # Encode_plus - Simple input + words, boxes = self.get_words_and_boxes() + input_r = tokenizer_r.encode_plus_boxes( + words, boxes=boxes, max_length=max_length, pad_to_max_length=True + ) + input_p = tokenizer_p.encode_plus_boxes( + words, boxes=boxes, max_length=max_length, pad_to_max_length=True + ) + self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id) + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + input_r = tokenizer_r.encode_plus_boxes( + words, boxes=boxes, max_length=max_length, padding="max_length" + ) + input_p = tokenizer_p.encode_plus_boxes( + words, boxes=boxes, max_length=max_length, padding="max_length" + ) + self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id) + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + + input_r = tokenizer_r.encode_plus_boxes(words, boxes=boxes, padding="longest") + input_p = tokenizer_p.encode_plus_boxes(words, boxes=boxes, padding=True) + self.assert_padded_input_match( + input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id + ) + + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + + # Encode_plus - Pair input + question, words, boxes = self.get_question_words_and_boxes() + input_r = tokenizer_r.encode_plus_boxes( + question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True + ) + input_p = tokenizer_p.encode_plus_boxes( + question, words, boxes=boxes, max_length=max_length, pad_to_max_length=True + ) + self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id) + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + input_r = tokenizer_r.encode_plus_boxes( + question, words, boxes=boxes, max_length=max_length, padding="max_length" + ) + input_p = tokenizer_p.encode_plus_boxes( + question, words, boxes=boxes, max_length=max_length, padding="max_length" + ) + self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id) + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + input_r = tokenizer_r.encode_plus_boxes(question, words, boxes=boxes, padding="longest") + input_p = tokenizer_p.encode_plus_boxes(question, words, boxes=boxes, padding=True) + self.assert_padded_input_match( + input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id + ) + self.assertSequenceEqual(input_r["attention_mask"], input_p["attention_mask"]) + + # Batch_encode_plus - Simple input + words, boxes = self.get_words_and_boxes_batch() + + input_r = tokenizer_r.batch_encode_plus_boxes( + words, + boxes=boxes, + max_length=max_length, + pad_to_max_length=True, + ) + input_p = tokenizer_p.batch_encode_plus_boxes( + words, + boxes=boxes, + max_length=max_length, + pad_to_max_length=True, + ) + self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id) + + input_r = tokenizer_r.batch_encode_plus_boxes( + words, + boxes=boxes, + max_length=max_length, + padding="max_length", + ) + input_p = tokenizer_p.batch_encode_plus_boxes( + words, + boxes=boxes, + max_length=max_length, + padding="max_length", + ) + self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id) + + input_r = tokenizer_r.batch_encode_plus_boxes( + words, + boxes=boxes, + max_length=max_length, + padding="longest", + ) + input_p = tokenizer_p.batch_encode_plus_boxes( + words, + boxes=boxes, + max_length=max_length, + padding=True, + ) + self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id) + + input_r = tokenizer_r.batch_encode_plus_boxes(words, boxes=boxes, padding="longest") + input_p = tokenizer_p.batch_encode_plus_boxes(words, boxes=boxes, padding=True) + self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id) + + # Batch_encode_plus - Pair input + questions, words, boxes = self.get_question_words_and_boxes_batch() + + input_r = tokenizer_r.batch_encode_plus_boxes( + list(zip(questions, words)), + is_pair=True, + boxes=boxes, + max_length=max_length, + truncation=True, + padding="max_length", + ) + input_p = tokenizer_p.batch_encode_plus_boxes( + list(zip(questions, words)), + is_pair=True, + boxes=boxes, + max_length=max_length, + truncation=True, + padding="max_length", + ) + self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id) + + input_r = tokenizer_r.batch_encode_plus_boxes( + list(zip(questions, words)), + is_pair=True, + boxes=boxes, + padding=True, + ) + input_p = tokenizer_p.batch_encode_plus_boxes( + list(zip(questions, words)), + is_pair=True, + boxes=boxes, + padding="longest", + ) + self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id) + + # Using pad on single examples after tokenization + words, boxes = self.get_words_and_boxes() + input_r = tokenizer_r.encode_plus_boxes(words, boxes=boxes) + input_r = tokenizer_r.pad(input_r) + + input_p = tokenizer_r.encode_plus_boxes(words, boxes=boxes) + input_p = tokenizer_r.pad(input_p) + + self.assert_padded_input_match( + input_r["input_ids"], input_p["input_ids"], len(input_r["input_ids"]), pad_token_id + ) + + # Using pad on single examples after tokenization + input_r = tokenizer_r.encode_plus_boxes(words, boxes=boxes) + input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length") + + input_p = tokenizer_r.encode_plus_boxes(words, boxes=boxes) + input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length") + + self.assert_padded_input_match(input_r["input_ids"], input_p["input_ids"], max_length, pad_token_id) + + # Using pad after tokenization + words, boxes = self.get_words_and_boxes_batch() + input_r = tokenizer_r.batch_encode_plus_boxes( + words, + boxes=boxes, + ) + input_r = tokenizer_r.pad(input_r) + + input_p = tokenizer_r.batch_encode_plus_boxes( + words, + boxes=boxes, + ) + input_p = tokenizer_r.pad(input_p) + + self.assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]), pad_token_id) + + # Using pad after tokenization + words, boxes = self.get_words_and_boxes_batch() + input_r = tokenizer_r.batch_encode_plus_boxes( + words, + boxes=boxes, + ) + input_r = tokenizer_r.pad(input_r, max_length=max_length, padding="max_length") + + input_p = tokenizer_r.batch_encode_plus_boxes( + words, + boxes=boxes, + ) + input_p = tokenizer_r.pad(input_p, max_length=max_length, padding="max_length") + + self.assert_batch_padded_input_match(input_r, input_p, max_length, pad_token_id) + + def test_padding_warning_message_fast_tokenizer(self): + if not self.test_rust_tokenizer: + return + + words, boxes = self.get_words_and_boxes_batch() + + tokenizer_fast = self.get_rust_tokenizer() + + encoding_fast = tokenizer_fast( + words, + boxes=boxes, + ) + + with self.assertLogs("transformers", level="WARNING") as cm: + tokenizer_fast.pad(encoding_fast) + self.assertEqual(len(cm.records), 1) + self.assertIn( + "Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to" + " encode the text followed by a call to the `pad` method to get a padded encoding.", + cm.records[0].message, + ) + + if not self.test_slow_tokenizer: + return + + tokenizer_slow = self.get_tokenizer() + + encoding_slow = tokenizer_slow( + words, + boxes=boxes, + ) + + with self.assertLogs(level="WARNING") as cm: + # We want to assert there are no warnings, but the 'assertLogs' method does not support that. + # Therefore, we are adding a dummy warning, and then we will assert it is the only warning. + logger.warning("Dummy warning") + tokenizer_slow.pad(encoding_slow) + self.assertEqual(len(cm.records), 1) + self.assertIn( + "Dummy warning", + cm.records[0].message, + ) + + def test_call(self): + # Tests that all call wrap to encode_plus and batch_encode_plus + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + # Test not batched + words, boxes = self.get_words_and_boxes() + encoded_sequences_1 = tokenizer.encode_plus_boxes(words, boxes=boxes) + encoded_sequences_2 = tokenizer(words, boxes=boxes) + self.assertEqual(encoded_sequences_1, encoded_sequences_2) + + # Test not batched pairs + question, words, boxes = self.get_question_words_and_boxes() + encoded_sequences_1 = tokenizer.encode_plus_boxes(words, boxes=boxes) + encoded_sequences_2 = tokenizer(words, boxes=boxes) + self.assertEqual(encoded_sequences_1, encoded_sequences_2) + + # Test batched + words, boxes = self.get_words_and_boxes_batch() + encoded_sequences_1 = tokenizer.batch_encode_plus_boxes(words, is_pair=False, boxes=boxes) + encoded_sequences_2 = tokenizer(words, boxes=boxes) + self.assertEqual(encoded_sequences_1, encoded_sequences_2) + + def test_batch_encode_plus_batch_sequence_length(self): + # Tests that all encoded values have the correct size + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes_batch() + + encoded_sequences = [ + tokenizer.encode_plus_boxes(words_example, boxes=boxes_example) + for words_example, boxes_example in zip(words, boxes) + ] + encoded_sequences_batch = tokenizer.batch_encode_plus_boxes( + words, is_pair=False, boxes=boxes, padding=False + ) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + maximum_length = len( + max([encoded_sequence["input_ids"] for encoded_sequence in encoded_sequences], key=len) + ) + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, words) + + encoded_sequences_padded = [ + tokenizer.encode_plus_boxes( + words_example, boxes=boxes_example, max_length=maximum_length, padding="max_length" + ) + for words_example, boxes_example in zip(words, boxes) + ] + + encoded_sequences_batch_padded = tokenizer.batch_encode_plus_boxes( + words, is_pair=False, boxes=boxes, padding=True + ) + self.assertListEqual( + encoded_sequences_padded, + self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch_padded), + ) + + # check 'longest' is unsensitive to a max length + encoded_sequences_batch_padded_1 = tokenizer.batch_encode_plus_boxes( + words, is_pair=False, boxes=boxes, padding=True + ) + encoded_sequences_batch_padded_2 = tokenizer.batch_encode_plus_boxes( + words, is_pair=False, boxes=boxes, max_length=maximum_length + 10, padding="longest" + ) + for key in encoded_sequences_batch_padded_1.keys(): + self.assertListEqual( + encoded_sequences_batch_padded_1[key], + encoded_sequences_batch_padded_2[key], + ) + + # check 'no_padding' is unsensitive to a max length + encoded_sequences_batch_padded_1 = tokenizer.batch_encode_plus_boxes( + words, is_pair=False, boxes=boxes, padding=False + ) + encoded_sequences_batch_padded_2 = tokenizer.batch_encode_plus_boxes( + words, is_pair=False, boxes=boxes, max_length=maximum_length + 10, padding=False + ) + for key in encoded_sequences_batch_padded_1.keys(): + self.assertListEqual( + encoded_sequences_batch_padded_1[key], + encoded_sequences_batch_padded_2[key], + ) + + @unittest.skip("batch_encode_plus does not handle overflowing tokens.") + def test_batch_encode_plus_overflowing_tokens(self): + pass + + def test_batch_encode_plus_padding(self): + # Test that padded sequences are equivalent between batch_encode_plus and encode_plus + + # Right padding tests + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes_batch() + + max_length = 100 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, words) + + encoded_sequences = [ + tokenizer.encode_plus_boxes( + words_example, boxes=boxes_example, max_length=max_length, padding="max_length" + ) + for words_example, boxes_example in zip(words, boxes) + ] + encoded_sequences_batch = tokenizer.batch_encode_plus_boxes( + words, is_pair=False, boxes=boxes, max_length=max_length, padding="max_length" + ) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + # Left padding tests + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + tokenizer.padding_side = "left" + words, boxes = self.get_words_and_boxes_batch() + + max_length = 100 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, words) + + encoded_sequences = [ + tokenizer.encode_plus_boxes( + words_example, boxes=boxes_example, max_length=max_length, padding="max_length" + ) + for words_example, boxes_example in zip(words, boxes) + ] + encoded_sequences_batch = tokenizer.batch_encode_plus_boxes( + words, is_pair=False, boxes=boxes, max_length=max_length, padding="max_length" + ) + self.assertListEqual( + encoded_sequences, self.convert_batch_encode_plus_format_to_encode_plus(encoded_sequences_batch) + ) + + def test_padding_to_multiple_of(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + if tokenizer.pad_token is None: + self.skipTest("No padding token.") + else: + words, boxes = self.get_words_and_boxes() + + normal_tokens = tokenizer(words, boxes=boxes, padding=True, pad_to_multiple_of=8) + + for key, value in normal_tokens.items(): + self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8") + + normal_tokens = tokenizer(words, boxes=boxes, pad_to_multiple_of=8) + for key, value in normal_tokens.items(): + self.assertNotEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8") + + # Should also work with truncation + normal_tokens = tokenizer(words, boxes=boxes, padding=True, truncation=True, pad_to_multiple_of=8) + for key, value in normal_tokens.items(): + self.assertEqual(len(value) % 8, 0, f"BatchEncoding.{key} is not multiple of 8") + + # truncation to something which is not a multiple of pad_to_multiple_of raises an error + self.assertRaises( + ValueError, + tokenizer.__call__, + words, + boxes=boxes, + padding=True, + truncation=True, + max_length=12, + pad_to_multiple_of=8, + ) + + def test_tokenizer_slow_store_full_signature(self): + signature = inspect.signature(self.tokenizer_class.__init__) + tokenizer = self.get_tokenizer() + + for parameter_name, parameter in signature.parameters.items(): + if parameter.default != inspect.Parameter.empty: + self.assertIn(parameter_name, tokenizer.init_kwargs) + + def test_build_inputs_with_special_tokens(self): + if not self.test_slow_tokenizer: + # as we don't have a slow version, we can't compare the outputs between slow and fast versions + return + + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + # Input tokens id + words, boxes = self.get_words_and_boxes() + input_simple = tokenizer_p.encode_boxes(words, boxes=boxes, add_special_tokens=False) + input_pair = tokenizer_p.encode_boxes(words, boxes=boxes, add_special_tokens=False) + + # Generate output + output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple) + output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple) + self.assertEqual(output_p, output_r) + + # Generate pair output + output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair) + output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair) + self.assertEqual(output_p, output_r) + + def test_special_tokens_mask_input_pairs(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + encoded_sequence = tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=False) + encoded_sequence_dict = tokenizer.encode_plus_boxes( + words, + boxes=boxes, + add_special_tokens=True, + return_special_tokens_mask=True, + # add_prefix_space=False, + ) + encoded_sequence_w_special = encoded_sequence_dict["input_ids"] + special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] + self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special)) + + filtered_sequence = [ + (x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special) + ] + filtered_sequence = [x for x in filtered_sequence if x is not None] + self.assertEqual(encoded_sequence, filtered_sequence) + + def test_special_tokens_mask(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + # Testing single inputs + encoded_sequence = tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=False) + encoded_sequence_dict = tokenizer.encode_plus_boxes( + words, boxes=boxes, add_special_tokens=True, return_special_tokens_mask=True + ) + encoded_sequence_w_special = encoded_sequence_dict["input_ids"] + special_tokens_mask = encoded_sequence_dict["special_tokens_mask"] + self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special)) + + filtered_sequence = [x for i, x in enumerate(encoded_sequence_w_special) if not special_tokens_mask[i]] + self.assertEqual(encoded_sequence, filtered_sequence) + + def test_save_and_load_tokenizer(self): + # safety check on max_len default value so we are sure the test works + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + self.assertNotEqual(tokenizer.model_max_length, 42) + + # Now let's start the test + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + # Isolate this from the other tests because we save additional tokens/etc + words, boxes = self.get_words_and_boxes() + tmpdirname = tempfile.mkdtemp() + + before_tokens = tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=False) + before_vocab = tokenizer.get_vocab() + tokenizer.save_pretrained(tmpdirname) + + after_tokenizer = tokenizer.__class__.from_pretrained(tmpdirname) + after_tokens = after_tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=False) + after_vocab = after_tokenizer.get_vocab() + self.assertListEqual(before_tokens, after_tokens) + self.assertDictEqual(before_vocab, after_vocab) + + shutil.rmtree(tmpdirname) + + @unittest.skip("Not implemented") + def test_right_and_left_truncation(self): + pass + + def test_right_and_left_padding(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + sequence = "Sequence" + padding_size = 10 + + # check correct behaviour if no pad_token_id exists and add it eventually + self._check_no_pad_token_padding(tokenizer, sequence) + + padding_idx = tokenizer.pad_token_id + + # RIGHT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True + tokenizer.padding_side = "right" + encoded_sequence = tokenizer.encode_boxes(words, boxes=boxes) + sequence_length = len(encoded_sequence) + padded_sequence = tokenizer.encode_boxes( + words, boxes=boxes, max_length=sequence_length + padding_size, padding="max_length" + ) + padded_sequence_length = len(padded_sequence) + assert sequence_length + padding_size == padded_sequence_length + assert encoded_sequence + [padding_idx] * padding_size == padded_sequence + + # LEFT PADDING - Check that it correctly pads when a maximum length is specified along with the padding flag set to True + tokenizer.padding_side = "left" + encoded_sequence = tokenizer.encode_boxes(words, boxes=boxes) + sequence_length = len(encoded_sequence) + padded_sequence = tokenizer.encode_boxes( + words, boxes=boxes, max_length=sequence_length + padding_size, padding="max_length" + ) + padded_sequence_length = len(padded_sequence) + assert sequence_length + padding_size == padded_sequence_length + assert [padding_idx] * padding_size + encoded_sequence == padded_sequence + + # RIGHT & LEFT PADDING - Check that nothing is done for 'longest' and 'no_padding' + encoded_sequence = tokenizer.encode_boxes(words, boxes=boxes) + sequence_length = len(encoded_sequence) + + tokenizer.padding_side = "right" + padded_sequence_right = tokenizer.encode_boxes(words, boxes=boxes, padding=True) + padded_sequence_right_length = len(padded_sequence_right) + assert sequence_length == padded_sequence_right_length + assert encoded_sequence == padded_sequence_right + + tokenizer.padding_side = "left" + padded_sequence_left = tokenizer.encode_boxes(words, boxes=boxes, padding="longest") + padded_sequence_left_length = len(padded_sequence_left) + assert sequence_length == padded_sequence_left_length + assert encoded_sequence == padded_sequence_left + + tokenizer.padding_side = "right" + padded_sequence_right = tokenizer.encode_boxes(words, boxes=boxes) + padded_sequence_right_length = len(padded_sequence_right) + assert sequence_length == padded_sequence_right_length + assert encoded_sequence == padded_sequence_right + + tokenizer.padding_side = "left" + padded_sequence_left = tokenizer.encode_boxes(words, boxes=boxes, padding=False) + padded_sequence_left_length = len(padded_sequence_left) + assert sequence_length == padded_sequence_left_length + assert encoded_sequence == padded_sequence_left + + def test_token_type_ids(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + # test 1: single sequence + words, boxes = self.get_words_and_boxes() + + output = tokenizer(words, boxes=boxes, return_token_type_ids=True) + + # Assert that the token type IDs have the same length as the input IDs + self.assertEqual(len(output["token_type_ids"]), len(output["input_ids"])) + + # Assert that the token type IDs have the same length as the attention mask + self.assertEqual(len(output["token_type_ids"]), len(output["attention_mask"])) + + self.assertIn(0, output["token_type_ids"]) + self.assertNotIn(1, output["token_type_ids"]) + + # test 2: two sequences (question + words) + question, words, boxes = self.get_question_words_and_boxes() + + output = tokenizer(question, words, boxes, return_token_type_ids=True) + + # Assert that the token type IDs have the same length as the input IDs + self.assertEqual(len(output["token_type_ids"]), len(output["input_ids"])) + + # Assert that the token type IDs have the same length as the attention mask + self.assertEqual(len(output["token_type_ids"]), len(output["attention_mask"])) + + self.assertIn(0, output["token_type_ids"]) + self.assertNotIn(1, output["token_type_ids"]) + + def test_offsets_mapping(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + text = ["a", "wonderful", "test"] + boxes = [[1, 8, 12, 20] for _ in range(len(text))] + + # No pair + tokens_with_offsets = tokenizer_r.encode_plus_boxes( + text, + boxes=boxes, + return_special_tokens_mask=True, + return_offsets_mapping=True, + add_special_tokens=True, + ) + added_tokens = tokenizer_r.num_special_tokens_to_add(False) + offsets = tokens_with_offsets["offset_mapping"] + + # Assert there is the same number of tokens and offsets + self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"])) + + # Assert there is online added_tokens special_tokens + self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens) + + # Pairs + text = "what's his name" + pair = ["a", "wonderful", "test"] + boxes = [[1, 8, 12, 20] for _ in range(len(pair))] + tokens_with_offsets = tokenizer_r.encode_plus_boxes( + text, + pair, + boxes=boxes, + return_special_tokens_mask=True, + return_offsets_mapping=True, + add_special_tokens=True, + ) + added_tokens = tokenizer_r.num_special_tokens_to_add(True) + offsets = tokens_with_offsets["offset_mapping"] + + # Assert there is the same number of tokens and offsets + self.assertEqual(len(offsets), len(tokens_with_offsets["input_ids"])) + + # Assert there is online added_tokens special_tokens + self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens) + + @require_torch + @slow + def test_torch_encode_plus_sent_to_model(self): + import torch + + from transformers import MODEL_MAPPING, TOKENIZER_MAPPING + + MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING) + + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + if tokenizer.__class__ not in MODEL_TOKENIZER_MAPPING: + return + + config_class, model_class = MODEL_TOKENIZER_MAPPING[tokenizer.__class__] + config = config_class() + + if config.is_encoder_decoder or config.pad_token_id is None: + return + + model = model_class(config) + + # Make sure the model contains at least the full vocabulary size in its embedding matrix + is_using_common_embeddings = hasattr(model.get_input_embeddings(), "weight") + assert ( + (model.get_input_embeddings().weight.shape[0] >= len(tokenizer)) + if is_using_common_embeddings + else True + ) + + # Build sequence + words, boxes = self.get_words_and_boxes() + encoded_sequence = tokenizer.encode_plus_boxes(words, boxes=boxes, return_tensors="pt") + batch_encoded_sequence = tokenizer.batch_encode_plus_boxes( + [words, words], [boxes, boxes], return_tensors="pt" + ) + # This should not fail + + with torch.no_grad(): # saves some time + model(**encoded_sequence) + model(**batch_encoded_sequence) + + def test_rust_and_python_full_tokenizers(self): + if not self.test_rust_tokenizer: + return + + if not self.test_slow_tokenizer: + # as we don't have a slow version, we can't compare the outputs between slow and fast versions + return + + tokenizer = self.get_tokenizer() + rust_tokenizer = self.get_rust_tokenizer() + + words, boxes = self.get_words_and_boxes() + + ids = tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=False) + rust_ids = rust_tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=False) + self.assertListEqual(ids, rust_ids) + + ids = tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=True) + rust_ids = rust_tokenizer.encode_boxes(words, boxes=boxes, add_special_tokens=True) + self.assertListEqual(ids, rust_ids) + + def test_tokenization_python_rust_equals(self): + if not self.test_slow_tokenizer: + # as we don't have a slow version, we can't compare the outputs between slow and fast versions + return + + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + words, boxes = self.get_words_and_boxes() + + # Ensure basic input match + input_p = tokenizer_p.encode_plus_boxes(words, boxes=boxes) + input_r = tokenizer_r.encode_plus_boxes(words, boxes=boxes) + + for key in filter( + lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys() + ): + self.assertSequenceEqual(input_p[key], input_r[key]) + + input_pairs_p = tokenizer_p.encode_plus_boxes(words, boxes=boxes) + input_pairs_r = tokenizer_r.encode_plus_boxes(words, boxes=boxes) + + for key in filter( + lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys() + ): + self.assertSequenceEqual(input_pairs_p[key], input_pairs_r[key]) + + words = ["hello" for _ in range(1000)] + boxes = [[1000, 1000, 1000, 1000] for _ in range(1000)] + + # Ensure truncation match + input_p = tokenizer_p.encode_plus_boxes(words, boxes=boxes, max_length=512, truncation=True) + input_r = tokenizer_r.encode_plus_boxes(words, boxes=boxes, max_length=512, truncation=True) + + for key in filter( + lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys() + ): + self.assertSequenceEqual(input_p[key], input_r[key]) + + # Ensure truncation with stride match + input_p = tokenizer_p.encode_plus_boxes( + words, boxes=boxes, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True + ) + input_r = tokenizer_r.encode_plus_boxes( + words, boxes=boxes, max_length=512, truncation=True, stride=3, return_overflowing_tokens=True + ) + + for key in filter( + lambda x: x in ["input_ids", "token_type_ids", "attention_mask", "bbox"], input_p.keys() + ): + self.assertSequenceEqual(input_p[key], input_r[key][0]) + + def test_embeded_special_tokens(self): + if not self.test_slow_tokenizer: + # as we don't have a slow version, we can't compare the outputs between slow and fast versions + return + + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + words, boxes = self.get_words_and_boxes() + tokens_r = tokenizer_r.encode_plus_boxes( + words, + boxes=boxes, + add_special_tokens=True, + ) + tokens_p = tokenizer_p.encode_plus_boxes( + words, + boxes=boxes, + add_special_tokens=True, + ) + + for key in tokens_p.keys(): + self.assertEqual(tokens_r[key], tokens_p[key]) + + if "token_type_ids" in tokens_r: + self.assertEqual(sum(tokens_r["token_type_ids"]), sum(tokens_p["token_type_ids"])) + + tokens_r = tokenizer_r.convert_ids_to_tokens(tokens_r["input_ids"]) + tokens_p = tokenizer_p.convert_ids_to_tokens(tokens_p["input_ids"]) + self.assertSequenceEqual(tokens_r, tokens_p) + + def test_compare_add_special_tokens(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + simple_num_special_tokens_to_add = tokenizer_r.num_special_tokens_to_add(pair=False) + + words, boxes = self.get_words_and_boxes() + # tokenize() + no_special_tokens = tokenizer_r.tokenize(" ".join(words), add_special_tokens=False) + with_special_tokens = tokenizer_r.tokenize(" ".join(words), add_special_tokens=True) + self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add) + + # encode() + no_special_tokens = tokenizer_r.encode_boxes(words, boxes=boxes, add_special_tokens=False) + with_special_tokens = tokenizer_r.encode_boxes(words, boxes=boxes, add_special_tokens=True) + self.assertEqual(len(no_special_tokens), len(with_special_tokens) - simple_num_special_tokens_to_add) + + # encode_plus() + no_special_tokens = tokenizer_r.encode_plus_boxes(words, boxes=boxes, add_special_tokens=False) + with_special_tokens = tokenizer_r.encode_plus_boxes(words, boxes=boxes, add_special_tokens=True) + for key in no_special_tokens.keys(): + self.assertEqual( + len(no_special_tokens[key]), + len(with_special_tokens[key]) - simple_num_special_tokens_to_add, + ) + + # # batch_encode_plus + words, boxes = self.get_words_and_boxes_batch() + + no_special_tokens = tokenizer_r.batch_encode_plus_boxes(words, boxes=boxes, add_special_tokens=False) + with_special_tokens = tokenizer_r.batch_encode_plus_boxes(words, boxes=boxes, add_special_tokens=True) + for key in no_special_tokens.keys(): + for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]): + self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add) + + @slow + def test_udop_truncation_integration_test(self): + words, boxes = self.get_words_and_boxes() + + tokenizer = UdopTokenizer.from_pretrained("microsoft/udop-large", model_max_length=512) + + for i in range(12, 512): + new_encoded_inputs = tokenizer.encode_boxes(words, boxes=boxes, max_length=i, truncation=True) + + # Ensure that the input IDs are less than the max length defined. + self.assertLessEqual(len(new_encoded_inputs), i) + + tokenizer.model_max_length = 20 + new_encoded_inputs = tokenizer.encode_boxes(words, boxes=boxes, truncation=True) + dropped_encoded_inputs = tokenizer.encode_boxes(words, boxes=boxes, truncation=True) + + # Ensure that the input IDs are still truncated when no max_length is specified + self.assertListEqual(new_encoded_inputs, dropped_encoded_inputs) + self.assertLessEqual(len(new_encoded_inputs), 20) + + @is_pt_tf_cross_test + def test_batch_encode_plus_tensors(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes_batch() + + # A Tensor cannot be build by sequences which are not the same size + self.assertRaises( + ValueError, tokenizer.batch_encode_plus_boxes, words, boxes=boxes, return_tensors="pt" + ) + self.assertRaises( + ValueError, tokenizer.batch_encode_plus_boxes, words, boxes=boxes, return_tensors="tf" + ) + + if tokenizer.pad_token_id is None: + self.assertRaises( + ValueError, + tokenizer.batch_encode_plus_boxes, + words, + boxes=boxes, + padding=True, + return_tensors="pt", + ) + self.assertRaises( + ValueError, + tokenizer.batch_encode_plus_boxes, + words, + boxes=boxes, + padding="longest", + return_tensors="tf", + ) + else: + pytorch_tensor = tokenizer.batch_encode_plus_boxes( + words, boxes=boxes, padding=True, return_tensors="pt" + ) + tensorflow_tensor = tokenizer.batch_encode_plus_boxes( + words, boxes=boxes, padding="longest", return_tensors="tf" + ) + encoded_sequences = tokenizer.batch_encode_plus_boxes(words, boxes=boxes, padding=True) + + for key in encoded_sequences.keys(): + pytorch_value = pytorch_tensor[key].tolist() + tensorflow_value = tensorflow_tensor[key].numpy().tolist() + encoded_value = encoded_sequences[key] + + self.assertEqual(pytorch_value, tensorflow_value, encoded_value) + + def test_sequence_ids(self): + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + if not tokenizer.is_fast: + continue + with self.subTest(f"{tokenizer.__class__.__name__}"): + seq_0 = "Test this method." + seq_1 = ["With", "these", "inputs."] + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(seq_1))] + + # We want to have sequence 0 and sequence 1 are tagged + # respectively with 0 and 1 token_ids + # (regardless of whether the model use token type ids) + # We use this assumption in the QA pipeline among other place + output = tokenizer(seq_0.split(), boxes=boxes) + self.assertIn(0, output.sequence_ids()) + + output = tokenizer(seq_0, seq_1, boxes=boxes) + self.assertIn(0, output.sequence_ids()) + self.assertIn(1, output.sequence_ids()) + + if tokenizer.num_special_tokens_to_add(pair=True): + self.assertIn(None, output.sequence_ids()) + + def test_special_tokens_initialization(self): + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + added_tokens = [AddedToken("", lstrip=True)] + + tokenizer_r = self.rust_tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + words = "Hey this is a token".split() + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))] + r_output = tokenizer_r.encode_boxes(words, boxes=boxes) + + special_token_id = tokenizer_r.encode_boxes( + [""], boxes=[1000, 1000, 1000, 1000], add_special_tokens=False + )[0] + + self.assertTrue(special_token_id in r_output) + + if self.test_slow_tokenizer: + tokenizer_cr = self.rust_tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True + ) + tokenizer_p = self.tokenizer_class.from_pretrained( + pretrained_name, additional_special_tokens=added_tokens, **kwargs + ) + + words = "Hey this is a token".split() + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))] + + p_output = tokenizer_p.encode_boxes(words, boxes=boxes) + cr_output = tokenizer_cr.encode_boxes(words, boxes=boxes) + + self.assertEqual(p_output, r_output) + self.assertEqual(cr_output, r_output) + self.assertTrue(special_token_id in p_output) + self.assertTrue(special_token_id in cr_output) + + def test_training_new_tokenizer(self): + # This feature only exists for fast tokenizers + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_rust_tokenizer() + new_tokenizer = tokenizer.train_new_from_iterator(SMALL_TRAINING_CORPUS, 100) + + # Test we can use the new tokenizer with something not seen during training + text = [["this", "is", "the"], ["how", "are", "you"]] + boxes = [[[1, 2, 3, 4], [5, 6, 7, 8], [1, 3, 4, 8]], [[5, 6, 7, 8], [4, 5, 6, 7], [3, 9, 2, 7]]] + inputs = new_tokenizer(text, boxes=boxes) + self.assertEqual(len(inputs["input_ids"]), 2) + decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True) + expected_result = "this is the" + + if tokenizer.backend_tokenizer.normalizer is not None: + expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result) + self.assertEqual(expected_result, decoded_input) + + # We check that the parameters of the tokenizer remained the same + # Check we have the same number of added_tokens for both pair and non-pair inputs. + self.assertEqual(tokenizer.num_special_tokens_to_add(False), new_tokenizer.num_special_tokens_to_add(False)) + self.assertEqual(tokenizer.num_special_tokens_to_add(True), new_tokenizer.num_special_tokens_to_add(True)) + + # Check we have the correct max_length for both pair and non-pair inputs. + self.assertEqual(tokenizer.max_len_single_sentence, new_tokenizer.max_len_single_sentence) + self.assertEqual(tokenizer.max_len_sentences_pair, new_tokenizer.max_len_sentences_pair) + + # Assert the set of special tokens match as we didn't ask to change them + self.assertSequenceEqual( + tokenizer.all_special_tokens_extended, + new_tokenizer.all_special_tokens_extended, + ) + + self.assertDictEqual(tokenizer.special_tokens_map, new_tokenizer.special_tokens_map) + + def test_training_new_tokenizer_with_special_tokens_change(self): + # This feature only exists for fast tokenizers + if not self.test_rust_tokenizer: + return + + tokenizer = self.get_rust_tokenizer() + # Test with a special tokens map + class_signature = inspect.signature(tokenizer.__class__) + if "cls_token" in class_signature.parameters: + new_tokenizer = tokenizer.train_new_from_iterator( + SMALL_TRAINING_CORPUS, 100, special_tokens_map={tokenizer.cls_token: ""} + ) + cls_id = new_tokenizer.get_vocab()[""] + self.assertEqual(new_tokenizer.cls_token, "") + self.assertEqual(new_tokenizer.cls_token_id, cls_id) + + # Create a new mapping from the special tokens defined in the original tokenizer + special_tokens_list = SpecialTokensMixin.SPECIAL_TOKENS_ATTRIBUTES.copy() + special_tokens_list.remove("additional_special_tokens") + special_tokens_map = {} + for token in special_tokens_list: + # Get the private one to avoid unnecessary warnings. + if getattr(tokenizer, f"_{token}") is not None: + special_token = getattr(tokenizer, token) + special_tokens_map[special_token] = f"{special_token}a" + + # Train new tokenizer + new_tokenizer = tokenizer.train_new_from_iterator( + SMALL_TRAINING_CORPUS, 100, special_tokens_map=special_tokens_map + ) + + # Check the changes + for token in special_tokens_list: + # Get the private one to avoid unnecessary warnings. + if getattr(tokenizer, f"_{token}") is None: + continue + special_token = getattr(tokenizer, token) + if special_token in special_tokens_map: + new_special_token = getattr(new_tokenizer, token) + self.assertEqual(special_tokens_map[special_token], new_special_token) + + new_id = new_tokenizer.get_vocab()[new_special_token] + self.assertEqual(getattr(new_tokenizer, f"{token}_id"), new_id) + + # Check if the AddedToken / string format has been kept + for special_token in tokenizer.all_special_tokens_extended: + if isinstance(special_token, AddedToken) and special_token.content not in special_tokens_map: + # The special token must appear identically in the list of the new tokenizer. + self.assertTrue( + special_token in new_tokenizer.all_special_tokens_extended, + f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}", + ) + elif isinstance(special_token, AddedToken): + # The special token must appear in the list of the new tokenizer as an object of type AddedToken with + # the same parameters as the old AddedToken except the content that the user has requested to change. + special_token_str = special_token.content + new_special_token_str = special_tokens_map[special_token_str] + + find = False + for candidate in new_tokenizer.all_special_tokens_extended: + if ( + isinstance(candidate, AddedToken) + and candidate.content == new_special_token_str + and candidate.lstrip == special_token.lstrip + and candidate.rstrip == special_token.rstrip + and candidate.normalized == special_token.normalized + and candidate.single_word == special_token.single_word + ): + find = True + break + self.assertTrue( + find, + f"'{new_special_token_str}' doesn't appear in the list " + f"'{new_tokenizer.all_special_tokens_extended}' as an AddedToken with the same parameters as " + f"'{special_token}' in the list {tokenizer.all_special_tokens_extended}", + ) + elif special_token not in special_tokens_map: + # The special token must appear identically in the list of the new tokenizer. + self.assertTrue( + special_token in new_tokenizer.all_special_tokens_extended, + f"'{special_token}' should be in {new_tokenizer.all_special_tokens_extended}", + ) + + else: + # The special token must appear in the list of the new tokenizer as an object of type string. + self.assertTrue(special_tokens_map[special_token] in new_tokenizer.all_special_tokens_extended) + + # Test we can use the new tokenizer with something not seen during training + words = [["this", "is"], ["hello", "🤗"]] + boxes = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[1, 2, 3, 4], [5, 6, 7, 8]]] + inputs = new_tokenizer(words, boxes=boxes) + self.assertEqual(len(inputs["input_ids"]), 2) + decoded_input = new_tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True) + expected_result = "this is" + + if tokenizer.backend_tokenizer.normalizer is not None: + expected_result = tokenizer.backend_tokenizer.normalizer.normalize_str(expected_result) + self.assertEqual(expected_result, decoded_input) + + def test_prepare_for_model(self): + tokenizers = self.get_tokenizers(do_lower_case=False) + for tokenizer in tokenizers: + # only test prepare_for_model for the slow tokenizer + if tokenizer.__class__.__name__ == "UdopTokenizerFast": + continue + with self.subTest(f"{tokenizer.__class__.__name__}"): + words, boxes = self.get_words_and_boxes() + prepared_input_dict = tokenizer.prepare_for_model_boxes(words, boxes=boxes, add_special_tokens=True) + + input_dict = tokenizer.encode_plus_boxes(words, boxes=boxes, add_special_tokens=True) + + self.assertEqual(input_dict, prepared_input_dict) + + def test_padding_different_model_input_name(self): + if not self.test_slow_tokenizer: + # as we don't have a slow version, we can't compare the outputs between slow and fast versions + return + + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): + tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs) + self.assertEqual(tokenizer_p.pad_token_id, tokenizer_r.pad_token_id) + pad_token_id = tokenizer_p.pad_token_id + + words, boxes = self.get_words_and_boxes_batch() + + input_r = tokenizer_r.batch_encode_plus_boxes(words, boxes=boxes) + input_p = tokenizer_r.batch_encode_plus_boxes(words, boxes=boxes) + + # rename encoded batch to "inputs" + input_r["inputs"] = input_r[tokenizer_r.model_input_names[0]] + del input_r[tokenizer_r.model_input_names[0]] + + input_p["inputs"] = input_p[tokenizer_p.model_input_names[0]] + del input_p[tokenizer_p.model_input_names[0]] + + # Renaming `input_ids` to `inputs` + tokenizer_r.model_input_names = ["inputs"] + tokenizer_r.model_input_names[1:] + tokenizer_p.model_input_names = ["inputs"] + tokenizer_p.model_input_names[1:] + + input_r = tokenizer_r.pad(input_r, padding="longest") + input_p = tokenizer_r.pad(input_p, padding="longest") + + max_length = len(input_p["inputs"][0]) + self.assert_batch_padded_input_match( + input_r, input_p, max_length, pad_token_id, model_main_input_name="inputs" + ) + + def test_batch_encode_dynamic_overflowing(self): + """ + When calling batch_encode with multiple sequences, it can return different number of + overflowing encoding for each sequence: + [ + Sequence 1: [Encoding 1, Encoding 2], + Sequence 2: [Encoding 1], + Sequence 3: [Encoding 1, Encoding 2, ... Encoding N] + ] + This needs to be padded so that it can represented as a tensor + """ + for tokenizer, pretrained_name, kwargs in self.tokenizers_list: + tokenizer = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs) + + with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name}, {tokenizer.__class__.__name__})"): + if is_torch_available(): + returned_tensor = "pt" + elif is_tf_available(): + returned_tensor = "tf" + else: + returned_tensor = "jax" + + # Single example + words, boxes = self.get_words_and_boxes() + tokens = tokenizer.encode_plus_boxes( + words, + boxes=boxes, + max_length=6, + padding=True, + truncation=True, + return_tensors=returned_tensor, + return_overflowing_tokens=True, + ) + + for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()): + if key != "bbox": + self.assertEqual(len(tokens[key].shape), 2) + else: + self.assertEqual(len(tokens[key].shape), 3) + + # Batch of examples + # For these 2 examples, 3 training examples will be created + words, boxes = self.get_words_and_boxes_batch() + tokens = tokenizer.batch_encode_plus_boxes( + words, + boxes=boxes, + max_length=6, + padding=True, + truncation="only_first", + return_tensors=returned_tensor, + return_overflowing_tokens=True, + ) + + for key in filter(lambda x: "overflow_to_sample_mapping" not in x, tokens.keys()): + if key != "bbox": + self.assertEqual(len(tokens[key].shape), 2) + self.assertEqual(tokens[key].shape[-1], 6) + else: + self.assertEqual(len(tokens[key].shape), 3) + self.assertEqual(tokens[key].shape[-1], 4) + + @unittest.skip("TO DO: overwrite this very extensive test.") + def test_alignement_methods(self): + pass + + @unittest.skip("UDOP tokenizer requires boxes besides sequences.") + def test_maximum_encoding_length_pair_input(self): + pass + + @unittest.skip("UDOP tokenizer requires boxes besides sequences.") + def test_maximum_encoding_length_single_input(self): + pass + + @unittest.skip("UDOP tokenizer requires boxes besides sequences.") + def test_pretokenized_inputs(self): + pass + + @unittest.skip("UDOP tokenizer always expects pretokenized inputs.") + def test_compare_pretokenized_inputs(self): + pass + + @unittest.skip("UDOP fast tokenizer does not support prepare_for_model") + def test_compare_prepare_for_model(self): + pass + + @slow + def test_only_label_first_subword(self): + words = ["hello", "niels"] + boxes = [[1000, 1000, 1000, 1000] for _ in range(len(words))] + word_labels = [0, 1] + + # test slow tokenizer + tokenizer_p = UdopTokenizer.from_pretrained("microsoft/udop-large") + encoding = tokenizer_p(words, boxes=boxes, word_labels=word_labels) + self.assertListEqual(encoding.labels, [0, 1, -100, -100, -100]) + + tokenizer_p = UdopTokenizer.from_pretrained("microsoft/udop-large", only_label_first_subword=False) + encoding = tokenizer_p(words, boxes=boxes, word_labels=word_labels) + self.assertListEqual(encoding.labels, [0, 1, 1, 1, -100]) + + # test fast tokenizer + tokenizer_r = UdopTokenizerFast.from_pretrained("microsoft/udop-large") + encoding = tokenizer_r(words, boxes=boxes, word_labels=word_labels) + self.assertListEqual(encoding.labels, [0, 1, -100, -100, -100]) + + tokenizer_r = UdopTokenizerFast.from_pretrained("microsoft/udop-large", only_label_first_subword=False) + encoding = tokenizer_r(words, boxes=boxes, word_labels=word_labels) + self.assertListEqual(encoding.labels, [0, 1, 1, 1, -100]) + + @slow + def test_udop_integration_test(self): + tokenizer_p = UdopTokenizer.from_pretrained("microsoft/udop-large") + tokenizer_r = UdopTokenizerFast.from_pretrained("microsoft/udop-large") + + # There are 3 cases: + # CASE 1: document image classification (training + inference), document image token classification (inference), + # in which case only words and normalized bounding boxes are provided to the tokenizer + # CASE 2: document image token classification (training), + # in which case one also provides word labels to the tokenizer + # CASE 3: document image visual question answering (inference), + # in which case one also provides a question to the tokenizer + + # We need to test all 3 cases both on batched and non-batched inputs. + + # CASE 1: not batched + words, boxes = self.get_words_and_boxes() + + # fmt: off + expected_results = {'input_ids': [3, 9, 10088, 120, 794, 21820, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'bbox': [[423, 237, 440, 251], [423, 237, 440, 251], [427, 272, 441, 287], [427, 272, 441, 287], [419, 115, 437, 129], [961, 885, 992, 912], [1000, 1000, 1000, 1000], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(words, boxes=boxes, padding="max_length", max_length=20) + encoding_r = tokenizer_r(words, boxes=boxes, padding="max_length", max_length=20) + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + # CASE 1: batched + words, boxes = self.get_words_and_boxes_batch() + + # fmt: off + expected_results = {'input_ids': [[3, 9, 10088, 120, 794, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [21820, 82, 564, 19, 3, 17396, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'bbox': [[[423, 237, 440, 251], [423, 237, 440, 251], [427, 272, 441, 287], [427, 272, 441, 287], [419, 115, 437, 129], [1000, 1000, 1000, 1000], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[961, 885, 992, 912], [256, 38, 330, 58], [256, 38, 330, 58], [336, 42, 353, 57], [34, 42, 66, 69], [34, 42, 66, 69], [1000, 1000, 1000, 1000], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(words, boxes=boxes, padding="max_length", max_length=20) + encoding_r = tokenizer_r(words, boxes=boxes, padding="max_length", max_length=20) + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + # CASE 2: not batched + words, boxes = self.get_words_and_boxes() + word_labels = [1, 2, 3, 4] + + # fmt: off + expected_results = {'input_ids': [3, 9, 10088, 120, 794, 21820, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'bbox': [[423, 237, 440, 251], [423, 237, 440, 251], [427, 272, 441, 287], [427, 272, 441, 287], [419, 115, 437, 129], [961, 885, 992, 912], [1000, 1000, 1000, 1000], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'labels': [1, -100, 2, -100, 3, 4, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20) + encoding_r = tokenizer_r(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20) + + for key in expected_results: + self.assertListEqual(encoding_p[key], encoding_r[key]) + + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + # CASE 2: batched + words, boxes = self.get_words_and_boxes_batch() + word_labels = [[1, 2, 3], [2, 46, 17, 22, 3]] + + # fmt: off + expected_results = {'input_ids': [[3, 9, 10088, 120, 794, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [21820, 82, 564, 19, 3, 17396, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'bbox': [[[423, 237, 440, 251], [423, 237, 440, 251], [427, 272, 441, 287], [427, 272, 441, 287], [419, 115, 437, 129], [1000, 1000, 1000, 1000], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[961, 885, 992, 912], [256, 38, 330, 58], [256, 38, 330, 58], [336, 42, 353, 57], [34, 42, 66, 69], [34, 42, 66, 69], [1000, 1000, 1000, 1000], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'labels': [[1, -100, 2, -100, 3, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100], [2, 46, 17, 22, 3, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20) + encoding_r = tokenizer_r(words, boxes=boxes, word_labels=word_labels, padding="max_length", max_length=20) + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + # CASE 3: not batched + question, words, boxes = self.get_question_words_and_boxes() + + # fmt: off + expected_results = {'input_ids': [125, 31, 7, 112, 564, 58, 1, 3, 9, 10088, 120, 794, 1, 0, 0, 0, 0, 0, 0, 0], 'bbox': [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1000, 1000, 1000, 1000], [423, 237, 440, 251], [423, 237, 440, 251], [427, 272, 441, 287], [427, 272, 441, 287], [419, 115, 437, 129], [1000, 1000, 1000, 1000], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(question, words, boxes, padding="max_length", max_length=20) + encoding_r = tokenizer_r(question, words, boxes, padding="max_length", max_length=20) + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + # CASE 3: batched + questions, words, boxes = self.get_question_words_and_boxes_batch() + + # fmt: off + expected_results = {'input_ids': [[125, 31, 7, 112, 564, 58, 1, 3, 9, 10088, 120, 794, 1, 0, 0, 0, 0, 0, 0, 0], [149, 19, 3, 88, 718, 58, 1, 125, 3, 9, 50, 99, 1807, 17, 29, 1, 0, 0, 0, 0]], 'bbox': [[[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1000, 1000, 1000, 1000], [423, 237, 440, 251], [423, 237, 440, 251], [427, 272, 441, 287], [427, 272, 441, 287], [419, 115, 437, 129], [1000, 1000, 1000, 1000], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1000, 1000, 1000, 1000], [256, 38, 330, 58], [256, 38, 330, 58], [256, 38, 330, 58], [336, 42, 353, 57], [336, 42, 353, 57], [34, 42, 66, 69], [34, 42, 66, 69], [34, 42, 66, 69], [1000, 1000, 1000, 1000], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]} # noqa: E231 + # fmt: on + + encoding_p = tokenizer_p(questions, words, boxes, padding="max_length", max_length=20) + encoding_r = tokenizer_r(questions, words, boxes, padding="max_length", max_length=20) + self.assertDictEqual(dict(encoding_p), expected_results) + self.assertDictEqual(dict(encoding_r), expected_results) + + @unittest.skip("Doesn't support another framework than PyTorch") + def test_np_encode_plus_sent_to_model(self): + pass + + @unittest.skip("Doesn't use SentencePiece") + def test_sentencepiece_tokenize_and_convert_tokens_to_string(self): + pass + + @unittest.skip("Doesn't use SentencePiece") + def test_sentencepiece_tokenize_and_decode(self): + pass + + def test_text_target(self): + tokenizer_p = UdopTokenizer.from_pretrained("microsoft/udop-large") + tokenizer_r = UdopTokenizerFast.from_pretrained("microsoft/udop-large") + + text = "hello world" + expected_decoding = "hello world" + + # should raise an error if we don't provide it using the `text_target` argument + with self.assertRaises(ValueError): + tokenizer_p(text) + + encoding_p = tokenizer_p(text_target=text) + encoding_r = tokenizer_r(text_target=text) + + self.assertListEqual(encoding_p["input_ids"], [21820, 296, 1]) + self.assertListEqual(encoding_p["attention_mask"], [1, 1, 1]) + self.assertDictEqual(dict(encoding_p), dict(encoding_r)) + self.assertEqual(tokenizer_p.decode(encoding_p["input_ids"]), expected_decoding) diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index da4a1210357daf..fae3ed8da0b4ef 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -84,6 +84,8 @@ "ClapAudioConfig": ["num_classes"], # Not used, but providing useful information to users "SpeechT5HifiGanConfig": ["sampling_rate"], + # used internally in the configuration class file + "UdopConfig": ["feed_forward_proj"], # Actually used in the config or generation config, in that case necessary for the sub-components generation "SeamlessM4TConfig": [ "max_new_tokens", diff --git a/utils/check_repo.py b/utils/check_repo.py index 7cc06c6781164c..44c99194f309a2 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -61,6 +61,7 @@ PRIVATE_MODELS = [ "AltRobertaModel", "DPRSpanPredictor", + "UdopStack", "LongT5Stack", "RealmBertModel", "T5Stack", @@ -304,6 +305,7 @@ "SeamlessM4TCodeHifiGan", "SeamlessM4TForSpeechToSpeech", # no auto class for speech-to-speech "TvpForVideoGrounding", + "UdopForConditionalGeneration", "SeamlessM4Tv2NARTextToUnitModel", "SeamlessM4Tv2NARTextToUnitForConditionalGeneration", "SeamlessM4Tv2CodeHifiGan",