Browse Source

feat(cli&analyze&pipeline): add start_page and end_page args for pagination (#507)

* feat(cli&analyze&pipeline): add start_page and end_page args for pagination

Add start_page_id and end_page_id arguments to various components of the PDF parsing
pipeline to support pagination functionality. This feature allows users to specify the
range of pages to be processed, enhancing the efficiency and flexibility of the system.

* feat(cli&analyze&pipeline): add start_page and end_page args for pagination

Add start_page_id and end_page_id arguments to various components of the PDF parsing
pipeline to support pagination functionality. This feature allows users to specify the
range of pages to be processed, enhancing the efficiency and flexibility of the system.

* feat(cli&analyze&pipeline): add start_page and end_page args for pagination

Add start_page_id and end_page_id arguments to various components of the PDF parsing
pipeline to support pagination functionality. This feature allows users to specify the
range of pages to be processed, enhancing the efficiency and flexibility of the system.
Xiaomeng Zhao 1 year ago
parent
commit
0f91fcf61f

+ 13 - 2
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -103,20 +103,31 @@ def custom_model_init(ocr: bool = False, show_log: bool = False):
     return custom_model
     return custom_model
 
 
 
 
-def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False):
+def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
+                start_page_id=0, end_page_id=None):
 
 
     model_manager = ModelSingleton()
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(ocr, show_log)
     custom_model = model_manager.get_model(ocr, show_log)
 
 
     images = load_images_from_pdf(pdf_bytes)
     images = load_images_from_pdf(pdf_bytes)
 
 
+    end_page_id = end_page_id if end_page_id else len(images) - 1
+
+    if end_page_id > len(images) - 1:
+        logger.warning("end_page_id is out of range, use images length")
+        end_page_id = len(images) - 1
+
     model_json = []
     model_json = []
     doc_analyze_start = time.time()
     doc_analyze_start = time.time()
+
     for index, img_dict in enumerate(images):
     for index, img_dict in enumerate(images):
         img = img_dict["img"]
         img = img_dict["img"]
         page_width = img_dict["width"]
         page_width = img_dict["width"]
         page_height = img_dict["height"]
         page_height = img_dict["height"]
-        result = custom_model(img)
+        if start_page_id <= index <= end_page_id:
+            result = custom_model(img)
+        else:
+            result = []
         page_info = {"page_no": index, "height": page_height, "width": page_width}
         page_info = {"page_no": index, "height": page_height, "width": page_width}
         page_dict = {"layout_dets": result, "page_info": page_info}
         page_dict = {"layout_dets": result, "page_info": page_info}
         model_json.append(page_dict)
         model_json.append(page_dict)

+ 13 - 3
magic_pdf/pdf_parse_union_core.py

@@ -210,11 +210,14 @@ def pdf_parse_union(pdf_bytes,
     '''根据输入的起始范围解析pdf'''
     '''根据输入的起始范围解析pdf'''
     end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
     end_page_id = end_page_id if end_page_id else len(pdf_docs) - 1
 
 
+    if end_page_id > len(pdf_docs) - 1:
+        logger.warning("end_page_id is out of range, use pdf_docs length")
+        end_page_id = len(pdf_docs) - 1
+
     '''初始化启动时间'''
     '''初始化启动时间'''
     start_time = time.time()
     start_time = time.time()
 
 
-    for page_id in range(start_page_id, end_page_id + 1):
-
+    for page_id, page in enumerate(pdf_docs):
         '''debug时输出每页解析的耗时'''
         '''debug时输出每页解析的耗时'''
         if debug_mode:
         if debug_mode:
             time_now = time.time()
             time_now = time.time()
@@ -224,7 +227,14 @@ def pdf_parse_union(pdf_bytes,
             start_time = time_now
             start_time = time_now
 
 
         '''解析pdf中的每一页'''
         '''解析pdf中的每一页'''
-        page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
+        if start_page_id <= page_id <= end_page_id:
+            page_info = parse_page_core(pdf_docs, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode)
+        else:
+            page_w = page.rect.width
+            page_h = page.rect.height
+            page_info = ocr_construct_page_component_v2([], [], page_id, page_w, page_h, [],
+                                                [], [], [], [],
+                                                True, "skip page")
         pdf_info_dict[f"page_{page_id}"] = page_info
         pdf_info_dict[f"page_{page_id}"] = page_info
 
 
     """分段"""
     """分段"""

+ 4 - 1
magic_pdf/pipe/AbsPipe.py

@@ -16,12 +16,15 @@ class AbsPipe(ABC):
     PIP_OCR = "ocr"
     PIP_OCR = "ocr"
     PIP_TXT = "txt"
     PIP_TXT = "txt"
 
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False):
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
         self.pdf_bytes = pdf_bytes
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
         self.model_list = model_list
         self.image_writer = image_writer
         self.image_writer = image_writer
         self.pdf_mid_data = None  # 未压缩
         self.pdf_mid_data = None  # 未压缩
         self.is_debug = is_debug
         self.is_debug = is_debug
+        self.start_page_id = start_page_id
+        self.end_page_id = end_page_id
     
     
     def get_compress_pdf_mid_data(self):
     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)
         return JsonCompressor.compress_json(self.pdf_mid_data)

+ 7 - 4
magic_pdf/pipe/OCRPipe.py

@@ -9,17 +9,20 @@ from magic_pdf.user_api import parse_ocr_pdf
 
 
 class OCRPipe(AbsPipe):
 class OCRPipe(AbsPipe):
 
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug)
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
 
 
     def pipe_classify(self):
     def pipe_classify(self):
         pass
         pass
 
 
     def pipe_analyze(self):
     def pipe_analyze(self):
-        self.model_list = doc_analyze(self.pdf_bytes, ocr=True)
+        self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
 
     def pipe_parse(self):
     def pipe_parse(self):
-        self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug)
+        self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 7 - 4
magic_pdf/pipe/TXTPipe.py

@@ -10,17 +10,20 @@ from magic_pdf.user_api import parse_txt_pdf
 
 
 class TXTPipe(AbsPipe):
 class TXTPipe(AbsPipe):
 
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug)
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
 
 
     def pipe_classify(self):
     def pipe_classify(self):
         pass
         pass
 
 
     def pipe_analyze(self):
     def pipe_analyze(self):
-        self.model_list = doc_analyze(self.pdf_bytes, ocr=False)
+        self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
 
     def pipe_parse(self):
     def pipe_parse(self):
-        self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug)
+        self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 11 - 6
magic_pdf/pipe/UNIPipe.py

@@ -13,9 +13,10 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
 
 
 class UNIPipe(AbsPipe):
 class UNIPipe(AbsPipe):
 
 
-    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False):
+    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
+                 start_page_id=0, end_page_id=None):
         self.pdf_type = jso_useful_key["_pdf_type"]
         self.pdf_type = jso_useful_key["_pdf_type"]
-        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug)
+        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id)
         if len(self.model_list) == 0:
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
             self.input_model_is_empty = True
         else:
         else:
@@ -26,17 +27,21 @@ class UNIPipe(AbsPipe):
 
 
     def pipe_analyze(self):
     def pipe_analyze(self):
         if self.pdf_type == self.PIP_TXT:
         if self.pdf_type == self.PIP_TXT:
-            self.model_list = doc_analyze(self.pdf_bytes, ocr=False)
+            self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
         elif self.pdf_type == self.PIP_OCR:
         elif self.pdf_type == self.PIP_OCR:
-            self.model_list = doc_analyze(self.pdf_bytes, ocr=True)
+            self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
 
     def pipe_parse(self):
     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
         if self.pdf_type == self.PIP_TXT:
             self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
             self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
-                                                is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty)
+                                                is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
+                                                start_page_id=self.start_page_id, end_page_id=self.end_page_id)
         elif self.pdf_type == self.PIP_OCR:
         elif self.pdf_type == self.PIP_OCR:
             self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
             self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
-                                              is_debug=self.is_debug)
+                                              is_debug=self.is_debug,
+                                              start_page_id=self.start_page_id, end_page_id=self.end_page_id)
 
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)

+ 20 - 3
magic_pdf/tools/cli.py

@@ -49,11 +49,26 @@ without method specified, auto will be used by default.""",
     '--debug',
     '--debug',
     'debug_able',
     'debug_able',
     type=bool,
     type=bool,
-    help=('Enables detailed debugging information during'
-          'the execution of the CLI commands.', ),
+    help='Enables detailed debugging information during the execution of the CLI commands.',
     default=False,
     default=False,
 )
 )
-def cli(path, output_dir, method, debug_able):
+@click.option(
+    '-s',
+    '--start',
+    'start_page_id',
+    type=int,
+    help='The starting page for PDF parsing, beginning from 0.',
+    default=0,
+)
+@click.option(
+    '-e',
+    '--end',
+    'end_page_id',
+    type=int,
+    help='The ending page for PDF parsing, beginning from 0.',
+    default=None,
+)
+def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
     model_config.__use_inside_model__ = True
     model_config.__use_inside_model__ = True
     model_config.__model_mode__ = 'full'
     model_config.__model_mode__ = 'full'
     os.makedirs(output_dir, exist_ok=True)
     os.makedirs(output_dir, exist_ok=True)
@@ -73,6 +88,8 @@ def cli(path, output_dir, method, debug_able):
                 [],
                 [],
                 method,
                 method,
                 debug_able,
                 debug_able,
+                start_page_id=start_page_id,
+                end_page_id=end_page_id,
             )
             )
 
 
         except Exception as e:
         except Exception as e:

+ 8 - 3
magic_pdf/tools/common.py

@@ -42,6 +42,8 @@ def do_parse(
     f_dump_content_list=False,
     f_dump_content_list=False,
     f_make_md_mode=MakeMode.MM_MD,
     f_make_md_mode=MakeMode.MM_MD,
     f_draw_model_bbox=False,
     f_draw_model_bbox=False,
+    start_page_id=0,
+    end_page_id=None,
 ):
 ):
     if debug_able:
     if debug_able:
         logger.warning("debug mode is on")
         logger.warning("debug mode is on")
@@ -58,11 +60,14 @@ def do_parse(
 
 
     if parse_method == 'auto':
     if parse_method == 'auto':
         jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
         jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
-        pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True)
+        pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
+                       start_page_id=start_page_id, end_page_id=end_page_id)
     elif parse_method == 'txt':
     elif parse_method == 'txt':
-        pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True)
+        pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
+                       start_page_id=start_page_id, end_page_id=end_page_id)
     elif parse_method == 'ocr':
     elif parse_method == 'ocr':
-        pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True)
+        pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
+                       start_page_id=start_page_id, end_page_id=end_page_id)
     else:
     else:
         logger.error('unknown parse method')
         logger.error('unknown parse method')
         exit(1)
         exit(1)

+ 17 - 9
magic_pdf/user_api.py

@@ -25,8 +25,9 @@ PARSE_TYPE_TXT = "txt"
 PARSE_TYPE_OCR = "ocr"
 PARSE_TYPE_OCR = "ocr"
 
 
 
 
-def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args,
-                  **kwargs):
+def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
+                  start_page_id=0, end_page_id=None,
+                  *args, **kwargs):
     """
     """
     解析文本类pdf
     解析文本类pdf
     """
     """
@@ -34,7 +35,8 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
         pdf_bytes,
         pdf_bytes,
         pdf_models,
         pdf_models,
         imageWriter,
         imageWriter,
-        start_page_id=start_page,
+        start_page_id=start_page_id,
+        end_page_id=end_page_id,
         debug_mode=is_debug,
         debug_mode=is_debug,
     )
     )
 
 
@@ -45,8 +47,9 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
     return pdf_info_dict
     return pdf_info_dict
 
 
 
 
-def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0, *args,
-                  **kwargs):
+def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
+                  start_page_id=0, end_page_id=None,
+                  *args, **kwargs):
     """
     """
     解析ocr类pdf
     解析ocr类pdf
     """
     """
@@ -54,7 +57,8 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
         pdf_bytes,
         pdf_bytes,
         pdf_models,
         pdf_models,
         imageWriter,
         imageWriter,
-        start_page_id=start_page,
+        start_page_id=start_page_id,
+        end_page_id=end_page_id,
         debug_mode=is_debug,
         debug_mode=is_debug,
     )
     )
 
 
@@ -65,8 +69,9 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
     return pdf_info_dict
     return pdf_info_dict
 
 
 
 
-def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False, start_page=0,
+def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
                     input_model_is_empty: bool = False,
                     input_model_is_empty: bool = False,
+                    start_page_id=0, end_page_id=None,
                     *args, **kwargs):
                     *args, **kwargs):
     """
     """
     ocr和文本混合的pdf,全部解析出来
     ocr和文本混合的pdf,全部解析出来
@@ -78,7 +83,8 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
                 pdf_bytes,
                 pdf_bytes,
                 pdf_models,
                 pdf_models,
                 imageWriter,
                 imageWriter,
-                start_page_id=start_page,
+                start_page_id=start_page_id,
+                end_page_id=end_page_id,
                 debug_mode=is_debug,
                 debug_mode=is_debug,
             )
             )
         except Exception as e:
         except Exception as e:
@@ -89,7 +95,9 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
     if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
     if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
         logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
         logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
         if input_model_is_empty:
         if input_model_is_empty:
-            pdf_models = doc_analyze(pdf_bytes, ocr=True)
+            pdf_models = doc_analyze(pdf_bytes, ocr=True,
+                                     start_page_id=start_page_id,
+                                     end_page_id=end_page_id)
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         if pdf_info_dict is None:
         if pdf_info_dict is None:
             raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
             raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")