Merge pull request #33 from yjmm10/master

Support inference with Transformers
Qing Yan, 3 months ago
commit 6a5e2ba95c
2 files changed, 80 insertions(+), 2 deletions(-)
  1. README.md (+3, -0)
  2. dots_ocr/parser.py (+77, -2)

+ 3 - 0
README.md

@@ -1151,6 +1151,9 @@ python3 dots_ocr/parser.py demo/demo_image1.jpg --prompt prompt_ocr
 python3 dots_ocr/parser.py demo/demo_image1.jpg --prompt prompt_grounding_ocr --bbox 163 241 1536 705
 
 ```
+**Based on Transformers**, you can parse an image or a PDF file with the same commands shown above by adding `--use_hf`. 
+
+> Note: Transformers inference is slower than vLLM. To use demo/* with Transformers, pass `use_hf=True` when constructing the parser: `DotsOCRParser(..., use_hf=True)`.
 
 <details>
 <summary><b>Output Results</b></summary>
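
A minimal sketch of the Python-side usage described in the note above (the `parse_file` call is illustrative; see `main()` in `dots_ocr/parser.py` for the full option set):

```python
from dots_ocr.parser import DotsOCRParser

# use_hf=True loads the Transformers model at construction and switches
# inference from vLLM to Transformers; page parsing then runs single-threaded
parser = DotsOCRParser(use_hf=True)
result = parser.parse_file("demo/demo_image1.jpg")  # argument shown as an assumption
```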

+ 77 - 2
dots_ocr/parser.py

@@ -31,6 +31,7 @@ class DotsOCRParser:
             output_dir="./output", 
             min_pixels=None,
             max_pixels=None,
+            use_hf=False,
         ):
         self.dpi = dpi
 
@@ -46,9 +47,72 @@ class DotsOCRParser:
         self.output_dir = output_dir
         self.min_pixels = min_pixels
         self.max_pixels = max_pixels
+
+        self.use_hf = use_hf
+        if self.use_hf:
+            self._load_hf_model()
+            print(f"use hf model, num_thread will be set to 1")
+        else:
+            print(f"use vllm model, num_thread will be set to {self.num_thread}")
         assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
         assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS
 
+    def _load_hf_model(self):
+        import torch
+        from transformers import AutoModelForCausalLM, AutoProcessor
+        from qwen_vl_utils import process_vision_info
+
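+        # Assumes the DotsOCR weights have been downloaded to ./weights/DotsOCR;
+        # flash_attention_2 requires a GPU with FlashAttention support (Ampere or newer)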
+        model_path = "./weights/DotsOCR"
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            attn_implementation="flash_attention_2",
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True
+        )
+        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
+        self.process_vision_info = process_vision_info
+
+    def _inference_with_hf(self, image, prompt):
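+        # Build a single-turn chat message pairing the page image with the OCR prompt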
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": image
+                    },
+                    {"type": "text", "text": prompt}
+                ]
+            }
+        ]
+
+        # Preparation for inference
+        text = self.processor.apply_chat_template(
+            messages, 
+            tokenize=False, 
+            add_generation_prompt=True
+        )
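+        # qwen_vl_utils resolves the image entries in the messages into model inputs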
+        image_inputs, video_inputs = self.process_vision_info(messages)
+        inputs = self.processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+
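+        # Assumes a CUDA device; with device_map="auto" the inputs land on the first GPU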
+        inputs = inputs.to("cuda")
+
+        # Inference: Generation of the output
+        generated_ids = self.model.generate(**inputs, max_new_tokens=24000)
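+        # Strip the prompt tokens so only the newly generated text is decoded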
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        response = self.processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+        return response
 
     def _inference_with_vllm(self, image, prompt):
         response = inference_with_vllm(
@@ -98,7 +162,10 @@ class DotsOCRParser:
             image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
         input_height, input_width = smart_resize(image.height, image.width)
         prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
-        response = self._inference_with_vllm(image, prompt)
+        if self.use_hf:
+            response = self._inference_with_hf(image, prompt)
+        else:
+            response = self._inference_with_vllm(image, prompt)
         result = {'page_no': page_idx,
             "input_height": input_height,
             "input_width": input_width
@@ -206,7 +273,10 @@ class DotsOCRParser:
         def _execute_task(task_args):
             return self._parse_single_image(**task_args)
 
-        num_thread = min(total_pages, self.num_thread)
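+        # HF generation runs on a single model instance, so page parsing is serialized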
+        if self.use_hf:
+            num_thread = 1
+        else:
+            num_thread = min(total_pages, self.num_thread)
         print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")
 
         results = []
@@ -321,6 +391,10 @@ def main():
         "--max_pixels", type=int, default=None,
         help=""
     )
+    parser.add_argument(
+        "--use_hf", action="store_true",
+        help="use Transformers instead of vLLM for inference"
+    )
     args = parser.parse_args()
 
     dots_ocr_parser = DotsOCRParser(
@@ -335,6 +409,7 @@ def main():
         output_dir=args.output, 
         min_pixels=args.min_pixels,
         max_pixels=args.max_pixels,
+        use_hf=args.use_hf,
     )
 
     result = dots_ocr_parser.parse_file(