parser.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. import os
  2. import json
  3. from tqdm import tqdm
  4. from multiprocessing.pool import ThreadPool, Pool
  5. import argparse
  6. from dots_ocr.model.inference import inference_with_vllm
  7. from dots_ocr.utils.consts import image_extensions, MIN_PIXELS, MAX_PIXELS
  8. from dots_ocr.utils.image_utils import get_image_by_fitz_doc, fetch_image, smart_resize
  9. from dots_ocr.utils.doc_utils import fitz_doc_to_image, load_images_from_pdf
  10. from dots_ocr.utils.prompts import dict_promptmode_to_prompt
  11. from dots_ocr.utils.layout_utils import post_process_output, draw_layout_on_image, pre_process_bboxes
  12. from dots_ocr.utils.format_transformer import layoutjson2md
  13. class DotsOCRParser:
  14. """
  15. parse image or pdf file
  16. """
  17. def __init__(self,
  18. ip='localhost',
  19. port=8000,
  20. model_name='model',
  21. temperature=0.1,
  22. top_p=1.0,
  23. max_completion_tokens=16384,
  24. num_thread=64,
  25. dpi = 200,
  26. output_dir="./output",
  27. min_pixels=None,
  28. max_pixels=None,
  29. ):
  30. self.dpi = dpi
  31. # default args for vllm server
  32. self.ip = ip
  33. self.port = port
  34. self.model_name = model_name
  35. # default args for inference
  36. self.temperature = temperature
  37. self.top_p = top_p
  38. self.max_completion_tokens = max_completion_tokens
  39. self.num_thread = num_thread
  40. self.output_dir = output_dir
  41. self.min_pixels = min_pixels
  42. self.max_pixels = max_pixels
  43. assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
  44. assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS
  45. def _inference_with_vllm(self, image, prompt):
  46. response = inference_with_vllm(
  47. image,
  48. prompt,
  49. model_name=self.model_name,
  50. ip=self.ip,
  51. port=self.port,
  52. temperature=self.temperature,
  53. top_p=self.top_p,
  54. max_completion_tokens=self.max_completion_tokens,
  55. )
  56. return response
  57. def get_prompt(self, prompt_mode, bbox=None, origin_image=None, image=None, min_pixels=None, max_pixels=None):
  58. prompt = dict_promptmode_to_prompt[prompt_mode]
  59. if prompt_mode == 'prompt_grounding_ocr':
  60. assert bbox is not None
  61. bboxes = [bbox]
  62. bbox = pre_process_bboxes(origin_image, bboxes, input_width=image.width, input_height=image.height, min_pixels=min_pixels, max_pixels=max_pixels)[0]
  63. prompt = prompt + str(bbox)
  64. return prompt
  65. # def post_process_results(self, response, prompt_mode, save_dir, save_name, origin_image, image, min_pixels, max_pixels)
  66. def _parse_single_image(
  67. self,
  68. origin_image,
  69. prompt_mode,
  70. save_dir,
  71. save_name,
  72. source="image",
  73. page_idx=0,
  74. bbox=None,
  75. fitz_preprocess=False,
  76. ):
  77. min_pixels, max_pixels = self.min_pixels, self.max_pixels
  78. if prompt_mode == "prompt_grounding_ocr":
  79. min_pixels = min_pixels or MIN_PIXELS # preprocess image to the final input
  80. max_pixels = max_pixels or MAX_PIXELS
  81. if min_pixels is not None: assert min_pixels >= MIN_PIXELS, f"min_pixels should >= {MIN_PIXELS}"
  82. if max_pixels is not None: assert max_pixels <= MAX_PIXELS, f"max_pixels should <+ {MAX_PIXELS}"
  83. if source == 'image' and fitz_preprocess:
  84. image = get_image_by_fitz_doc(origin_image, target_dpi=self.dpi)
  85. image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
  86. else:
  87. image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
  88. input_height, input_width = smart_resize(image.height, image.width)
  89. prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
  90. response = self._inference_with_vllm(image, prompt)
  91. result = {'page_no': page_idx,
  92. "input_height": input_height,
  93. "input_width": input_width
  94. }
  95. if source == 'pdf':
  96. save_name = f"{save_name}_page_{page_idx}"
  97. if prompt_mode in ['prompt_layout_all_en', 'prompt_layout_only_en', 'prompt_grounding_ocr']:
  98. cells, filtered = post_process_output(
  99. response,
  100. prompt_mode,
  101. origin_image,
  102. image,
  103. min_pixels=min_pixels,
  104. max_pixels=max_pixels,
  105. )
  106. if filtered and prompt_mode != 'prompt_layout_only_en': # model output json failed, use filtered process
  107. json_file_path = os.path.join(save_dir, f"{save_name}.json")
  108. with open(json_file_path, 'w') as w:
  109. json.dump(response, w, ensure_ascii=False)
  110. image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
  111. origin_image.save(image_layout_path)
  112. result.update({
  113. 'layout_info_path': json_file_path,
  114. 'layout_image_path': image_layout_path,
  115. })
  116. md_file_path = os.path.join(save_dir, f"{save_name}.md")
  117. with open(md_file_path, "w", encoding="utf-8") as md_file:
  118. md_file.write(cells)
  119. result.update({
  120. 'md_content_path': md_file_path
  121. })
  122. result.update({
  123. 'filtered': True
  124. })
  125. else:
  126. try:
  127. image_with_layout = draw_layout_on_image(origin_image, cells)
  128. except Exception as e:
  129. print(f"Error drawing layout on image: {e}")
  130. image_with_layout = origin_image
  131. json_file_path = os.path.join(save_dir, f"{save_name}.json")
  132. with open(json_file_path, 'w') as w:
  133. json.dump(cells, w, ensure_ascii=False)
  134. image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
  135. image_with_layout.save(image_layout_path)
  136. result.update({
  137. 'layout_info_path': json_file_path,
  138. 'layout_image_path': image_layout_path,
  139. })
  140. if prompt_mode != "prompt_layout_only_en": # no text md when detection only
  141. md_content = layoutjson2md(origin_image, cells, text_key='text')
  142. md_content_no_hf = layoutjson2md(origin_image, cells, text_key='text', no_page_hf=True) # used for clean output or metric of omnidocbench、olmbench
  143. md_file_path = os.path.join(save_dir, f"{save_name}.md")
  144. with open(md_file_path, "w", encoding="utf-8") as md_file:
  145. md_file.write(md_content)
  146. md_nohf_file_path = os.path.join(save_dir, f"{save_name}_nohf.md")
  147. with open(md_nohf_file_path, "w", encoding="utf-8") as md_file:
  148. md_file.write(md_content_no_hf)
  149. result.update({
  150. 'md_content_path': md_file_path,
  151. 'md_content_nohf_path': md_nohf_file_path,
  152. })
  153. else:
  154. image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
  155. origin_image.save(image_layout_path)
  156. result.update({
  157. 'layout_image_path': image_layout_path,
  158. })
  159. md_content = response
  160. md_file_path = os.path.join(save_dir, f"{save_name}.md")
  161. with open(md_file_path, "w", encoding="utf-8") as md_file:
  162. md_file.write(md_content)
  163. result.update({
  164. 'md_content_path': md_file_path,
  165. })
  166. return result
  167. def parse_image(self, input_path, filename, prompt_mode, save_dir, bbox=None, fitz_preprocess=False):
  168. origin_image = fetch_image(input_path)
  169. result = self._parse_single_image(origin_image, prompt_mode, save_dir, filename, source="image", bbox=bbox, fitz_preprocess=fitz_preprocess)
  170. result['file_path'] = input_path
  171. return [result]
  172. def parse_pdf(self, input_path, filename, prompt_mode, save_dir):
  173. print(f"loading pdf: {input_path}")
  174. images_origin = load_images_from_pdf(input_path, dpi=self.dpi)
  175. total_pages = len(images_origin)
  176. tasks = [
  177. {
  178. "origin_image": image,
  179. "prompt_mode": prompt_mode,
  180. "save_dir": save_dir,
  181. "save_name": filename,
  182. "source":"pdf",
  183. "page_idx": i,
  184. } for i, image in enumerate(images_origin)
  185. ]
  186. def _execute_task(task_args):
  187. return self._parse_single_image(**task_args)
  188. num_thread = min(total_pages, self.num_thread)
  189. print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")
  190. results = []
  191. with ThreadPool(num_thread) as pool:
  192. with tqdm(total=total_pages, desc="Processing PDF pages") as pbar:
  193. for result in pool.imap_unordered(_execute_task, tasks):
  194. results.append(result)
  195. pbar.update(1)
  196. results.sort(key=lambda x: x["page_no"])
  197. for i in range(len(results)):
  198. results[i]['file_path'] = input_path
  199. return results
  200. def parse_file(self,
  201. input_path,
  202. output_dir="",
  203. prompt_mode="prompt_layout_all_en",
  204. bbox=None,
  205. fitz_preprocess=False
  206. ):
  207. output_dir = output_dir or self.output_dir
  208. output_dir = os.path.abspath(output_dir)
  209. filename, file_ext = os.path.splitext(os.path.basename(input_path))
  210. save_dir = os.path.join(output_dir, filename)
  211. os.makedirs(save_dir, exist_ok=True)
  212. if file_ext == '.pdf':
  213. results = self.parse_pdf(input_path, filename, prompt_mode, save_dir)
  214. elif file_ext in image_extensions:
  215. results = self.parse_image(input_path, filename, prompt_mode, save_dir, bbox=bbox, fitz_preprocess=fitz_preprocess)
  216. else:
  217. raise ValueError(f"file extension {file_ext} not supported, supported extensions are {image_extensions} and pdf")
  218. print(f"Parsing finished, results saving to {save_dir}")
  219. with open(os.path.join(output_dir, os.path.basename(filename)+'.jsonl'), 'w') as w:
  220. for result in results:
  221. w.write(json.dumps(result, ensure_ascii=False) + '\n')
  222. return results
  223. def main():
  224. prompts = list(dict_promptmode_to_prompt.keys())
  225. parser = argparse.ArgumentParser(
  226. description="dots.ocr Multilingual Document Layout Parser",
  227. )
  228. parser.add_argument(
  229. "input_path", type=str,
  230. help="Input PDF/image file path"
  231. )
  232. parser.add_argument(
  233. "--output", type=str, default="./output",
  234. help="Output directory (default: ./output)"
  235. )
  236. parser.add_argument(
  237. "--prompt", choices=prompts, type=str, default="prompt_layout_all_en",
  238. help="prompt to query the model, different prompts for different tasks"
  239. )
  240. parser.add_argument(
  241. '--bbox',
  242. type=int,
  243. nargs=4,
  244. metavar=('x1', 'y1', 'x2', 'y2'),
  245. help='should give this argument if you want to prompt_grounding_ocr'
  246. )
  247. parser.add_argument(
  248. "--ip", type=str, default="localhost",
  249. help=""
  250. )
  251. parser.add_argument(
  252. "--port", type=int, default=8000,
  253. help=""
  254. )
  255. parser.add_argument(
  256. "--model_name", type=str, default="model",
  257. help=""
  258. )
  259. parser.add_argument(
  260. "--temperature", type=float, default=0.1,
  261. help=""
  262. )
  263. parser.add_argument(
  264. "--top_p", type=float, default=1.0,
  265. help=""
  266. )
  267. parser.add_argument(
  268. "--dpi", type=int, default=200,
  269. help=""
  270. )
  271. parser.add_argument(
  272. "--max_completion_tokens", type=int, default=16384,
  273. help=""
  274. )
  275. parser.add_argument(
  276. "--num_thread", type=int, default=16,
  277. help=""
  278. )
  279. # parser.add_argument(
  280. # "--fitz_preprocess", type=bool, default=False,
  281. # help="False will use tikz dpi upsample pipeline, good for images which has been render with low dpi, but maybe result in higher computational costs"
  282. # )
  283. parser.add_argument(
  284. "--min_pixels", type=int, default=None,
  285. help=""
  286. )
  287. parser.add_argument(
  288. "--max_pixels", type=int, default=None,
  289. help=""
  290. )
  291. args = parser.parse_args()
  292. dots_ocr_parser = DotsOCRParser(
  293. ip=args.ip,
  294. port=args.port,
  295. model_name=args.model_name,
  296. temperature=args.temperature,
  297. top_p=args.top_p,
  298. max_completion_tokens=args.max_completion_tokens,
  299. num_thread=args.num_thread,
  300. dpi=args.dpi,
  301. output_dir=args.output,
  302. min_pixels=args.min_pixels,
  303. max_pixels=args.max_pixels,
  304. )
  305. result = dots_ocr_parser.parse_file(
  306. args.input_path,
  307. prompt_mode=args.prompt,
  308. bbox=args.bbox,
  309. )
  310. if __name__ == "__main__":
  311. main()