Skip to content

Commit

Permalink
Merge branch 'main' into add-doge-model
Browse files Browse the repository at this point in the history
  • Loading branch information
LoserCheems authored Feb 1, 2025
2 parents 5f7545d + 62db3e6 commit decc891
Show file tree
Hide file tree
Showing 29 changed files with 4,187 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docker/transformers-all-latest-gpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ SHELL ["sh", "-lc"]
# The following `ARG` are mainly used to specify the versions explicitly & directly in this docker file, and not meant
# to be used as arguments for docker build (so far).

ARG PYTORCH='2.5.1'
ARG PYTORCH='2.6.0'
# (not always a valid torch version)
ARG INTEL_TORCH_EXT='2.3.0'
# Example: `cu102`, `cu113`, etc.
Expand Down
2 changes: 1 addition & 1 deletion docker/transformers-pytorch-gpu/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ ARG REF=main
RUN git clone https://github.com/huggingface/transformers && cd transformers && git checkout $REF

# If set to nothing, will install the latest version
ARG PYTORCH='2.5.1'
ARG PYTORCH='2.6.0'
ARG TORCH_VISION=''
ARG TORCH_AUDIO=''
# Example: `cu102`, `cu113`, etc.
Expand Down
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,8 @@
title: FLAVA
- local: model_doc/git
title: GIT
- local: model_doc/got_ocr2
title: GOT-OCR2
- local: model_doc/grounding-dino
title: Grounding DINO
- local: model_doc/groupvit
Expand Down
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ Flax), PyTorch, and/or TensorFlow.
| [GIT](model_doc/git) ||||
| [GLM](model_doc/glm) ||||
| [GLPN](model_doc/glpn) ||||
| [GOT-OCR2](model_doc/got_ocr2) ||||
| [GPT Neo](model_doc/gpt_neo) ||||
| [GPT NeoX](model_doc/gpt_neox) ||||
| [GPT NeoX Japanese](model_doc/gpt_neox_japanese) ||||
Expand Down
269 changes: 269 additions & 0 deletions docs/source/en/model_doc/got_ocr2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
<!--Copyright 2024 StepFun and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# GOT-OCR2

## Overview

The GOT-OCR2 model was proposed in [General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model](https://arxiv.org/abs/2409.01704) by Haoran Wei, Chenglong Liu, Jinyue Chen, Jia Wang, Lingyu Kong, Yanming Xu, Zheng Ge, Liang Zhao, Jianjian Sun, Yuang Peng, Chunrui Han, Xiangyu Zhang.

The abstract from the paper is the following:

*Traditional OCR systems (OCR-1.0) are increasingly unable to meet people’snusage due to the growing demand for intelligent processing of man-made opticalncharacters. In this paper, we collectively refer to all artificial optical signals (e.g., plain texts, math/molecular formulas, tables, charts, sheet music, and even geometric shapes) as "characters" and propose the General OCR Theory along with an excellent model, namely GOT, to promote the arrival of OCR-2.0. The GOT, with 580M parameters, is a unified, elegant, and end-to-end model, consisting of a high-compression encoder and a long-contexts decoder. As an OCR-2.0 model, GOT can handle all the above "characters" under various OCR tasks. On the input side, the model supports commonly used scene- and document-style images in slice and whole-page styles. On the output side, GOT can generate plain or formatted results (markdown/tikz/smiles/kern) via an easy prompt. Besides, the model enjoys interactive OCR features, i.e., region-level recognition guided by coordinates or colors. Furthermore, we also adapt dynamic resolution and multipage OCR technologies to GOT for better practicality. In experiments, we provide sufficient results to prove the superiority of our model.*

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/got_ocr_overview.png"
alt="drawing" width="600"/>

<small> GOT-OCR2 training stages. Taken from the <a href="https://arxiv.org/abs/2409.01704">original paper.</a> </small>


Tips:

GOT-OCR2 works on a wide range of tasks, including plain document OCR, scene text OCR, formatted document OCR, and even OCR for tables, charts, mathematical formulas, geometric shapes, molecular formulas and sheet music. While this implementation of the model will only output plain text, the outputs can be further processed to render the desired format, with packages like `pdftex`, `mathpix`, `matplotlib`, `tikz`, `verovio` or `pyecharts`.
The model can also be used for interactive OCR, where the user can specify the region to be recognized by providing the coordinates or the color of the region's bounding box.

This model was contributed by [yonigozlan](https://huggingface.co/yonigozlan).
The original code can be found [here](https://github.com/Ucas-HaoranWei/GOT-OCR2.0).

## Usage example

### Plain text inference

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"
>>> inputs = processor(image, return_tensors="pt").to(device)

>>> generate_ids = model.generate(
... **inputs,
... do_sample=False,
... tokenizer=processor.tokenizer,
... stop_strings="<|im_end|>",
... max_new_tokens=4096,
... )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"R&D QUALITY IMPROVEMENT\nSUGGESTION/SOLUTION FORM\nName/Phone Ext. : (...)"
```

### Plain text inference batched

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/image_ocr.jpg"

>>> inputs = processor([image1, image2], return_tensors="pt").to(device)

>>> generate_ids = model.generate(
... **inputs,
... do_sample=False,
... tokenizer=processor.tokenizer,
... stop_strings="<|im_end|>",
... max_new_tokens=4,
... )

>>> processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
["Reducing the number", "R&D QUALITY"]
```

### Formatted text inference

GOT-OCR2 can also generate formatted text, such as markdown or LaTeX. Here is an example of how to generate formatted text:

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/latex.png"
>>> inputs = processor(image, return_tensors="pt", format=True).to(device)

>>> generate_ids = model.generate(
... **inputs,
... do_sample=False,
... tokenizer=processor.tokenizer,
... stop_strings="<|im_end|>",
... max_new_tokens=4096,
... )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"\\author{\nHanwen Jiang* \\(\\quad\\) Arjun Karpur \\({ }^{\\dagger} \\quad\\) Bingyi Cao \\({ }^{\\dagger} \\quad\\) (...)"
```

### Inference on multiple pages

Although it might be reasonable in most cases to use a “for loop” for multi-page processing, some text data with formatting across several pages make it necessary to process all pages at once. GOT introduces a multi-page OCR (without “for loop”) feature, where multiple pages can be processed by the model at once, whith the output being one continuous text.
Here is an example of how to process multiple pages at once:


```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page1.png"
>>> image2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/page2.png"
>>> inputs = processor([image1, image2], return_tensors="pt", multi_page=True, format=True).to(device)

>>> generate_ids = model.generate(
... **inputs,
... do_sample=False,
... tokenizer=processor.tokenizer,
... stop_strings="<|im_end|>",
... max_new_tokens=4096,
... )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"\\title{\nGeneral OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model\n}\n\\author{\nHaoran Wei (...)"
```

### Inference on cropped patches

GOT supports a 1024×1024 input resolution, which is sufficient for most OCR tasks, such as scene OCR or processing A4-sized PDF pages. However, certain scenarios, like horizontally stitched two-page PDFs commonly found in academic papers or images with unusual aspect ratios, can lead to accuracy issues when processed as a single image. To address this, GOT can dynamically crop an image into patches, process them all at once, and merge the results for better accuracy with such inputs.
Here is an example of how to process cropped patches:

```python
>>> import torch
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", torch_dtype=torch.bfloat16, device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/one_column.png"
>>> inputs = processor(image, return_tensors="pt", format=True, crop_to_patches=True, max_patches=3).to(device)

>>> generate_ids = model.generate(
... **inputs,
... do_sample=False,
... tokenizer=processor.tokenizer,
... stop_strings="<|im_end|>",
... max_new_tokens=4096,
... )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"on developing architectural improvements to make learnable matching methods generalize.\nMotivated by the above observations, (...)"
```

### Inference on a specific region

GOT supports interactive OCR, where the user can specify the region to be recognized by providing the coordinates or the color of the region's bounding box. Here is an example of how to process a specific region:

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/multi_box.png"
>>> inputs = processor(image, return_tensors="pt", color="green").to(device) # or box=[x1, y1, x2, y2] for coordinates (image pixels)

>>> generate_ids = model.generate(
... **inputs,
... do_sample=False,
... tokenizer=processor.tokenizer,
... stop_strings="<|im_end|>",
... max_new_tokens=4096,
... )

>>> processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
"You should keep in mind what features from the module should be used, especially \nwhen you’re planning to sell a template."
```

### Inference on general OCR data example: sheet music

Although this implementation of the model will only output plain text, the outputs can be further processed to render the desired format, with packages like `pdftex`, `mathpix`, `matplotlib`, `tikz`, `verovio` or `pyecharts`.
Here is an example of how to process sheet music:

```python
>>> from transformers import AutoProcessor, AutoModelForImageTextToText
>>> import verovio

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model = AutoModelForImageTextToText.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf", device_map=device)
>>> processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR-2.0-hf")

>>> image = "https://huggingface.co/datasets/hf-internal-testing/fixtures_got_ocr/resolve/main/sheet_music.png"
>>> inputs = processor(image, return_tensors="pt", format=True).to(device)

>>> generate_ids = model.generate(
... **inputs,
... do_sample=False,
... tokenizer=processor.tokenizer,
... stop_strings="<|im_end|>",
... max_new_tokens=4096,
... )

>>> outputs = processor.decode(generate_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
>>> tk = verovio.toolkit()
>>> tk.loadData(outputs)
>>> tk.setOptions(
... {
... "pageWidth": 2100,
... "pageHeight": 800,
... "footer": "none",
... "barLineWidth": 0.5,
... "beamMaxSlope": 15,
... "staffLineWidth": 0.2,
... "spacingStaff": 6,
... }
... )
>>> tk.getPageCount()
>>> svg = tk.renderToSVG()
>>> svg = svg.replace('overflow="inherit"', 'overflow="visible"')
>>> with open("output.svg", "w") as f:
>>> f.write(svg)
```
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/sheet_music.svg"
alt="drawing" width="600"/>

## GotOcr2Config

[[autodoc]] GotOcr2Config

## GotOcr2VisionConfig

[[autodoc]] GotOcr2VisionConfig

## GotOcr2ImageProcessor

[[autodoc]] GotOcr2ImageProcessor

## GotOcr2Processor

[[autodoc]] GotOcr2Processor

## GotOcr2ForConditionalGeneration

[[autodoc]] GotOcr2ForConditionalGeneration
- forward

2 changes: 2 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [Emu3](https://huggingface.co/docs/transformers/model_doc/emu3)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
* [GotOcr2](https://huggingface.co/docs/transformers/model_doc/got_ocr2#transformers.GotOcr2ForConditionalGeneration)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
Expand Down Expand Up @@ -254,6 +255,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
* [GotOcr2](https://huggingface.co/docs/transformers/model_doc/got_ocr2#transformers.GotOcr2ForConditionalGeneration)
* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
Expand Down
18 changes: 18 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,11 @@
],
"models.glm": ["GlmConfig"],
"models.glpn": ["GLPNConfig"],
"models.got_ocr2": [
"GotOcr2Config",
"GotOcr2Processor",
"GotOcr2VisionConfig",
],
"models.gpt2": [
"GPT2Config",
"GPT2Tokenizer",
Expand Down Expand Up @@ -1239,6 +1244,7 @@
_import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"])
_import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"])
_import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"])
_import_structure["models.got_ocr2"].extend(["GotOcr2ImageProcessor"])
_import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"])
_import_structure["models.idefics"].extend(["IdeficsImageProcessor"])
_import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"])
Expand Down Expand Up @@ -2435,6 +2441,12 @@
"GLPNPreTrainedModel",
]
)
_import_structure["models.got_ocr2"].extend(
[
"GotOcr2ForConditionalGeneration",
"GotOcr2PreTrainedModel",
]
)
_import_structure["models.gpt2"].extend(
[
"GPT2DoubleHeadsModel",
Expand Down Expand Up @@ -5550,6 +5562,7 @@
)
from .models.glm import GlmConfig
from .models.glpn import GLPNConfig
from .models.got_ocr2 import GotOcr2Config, GotOcr2Processor, GotOcr2VisionConfig
from .models.gpt2 import (
GPT2Config,
GPT2Tokenizer,
Expand Down Expand Up @@ -6352,6 +6365,7 @@
)
from .models.fuyu import FuyuImageProcessor, FuyuProcessor
from .models.glpn import GLPNFeatureExtractor, GLPNImageProcessor
from .models.got_ocr2 import GotOcr2ImageProcessor
from .models.grounding_dino import GroundingDinoImageProcessor
from .models.idefics import IdeficsImageProcessor
from .models.idefics2 import Idefics2ImageProcessor
Expand Down Expand Up @@ -7362,6 +7376,10 @@
GLPNModel,
GLPNPreTrainedModel,
)
from .models.got_ocr2 import (
GotOcr2ForConditionalGeneration,
GotOcr2PreTrainedModel,
)
from .models.gpt2 import (
GPT2DoubleHeadsModel,
GPT2ForQuestionAnswering,
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3183,7 +3183,7 @@ def _sample(

model_forward = self.__call__
if isinstance(model_kwargs.get("past_key_values"), Cache):
is_compileable = model_kwargs["past_key_values"].is_compileable
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
if is_compileable and (
self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
git,
glm,
glpn,
got_ocr2,
gpt2,
gpt_bigcode,
gpt_neo,
Expand Down
Loading

0 comments on commit decc891

Please sign in to comment.