Sfoglia il codice sorgente

feat(cli&analyze&pipeline): add start_page and end_page args for pagination (#507)

* feat(cli&analyze&pipeline): add start_page and end_page args for pagination

Add start_page_id and end_page_id arguments to various components of the PDF parsing
pipeline to support pagination functionality. This feature allows users to specify the
range of pages to be processed, enhancing the efficiency and flexibility of the system.

* feat(cli&analyze&pipeline): add start_page and end_page args for pagination

Add start_page_id and end_page_id arguments to various components of the PDF parsing
pipeline to support pagination functionality. This feature allows users to specify the
range of pages to be processed, enhancing the efficiency and flexibility of the system.

* feat(cli&analyze&pipeline): add start_page and end_page args for pagination

Add start_page_id and end_page_id arguments to various components of the PDF parsing
pipeline to support pagination functionality. This feature allows users to specify the
range of pages to be processed, enhancing the efficiency and flexibility of the system.
Xiaomeng Zhao 1 anno fa
parent
commit
0f91fcf61f

+ 13 - 2
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -103,20 +103,31 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
     return custom_model
 
 
-def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
+def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
+                start_page_id=0, end_page_id=None):
 
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(ocr, show_log)
 
     images = load_images_from_pdf(pdf_bytes)
 
+    end_page_id = end_page_id if end_page_id else len(images) - 1
+
+    if end_page_id > len(images) - 1:
+        logger.warning("end_page_id is out of range, use images length")
+        end_page_id = len(images) - 1
+
     model_json = []
     doc_analyze_start = time.time()
+
     for index, img_dict in enumerate(images):
         img = img_dict["img"]
         page_width = img_dict["width"]
         page_height = img_dict["height"]
-        result = custom_model(img)
+        if start_page_id <= index <= end_page_id:
+            result = custom_model(img)
+        else:
+            result = []
         page_info = {"page_no": index, "height": page_height, "width": page_width}
         page_dict = {"layout_dets": result, "page_info": page_info}
         model_json.append(page_dict)

+ 13 - 3
magic_pdf/pdf_parse_union_core.py

@@ -210,11 +210,14 @@ def pdf_parse_union(pdf_bytes,
     '''根据输入的起始范围解析pdf'''
     end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
 
+    if end_page_id > len(pdf_docs) - 1:
+        logger.warning("end_page_id is out of range, use pdf_docs length")
+        end_page_id = len(pdf_docs) - 1
+
     '''初始化启动时间'''
     start_time = time.time()
 
-    for page_id in range(start_page_id, end_page_id + 1):
-
+    for page_id, page in enumerate(pdf_docs):
         '''debug时输出每页解析的耗时'''
         if debug_mode:
             time_now = time.time()
@@ -224,7 +227,14 @@ def pdf_parse_union(pdf_bytes,
             start_time = time_now
 
         '''解析pdf中的每一页'''
-        page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
+        if start_page_id <= page_id <= end_page_id:
+            page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
+        else:
+            page_w = page.rect.width
+            page_h = page.rect.height
+            page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
+                                                [], [], [], [],
+                                                True, "skip page")
         pdf_info_dict[f"page_{page_id}"] = page_info
 
     """分段"""

+ 4 - 1
magic_pdf/pipe/AbsPipe.py

@@ -16,12 +16,15 @@ class AbsPipe(ABC):
     PIP_OCR = "ocr"
     PIP_TXT = "txt"
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False):
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
         self.image_writer = image_writer
         self.pdf_mid_data = None  # 未压缩
         self.is_debug = is_debug
+        self.start_page_id = start_page_id
+        self.end_page_id = end_page_id
     
     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)

+ 7 - 4
magic_pdf/pipe/OCRPipe.py

@@ -9,17 +9,20 @@ from magic_pdf.user_api import parse_ocr_pdf
 
 class OCRPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug)
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
-        self.model_list = doc_analyze(self.pdf_bytes, ocr=True)
+        self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_parse(self):
-        self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug)
+        self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 7 - 4
magic_pdf/pipe/TXTPipe.py

@@ -10,17 +10,20 @@ from magic_pdf.user_api import parse_txt_pdf
 
 class TXTPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug)
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
-        self.model_list = doc_analyze(self.pdf_bytes, ocr=False)
+        self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_parse(self):
-        self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug)
+        self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 11 - 6
magic_pdf/pipe/UNIPipe.py

@@ -13,9 +13,10 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
 
 class UNIPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False):
+    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
         self.pdf_type = jso_useful_key["_pdf_type"]
-        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug)
+        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id)
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
         else:
@@ -26,17 +27,21 @@ class UNIPipe(AbsPipe):
 
     def pipe_analyze(self):
         if self.pdf_type == self.PIP_TXT:
-            self.model_list = doc_analyze(self.pdf_bytes, ocr=False)
+            self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
         elif self.pdf_type == self.PIP_OCR:
-            self.model_list = doc_analyze(self.pdf_bytes, ocr=True)
+            self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
             self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
-                                                is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty)
+                                                is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
+                                                start_page_id=self.start_page_id, end_page_id=self.end_page_id)
         elif self.pdf_type == self.PIP_OCR:
             self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
-                                              is_debug=self.is_debug)
+                                              is_debug=self.is_debug,
+                                              start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 20 - 3
magic_pdf/tools/cli.py

@@ -49,11 +49,26 @@ without method specified, auto will be used by default.""",
     '--debug',
     'debug_able',
     type=bool,
-    help=('Enables detailed debugging information during'
-          'the execution of the CLI commands.', ),
+    help='Enables detailed debugging information during the execution of the CLI commands.',
     default=False,
 )
-def cli(path, output_dir, method, debug_able):
+@click.option(
+    '-s',
+    '--start',
+    'start_page_id',
+    type=int,
+    help='The starting page for PDF parsing, beginning from 0.',
+    default=0,
+)
+@click.option(
+    '-e',
+    '--end',
+    'end_page_id',
+    type=int,
+    help='The ending page for PDF parsing, beginning from 0.',
+    default=None,
+)
+def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
     model_config.__use_inside_model__ = True
     model_config.__model_mode__ = 'full'
     os.makedirs(output_dir, exist_ok=True)
@@ -73,6 +88,8 @@ def cli(path, output_dir, method, debug_able):
                 [],
                 method,
                 debug_able,
+                start_page_id=start_page_id,
+                end_page_id=end_page_id,
             )
 
         except Exception as e:

+ 8 - 3
magic_pdf/tools/common.py

@@ -42,6 +42,8 @@ def do_parse(
     f_dump_content_list=False,
     f_make_md_mode=MakeMode.MM_MD,
     f_draw_model_bbox=False,
+    start_page_id=0,
+    end_page_id=None,
 ):
     if debug_able:
         logger.warning("debug mode is on")
@@ -58,11 +60,14 @@ def do_parse(
 
     if parse_method == 'auto':
         jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
-        pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True)
+        pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
+                       start_page_id=start_page_id, end_page_id=end_page_id)
     elif parse_method == 'txt':
-        pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True)
+        pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
+                       start_page_id=start_page_id, end_page_id=end_page_id)
     elif parse_method == 'ocr':
-        pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True)
+        pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
+                       start_page_id=start_page_id, end_page_id=end_page_id)
     else:
         logger.error('unknown parse method')
         exit(1)

+ 17 - 9
magic_pdf/user_api.py

@@ -25,8 +25,9 @@ PARSE_TYPE_TXT = "txt"
 PARSE_TYPE_OCR = "ocr"
 
 
-def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args,
-                  **kwargs):
+def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
+                  start_page_id=0, end_page_id=None,
+                  *args, **kwargs):
     """
     解析文本类pdf
     """
@@ -34,7 +35,8 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
         pdf_bytes,
         pdf_models,
         imageWriter,
-        start_page_id=start_page,
+        start_page_id=start_page_id,
+        end_page_id=end_page_id,
         debug_mode=is_debug,
     )
 
@@ -45,8 +47,9 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
     return pdf_info_dict
 
 
-def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args,
-                  **kwargs):
+def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
+                  start_page_id=0, end_page_id=None,
+                  *args, **kwargs):
     """
     解析ocr类pdf
     """
@@ -54,7 +57,8 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
         pdf_bytes,
         pdf_models,
         imageWriter,
-        start_page_id=start_page,
+        start_page_id=start_page_id,
+        end_page_id=end_page_id,
         debug_mode=is_debug,
     )
 
@@ -65,8 +69,9 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
     return pdf_info_dict
 
 
-def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0,
+def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
                     input_model_is_empty: bool = False,
+                    start_page_id=0, end_page_id=None,
                     *args, **kwargs):
     """
     ocr和文本混合的pdf,全部解析出来
@@ -78,7 +83,8 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
                 pdf_bytes,
                 pdf_models,
                 imageWriter,
-                start_page_id=start_page,
+                start_page_id=start_page_id,
+                end_page_id=end_page_id,
                 debug_mode=is_debug,
             )
         except Exception as e:
@@ -89,7 +95,9 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
     if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
         logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
         if input_model_is_empty:
-            pdf_models = doc_analyze(pdf_bytes, ocr=True)
+            pdf_models = doc_analyze(pdf_bytes, ocr=True,
+                                     start_page_id=start_page_id,
+                                     end_page_id=end_page_id)
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         if pdf_info_dict is None:
             raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")