# parser.py

import os
import json
import argparse
from multiprocessing.pool import ThreadPool

from tqdm import tqdm

from dots_ocr.model.inference import inference_with_vllm
from dots_ocr.utils.consts import image_extensions, MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.image_utils import get_image_by_fitz_doc, fetch_image, smart_resize
from dots_ocr.utils.doc_utils import load_images_from_pdf
from dots_ocr.utils.prompts import dict_promptmode_to_prompt
from dots_ocr.utils.layout_utils import post_process_output, draw_layout_on_image, pre_process_bboxes
from dots_ocr.utils.format_transformer import layoutjson2md


class DotsOCRParser:
    """
    Parse an image or a PDF file with the dots.ocr model.
    """
    def __init__(self,
                 protocol='http',
                 ip='localhost',
                 port=8000,
                 model_name='model',
                 temperature=0.1,
                 top_p=1.0,
                 max_completion_tokens=16384,
                 num_thread=64,
                 dpi=200,
                 output_dir="./output",
                 min_pixels=None,
                 max_pixels=None,
                 use_hf=False,
                 ):
        self.dpi = dpi

        # default args for the vLLM server
        self.protocol = protocol
        self.ip = ip
        self.port = port
        self.model_name = model_name

        # default args for inference
        self.temperature = temperature
        self.top_p = top_p
        self.max_completion_tokens = max_completion_tokens
        self.num_thread = num_thread
        self.output_dir = output_dir
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.use_hf = use_hf
        if self.use_hf:
            self._load_hf_model()
            print("using hf model, num_thread will be set to 1")
        else:
            print(f"using vllm model, num_thread will be set to {self.num_thread}")

        assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
        assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS

    def _load_hf_model(self):
        # lazy imports: torch/transformers are only needed for local HF inference
        import torch
        from transformers import AutoModelForCausalLM, AutoProcessor
        from qwen_vl_utils import process_vision_info

        model_path = "./weights/DotsOCR"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=True)
        self.process_vision_info = process_vision_info

    def _inference_with_hf(self, image, prompt):
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image
                    },
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        # Preparation for inference
        text = self.processor.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        image_inputs, video_inputs = self.process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to("cuda")

        # Inference: generation of the output
        generated_ids = self.model.generate(**inputs, max_new_tokens=24000)
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return response

    def _inference_with_vllm(self, image, prompt):
        response = inference_with_vllm(
            image,
            prompt,
            model_name=self.model_name,
            protocol=self.protocol,
            ip=self.ip,
            port=self.port,
            temperature=self.temperature,
            top_p=self.top_p,
            max_completion_tokens=self.max_completion_tokens,
        )
        return response

    def get_prompt(self, prompt_mode, bbox=None, origin_image=None, image=None, min_pixels=None, max_pixels=None):
        prompt = dict_promptmode_to_prompt[prompt_mode]
        if prompt_mode == 'prompt_grounding_ocr':
            assert bbox is not None
            bboxes = [bbox]
            bbox = pre_process_bboxes(origin_image, bboxes, input_width=image.width, input_height=image.height, min_pixels=min_pixels, max_pixels=max_pixels)[0]
            prompt = prompt + str(bbox)
        return prompt
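
    # Illustrative note: with prompt_mode='prompt_grounding_ocr' and a bbox such
    # as [100, 200, 400, 500] (hypothetical values), pre_process_bboxes rescales
    # the box from the original image to the resized model input, and the prompt
    # becomes the grounding instruction text with str(rescaled_bbox) appended.
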
    def _parse_single_image(
        self,
        origin_image,
        prompt_mode,
        save_dir,
        save_name,
        source="image",
        page_idx=0,
        bbox=None,
        fitz_preprocess=False,
    ):
        min_pixels, max_pixels = self.min_pixels, self.max_pixels
        if prompt_mode == "prompt_grounding_ocr":
            # preprocess the image to its final input size so the bbox can be rescaled
            min_pixels = min_pixels or MIN_PIXELS
            max_pixels = max_pixels or MAX_PIXELS
        if min_pixels is not None:
            assert min_pixels >= MIN_PIXELS, f"min_pixels should >= {MIN_PIXELS}"
        if max_pixels is not None:
            assert max_pixels <= MAX_PIXELS, f"max_pixels should <= {MAX_PIXELS}"

        if source == 'image' and fitz_preprocess:
            image = get_image_by_fitz_doc(origin_image, target_dpi=self.dpi)
            image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
        else:
            image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
        input_height, input_width = smart_resize(image.height, image.width)
        prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
        if self.use_hf:
            response = self._inference_with_hf(image, prompt)
        else:
            response = self._inference_with_vllm(image, prompt)
        result = {
            'page_no': page_idx,
            "input_height": input_height,
            "input_width": input_width,
        }
        if source == 'pdf':
            save_name = f"{save_name}_page_{page_idx}"
        if prompt_mode in ['prompt_layout_all_en', 'prompt_layout_only_en', 'prompt_grounding_ocr']:
            cells, filtered = post_process_output(
                response,
                prompt_mode,
                origin_image,
                image,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            if filtered and prompt_mode != 'prompt_layout_only_en':
                # the model failed to output valid JSON; fall back to the filtered raw response
                json_file_path = os.path.join(save_dir, f"{save_name}.json")
                with open(json_file_path, 'w', encoding="utf-8") as w:
                    json.dump(response, w, ensure_ascii=False)
                image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
                origin_image.save(image_layout_path)
                result.update({
                    'layout_info_path': json_file_path,
                    'layout_image_path': image_layout_path,
                })

                md_file_path = os.path.join(save_dir, f"{save_name}.md")
                with open(md_file_path, "w", encoding="utf-8") as md_file:
                    md_file.write(cells)
                result.update({
                    'md_content_path': md_file_path,
                })
                result.update({
                    'filtered': True,
                })
            else:
                try:
                    image_with_layout = draw_layout_on_image(origin_image, cells)
                except Exception as e:
                    print(f"Error drawing layout on image: {e}")
                    image_with_layout = origin_image

                json_file_path = os.path.join(save_dir, f"{save_name}.json")
                with open(json_file_path, 'w', encoding="utf-8") as w:
                    json.dump(cells, w, ensure_ascii=False)
                image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
                image_with_layout.save(image_layout_path)
                result.update({
                    'layout_info_path': json_file_path,
                    'layout_image_path': image_layout_path,
                })

                if prompt_mode != "prompt_layout_only_en":  # no Markdown text when detection only
                    md_content = layoutjson2md(origin_image, cells, text_key='text')
                    # version without page headers/footers, used for clean output or
                    # for metrics such as omnidocbench and olmbench
                    md_content_no_hf = layoutjson2md(origin_image, cells, text_key='text', no_page_hf=True)
                    md_file_path = os.path.join(save_dir, f"{save_name}.md")
                    with open(md_file_path, "w", encoding="utf-8") as md_file:
                        md_file.write(md_content)
                    md_nohf_file_path = os.path.join(save_dir, f"{save_name}_nohf.md")
                    with open(md_nohf_file_path, "w", encoding="utf-8") as md_file:
                        md_file.write(md_content_no_hf)
                    result.update({
                        'md_content_path': md_file_path,
                        'md_content_nohf_path': md_nohf_file_path,
                    })
        else:
            image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
            origin_image.save(image_layout_path)
            result.update({
                'layout_image_path': image_layout_path,
            })
            md_content = response
            md_file_path = os.path.join(save_dir, f"{save_name}.md")
            with open(md_file_path, "w", encoding="utf-8") as md_file:
                md_file.write(md_content)
            result.update({
                'md_content_path': md_file_path,
            })
        return result

    def parse_image(self, input_path, filename, prompt_mode, save_dir, bbox=None, fitz_preprocess=False):
        origin_image = fetch_image(input_path)
        result = self._parse_single_image(origin_image, prompt_mode, save_dir, filename, source="image", bbox=bbox, fitz_preprocess=fitz_preprocess)
        result['file_path'] = input_path
        return [result]

    def parse_pdf(self, input_path, filename, prompt_mode, save_dir):
        print(f"loading pdf: {input_path}")
        images_origin = load_images_from_pdf(input_path, dpi=self.dpi)
        total_pages = len(images_origin)
        tasks = [
            {
                "origin_image": image,
                "prompt_mode": prompt_mode,
                "save_dir": save_dir,
                "save_name": filename,
                "source": "pdf",
                "page_idx": i,
            } for i, image in enumerate(images_origin)
        ]

        def _execute_task(task_args):
            return self._parse_single_image(**task_args)

        if self.use_hf:
            num_thread = 1
        else:
            num_thread = min(total_pages, self.num_thread)
        print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")

        results = []
        with ThreadPool(num_thread) as pool:
            with tqdm(total=total_pages, desc="Processing PDF pages") as pbar:
                for result in pool.imap_unordered(_execute_task, tasks):
                    results.append(result)
                    pbar.update(1)

        results.sort(key=lambda x: x["page_no"])
        for result in results:
            result['file_path'] = input_path
        return results

    def parse_file(self,
                   input_path,
                   output_dir="",
                   prompt_mode="prompt_layout_all_en",
                   bbox=None,
                   fitz_preprocess=False,
                   ):
        output_dir = output_dir or self.output_dir
        output_dir = os.path.abspath(output_dir)
        filename, file_ext = os.path.splitext(os.path.basename(input_path))
        save_dir = os.path.join(output_dir, filename)
        os.makedirs(save_dir, exist_ok=True)

        if file_ext == '.pdf':
            results = self.parse_pdf(input_path, filename, prompt_mode, save_dir)
        elif file_ext in image_extensions:
            results = self.parse_image(input_path, filename, prompt_mode, save_dir, bbox=bbox, fitz_preprocess=fitz_preprocess)
        else:
            raise ValueError(f"file extension {file_ext} not supported, supported extensions are {image_extensions} and .pdf")
        print(f"Parsing finished, results saved to {save_dir}")

        with open(os.path.join(output_dir, filename + '.jsonl'), 'w', encoding="utf-8") as w:
            for result in results:
                w.write(json.dumps(result, ensure_ascii=False) + '\n')
        return results
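
# Shape of one record in the output .jsonl (illustrative paths; the md/layout
# keys vary with prompt_mode and whether the filtered fallback was used):
#   {"page_no": 0, "input_height": 1024, "input_width": 768,
#    "layout_info_path": ".../demo_page_0.json",
#    "layout_image_path": ".../demo_page_0.jpg",
#    "md_content_path": ".../demo_page_0.md",
#    "md_content_nohf_path": ".../demo_page_0_nohf.md",
#    "file_path": "demo.pdf"}
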

def main():
    prompts = list(dict_promptmode_to_prompt.keys())
    parser = argparse.ArgumentParser(
        description="dots.ocr Multilingual Document Layout Parser",
    )
    parser.add_argument(
        "input_path", type=str,
        help="Input PDF/image file path"
    )
    parser.add_argument(
        "--output", type=str, default="./output",
        help="Output directory (default: ./output)"
    )
    parser.add_argument(
        "--prompt", choices=prompts, type=str, default="prompt_layout_all_en",
        help="prompt to query the model; different prompts for different tasks"
    )
    parser.add_argument(
        '--bbox',
        type=int,
        nargs=4,
        metavar=('x1', 'y1', 'x2', 'y2'),
        help="target bounding box; required when using prompt_grounding_ocr"
    )
    parser.add_argument(
        "--protocol", type=str, choices=['http', 'https'], default="http",
        help="protocol used to reach the vLLM server (default: http)"
    )
    parser.add_argument(
        "--ip", type=str, default="localhost",
        help="IP address of the vLLM server (default: localhost)"
    )
    parser.add_argument(
        "--port", type=int, default=8000,
        help="port of the vLLM server (default: 8000)"
    )
    parser.add_argument(
        "--model_name", type=str, default="model",
        help="name of the model served by the vLLM server (default: model)"
    )
    parser.add_argument(
        "--temperature", type=float, default=0.1,
        help="sampling temperature (default: 0.1)"
    )
    parser.add_argument(
        "--top_p", type=float, default=1.0,
        help="top-p nucleus sampling parameter (default: 1.0)"
    )
    parser.add_argument(
        "--dpi", type=int, default=200,
        help="DPI used when rendering PDF pages to images (default: 200)"
    )
    parser.add_argument(
        "--max_completion_tokens", type=int, default=16384,
        help="maximum number of completion tokens (default: 16384)"
    )
    parser.add_argument(
        "--num_thread", type=int, default=16,
        help="number of threads for parallel page parsing (default: 16)"
    )
    parser.add_argument(
        "--no_fitz_preprocess", action='store_true',
        help="disable the fitz DPI-upsampling pipeline for image inputs; by default it is "
             "enabled, which helps images rendered at low DPI but may increase computational cost"
    )
    parser.add_argument(
        "--min_pixels", type=int, default=None,
        help=f"minimum number of input image pixels (must be >= {MIN_PIXELS} if set)"
    )
    parser.add_argument(
        "--max_pixels", type=int, default=None,
        help=f"maximum number of input image pixels (must be <= {MAX_PIXELS} if set)"
    )
    parser.add_argument(
        "--use_hf", action='store_true',
        help="run inference locally with Hugging Face transformers instead of a vLLM server"
    )
    args = parser.parse_args()

    dots_ocr_parser = DotsOCRParser(
        protocol=args.protocol,
        ip=args.ip,
        port=args.port,
        model_name=args.model_name,
        temperature=args.temperature,
        top_p=args.top_p,
        max_completion_tokens=args.max_completion_tokens,
        num_thread=args.num_thread,
        dpi=args.dpi,
        output_dir=args.output,
        min_pixels=args.min_pixels,
        max_pixels=args.max_pixels,
        use_hf=args.use_hf,
    )

    fitz_preprocess = not args.no_fitz_preprocess
    if fitz_preprocess:
        print("Using fitz preprocess for image input; note that this may change the image pixel dimensions")
    result = dots_ocr_parser.parse_file(
        args.input_path,
        prompt_mode=args.prompt,
        bbox=args.bbox,
        fitz_preprocess=fitz_preprocess,
    )


if __name__ == "__main__":
    main()
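
# Example CLI invocations (file names are illustrative):
#   python parser.py demo.pdf --prompt prompt_layout_all_en --output ./output
#   python parser.py page.png --prompt prompt_grounding_ocr --bbox 100 200 400 500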