import os
import json
import argparse
from multiprocessing.pool import ThreadPool

from tqdm import tqdm

from dots_ocr.model.inference import inference_with_vllm
from dots_ocr.utils.consts import image_extensions, MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.image_utils import get_image_by_fitz_doc, fetch_image, smart_resize
from dots_ocr.utils.doc_utils import load_images_from_pdf
from dots_ocr.utils.prompts import dict_promptmode_to_prompt
from dots_ocr.utils.layout_utils import post_process_output, draw_layout_on_image, pre_process_bboxes
from dots_ocr.utils.format_transformer import layoutjson2md
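
# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the parser below): Qwen-VL style
# smart_resize rounds each side to a multiple of the vision patch factor and
# clamps the total pixel count into [min_pixels, max_pixels]. The real
# implementation lives in dots_ocr.utils.image_utils; the factor and rounding
# details here are assumptions for illustration only.
# ---------------------------------------------------------------------------
def _smart_resize_sketch(height, width, factor=28, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS):
    import math
    # Round both sides to the nearest multiple of `factor`.
    h_bar = max(factor, round(height / factor) * factor)
    w_bar = max(factor, round(width / factor) * factor)
    if h_bar * w_bar > max_pixels:
        # Too large: scale down so the area fits under max_pixels.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        # Too small: scale up so the area reaches min_pixels.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar
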

class DotsOCRParser:
    """Parse an image or a PDF file with the dots.ocr model."""

    def __init__(
        self,
        ip='localhost',
        port=8000,
        model_name='model',
        temperature=0.1,
        top_p=1.0,
        max_completion_tokens=16384,
        num_thread=64,
        dpi=200,
        output_dir="./output",
        min_pixels=None,
        max_pixels=None,
        use_hf=False,
    ):
        self.dpi = dpi

        # default args for the vLLM server
        self.ip = ip
        self.port = port
        self.model_name = model_name

        # default args for inference
        self.temperature = temperature
        self.top_p = top_p
        self.max_completion_tokens = max_completion_tokens
        self.num_thread = num_thread
        self.output_dir = output_dir
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels

        self.use_hf = use_hf
        if self.use_hf:
            self._load_hf_model()
            print("use hf model, num_thread will be set to 1")
        else:
            print(f"use vllm model, num_thread will be set to {self.num_thread}")
        assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
        assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS

    def _load_hf_model(self):
        import torch
        from transformers import AutoModelForCausalLM, AutoProcessor
        from qwen_vl_utils import process_vision_info

        model_path = "./weights/DotsOCR"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
        self.process_vision_info = process_vision_info

    def _inference_with_hf(self, image, prompt):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        # Preparation for inference
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = self.process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to("cuda")

        # Inference: generate, then strip the prompt tokens from the output
        generated_ids = self.model.generate(**inputs, max_new_tokens=24000)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return response

    def _inference_with_vllm(self, image, prompt):
        response = inference_with_vllm(
            image,
            prompt,
            model_name=self.model_name,
            ip=self.ip,
            port=self.port,
            temperature=self.temperature,
            top_p=self.top_p,
            max_completion_tokens=self.max_completion_tokens,
        )
        return response

    def get_prompt(self, prompt_mode, bbox=None, origin_image=None, image=None, min_pixels=None, max_pixels=None):
        prompt = dict_promptmode_to_prompt[prompt_mode]
        if prompt_mode == 'prompt_grounding_ocr':
            assert bbox is not None
            bboxes = [bbox]
            # Map the bbox from original-image coordinates to model-input coordinates.
            bbox = pre_process_bboxes(
                origin_image, bboxes,
                input_width=image.width, input_height=image.height,
                min_pixels=min_pixels, max_pixels=max_pixels,
            )[0]
            prompt = prompt + str(bbox)
        return prompt

    # def post_process_results(self, response, prompt_mode, save_dir, save_name, origin_image, image, min_pixels, max_pixels)

    def _parse_single_image(
        self,
        origin_image,
        prompt_mode,
        save_dir,
        save_name,
        source="image",
        page_idx=0,
        bbox=None,
        fitz_preprocess=False,
    ):
        min_pixels, max_pixels = self.min_pixels, self.max_pixels
        if prompt_mode == "prompt_grounding_ocr":
            # preprocess the image to the final model input size
            min_pixels = min_pixels or MIN_PIXELS
            max_pixels = max_pixels or MAX_PIXELS
        if min_pixels is not None:
            assert min_pixels >= MIN_PIXELS, f"min_pixels should >= {MIN_PIXELS}"
        if max_pixels is not None:
            assert max_pixels <= MAX_PIXELS, f"max_pixels should <= {MAX_PIXELS}"

        if source == 'image' and fitz_preprocess:
            image = get_image_by_fitz_doc(origin_image, target_dpi=self.dpi)
            image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
        else:
            image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
        input_height, input_width = smart_resize(image.height, image.width)
        prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
        if self.use_hf:
            response = self._inference_with_hf(image, prompt)
        else:
            response = self._inference_with_vllm(image, prompt)

        result = {
            'page_no': page_idx,
            "input_height": input_height,
            "input_width": input_width,
        }
        if source == 'pdf':
            save_name = f"{save_name}_page_{page_idx}"

        if prompt_mode in ['prompt_layout_all_en', 'prompt_layout_only_en', 'prompt_grounding_ocr']:
            cells, filtered = post_process_output(
                response,
                prompt_mode,
                origin_image,
                image,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            if filtered and prompt_mode != 'prompt_layout_only_en':
                # The model output was not valid JSON; fall back to the filtered result.
                json_file_path = os.path.join(save_dir, f"{save_name}.json")
                with open(json_file_path, 'w', encoding="utf-8") as w:
                    json.dump(response, w, ensure_ascii=False)
                image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
                origin_image.save(image_layout_path)
                result.update({
                    'layout_info_path': json_file_path,
                    'layout_image_path': image_layout_path,
                })

                md_file_path = os.path.join(save_dir, f"{save_name}.md")
                with open(md_file_path, "w", encoding="utf-8") as md_file:
                    md_file.write(cells)
                result.update({
                    'md_content_path': md_file_path,
                })
                result.update({
                    'filtered': True,
                })
            else:
                try:
                    image_with_layout = draw_layout_on_image(origin_image, cells)
                except Exception as e:
                    print(f"Error drawing layout on image: {e}")
                    image_with_layout = origin_image

                json_file_path = os.path.join(save_dir, f"{save_name}.json")
                with open(json_file_path, 'w', encoding="utf-8") as w:
                    json.dump(cells, w, ensure_ascii=False)
                image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
                image_with_layout.save(image_layout_path)
                result.update({
                    'layout_info_path': json_file_path,
                    'layout_image_path': image_layout_path,
                })

                if prompt_mode != "prompt_layout_only_en":  # no text markdown when detection only
                    md_content = layoutjson2md(origin_image, cells, text_key='text')
                    # no_page_hf drops page headers/footers; used for clean output
                    # or the metrics of omnidocbench, olmbench
                    md_content_no_hf = layoutjson2md(origin_image, cells, text_key='text', no_page_hf=True)
                    md_file_path = os.path.join(save_dir, f"{save_name}.md")
                    with open(md_file_path, "w", encoding="utf-8") as md_file:
                        md_file.write(md_content)
                    md_nohf_file_path = os.path.join(save_dir, f"{save_name}_nohf.md")
                    with open(md_nohf_file_path, "w", encoding="utf-8") as md_file:
                        md_file.write(md_content_no_hf)
                    result.update({
                        'md_content_path': md_file_path,
                        'md_content_nohf_path': md_nohf_file_path,
                    })
        else:
            image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
            origin_image.save(image_layout_path)
            result.update({
                'layout_image_path': image_layout_path,
            })
            md_content = response
            md_file_path = os.path.join(save_dir, f"{save_name}.md")
            with open(md_file_path, "w", encoding="utf-8") as md_file:
                md_file.write(md_content)
            result.update({
                'md_content_path': md_file_path,
            })
        return result

    def parse_image(self, input_path, filename, prompt_mode, save_dir, bbox=None, fitz_preprocess=False):
        origin_image = fetch_image(input_path)
        result = self._parse_single_image(
            origin_image, prompt_mode, save_dir, filename,
            source="image", bbox=bbox, fitz_preprocess=fitz_preprocess,
        )
        result['file_path'] = input_path
        return [result]

    def parse_pdf(self, input_path, filename, prompt_mode, save_dir):
        print(f"loading pdf: {input_path}")
        images_origin = load_images_from_pdf(input_path, dpi=self.dpi)
        total_pages = len(images_origin)
        tasks = [
            {
                "origin_image": image,
                "prompt_mode": prompt_mode,
                "save_dir": save_dir,
                "save_name": filename,
                "source": "pdf",
                "page_idx": i,
            }
            for i, image in enumerate(images_origin)
        ]

        def _execute_task(task_args):
            return self._parse_single_image(**task_args)

        if self.use_hf:
            num_thread = 1
        else:
            num_thread = min(total_pages, self.num_thread)
        print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")

        results = []
        with ThreadPool(num_thread) as pool:
            with tqdm(total=total_pages, desc="Processing PDF pages") as pbar:
                for result in pool.imap_unordered(_execute_task, tasks):
                    results.append(result)
                    pbar.update(1)

        results.sort(key=lambda x: x["page_no"])
        for i in range(len(results)):
            results[i]['file_path'] = input_path
        return results

    def parse_file(
        self,
        input_path,
        output_dir="",
        prompt_mode="prompt_layout_all_en",
        bbox=None,
        fitz_preprocess=False,
    ):
        output_dir = output_dir or self.output_dir
        output_dir = os.path.abspath(output_dir)
        filename, file_ext = os.path.splitext(os.path.basename(input_path))
        save_dir = os.path.join(output_dir, filename)
        os.makedirs(save_dir, exist_ok=True)

        if file_ext == '.pdf':
            results = self.parse_pdf(input_path, filename, prompt_mode, save_dir)
        elif file_ext in image_extensions:
            results = self.parse_image(input_path, filename, prompt_mode, save_dir, bbox=bbox, fitz_preprocess=fitz_preprocess)
        else:
            raise ValueError(f"file extension {file_ext} not supported, supported extensions are {image_extensions} and pdf")

        print(f"Parsing finished, results saved to {save_dir}")
        with open(os.path.join(output_dir, os.path.basename(filename) + '.jsonl'), 'w', encoding="utf-8") as w:
            for result in results:
                w.write(json.dumps(result, ensure_ascii=False) + '\n')
        return results

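
# ---------------------------------------------------------------------------
# Minimal programmatic usage sketch. Assumes a vLLM server is already running
# on localhost:8000 and that "demo.pdf" exists; both the path and the server
# address are placeholders, not part of the parser above.
# ---------------------------------------------------------------------------
def _example_usage():
    parser = DotsOCRParser(ip="localhost", port=8000, num_thread=16)
    results = parser.parse_file(
        "demo.pdf",  # placeholder input path
        output_dir="./output",
        prompt_mode="prompt_layout_all_en",
    )
    # Each result dict carries the page number and paths to the saved outputs.
    for page in results:
        print(page["page_no"], page.get("md_content_path"))
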

def main():
    prompts = list(dict_promptmode_to_prompt.keys())
    parser = argparse.ArgumentParser(
        description="dots.ocr Multilingual Document Layout Parser",
    )
    parser.add_argument(
        "input_path", type=str,
        help="Input PDF/image file path",
    )
    parser.add_argument(
        "--output", type=str, default="./output",
        help="Output directory (default: ./output)",
    )
    parser.add_argument(
        "--prompt", choices=prompts, type=str, default="prompt_layout_all_en",
        help="prompt used to query the model; different prompts for different tasks",
    )
    parser.add_argument(
        '--bbox', type=int, nargs=4, metavar=('x1', 'y1', 'x2', 'y2'),
        help='bounding box on the original image; required for prompt_grounding_ocr',
    )
    parser.add_argument(
        "--ip", type=str, default="localhost",
        help="vLLM server IP (default: localhost)",
    )
    parser.add_argument(
        "--port", type=int, default=8000,
        help="vLLM server port (default: 8000)",
    )
    parser.add_argument(
        "--model_name", type=str, default="model",
        help="model name served by the vLLM server (default: model)",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.1,
        help="sampling temperature (default: 0.1)",
    )
    parser.add_argument(
        "--top_p", type=float, default=1.0,
        help="top-p nucleus sampling (default: 1.0)",
    )
    parser.add_argument(
        "--dpi", type=int, default=200,
        help="DPI used to render PDF pages (default: 200)",
    )
    parser.add_argument(
        "--max_completion_tokens", type=int, default=16384,
        help="maximum completion tokens per request (default: 16384)",
    )
    parser.add_argument(
        "--num_thread", type=int, default=16,
        help="number of threads used to parse PDF pages (default: 16)",
    )
    parser.add_argument(
        "--no_fitz_preprocess", action='store_true',
        help="disable the fitz DPI-upsample pipeline for image input; the pipeline "
             "helps images rendered at low DPI but may increase computational cost",
    )
    parser.add_argument(
        "--min_pixels", type=int, default=None,
        help="minimum pixel count of the resized model input",
    )
    parser.add_argument(
        "--max_pixels", type=int, default=None,
        help="maximum pixel count of the resized model input",
    )
    parser.add_argument(
        "--use_hf", action='store_true',
        help="run inference with the local HuggingFace model instead of the vLLM server",
    )
    args = parser.parse_args()

    dots_ocr_parser = DotsOCRParser(
        ip=args.ip,
        port=args.port,
        model_name=args.model_name,
        temperature=args.temperature,
        top_p=args.top_p,
        max_completion_tokens=args.max_completion_tokens,
        num_thread=args.num_thread,
        dpi=args.dpi,
        output_dir=args.output,
        min_pixels=args.min_pixels,
        max_pixels=args.max_pixels,
        use_hf=args.use_hf,
    )

    fitz_preprocess = not args.no_fitz_preprocess
    if fitz_preprocess:
        print("Using fitz preprocess for image input; the input pixel size may change")
    result = dots_ocr_parser.parse_file(
        args.input_path,
        prompt_mode=args.prompt,
        bbox=args.bbox,
        fitz_preprocess=fitz_preprocess,
    )

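
# ---------------------------------------------------------------------------
# Sketch of the grounding-OCR path: OCR only the region inside `bbox`, given
# as (x1, y1, x2, y2) in original-image pixel coordinates. "page.png" and the
# box values are placeholders.
# ---------------------------------------------------------------------------
def _example_grounding_ocr():
    ocr = DotsOCRParser()
    results = ocr.parse_file(
        "page.png",  # placeholder input path
        prompt_mode="prompt_grounding_ocr",
        bbox=[10, 10, 200, 80],  # placeholder region on the original image
    )
    print(results[0].get("md_content_path"))
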

if __name__ == "__main__":
    main()