Browse Source

refactor: isolate inference and pipeline

icecraft 11 tháng trước cách đây
mục cha
commit
a3a720ea87

+ 66 - 3
magic_pdf/data/dataset.py

@@ -1,11 +1,13 @@
+import os
 from abc import ABC, abstractmethod
-from typing import Iterator
+from typing import Callable, Iterator
 
 import fitz
 
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.data.schemas import PageInfo
 from magic_pdf.data.utils import fitz_doc_to_image
+from magic_pdf.filter import classify
 
 
 class PageableData(ABC):
@@ -28,6 +30,14 @@ class PageableData(ABC):
         """
         pass
 
+    @abstractmethod
+    def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
+        pass
+
+    @abstractmethod
+    def insert_text(self, coord, content, fontsize, color):
+        pass
+
 
 class Dataset(ABC):
     @abstractmethod
@@ -66,6 +76,18 @@ class Dataset(ABC):
         """
         pass
 
+    @abstractmethod
+    def dump_to_file(self, file_path: str):
+        pass
+
+    @abstractmethod
+    def apply(self, proc: Callable, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def classify(self) -> SupportedPdfParseMethod:
+        pass
+
 
 class PymuDocDataset(Dataset):
     def __init__(self, bits: bytes):
@@ -74,7 +96,8 @@ class PymuDocDataset(Dataset):
         Args:
             bits (bytes): the bytes of the pdf
         """
-        self._records = [Doc(v) for v in fitz.open('pdf', bits)]
+        self._raw_fitz = fitz.open('pdf', bits)
+        self._records = [Doc(v) for v in self._raw_fitz]
         self._data_bits = bits
         self._raw_data = bits
 
@@ -109,6 +132,19 @@ class PymuDocDataset(Dataset):
         """
         return self._records[page_id]
 
+    def dump_to_file(self, file_path: str):
+        dir_name = os.path.dirname(file_path)
+        if dir_name not in ('', '.', '..'):
+            os.makedirs(dir_name, exist_ok=True)
+        self._raw_fitz.save(file_path)
+
+    def apply(self, proc: Callable, *args, **kwargs):
+        new_args = tuple([self] + list(args))
+        return proc(*new_args, **kwargs)
+
+    def classify(self) -> SupportedPdfParseMethod:
+        return classify(self._data_bits)
+
 
 class ImageDataset(Dataset):
     def __init__(self, bits: bytes):
@@ -118,7 +154,8 @@ class ImageDataset(Dataset):
             bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
         """
         pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
-        self._records = [Doc(v) for v in fitz.open('pdf', pdf_bytes)]
+        self._raw_fitz = fitz.open('pdf', pdf_bytes)
+        self._records = [Doc(v) for v in self._raw_fitz]
         self._raw_data = bits
         self._data_bits = pdf_bytes
 
@@ -153,9 +190,22 @@ class ImageDataset(Dataset):
         """
         return self._records[page_id]
 
+    def dump_to_file(self, file_path: str):
+        dir_name = os.path.dirname(file_path)
+        if dir_name not in ('', '.', '..'):
+            os.makedirs(dir_name, exist_ok=True)
+        self._raw_fitz.save(file_path)
+
+    def apply(self, proc: Callable, *args, **kwargs):
+        return proc(self, *args, **kwargs)
+
+    def classify(self) -> SupportedPdfParseMethod:
+        return SupportedPdfParseMethod.OCR
+
 
 class Doc(PageableData):
     """Initialized with pymudoc object."""
+
     def __init__(self, doc: fitz.Page):
         self._doc = doc
 
@@ -192,3 +242,16 @@ class Doc(PageableData):
     def __getattr__(self, name):
         if hasattr(self._doc, name):
             return getattr(self._doc, name)
+
+    def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
+        self._doc.draw_rect(
+            rect_coords,
+            color=color,
+            fill=fill,
+            fill_opacity=fill_opacity,
+            width=width,
+            overlay=overlay,
+        )
+
+    def insert_text(self, coord, content, fontsize, color):
+        self._doc.insert_text(coord, content, fontsize=fontsize, color=color)

+ 32 - 0
magic_pdf/filter/__init__.py

@@ -0,0 +1,32 @@
+
+from magic_pdf.config.drop_reason import DropReason
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.filter.pdf_classify_by_type import classify as do_classify
+from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
+
+
+def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
+    """根据pdf的元数据,判断是文本pdf,还是ocr pdf."""
+    pdf_meta = pdf_meta_scan(pdf_bytes)
+    if pdf_meta.get('_need_drop', False):  # 如果返回了需要丢弃的标志,则抛出异常
+        raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
+    else:
+        is_encrypted = pdf_meta['is_encrypted']
+        is_needs_password = pdf_meta['is_needs_password']
+        if is_encrypted or is_needs_password:  # 加密的,需要密码的,没有页面的,都不处理
+            raise Exception(f'pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}')
+        else:
+            is_text_pdf, results = do_classify(
+                pdf_meta['total_page'],
+                pdf_meta['page_width_pts'],
+                pdf_meta['page_height_pts'],
+                pdf_meta['image_info_per_page'],
+                pdf_meta['text_len_per_page'],
+                pdf_meta['imgs_per_page'],
+                pdf_meta['text_layout_per_page'],
+                pdf_meta['invalid_chars'],
+            )
+            if is_text_pdf:
+                return SupportedPdfParseMethod.TXT
+            else:
+                return SupportedPdfParseMethod.OCR

+ 11 - 9
magic_pdf/libs/draw_bbox.py

@@ -1,7 +1,9 @@
 import fitz
 from magic_pdf.config.constants import CROSS_PAGE
-from magic_pdf.config.ocr_content_type import BlockType, CategoryId, ContentType
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
+                                               ContentType)
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.libs.commons import fitz  # PyMuPDF
 from magic_pdf.model.magic_model import MagicModel
 
 
@@ -194,7 +196,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
         )
 
     # Save the PDF
-    pdf_docs.save(f'{out_path}/{filename}_layout.pdf')
+    pdf_docs.save(f'{out_path}/{filename}')
 
 
 def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
@@ -282,18 +284,17 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
         draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)
 
     # Save the PDF
-    pdf_docs.save(f'{out_path}/{filename}_spans.pdf')
+    pdf_docs.save(f'{out_path}/{filename}')
 
 
-def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
+def draw_model_bbox(model_list, dataset: Dataset, out_path, filename):
     dropped_bbox_list = []
     tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
     imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
     titles_list = []
     texts_list = []
     interequations_list = []
-    pdf_docs = fitz.open('pdf', pdf_bytes)
-    magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
+    magic_model = MagicModel(model_list, dataset)
     for i in range(len(model_list)):
         page_dropped_list = []
         tables_body, tables_caption, tables_footnote = [], [], []
@@ -337,7 +338,8 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
         dropped_bbox_list.append(page_dropped_list)
         imgs_footnote_list.append(imgs_footnote)
 
-    for i, page in enumerate(pdf_docs):
+    for i in range(len(dataset)):
+        page = dataset.get_page(i)
         draw_bbox_with_number(
             i, dropped_bbox_list, page, [158, 158, 158], True
         )  # color !
@@ -352,7 +354,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
         draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)
 
     # Save the PDF
-    pdf_docs.save(f'{out_path}/{filename}_model.pdf')
+    dataset.dump_to_file(f'{out_path}/{filename}')
 
 
 def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):

+ 104 - 60
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,14 +1,19 @@
+
 import time
 
 import fitz
 import numpy as np
 from loguru import logger
 
+import magic_pdf.model as model_config
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
-    get_formula_config
+from magic_pdf.libs.config_reader import (get_device, get_formula_config,
+                                          get_layout_config,
+                                          get_local_models_dir,
+                                          get_table_recog_config)
 from magic_pdf.model.model_list import MODEL
-import magic_pdf.model as model_config
+from magic_pdf.model.types import InferenceResult
 
 
 def dict_compare(d1, d2):
@@ -19,25 +24,31 @@ def remove_duplicates_dicts(lst):
     unique_dicts = []
     for dict_item in lst:
         if not any(
-                dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
+            dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
         ):
             unique_dicts.append(dict_item)
     return unique_dicts
 
 
-def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
+def load_images_from_pdf(
+    pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None
+) -> list:
     try:
         from PIL import Image
     except ImportError:
-        logger.error("Pillow not installed, please install by pip.")
+        logger.error('Pillow not installed, please install by pip.')
         exit(1)
 
     images = []
-    with fitz.open("pdf", pdf_bytes) as doc:
+    with fitz.open('pdf', pdf_bytes) as doc:
         pdf_page_num = doc.page_count
-        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
+        end_page_id = (
+            end_page_id
+            if end_page_id is not None and end_page_id >= 0
+            else pdf_page_num - 1
+        )
         if end_page_id > pdf_page_num - 1:
-            logger.warning("end_page_id is out of range, use images length")
+            logger.warning('end_page_id is out of range, use images length')
             end_page_id = pdf_page_num - 1
 
         for index in range(0, doc.page_count):
@@ -50,11 +61,11 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
                 if pm.width > 4500 or pm.height > 4500:
                     pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)
 
-                img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
+                img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
                 img = np.array(img)
-                img_dict = {"img": img, "width": pm.width, "height": pm.height}
+                img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
             else:
-                img_dict = {"img": [], "width": 0, "height": 0}
+                img_dict = {'img': [], 'width': 0, 'height': 0}
 
             images.append(img_dict)
     return images
@@ -69,117 +80,150 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
+    def get_model(
+        self,
+        ocr: bool,
+        show_log: bool,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None,
+    ):
         key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
         if key not in self._models:
-            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
-                                                  formula_enable=formula_enable, table_enable=table_enable)
+            self._models[key] = custom_model_init(
+                ocr=ocr,
+                show_log=show_log,
+                lang=lang,
+                layout_model=layout_model,
+                formula_enable=formula_enable,
+                table_enable=table_enable,
+            )
         return self._models[key]
 
 
-def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
-                      layout_model=None, formula_enable=None, table_enable=None):
+def custom_model_init(
+    ocr: bool = False,
+    show_log: bool = False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):
 
     model = None
 
-    if model_config.__model_mode__ == "lite":
-        logger.warning("The Lite mode is provided for developers to conduct testing only, and the output quality is "
-                       "not guaranteed to be reliable.")
+    if model_config.__model_mode__ == 'lite':
+        logger.warning(
+            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
+            'not guaranteed to be reliable.'
+        )
         model = MODEL.Paddle
-    elif model_config.__model_mode__ == "full":
+    elif model_config.__model_mode__ == 'full':
         model = MODEL.PEK
 
     if model_config.__use_inside_model__:
         model_init_start = time.time()
         if model == MODEL.Paddle:
             from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
+
             custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
         elif model == MODEL.PEK:
             from magic_pdf.model.pdf_extract_kit import CustomPEKModel
+
             # 从配置文件读取model-dir和device
             local_models_dir = get_local_models_dir()
             device = get_device()
 
             layout_config = get_layout_config()
             if layout_model is not None:
-                layout_config["model"] = layout_model
+                layout_config['model'] = layout_model
 
             formula_config = get_formula_config()
             if formula_enable is not None:
-                formula_config["enable"] = formula_enable
+                formula_config['enable'] = formula_enable
 
             table_config = get_table_recog_config()
             if table_enable is not None:
-                table_config["enable"] = table_enable
+                table_config['enable'] = table_enable
 
             model_input = {
-                            "ocr": ocr,
-                            "show_log": show_log,
-                            "models_dir": local_models_dir,
-                            "device": device,
-                            "table_config": table_config,
-                            "layout_config": layout_config,
-                            "formula_config": formula_config,
-                            "lang": lang,
+                'ocr': ocr,
+                'show_log': show_log,
+                'models_dir': local_models_dir,
+                'device': device,
+                'table_config': table_config,
+                'layout_config': layout_config,
+                'formula_config': formula_config,
+                'lang': lang,
             }
 
             custom_model = CustomPEKModel(**model_input)
         else:
-            logger.error("Not allow model_name!")
+            logger.error('Not allow model_name!')
             exit(1)
         model_init_cost = time.time() - model_init_start
-        logger.info(f"model init cost: {model_init_cost}")
+        logger.info(f'model init cost: {model_init_cost}')
     else:
-        logger.error("use_inside_model is False, not allow to use inside model")
+        logger.error('use_inside_model is False, not allow to use inside model')
         exit(1)
 
     return custom_model
 
 
-def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
-                start_page_id=0, end_page_id=None, lang=None,
-                layout_model=None, formula_enable=None, table_enable=None):
+def doc_analyze(
+    dataset: Dataset,
+    ocr: bool = False,
+    show_log: bool = False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+) -> InferenceResult:
 
-    if lang == "":
+    if lang == '':
         lang = None
 
     model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
-
-    with fitz.open("pdf", pdf_bytes) as doc:
-        pdf_page_num = doc.page_count
-        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
-        if end_page_id > pdf_page_num - 1:
-            logger.warning("end_page_id is out of range, use images length")
-            end_page_id = pdf_page_num - 1
-
-    images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)
+    custom_model = model_manager.get_model(
+        ocr, show_log, lang, layout_model, formula_enable, table_enable
+    )
 
     model_json = []
     doc_analyze_start = time.time()
 
-    for index, img_dict in enumerate(images):
-        img = img_dict["img"]
-        page_width = img_dict["width"]
-        page_height = img_dict["height"]
+    if end_page_id is None:
+        end_page_id = len(dataset)
+
+    for index in range(len(dataset)):
+        page_data = dataset.get_page(index)
+        img_dict = page_data.get_image()
+        img = img_dict['img']
+        page_width = img_dict['width']
+        page_height = img_dict['height']
         if start_page_id <= index <= end_page_id:
             page_start = time.time()
             result = custom_model(img)
             logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
         else:
             result = []
-        page_info = {"page_no": index, "height": page_height, "width": page_width}
-        page_dict = {"layout_dets": result, "page_info": page_info}
+
+        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
         model_json.append(page_dict)
 
     gc_start = time.time()
     clean_memory()
     gc_time = round(time.time() - gc_start, 2)
-    logger.info(f"gc time: {gc_time}")
+    logger.info(f'gc time: {gc_time}')
 
     doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round( (end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
-    logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)},"
-                f" speed: {doc_analyze_speed} pages/second")
+    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
+    logger.info(
+        f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
+        f' speed: {doc_analyze_speed} pages/second'
+    )
 
-    return model_json
+    return InferenceResult(model_json, dataset)

+ 122 - 0
magic_pdf/model/types.py

@@ -0,0 +1,122 @@
+import copy
+import json
+import os
+from typing import Callable
+
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.filter import classify
+from magic_pdf.libs.draw_bbox import draw_model_bbox
+from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
+from magic_pdf.pipe.types import PipeResult
+
+
+class InferenceResult:
+    def __init__(self, inference_results: list, dataset: Dataset):
+        self._infer_res = inference_results
+        self._dataset = dataset
+
+    def draw_model(self, file_path: str) -> None:
+        dir_name = os.path.dirname(file_path)
+        base_name = os.path.basename(file_path)
+        if not os.path.exists(dir_name):
+            os.makedirs(dir_name, exist_ok=True)
+        draw_model_bbox(
+            copy.deepcopy(self._infer_res), self._dataset, dir_name, base_name
+        )
+
+    def dump_model(self, writer: DataWriter, file_path: str):
+        writer.write_string(
+            file_path, json.dumps(self._infer_res, ensure_ascii=False, indent=4)
+        )
+
+    def get_infer_res(self):
+        return self._infer_res
+
+    def apply(self, proc: Callable, *args, **kwargs):
+        return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
+
+    def pipe_auto_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        def proc(*args, **kwargs) -> PipeResult:
+            res = pdf_parse_union(*args, **kwargs)
+            return PipeResult(res, self._dataset)
+
+        pdf_proc_method = classify(self._dataset.data_bits())
+
+        if pdf_proc_method == SupportedPdfParseMethod.TXT:
+            return self.apply(
+                proc,
+                self._dataset,
+                imageWriter,
+                SupportedPdfParseMethod.TXT,
+                start_page_id=0,
+                end_page_id=None,
+                debug_mode=False,
+                lang=None,
+            )
+        else:
+            return self.apply(
+                proc,
+                self._dataset,
+                imageWriter,
+                SupportedPdfParseMethod.OCR,
+                start_page_id=0,
+                end_page_id=None,
+                debug_mode=False,
+                lang=None,
+            )
+
+    def pipe_txt_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        def proc(*args, **kwargs) -> PipeResult:
+            res = pdf_parse_union(*args, **kwargs)
+            return PipeResult(res, self._dataset)
+
+        return self.apply(
+            proc,
+            self._dataset,
+            imageWriter,
+            SupportedPdfParseMethod.TXT,
+            start_page_id=0,
+            end_page_id=None,
+            debug_mode=False,
+            lang=None,
+        )
+
+    def pipe_ocr_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+
+        def proc(*args, **kwargs) -> PipeResult:
+            res = pdf_parse_union(*args, **kwargs)
+            return PipeResult(res, self._dataset)
+
+        return self.apply(
+            proc,
+            self._dataset,
+            imageWriter,
+            SupportedPdfParseMethod.TXT,
+            start_page_id=0,
+            end_page_id=None,
+            debug_mode=False,
+            lang=None,
+        )

+ 4 - 5
magic_pdf/pdf_parse_by_ocr.py

@@ -1,9 +1,9 @@
 from magic_pdf.config.enums import SupportedPdfParseMethod
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
-def parse_pdf_by_ocr(pdf_bytes,
+def parse_pdf_by_ocr(dataset: Dataset,
                      model_list,
                      imageWriter,
                      start_page_id=0,
@@ -11,9 +11,8 @@ def parse_pdf_by_ocr(pdf_bytes,
                      debug_mode=False,
                      lang=None,
                      ):
-    dataset = PymuDocDataset(pdf_bytes)
-    return pdf_parse_union(dataset,
-                           model_list,
+    return pdf_parse_union(model_list,
+                           dataset,
                            imageWriter,
                            SupportedPdfParseMethod.OCR,
                            start_page_id=start_page_id,

+ 4 - 5
magic_pdf/pdf_parse_by_txt.py

@@ -1,10 +1,10 @@
 from magic_pdf.config.enums import SupportedPdfParseMethod
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
 
 
 def parse_pdf_by_txt(
-    pdf_bytes,
+    dataset: Dataset,
     model_list,
     imageWriter,
     start_page_id=0,
@@ -12,9 +12,8 @@ def parse_pdf_by_txt(
     debug_mode=False,
     lang=None,
 ):
-    dataset = PymuDocDataset(pdf_bytes)
-    return pdf_parse_union(dataset,
-                           model_list,
+    return pdf_parse_union(model_list,
+                           dataset,
                            imageWriter,
                            SupportedPdfParseMethod.TXT,
                            start_page_id=start_page_id,

+ 1 - 1
magic_pdf/pdf_parse_union_core_v2.py

@@ -832,4 +832,4 @@ def pdf_parse_union(
 
 
 if __name__ == '__main__':
-    pass
+    pass

+ 3 - 2
magic_pdf/pipe/AbsPipe.py

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 from magic_pdf.config.drop_reason import DropReason
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.dict2md.ocr_mkcontent import union_make
 from magic_pdf.filter.pdf_classify_by_type import classify
 from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
@@ -14,9 +15,9 @@ class AbsPipe(ABC):
     PIP_OCR = 'ocr'
     PIP_TXT = 'txt'
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
+    def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
-        self.pdf_bytes = pdf_bytes
+        self.dataset = Dataset
         self.model_list = model_list
         self.image_writer = image_writer
         self.pdf_mid_data = None  # 未压缩

+ 54 - 15
magic_pdf/pipe/OCRPipe.py

@@ -2,40 +2,79 @@ from loguru import logger
 
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.user_api import parse_ocr_pdf
 
 
 class OCRPipe(AbsPipe):
-
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None,
-                 layout_model=None, formula_enable=None, table_enable=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
-                         layout_model, formula_enable, table_enable)
+    def __init__(
+        self,
+        dataset: Dataset,
+        model_list: list,
+        image_writer: DataWriter,
+        is_debug: bool = False,
+        start_page_id=0,
+        end_page_id=None,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None,
+    ):
+        super().__init__(
+            dataset,
+            model_list,
+            image_writer,
+            is_debug,
+            start_page_id,
+            end_page_id,
+            lang,
+            layout_model,
+            formula_enable,
+            table_enable,
+        )
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
-        self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
-                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                      lang=self.lang, layout_model=self.layout_model,
-                                      formula_enable=self.formula_enable, table_enable=self.table_enable)
+        self.infer_res = doc_analyze(
+            self.dataset,
+            ocr=True,
+            start_page_id=self.start_page_id,
+            end_page_id=self.end_page_id,
+            lang=self.lang,
+            layout_model=self.layout_model,
+            formula_enable=self.formula_enable,
+            table_enable=self.table_enable,
+        )
 
     def pipe_parse(self):
-        self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang, layout_model=self.layout_model,
-                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
+        self.pdf_mid_data = parse_ocr_pdf(
+            self.dataset,
+            self.infer_res,
+            self.image_writer,
+            is_debug=self.is_debug,
+            start_page_id=self.start_page_id,
+            end_page_id=self.end_page_id,
+            lang=self.lang,
+            layout_model=self.layout_model,
+            formula_enable=self.formula_enable,
+            table_enable=self.table_enable,
+        )
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         logger.info('ocr_pipe mk content list finished')
         return result
 
-    def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
+    def pipe_mk_markdown(
+        self,
+        img_parent_path: str,
+        drop_mode=DropMode.WHOLE_PDF,
+        md_make_mode=MakeMode.MM_MD,
+    ):
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
         logger.info(f'ocr_pipe mk {md_make_mode} finished')
         return result

+ 5 - 4
magic_pdf/pipe/TXTPipe.py

@@ -2,6 +2,7 @@ from loguru import logger
 
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.user_api import parse_txt_pdf
@@ -9,23 +10,23 @@ from magic_pdf.user_api import parse_txt_pdf
 
 class TXTPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
+    def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None,
                  layout_model=None, formula_enable=None, table_enable=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+        super().__init__(dataset, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
                          layout_model, formula_enable, table_enable)
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
-        self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
+        self.model_list = doc_analyze(self.dataset, ocr=False,
                                       start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                       lang=self.lang, layout_model=self.layout_model,
                                       formula_enable=self.formula_enable, table_enable=self.table_enable)
 
     def pipe_parse(self):
-        self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
+        self.pdf_mid_data = parse_txt_pdf(self.dataset, self.model_list, self.image_writer, is_debug=self.is_debug,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                           lang=self.lang, layout_model=self.layout_model,
                                           formula_enable=self.formula_enable, table_enable=self.table_enable)

+ 82 - 30
magic_pdf/pipe/UNIPipe.py

@@ -4,6 +4,7 @@ from loguru import logger
 
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.commons import join_path
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.pipe.AbsPipe import AbsPipe
@@ -12,12 +13,32 @@ from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf
 
 class UNIPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: DataWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None,
-                 layout_model=None, formula_enable=None, table_enable=None):
+    def __init__(
+        self,
+        dataset: Dataset,
+        jso_useful_key: dict,
+        image_writer: DataWriter,
+        is_debug: bool = False,
+        start_page_id=0,
+        end_page_id=None,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None,
+    ):
         self.pdf_type = jso_useful_key['_pdf_type']
-        super().__init__(pdf_bytes, jso_useful_key['model_list'], image_writer, is_debug, start_page_id, end_page_id,
-                         lang, layout_model, formula_enable, table_enable)
+        super().__init__(
+            dataset,
+            jso_useful_key['model_list'],
+            image_writer,
+            is_debug,
+            start_page_id,
+            end_page_id,
+            lang,
+            layout_model,
+            formula_enable,
+            table_enable,
+        )
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
         else:
@@ -28,35 +49,66 @@ class UNIPipe(AbsPipe):
 
     def pipe_analyze(self):
         if self.pdf_type == self.PIP_TXT:
-            self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang, layout_model=self.layout_model,
-                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
+            self.model_list = doc_analyze(
+                self.dataset,
+                ocr=False,
+                start_page_id=self.start_page_id,
+                end_page_id=self.end_page_id,
+                lang=self.lang,
+                layout_model=self.layout_model,
+                formula_enable=self.formula_enable,
+                table_enable=self.table_enable,
+            )
         elif self.pdf_type == self.PIP_OCR:
-            self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang, layout_model=self.layout_model,
-                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
+            self.model_list = doc_analyze(
+                self.dataset,
+                ocr=True,
+                start_page_id=self.start_page_id,
+                end_page_id=self.end_page_id,
+                lang=self.lang,
+                layout_model=self.layout_model,
+                formula_enable=self.formula_enable,
+                table_enable=self.table_enable,
+            )
 
     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
-            self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
-                                                is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
-                                                start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                                lang=self.lang, layout_model=self.layout_model,
-                                                formula_enable=self.formula_enable, table_enable=self.table_enable)
+            self.pdf_mid_data = parse_union_pdf(
+                self.dataset,
+                self.model_list,
+                self.image_writer,
+                is_debug=self.is_debug,
+                start_page_id=self.start_page_id,
+                end_page_id=self.end_page_id,
+                lang=self.lang,
+                layout_model=self.layout_model,
+                formula_enable=self.formula_enable,
+                table_enable=self.table_enable,
+            )
         elif self.pdf_type == self.PIP_OCR:
-            self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
-                                              is_debug=self.is_debug,
-                                              start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                              lang=self.lang)
-
-    def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON):
+            self.pdf_mid_data = parse_ocr_pdf(
+                self.dataset,
+                self.model_list,
+                self.image_writer,
+                is_debug=self.is_debug,
+                start_page_id=self.start_page_id,
+                end_page_id=self.end_page_id,
+                lang=self.lang,
+            )
+
+    def pipe_mk_uni_format(
+        self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON
+    ):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         logger.info('uni_pipe mk content list finished')
         return result
 
-    def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
+    def pipe_mk_markdown(
+        self,
+        img_parent_path: str,
+        drop_mode=DropMode.WHOLE_PDF,
+        md_make_mode=MakeMode.MM_MD,
+    ):
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
         logger.info(f'uni_pipe mk {md_make_mode} finished')
         return result
@@ -65,6 +117,7 @@ class UNIPipe(AbsPipe):
 if __name__ == '__main__':
     # 测试
     from magic_pdf.data.data_reader_writer import DataReader
+
     drw = DataReader(r'D:/project/20231108code-clean')
 
     pdf_file_path = r'linshixuqiu\19983-00.pdf'
@@ -82,10 +135,7 @@ if __name__ == '__main__':
     #     "model_list": model_list
     # }
 
-    jso_useful_key = {
-        '_pdf_type': '',
-        'model_list': model_list
-    }
+    jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
     pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer)
     pipe.pipe_classify()
     pipe.pipe_parse()
@@ -94,5 +144,7 @@ if __name__ == '__main__':
 
     md_writer = DataWriter(write_path)
     md_writer.write_string('19983-00.md', md_content)
-    md_writer.write_string('19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4))
+    md_writer.write_string(
+        '19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
+    )
     md_writer.write_string('19983-00.txt', str(content_list))

+ 62 - 0
magic_pdf/pipe/types.py

@@ -0,0 +1,62 @@
+
+import json
+import os
+
+from magic_pdf.config.make_content_config import DropMode, MakeMode
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.dict2md.ocr_mkcontent import union_make
+from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
+                                      draw_span_bbox)
+from magic_pdf.libs.json_compressor import JsonCompressor
+
+
+class PipeResult:
+    def __init__(self, pipe_res, dataset: Dataset):
+        self._pipe_res = pipe_res
+        self._dataset = dataset
+
+    def dump_md(self, writer: DataWriter, file_path: str, img_dir_or_bucket_prefix: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
+        pdf_info_list = self._pipe_res['pdf_info']
+        md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix)
+        writer.write_string(file_path, md_content)
+
+    def dump_content_list(self, writer: DataWriter, file_path: str, image_dir_or_bucket_prefix: str, drop_mode=DropMode.NONE):
+        pdf_info_list = self._pipe_res['pdf_info']
+        content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, image_dir_or_bucket_prefix)
+        writer.write_string(file_path, json.dumps(content_list, ensure_ascii=False, indent=4))
+
+    def dump_middle_json(self, writer: DataWriter, file_path: str):
+        writer.write_string(file_path, json.dumps(self._pipe_res, ensure_ascii=False, indent=4))
+
+    def draw_layout(self, file_path: str) -> None:
+        dir_name = os.path.dirname(file_path)
+        base_name = os.path.basename(file_path)
+        if not os.path.exists(dir_name):
+            os.makedirs(dir_name, exist_ok=True)
+        pdf_info = self._pipe_res['pdf_info']
+        draw_layout_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
+
+    def draw_span(self, file_path: str):
+        dir_name = os.path.dirname(file_path)
+        base_name = os.path.basename(file_path)
+        if not os.path.exists(dir_name):
+            os.makedirs(dir_name, exist_ok=True)
+        pdf_info = self._pipe_res['pdf_info']
+        draw_span_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
+
+    def draw_line_sort(self, file_path: str):
+        dir_name = os.path.dirname(file_path)
+        base_name = os.path.basename(file_path)
+        if not os.path.exists(dir_name):
+            os.makedirs(dir_name, exist_ok=True)
+        pdf_info = self._pipe_res['pdf_info']
+        draw_line_sort_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
+
+    def draw_content_list(self, writer: DataWriter, file_path: str, img_dir_or_bucket_prefix: str, drop_mode=DropMode.WHOLE_PDF):
+        pdf_info_list = self._pipe_res['pdf_info']
+        content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_dir_or_bucket_prefix)
+        writer.write_string(file_path, json.dumps(content_list, ensure_ascii=False, indent=4))
+
+    def get_compress_pdf_mid_data(self):
+        return JsonCompressor.compress_json(self.pdf_mid_data)

+ 106 - 59
magic_pdf/tools/common.py

@@ -1,5 +1,3 @@
-import copy
-import json as json_parse
 import os
 
 import click
@@ -7,13 +5,12 @@ import fitz
 from loguru import logger
 
 import magic_pdf.model as model_config
+from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import FileBasedDataWriter
-from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
-                                      draw_model_bbox, draw_span_bbox)
-from magic_pdf.pipe.OCRPipe import OCRPipe
-from magic_pdf.pipe.TXTPipe import TXTPipe
-from magic_pdf.pipe.UNIPipe import UNIPipe
+from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
+from magic_pdf.model.types import InferenceResult
 
 # from io import BytesIO
 # from pypdf import PdfReader, PdfWriter
@@ -56,7 +53,11 @@ def prepare_env(output_dir, pdf_file_name, method):
 def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_id=None):
     document = fitz.open('pdf', pdf_bytes)
     output_document = fitz.open()
-    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(document) - 1
+    end_page_id = (
+        end_page_id
+        if end_page_id is not None and end_page_id >= 0
+        else len(document) - 1
+    )
     if end_page_id > len(document) - 1:
         logger.warning('end_page_id is out of range, use pdf_docs length')
         end_page_id = len(document) - 1
@@ -94,78 +95,123 @@ def do_parse(
         f_draw_model_bbox = True
         f_draw_line_sort_bbox = True
 
-    if lang == "":
+    if lang == '':
         lang = None
 
-    pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id, end_page_id)
+    pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
+        pdf_bytes, start_page_id, end_page_id
+    )
 
-    orig_model_list = copy.deepcopy(model_list)
-    local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
-                                                parse_method)
+    local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
 
-    image_writer, md_writer = FileBasedDataWriter(
-        local_image_dir), FileBasedDataWriter(local_md_dir)
+    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
+        local_md_dir
+    )
     image_dir = str(os.path.basename(local_image_dir))
 
-    if parse_method == 'auto':
-        jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
-        pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
-                       # start_page_id=start_page_id, end_page_id=end_page_id,
-                       lang=lang,
-                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
-    elif parse_method == 'txt':
-        pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       # start_page_id=start_page_id, end_page_id=end_page_id,
-                       lang=lang,
-                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
-    elif parse_method == 'ocr':
-        pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       # start_page_id=start_page_id, end_page_id=end_page_id,
-                       lang=lang,
-                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
-    else:
-        logger.error('unknown parse method')
-        exit(1)
-
-    pipe.pipe_classify()
+    ds = PymuDocDataset(pdf_bytes)
 
     if len(model_list) == 0:
         if model_config.__use_inside_model__:
-            pipe.pipe_analyze()
-            orig_model_list = copy.deepcopy(pipe.model_list)
+            if parse_method == 'auto':
+                if ds.classify() == SupportedPdfParseMethod.TXT:
+                    infer_result = ds.apply(
+                        doc_analyze,
+                        ocr=False,
+                        lang=lang,
+                        layout_model=layout_model,
+                        formula_enable=formula_enable,
+                        table_enable=table_enable,
+                    )
+                else:
+                    infer_result = ds.apply(
+                        doc_analyze,
+                        ocr=True,
+                        lang=lang,
+                        layout_model=layout_model,
+                        formula_enable=formula_enable,
+                        table_enable=table_enable,
+                    )
+                pipe_result = infer_result.pipe_auto_mode(
+                    image_writer, debug_mode=True, lang=lang
+                )
+
+            elif parse_method == 'txt':
+                infer_result = ds.apply(
+                    doc_analyze,
+                    ocr=False,
+                    lang=lang,
+                    layout_model=layout_model,
+                    formula_enable=formula_enable,
+                    table_enable=table_enable,
+                )
+                pipe_result = infer_result.pipe_txt_mode(
+                    image_writer, debug_mode=True, lang=lang
+                )
+            elif parse_method == 'ocr':
+                infer_result = ds.apply(
+                    doc_analyze,
+                    ocr=True,
+                    lang=lang,
+                    layout_model=layout_model,
+                    formula_enable=formula_enable,
+                    table_enable=table_enable,
+                )
+                pipe_result = infer_result.pipe_ocr_mode(
+                    image_writer, debug_mode=True, lang=lang
+                )
+            else:
+                logger.error('unknown parse method')
+                exit(1)
         else:
             logger.error('need model list input')
             exit(2)
+    else:
+        infer_result = InferenceResult(model_list, ds)
+        if parse_method == 'ocr':
+            pipe_result = infer_result.pipe_ocr_mode(
+                image_writer, debug_mode=True, lang=lang
+            )
+        elif parse_method == 'txt':
+            pipe_result = infer_result.pipe_txt_mode(
+                image_writer, debug_mode=True, lang=lang
+            )
+        else:
+            pipe_result = infer_result.pipe_auto_mode(
+                image_writer, debug_mode=True, lang=lang
+            )
+
+    if f_draw_model_bbox:
+        infer_result.draw_model(
+            os.path.join(local_md_dir, f'{pdf_file_name}_model.pdf')
+        )
 
-    pipe.pipe_parse()
-    pdf_info = pipe.pdf_mid_data['pdf_info']
     if f_draw_layout_bbox:
-        draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+        pipe_result.draw_layout(
+            os.path.join(local_md_dir, f'{pdf_file_name}_layout.pdf')
+        )
     if f_draw_span_bbox:
-        draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
-    if f_draw_model_bbox:
-        draw_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
+        pipe_result.draw_span(os.path.join(local_md_dir, f'{pdf_file_name}_spans.pdf'))
+
     if f_draw_line_sort_bbox:
-        draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+        pipe_result.draw_line_sort(
+            os.path.join(local_md_dir, f'{pdf_file_name}_line_sort.pdf')
+        )
 
-    md_content = pipe.pipe_mk_markdown(image_dir, drop_mode=DropMode.NONE, md_make_mode=f_make_md_mode)
     if f_dump_md:
-        md_writer.write_string(
+        pipe_result.dump_md(
+            md_writer,
             f'{pdf_file_name}.md',
-            md_content
+            image_dir,
+            drop_mode=DropMode.NONE,
+            md_make_mode=f_make_md_mode,
         )
 
     if f_dump_middle_json:
-        md_writer.write_string(
-            f'{pdf_file_name}_middle.json',
-            json_parse.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
-        )
+        pipe_result.dump_middle_json(md_writer, f'{pdf_file_name}_middle.json')
 
     if f_dump_model_json:
-        md_writer.write_string(
-            f'{pdf_file_name}_model.json',
-            json_parse.dumps(orig_model_list, ensure_ascii=False, indent=4)
-        )
+        infer_result.dump_model(md_writer, f'{pdf_file_name}_model.json')
 
     if f_dump_orig_pdf:
         md_writer.write(
@@ -173,11 +219,12 @@ def do_parse(
             pdf_bytes,
         )
 
-    content_list = pipe.pipe_mk_uni_format(image_dir, drop_mode=DropMode.NONE)
     if f_dump_content_list:
-        md_writer.write_string(
+        pipe_result.dump_content_list(
+            md_writer,
             f'{pdf_file_name}_content_list.json',
-            json_parse.dumps(content_list, ensure_ascii=False, indent=4)
+            image_dir,
+            drop_mode=DropMode.NONE,
         )
 
     logger.info(f'local output dir is {local_md_dir}')

+ 44 - 19
magic_pdf/user_api.py

@@ -10,6 +10,7 @@
 from loguru import logger
 
 from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.version import __version__
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
@@ -19,13 +20,21 @@ PARSE_TYPE_TXT = 'txt'
 PARSE_TYPE_OCR = 'ocr'
 
 
-def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
-                  start_page_id=0, end_page_id=None, lang=None,
-                  *args, **kwargs):
+def parse_txt_pdf(
+    dataset: Dataset,
+    model_list: list,
+    imageWriter: DataWriter,
+    is_debug=False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    *args,
+    **kwargs
+):
     """解析文本类pdf."""
     pdf_info_dict = parse_pdf_by_txt(
-        pdf_bytes,
-        pdf_models,
+        dataset,
+        model_list,
         imageWriter,
         start_page_id=start_page_id,
         end_page_id=end_page_id,
@@ -43,13 +52,21 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
     return pdf_info_dict
 
 
-def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
-                  start_page_id=0, end_page_id=None, lang=None,
-                  *args, **kwargs):
+def parse_ocr_pdf(
+    dataset: Dataset,
+    model_list: list,
+    imageWriter: DataWriter,
+    is_debug=False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    *args,
+    **kwargs
+):
     """解析ocr类pdf."""
     pdf_info_dict = parse_pdf_by_ocr(
-        pdf_bytes,
-        pdf_models,
+        dataset,
+        model_list,
         imageWriter,
         start_page_id=start_page_id,
         end_page_id=end_page_id,
@@ -67,17 +84,24 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
     return pdf_info_dict
 
 
-def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
-                    input_model_is_empty: bool = False,
-                    start_page_id=0, end_page_id=None, lang=None,
-                    *args, **kwargs):
+def parse_union_pdf(
+    dataset: Dataset,
+    model_list: list,
+    imageWriter: DataWriter,
+    is_debug=False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    *args,
+    **kwargs
+):
     """ocr和文本混合的pdf,全部解析出来."""
 
     def parse_pdf(method):
         try:
             return method(
-                pdf_bytes,
-                pdf_models,
+                dataset,
+                model_list,
                 imageWriter,
                 start_page_id=start_page_id,
                 end_page_id=end_page_id,
@@ -91,12 +115,12 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter,
     pdf_info_dict = parse_pdf(parse_pdf_by_txt)
     if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False):
         logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr')
-        if input_model_is_empty:
+        if len(model_list) == 0:
             layout_model = kwargs.get('layout_model', None)
             formula_enable = kwargs.get('formula_enable', None)
             table_enable = kwargs.get('table_enable', None)
-            pdf_models = doc_analyze(
-                pdf_bytes,
+            infer_res = doc_analyze(
+                dataset,
                 ocr=True,
                 start_page_id=start_page_id,
                 end_page_id=end_page_id,
@@ -105,6 +129,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter,
                 formula_enable=formula_enable,
                 table_enable=table_enable,
             )
+            model_list = infer_res.get_infer_res()
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         if pdf_info_dict is None:
             raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.')