@@ -31,6 +31,7 @@ class DotsOCRParser:
         output_dir="./output",
         min_pixels=None,
         max_pixels=None,
+        use_hf=False,
     ):
         self.dpi = dpi
@@ -46,9 +47,72 @@ class DotsOCRParser:
         self.output_dir = output_dir
         self.min_pixels = min_pixels
         self.max_pixels = max_pixels
+
+        self.use_hf = use_hf
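+        # In HF mode the transformers model is loaded in-process once, up front;
+        # otherwise inference goes through vLLM (see _inference_with_vllm below).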
+        if self.use_hf:
+            self._load_hf_model()
+            print("Using the HF model; num_thread will be set to 1")
+        else:
+            print(f"Using the vLLM model; num_thread will be set to {self.num_thread}")
         assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
         assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS
 
+    def _load_hf_model(self):
+        import torch
+        from transformers import AutoModelForCausalLM, AutoProcessor
+        from qwen_vl_utils import process_vision_info
+
+        model_path = "./weights/DotsOCR"
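+        # Assumes the DotsOCR weights have already been downloaded to this local path.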
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            attn_implementation="flash_attention_2",
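+            # flash_attention_2 requires the flash-attn package and a supported GPU;
+            # drop this kwarg to fall back to the default attention implementation.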
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
+        self.process_vision_info = process_vision_info
+
+    def _inference_with_hf(self, image, prompt):
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": image
+                    },
+                    {"type": "text", "text": prompt}
+                ]
+            }
+        ]
+
+        # Preparation for inference
+        text = self.processor.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        image_inputs, video_inputs = self.process_vision_info(messages)
+        inputs = self.processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+
+        inputs = inputs.to("cuda")
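+        # Assumes a CUDA device is available; the inputs must live on the same device as the model.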
+
+        # Inference: Generation of the output
+        generated_ids = self.model.generate(**inputs, max_new_tokens=24000)
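+        # generate() returns prompt + completion; strip the prompt tokens so only
+        # the newly generated text is decoded.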
+        generated_ids_trimmed = [
+            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        ]
+        response = self.processor.batch_decode(
+            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )[0]
+        return response
+
     def _inference_with_vllm(self, image, prompt):
         response = inference_with_vllm(
@@ -98,7 +162,10 @@ class DotsOCRParser:
         image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
         input_height, input_width = smart_resize(image.height, image.width)
         prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
-        response = self._inference_with_vllm(image, prompt)
+        if self.use_hf:
+            response = self._inference_with_hf(image, prompt)
+        else:
+            response = self._inference_with_vllm(image, prompt)
         result = {'page_no': page_idx,
                   "input_height": input_height,
                   "input_width": input_width
@@ -206,7 +273,10 @@ class DotsOCRParser:
         def _execute_task(task_args):
             return self._parse_single_image(**task_args)
 
-        num_thread = min(total_pages, self.num_thread)
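+        # HF generation runs in-process on the GPU, so pages are parsed sequentially;
+        # the vLLM path can fan out across up to self.num_thread worker threads.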
+        if self.use_hf:
+            num_thread = 1
+        else:
+            num_thread = min(total_pages, self.num_thread)
|
|
|
print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")
|
|
|
|
|
|
results = []
|
|
|
@@ -321,6 +391,10 @@ def main():
         "--max_pixels", type=int, default=None,
         help=""
     )
+    parser.add_argument(
+        "--use_hf", action="store_true", default=False,
+        help="Use the local HF transformers model for inference instead of vLLM"
+    )
     args = parser.parse_args()
 
     dots_ocr_parser = DotsOCRParser(
@@ -335,6 +409,7 @@ def main():
         output_dir=args.output,
         min_pixels=args.min_pixels,
         max_pixels=args.max_pixels,
+        use_hf=args.use_hf,
     )
 
     result = dots_ocr_parser.parse_file(