import argparse
import json
import os
from multiprocessing.pool import Pool, ThreadPool

from tqdm import tqdm

from dots_ocr.model.inference import inference_with_vllm
from dots_ocr.utils.consts import image_extensions, MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.doc_utils import fitz_doc_to_image, load_images_from_pdf
from dots_ocr.utils.format_transformer import layoutjson2md
from dots_ocr.utils.image_utils import get_image_by_fitz_doc, fetch_image, smart_resize
from dots_ocr.utils.layout_utils import post_process_output, draw_layout_on_image, pre_process_bboxes
from dots_ocr.utils.prompts import dict_promptmode_to_prompt
class DotsOCRParser:
    """Parse an image or PDF file with the dots.ocr model served through vLLM.

    For every parsed page the layout JSON, an annotated image and markdown
    renditions are written to ``<output_dir>/<input filename>/``; the paths of
    those artifacts are returned in per-page result dicts.
    """

    def __init__(self,
                 ip='localhost',
                 port=8000,
                 model_name='model',
                 temperature=0.1,
                 top_p=1.0,
                 max_completion_tokens=16384,
                 num_thread=64,
                 dpi=200,
                 output_dir="./output",
                 min_pixels=None,
                 max_pixels=None,
                 ):
        """Store server, sampling and preprocessing settings.

        Args:
            ip, port, model_name: address and served model name of the vLLM server.
            temperature, top_p, max_completion_tokens: sampling parameters forwarded
                to the inference call.
            num_thread: upper bound on worker threads when parsing PDF pages.
            dpi: target dpi for the optional fitz preprocessing path.
            output_dir: root directory for result files.
            min_pixels, max_pixels: optional resize bounds; when given they must
                stay within [MIN_PIXELS, MAX_PIXELS].
        """
        self.dpi = dpi
        # default args for the vLLM server
        self.ip = ip
        self.port = port
        self.model_name = model_name
        # default args for inference
        self.temperature = temperature
        self.top_p = top_p
        self.max_completion_tokens = max_completion_tokens
        self.num_thread = num_thread
        self.output_dir = output_dir
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        # Fail fast on misconfigured pixel bounds.
        assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
        assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS

    def _inference_with_vllm(self, image, prompt):
        """Send one image/prompt pair to the vLLM server and return the raw response text."""
        response = inference_with_vllm(
            image,
            prompt,
            model_name=self.model_name,
            ip=self.ip,
            port=self.port,
            temperature=self.temperature,
            top_p=self.top_p,
            max_completion_tokens=self.max_completion_tokens,
        )
        return response

    def get_prompt(self, prompt_mode, bbox=None, origin_image=None, image=None, min_pixels=None, max_pixels=None):
        """Return the prompt text for *prompt_mode*.

        For ``prompt_grounding_ocr`` the caller-supplied bbox (in origin_image
        coordinates) is rescaled to the resized model input image and appended
        to the prompt; *bbox* is then required.
        """
        prompt = dict_promptmode_to_prompt[prompt_mode]
        if prompt_mode == 'prompt_grounding_ocr':
            assert bbox is not None
            bboxes = [bbox]
            bbox = pre_process_bboxes(origin_image, bboxes, input_width=image.width, input_height=image.height, min_pixels=min_pixels, max_pixels=max_pixels)[0]
            prompt = prompt + str(bbox)
        return prompt

    def _parse_single_image(
            self,
            origin_image,
            prompt_mode,
            save_dir,
            save_name,
            source="image",
            page_idx=0,
            bbox=None,
            fitz_preprocess=False,
            ):
        """Run inference on one image/page and persist all result artifacts.

        Returns a dict with ``page_no``, the resized input dimensions and the
        paths of whatever files were written (layout JSON, layout image,
        markdown). For ``source == 'pdf'`` the page index is appended to the
        save name so pages do not overwrite each other.
        """
        min_pixels, max_pixels = self.min_pixels, self.max_pixels
        if prompt_mode == "prompt_grounding_ocr":
            # Grounding needs concrete bounds to preprocess the image to the final input.
            min_pixels = min_pixels or MIN_PIXELS
            max_pixels = max_pixels or MAX_PIXELS
        if min_pixels is not None: assert min_pixels >= MIN_PIXELS, f"min_pixels should >= {MIN_PIXELS}"
        if max_pixels is not None: assert max_pixels <= MAX_PIXELS, f"max_pixels should <= {MAX_PIXELS}"
        if source == 'image' and fitz_preprocess:
            # Re-render the image via fitz at the configured dpi (helps low-dpi inputs).
            image = get_image_by_fitz_doc(origin_image, target_dpi=self.dpi)
            image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
        else:
            image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
        input_height, input_width = smart_resize(image.height, image.width)
        prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
        response = self._inference_with_vllm(image, prompt)
        result = {
            'page_no': page_idx,
            "input_height": input_height,
            "input_width": input_width,
        }
        if source == 'pdf':
            save_name = f"{save_name}_page_{page_idx}"
        if prompt_mode in ['prompt_layout_all_en', 'prompt_layout_only_en', 'prompt_grounding_ocr']:
            cells, filtered = post_process_output(
                response,
                prompt_mode,
                origin_image,
                image,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            if filtered and prompt_mode != 'prompt_layout_only_en':
                # Model output failed JSON parsing; fall back to saving the raw
                # response and the unannotated original image.
                json_file_path = os.path.join(save_dir, f"{save_name}.json")
                with open(json_file_path, 'w', encoding="utf-8") as w:
                    json.dump(response, w, ensure_ascii=False)
                image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
                origin_image.save(image_layout_path)
                result.update({
                    'layout_info_path': json_file_path,
                    'layout_image_path': image_layout_path,
                })

                md_file_path = os.path.join(save_dir, f"{save_name}.md")
                with open(md_file_path, "w", encoding="utf-8") as md_file:
                    md_file.write(cells)
                result.update({
                    'md_content_path': md_file_path,
                    'filtered': True,
                })
            else:
                try:
                    image_with_layout = draw_layout_on_image(origin_image, cells)
                except Exception as e:
                    # Drawing is best-effort; fall back to the original image.
                    print(f"Error drawing layout on image: {e}")
                    image_with_layout = origin_image

                json_file_path = os.path.join(save_dir, f"{save_name}.json")
                with open(json_file_path, 'w', encoding="utf-8") as w:
                    json.dump(cells, w, ensure_ascii=False)
                image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
                image_with_layout.save(image_layout_path)
                result.update({
                    'layout_info_path': json_file_path,
                    'layout_image_path': image_layout_path,
                })

                if prompt_mode != "prompt_layout_only_en":  # no text md when detection only
                    md_content = layoutjson2md(origin_image, cells, text_key='text')
                    # No-header/footer variant, used for clean output or metrics
                    # of omnidocbench / olmbench.
                    md_content_no_hf = layoutjson2md(origin_image, cells, text_key='text', no_page_hf=True)
                    md_file_path = os.path.join(save_dir, f"{save_name}.md")
                    with open(md_file_path, "w", encoding="utf-8") as md_file:
                        md_file.write(md_content)
                    md_nohf_file_path = os.path.join(save_dir, f"{save_name}_nohf.md")
                    with open(md_nohf_file_path, "w", encoding="utf-8") as md_file:
                        md_file.write(md_content_no_hf)
                    result.update({
                        'md_content_path': md_file_path,
                        'md_content_nohf_path': md_nohf_file_path,
                    })
        else:
            # Plain OCR/text prompt modes: response is the markdown content itself.
            image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
            origin_image.save(image_layout_path)
            result.update({
                'layout_image_path': image_layout_path,
            })
            md_content = response
            md_file_path = os.path.join(save_dir, f"{save_name}.md")
            with open(md_file_path, "w", encoding="utf-8") as md_file:
                md_file.write(md_content)
            result.update({
                'md_content_path': md_file_path,
            })
        return result

    def parse_image(self, input_path, filename, prompt_mode, save_dir, bbox=None, fitz_preprocess=False):
        """Parse a single image file; returns a one-element list of result dicts."""
        origin_image = fetch_image(input_path)
        result = self._parse_single_image(origin_image, prompt_mode, save_dir, filename, source="image", bbox=bbox, fitz_preprocess=fitz_preprocess)
        result['file_path'] = input_path
        return [result]

    def parse_pdf(self, input_path, filename, prompt_mode, save_dir):
        """Parse every page of a PDF in parallel; returns results sorted by page number."""
        print(f"loading pdf: {input_path}")
        images_origin = load_images_from_pdf(input_path)
        total_pages = len(images_origin)
        tasks = [
            {
                "origin_image": image,
                "prompt_mode": prompt_mode,
                "save_dir": save_dir,
                "save_name": filename,
                "source": "pdf",
                "page_idx": i,
            } for i, image in enumerate(images_origin)
        ]

        def _execute_task(task_args):
            return self._parse_single_image(**task_args)

        # Threads suffice: page work is dominated by network I/O to the vLLM server.
        num_thread = min(total_pages, self.num_thread)
        print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")

        results = []
        with ThreadPool(num_thread) as pool:
            with tqdm(total=total_pages, desc="Processing PDF pages") as pbar:
                for result in pool.imap_unordered(_execute_task, tasks):
                    results.append(result)
                    pbar.update(1)

        # imap_unordered returns in completion order; restore page order.
        results.sort(key=lambda x: x["page_no"])
        for i in range(len(results)):
            results[i]['file_path'] = input_path
        return results

    def parse_file(self,
                   input_path,
                   output_dir="",
                   prompt_mode="prompt_layout_all_en",
                   bbox=None,
                   fitz_preprocess=False
                   ):
        """Parse *input_path* (PDF or image) and write all artifacts plus a summary JSONL.

        Raises:
            ValueError: if the file extension is neither ``.pdf`` nor a supported image type.
        """
        output_dir = output_dir or self.output_dir
        output_dir = os.path.abspath(output_dir)
        filename, file_ext = os.path.splitext(os.path.basename(input_path))
        file_ext = file_ext.lower()  # accept uppercase extensions such as .PDF / .JPG
        save_dir = os.path.join(output_dir, filename)
        os.makedirs(save_dir, exist_ok=True)
        if file_ext == '.pdf':
            results = self.parse_pdf(input_path, filename, prompt_mode, save_dir)
        elif file_ext in image_extensions:
            results = self.parse_image(input_path, filename, prompt_mode, save_dir, bbox=bbox, fitz_preprocess=fitz_preprocess)
        else:
            raise ValueError(f"file extension {file_ext} not supported, supported extensions are {image_extensions} and pdf")

        print(f"Parsing finished, results saving to {save_dir}")
        # One JSON line per page, written next to (not inside) the per-file save_dir.
        with open(os.path.join(output_dir, filename + '.jsonl'), 'w', encoding="utf-8") as w:
            for result in results:
                w.write(json.dumps(result, ensure_ascii=False) + '\n')
        return results
def main():
    """CLI entry point: parse one PDF or image file with dots.ocr via a vLLM server."""
    prompts = list(dict_promptmode_to_prompt.keys())

    parser = argparse.ArgumentParser(
        description="dots.ocr Multilingual Document Layout Parser",
    )
    parser.add_argument(
        "input_path", type=str,
        help="Input PDF/image file path"
    )
    parser.add_argument(
        "--output", type=str, default="./output",
        help="Output directory (default: ./output)"
    )
    parser.add_argument(
        "--prompt", choices=prompts, type=str, default="prompt_layout_all_en",
        help="prompt to query the model, different prompts for different tasks"
    )
    parser.add_argument(
        '--bbox',
        type=int,
        nargs=4,
        metavar=('x1', 'y1', 'x2', 'y2'),
        help='should give this argument if you want to prompt_grounding_ocr'
    )
    parser.add_argument(
        "--ip", type=str, default="localhost",
        help="vLLM server host (default: localhost)"
    )
    parser.add_argument(
        "--port", type=int, default=8000,
        help="vLLM server port (default: 8000)"
    )
    parser.add_argument(
        "--model_name", type=str, default="model",
        help="model name served by the vLLM server (default: model)"
    )
    parser.add_argument(
        "--temperature", type=float, default=0.1,
        help="sampling temperature (default: 0.1)"
    )
    parser.add_argument(
        "--top_p", type=float, default=1.0,
        help="nucleus sampling top_p (default: 1.0)"
    )
    parser.add_argument(
        "--dpi", type=int, default=200,
        help="render dpi used by the fitz preprocessing path (default: 200)"
    )
    parser.add_argument(
        "--max_completion_tokens", type=int, default=16384,
        help="maximum tokens generated per page (default: 16384)"
    )
    parser.add_argument(
        "--num_thread", type=int, default=16,
        help="maximum worker threads for PDF page parsing (default: 16)"
    )
    # NOTE: previously commented out with the broken `type=bool` argparse
    # pattern (any non-empty string is truthy); a store_true flag is correct.
    parser.add_argument(
        "--fitz_preprocess", action="store_true",
        help="use the fitz dpi upsample pipeline; good for images rendered at "
             "low dpi, but may increase computational cost"
    )
    parser.add_argument(
        "--min_pixels", type=int, default=None,
        help="lower bound on resized image pixel count"
    )
    parser.add_argument(
        "--max_pixels", type=int, default=None,
        help="upper bound on resized image pixel count"
    )
    args = parser.parse_args()

    dots_ocr_parser = DotsOCRParser(
        ip=args.ip,
        port=args.port,
        model_name=args.model_name,
        temperature=args.temperature,
        top_p=args.top_p,
        max_completion_tokens=args.max_completion_tokens,
        num_thread=args.num_thread,
        dpi=args.dpi,
        output_dir=args.output,
        min_pixels=args.min_pixels,
        max_pixels=args.max_pixels,
    )
    dots_ocr_parser.parse_file(
        args.input_path,
        prompt_mode=args.prompt,
        bbox=args.bbox,
        fitz_preprocess=args.fitz_preprocess,
    )
-
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()