Browse Source

Merge pull request #1261 from opendatalab/release-0.10.6

Release 0.10.6
Xiaomeng Zhao 11 months ago
parent
commit
b4f7b53ecb
46 changed files with 1797 additions and 478 deletions
  1. .github/workflows/cli.yml (+2 -18)
  2. .github/workflows/huigui.yml (+4 -4)
  3. docs/README_Windows_CUDA_Acceleration_en_US.md (+0 -8)
  4. docs/README_Windows_CUDA_Acceleration_zh_CN.md (+0 -9)
  5. magic_pdf/config/constants.py (+5 -0)
  6. magic_pdf/data/data_reader_writer/base.py (+13 -1)
  7. magic_pdf/data/dataset.py (+175 -4)
  8. magic_pdf/dict2md/ocr_mkcontent.py (+2 -2)
  9. magic_pdf/filter/__init__.py (+32 -0)
  10. magic_pdf/filter/pdf_meta_scan.py (+3 -2)
  11. magic_pdf/libs/draw_bbox.py (+11 -10)
  12. magic_pdf/libs/pdf_check.py (+30 -30)
  13. magic_pdf/model/__init__.py (+124 -0)
  14. magic_pdf/model/doc_analyze_by_custom_model.py (+119 -60)
  15. magic_pdf/model/operators.py (+190 -0)
  16. magic_pdf/model/pdf_extract_kit.py (+20 -1)
  17. magic_pdf/model/sub_modules/model_init.py (+13 -3)
  18. magic_pdf/model/sub_modules/model_utils.py (+11 -5)
  19. magic_pdf/pdf_parse_by_ocr.py (+4 -5)
  20. magic_pdf/pdf_parse_by_txt.py (+4 -5)
  21. magic_pdf/pdf_parse_union_core_v2.py (+10 -11)
  22. magic_pdf/pipe/AbsPipe.py (+3 -2)
  23. magic_pdf/pipe/OCRPipe.py (+54 -15)
  24. magic_pdf/pipe/TXTPipe.py (+5 -4)
  25. magic_pdf/pipe/UNIPipe.py (+82 -30)
  26. magic_pdf/pipe/operators.py (+138 -0)
  27. magic_pdf/tools/common.py (+108 -59)
  28. magic_pdf/user_api.py (+47 -24)
  29. next_docs/en/_static/image/pipeline.drawio.svg (+3 -0)
  30. next_docs/en/api.rst (+2 -0)
  31. next_docs/en/api/model_operators.rst (+8 -0)
  32. next_docs/en/api/pipe_operators.rst (+9 -0)
  33. next_docs/en/conf.py (+1 -1)
  34. next_docs/en/user_guide/quick_start/to_markdown.rst (+52 -38)
  35. next_docs/en/user_guide/tutorial.rst (+3 -1)
  36. next_docs/en/user_guide/tutorial/pipeline.rst (+185 -0)
  37. next_docs/requirements.txt (+5 -1)
  38. next_docs/zh_cn/_static/image/pipeline.drawio.svg (+3 -0)
  39. next_docs/zh_cn/user_guide/quick_start/to_markdown.rst (+53 -42)
  40. next_docs/zh_cn/user_guide/tutorial.rst (+2 -0)
  41. next_docs/zh_cn/user_guide/tutorial/output_file_description.rst (+59 -66)
  42. next_docs/zh_cn/user_guide/tutorial/pipeline.rst (+179 -0)
  43. requirements-docker.txt (+2 -2)
  44. requirements.txt (+2 -2)
  45. setup.py (+4 -2)
  46. tests/test_cli/test_cli_sdk.py (+16 -11)

+ 2 - 18
.github/workflows/cli.yml

@@ -30,7 +30,7 @@ jobs:
        source activate mineru
        conda env list
        pip show coverage
-        # cd $GITHUB_WORKSPACE && sh tests/retry_env.sh
+        cd $GITHUB_WORKSPACE && sh tests/retry_env.sh
        cd $GITHUB_WORKSPACE && python tests/clean_coverage.py
        cd $GITHUB_WORKSPACE && coverage run -m pytest tests/unittest/ --cov=magic_pdf/  --cov-report html --cov-report term-missing
        cd $GITHUB_WORKSPACE && python tests/get_coverage.py
@@ -41,22 +41,6 @@ jobs:
    needs: cli-test
    runs-on: pdf
    steps:
-    - name: get_actor
-      run: |
-          metion_list="dt-yy"
-          echo $GITHUB_ACTOR
-          if [[ $GITHUB_ACTOR == "drunkpig" ]]; then
-            metion_list="xuchao"
-          elif [[ $GITHUB_ACTOR == "myhloli" ]]; then
-            metion_list="zhaoxiaomeng"
-          elif [[ $GITHUB_ACTOR == "icecraft" ]]; then
-            metion_list="xurui1"
-          fi
-          echo $metion_list
-          echo "METIONS=$metion_list" >> "$GITHUB_ENV"
-          echo ${{ env.METIONS }}
-
    - name: notify
      run: |
-        echo ${{ secrets.USER_ID }}
-        curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.USER_ID }}'"}]]}}}}'  ${{ secrets.WEBHOOK_URL }}
+        curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'$USER_ID'"}]]}}}}'  $WEBHOOK_URL

+ 4 - 4
.github/workflows/huigui.yml

@@ -29,14 +29,14 @@ jobs:
        source activate mineru
        conda env list
        pip show coverage
-        # cd $GITHUB_WORKSPACE && sh tests/retry_env.sh
+        cd $GITHUB_WORKSPACE && sh tests/retry_env.sh
        cd $GITHUB_WORKSPACE && python tests/clean_coverage.py
        cd $GITHUB_WORKSPACE && coverage run -m pytest tests/unittest/ --cov=magic_pdf/  --cov-report html --cov-report term-missing
        cd $GITHUB_WORKSPACE && python tests/get_coverage.py
        cd $GITHUB_WORKSPACE && pytest -s -v tests/test_cli/test_cli_sdk.py

  notify_to_feishu:
-    if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure') && (github.ref_name == 'master') }}
+    if: ${{ always() && !cancelled() && contains(needs.*.result, 'failure')}}
    needs: cli-test
    runs-on: pdf
    steps:
@@ -57,5 +57,5 @@ jobs:

    - name: notify
      run: |
-        echo ${{ secrets.USER_ID }}
-        curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'${{ secrets.USER_ID }}'"}]]}}}}'  ${{ secrets.WEBHOOK_URL }}
+        #echo ${{ secrets.USER_ID }}
+        curl -X POST -H "Content-Type: application/json" -d '{"msg_type":"post","content":{"post":{"zh_cn":{"title":"'${{ github.repository }}' GitHubAction Failed","content":[[{"tag":"text","text":""},{"tag":"a","text":"Please click here for details ","href":"https://github.com/'${{ github.repository }}'/actions/runs/'${GITHUB_RUN_ID}'"},{"tag":"at","user_id":"'$USER_ID'"}]]}}}}'  $WEBHOOK_URL

+ 0 - 8
docs/README_Windows_CUDA_Acceleration_en_US.md

@@ -67,14 +67,6 @@ If your graphics card has at least 8GB of VRAM, follow these steps to test CUDA-
   ```
   pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
   ```
-   > [!IMPORTANT]
-   > Ensure the following versions are specified in the command:
-   >
-   > ```
-   > torch==2.3.1 torchvision==0.18.1
-   > ```
-   >
-   > These are the highest versions we support. Installing higher versions without specifying them will cause the program to fail.

2. **Modify the value of `"device-mode"`** in the `magic-pdf.json` configuration file located in your user directory.

+ 0 - 9
docs/README_Windows_CUDA_Acceleration_zh_CN.md

@@ -69,15 +69,6 @@ pip install -U magic-pdf[full] --extra-index-url https://wheels.myhloli.com -i h
pip install --force-reinstall torch==2.3.1 torchvision==0.18.1 --index-url https://download.pytorch.org/whl/cu118
```

-> [!IMPORTANT]
-> 务必在命令中指定以下版本
->
-> ```bash
-> torch==2.3.1 torchvision==0.18.1
-> ```
->
-> 这是我们支持的最高版本,如果不指定版本会自动安装更高版本导致程序无法运行
-
**2.修改【用户目录】中配置文件magic-pdf.json中"device-mode"的值**

```json

+ 5 - 0
magic_pdf/config/constants.py

@@ -51,3 +51,8 @@ class MODEL_NAME:
    UniMerNet_v2_Small = 'unimernet_small'

    RAPID_TABLE = 'rapid_table'
+
+
+PARSE_TYPE_TXT = 'txt'
+PARSE_TYPE_OCR = 'ocr'
+

+ 13 - 1
magic_pdf/data/data_reader_writer/base.py

@@ -48,4 +48,16 @@ class DataWriter(ABC):
            path (str): the target file where to write
            data (str): the data want to write
        """
-        self.write(path, data.encode())
+
+        def safe_encode(data: str, method: str):
+            try:
+                bit_data = data.encode(encoding=method, errors='replace')
+                return bit_data, True
+            except:  # noqa
+                return None, False
+
+        for method in ['utf-8', 'ascii']:
+            bit_data, flag = safe_encode(data, method)
+            if flag:
+                self.write(path, bit_data)
+                break
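In practice the change means `write_string` no longer raises on un-encodable input. A minimal caller sketch, assuming the package's `FileBasedDataWriter` as the concrete writer (any `DataWriter` subclass behaves the same):

```python
from magic_pdf.data.data_reader_writer import FileBasedDataWriter

writer = FileBasedDataWriter('/tmp/out')  # hypothetical output directory
# encode(errors='replace') substitutes characters that cannot be encoded,
# so a stray lone surrogate is written as '?' instead of raising UnicodeEncodeError.
writer.write_string('result.md', 'ok text \udcff more text')
```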

+ 175 - 4
magic_pdf/data/dataset.py

@@ -1,11 +1,13 @@
+import os
 from abc import ABC, abstractmethod
-from typing import Iterator
+from typing import Callable, Iterator

 import fitz

 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.data.schemas import PageInfo
 from magic_pdf.data.utils import fitz_doc_to_image
+from magic_pdf.filter import classify


 class PageableData(ABC):
@@ -28,6 +30,32 @@ class PageableData(ABC):
         """
         """
         pass
         pass
 
 
+    @abstractmethod
+    def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
+        """draw rectangle.
+
+        Args:
+            rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
+            color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
+            fill (list[float] | None): fill the board with RGB, None means will not fill with color
+            fill_opacity (float): opacity of the fill, range from [0, 1]
+            width (float): the width of board
+            overlay (bool): fill the color in foreground or background. True means fill in background.
+        """
+        pass
+
+    @abstractmethod
+    def insert_text(self, coord, content, fontsize, color):
+        """insert text.
+
+        Args:
+            coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
+            content (str): the text content
+            fontsize (int): font size of the text
+            color (list[float] | None):  three element tuple which describe the RGB of the board line, None will use the default font color!
+        """
+        pass
+

class Dataset(ABC):
    @abstractmethod
@@ -66,6 +94,43 @@ class Dataset(ABC):
         """
         """
         pass
         pass
 
 
+    @abstractmethod
+    def dump_to_file(self, file_path: str):
+        """Dump the file
+
+        Args: 
+            file_path (str): the file path 
+        """
+        pass
+
+    @abstractmethod
+    def apply(self, proc: Callable, *args, **kwargs):
+        """Apply callable method which.
+
+        Args:
+            proc (Callable): invoke proc as follows:
+                proc(self, *args, **kwargs)
+
+        Returns:
+            Any: return the result generated by proc
+        """
+        pass
+
+    @abstractmethod
+    def classify(self) -> SupportedPdfParseMethod:
+        """classify the dataset 
+
+        Returns:
+            SupportedPdfParseMethod: _description_
+        """
+        pass
+
+    @abstractmethod
+    def clone(self):
+        """clone this dataset
+        """
+        pass
+

class PymuDocDataset(Dataset):
    def __init__(self, bits: bytes):
@@ -74,7 +139,8 @@ class PymuDocDataset(Dataset):
        Args:
            bits (bytes): the bytes of the pdf
        """
-        self._records = [Doc(v) for v in fitz.open('pdf', bits)]
+        self._raw_fitz = fitz.open('pdf', bits)
+        self._records = [Doc(v) for v in self._raw_fitz]
        self._data_bits = bits
        self._raw_data = bits

@@ -109,6 +175,43 @@ class PymuDocDataset(Dataset):
         """
         """
         return self._records[page_id]
         return self._records[page_id]
 
 
+    def dump_to_file(self, file_path: str):
+        """Dump the file
+
+        Args: 
+            file_path (str): the file path 
+        """
+        
+        dir_name = os.path.dirname(file_path)
+        if dir_name not in ('', '.', '..'):
+            os.makedirs(dir_name, exist_ok=True)
+        self._raw_fitz.save(file_path)
+
+    def apply(self, proc: Callable, *args, **kwargs):
+        """Apply callable method which.
+
+        Args:
+            proc (Callable): invoke proc as follows:
+                proc(dataset, *args, **kwargs)
+
+        Returns:
+            Any: return the result generated by proc
+        """
+        return proc(self, *args, **kwargs)
+
+    def classify(self) -> SupportedPdfParseMethod:
+        """classify the dataset 
+
+        Returns:
+            SupportedPdfParseMethod: _description_
+        """
+        return classify(self._data_bits)
+
+    def clone(self):
+        """clone this dataset
+        """
+        return PymuDocDataset(self._raw_data)
+

class ImageDataset(Dataset):
    def __init__(self, bits: bytes):
@@ -118,7 +221,8 @@ class ImageDataset(Dataset):
            bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
        """
        pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
-        self._records = [Doc(v) for v in fitz.open('pdf', pdf_bytes)]
+        self._raw_fitz = fitz.open('pdf', pdf_bytes)
+        self._records = [Doc(v) for v in self._raw_fitz]
        self._raw_data = bits
        self._data_bits = pdf_bytes

@@ -153,14 +257,50 @@ class ImageDataset(Dataset):
         """
         """
         return self._records[page_id]
         return self._records[page_id]
 
 
+    def dump_to_file(self, file_path: str):
+        """Dump the file
+
+        Args: 
+            file_path (str): the file path 
+        """
+        dir_name = os.path.dirname(file_path)
+        if dir_name not in ('', '.', '..'):
+            os.makedirs(dir_name, exist_ok=True)
+        self._raw_fitz.save(file_path)
+
+    def apply(self, proc: Callable, *args, **kwargs):
+        """Apply callable method which.
+
+        Args:
+            proc (Callable): invoke proc as follows:
+                proc(dataset, *args, **kwargs)
+
+        Returns:
+            Any: return the result generated by proc
+        """
+        return proc(self, *args, **kwargs)
+
+    def classify(self) -> SupportedPdfParseMethod:
+        """classify the dataset 
+
+        Returns:
+            SupportedPdfParseMethod: _description_
+        """
+        return SupportedPdfParseMethod.OCR
+
+    def clone(self):
+        """clone this dataset
+        """
+        return ImageDataset(self._raw_data)

class Doc(PageableData):
    """Initialized with pymudoc object."""
+
    def __init__(self, doc: fitz.Page):
        self._doc = doc

    def get_image(self):
-        """Return the imge info.
+        """Return the image info.

        Returns:
            dict: {
@@ -192,3 +332,34 @@ class Doc(PageableData):
    def __getattr__(self, name):
        if hasattr(self._doc, name):
            return getattr(self._doc, name)
+
+    def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
+        """draw rectangle.
+
+        Args:
+            rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
+            color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
+            fill (list[float] | None): fill the board with RGB, None means will not fill with color
+            fill_opacity (float): opacity of the fill, range from [0, 1]
+            width (float): the width of board
+            overlay (bool): fill the color in foreground or background. True means fill in background.
+        """
+        self._doc.draw_rect(
+            rect_coords,
+            color=color,
+            fill=fill,
+            fill_opacity=fill_opacity,
+            width=width,
+            overlay=overlay,
+        )
+
+    def insert_text(self, coord, content, fontsize, color):
+        """insert text.
+
+        Args:
+            coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
+            content (str): the text content
+            fontsize (int): font size of the text
+            color (list[float] | None):  three element tuple which describe the RGB of the board line, None will use the default font color!
+        """
+        self._doc.insert_text(coord, content, fontsize=fontsize, color=color)
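Taken together, the new `Dataset` surface looks like this in use; a small sketch (file paths are hypothetical):

```python
from magic_pdf.data.dataset import PymuDocDataset

with open('example.pdf', 'rb') as f:      # hypothetical input
    ds = PymuDocDataset(f.read())

method = ds.classify()                    # SupportedPdfParseMethod.TXT or .OCR
copy_ds = ds.clone()                      # independent dataset rebuilt from the raw bytes
ds.dump_to_file('out/example.pdf')        # parent directory is created on demand
n_pages = ds.apply(lambda d: len(d))      # apply() hands the dataset itself to proc
```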

+ 2 - 2
magic_pdf/dict2md/ocr_mkcontent.py

@@ -165,8 +165,8 @@ def merge_para_with_text(para_block):
            if content:
                langs = ['zh', 'ja', 'ko']
                # logger.info(f'block_lang: {block_lang}, content: {content}')
-                if block_lang in langs: # 中文/日语/韩文语境下,换行不需要空格分隔
-                    if j == len(line['spans']) - 1:
+                if block_lang in langs: # 中文/日语/韩文语境下,换行不需要空格分隔,但是如果是行内公式结尾,还是要加空格
+                    if j == len(line['spans']) - 1 and span_type not in [ContentType.InlineEquation]:
                        para_text += content
                    else:
                        para_text += f'{content} '
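The rule being adjusted is easier to see in isolation. A toy re-implementation (not the library code) of how spans on a zh/ja/ko line are joined after this fix:

```python
# Last span of a line gets no trailing space in CJK text, unless it is an
# inline equation, which still needs a space to separate it from what follows.
line_spans = [('速度为', 'text'), ('$v = s/t$', 'inline_equation')]  # hypothetical spans
para_text = ''
for j, (content, span_type) in enumerate(line_spans):
    if j == len(line_spans) - 1 and span_type != 'inline_equation':
        para_text += content
    else:
        para_text += f'{content} '
# para_text == '速度为 $v = s/t$ '
```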

+ 32 - 0
magic_pdf/filter/__init__.py

@@ -0,0 +1,32 @@
+
+from magic_pdf.config.drop_reason import DropReason
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.filter.pdf_classify_by_type import classify as do_classify
+from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
+
+
+def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
+    """根据pdf的元数据,判断是文本pdf,还是ocr pdf."""
+    pdf_meta = pdf_meta_scan(pdf_bytes)
+    if pdf_meta.get('_need_drop', False):  # 如果返回了需要丢弃的标志,则抛出异常
+        raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
+    else:
+        is_encrypted = pdf_meta['is_encrypted']
+        is_needs_password = pdf_meta['is_needs_password']
+        if is_encrypted or is_needs_password:  # 加密的,需要密码的,没有页面的,都不处理
+            raise Exception(f'pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}')
+        else:
+            is_text_pdf, results = do_classify(
+                pdf_meta['total_page'],
+                pdf_meta['page_width_pts'],
+                pdf_meta['page_height_pts'],
+                pdf_meta['image_info_per_page'],
+                pdf_meta['text_len_per_page'],
+                pdf_meta['imgs_per_page'],
+                pdf_meta['text_layout_per_page'],
+                pdf_meta['invalid_chars'],
+            )
+            if is_text_pdf:
+                return SupportedPdfParseMethod.TXT
+            else:
+                return SupportedPdfParseMethod.OCR
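A short sketch of the new helper in use (the input path is hypothetical); note that it raises for encrypted or password-protected PDFs rather than returning a method:

```python
from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.filter import classify

with open('paper.pdf', 'rb') as f:        # hypothetical input
    pdf_bytes = f.read()

if classify(pdf_bytes) == SupportedPdfParseMethod.TXT:
    print('text layer looks usable; parse directly')
else:
    print('image-based or garbled; use OCR')
```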

+ 3 - 2
magic_pdf/filter/pdf_meta_scan.py

@@ -8,7 +8,7 @@ from loguru import logger
from magic_pdf.config.drop_reason import DropReason
from magic_pdf.libs.commons import get_top_percent_list, mymax
from magic_pdf.libs.language import detect_lang
-from magic_pdf.libs.pdf_check import detect_invalid_chars_by_pymupdf
+from magic_pdf.libs.pdf_check import detect_invalid_chars_by_pymupdf, detect_invalid_chars

scan_max_page = 50
junk_limit_min = 10
@@ -323,7 +323,8 @@ def get_language(doc: fitz.Document):

def check_invalid_chars(pdf_bytes):
    """乱码检测."""
-    return detect_invalid_chars_by_pymupdf(pdf_bytes)
+    # return detect_invalid_chars_by_pymupdf(pdf_bytes)
+    return detect_invalid_chars(pdf_bytes)


def pdf_meta_scan(pdf_bytes: bytes):

+ 11 - 10
magic_pdf/libs/draw_bbox.py

@@ -1,7 +1,8 @@
 import fitz
 from magic_pdf.config.constants import CROSS_PAGE
-from magic_pdf.config.ocr_content_type import BlockType, CategoryId, ContentType
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.config.ocr_content_type import (BlockType, CategoryId,
+                                               ContentType)
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.model.magic_model import MagicModel


@@ -194,7 +195,7 @@ def draw_layout_bbox(pdf_info, pdf_bytes, out_path, filename):
    )

    # Save the PDF
-    pdf_docs.save(f'{out_path}/{filename}_layout.pdf')
+    pdf_docs.save(f'{out_path}/{filename}')


def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
@@ -282,18 +283,17 @@ def draw_span_bbox(pdf_info, pdf_bytes, out_path, filename):
        draw_bbox_without_number(i, dropped_list, page, [158, 158, 158], False)

    # Save the PDF
-    pdf_docs.save(f'{out_path}/{filename}_spans.pdf')
+    pdf_docs.save(f'{out_path}/{filename}')


-def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
+def draw_model_bbox(model_list, dataset: Dataset, out_path, filename):
    dropped_bbox_list = []
    tables_body_list, tables_caption_list, tables_footnote_list = [], [], []
    imgs_body_list, imgs_caption_list, imgs_footnote_list = [], [], []
    titles_list = []
    texts_list = []
    interequations_list = []
-    pdf_docs = fitz.open('pdf', pdf_bytes)
-    magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
+    magic_model = MagicModel(model_list, dataset)
    for i in range(len(model_list)):
        page_dropped_list = []
        tables_body, tables_caption, tables_footnote = [], [], []
@@ -337,7 +337,8 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
        dropped_bbox_list.append(page_dropped_list)
        imgs_footnote_list.append(imgs_footnote)

-    for i, page in enumerate(pdf_docs):
+    for i in range(len(dataset)):
+        page = dataset.get_page(i)
        draw_bbox_with_number(
            i, dropped_bbox_list, page, [158, 158, 158], True
        )  # color !
@@ -352,7 +353,7 @@ def draw_model_bbox(model_list: list, pdf_bytes, out_path, filename):
        draw_bbox_with_number(i, interequations_list, page, [0, 255, 0], True)

    # Save the PDF
-    pdf_docs.save(f'{out_path}/{filename}_model.pdf')
+    dataset.dump_to_file(f'{out_path}/{filename}')


def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
@@ -390,7 +391,7 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
    for i, page in enumerate(pdf_docs):
        draw_bbox_with_number(i, layout_bbox_list, page, [255, 0, 0], False)

-    pdf_docs.save(f'{out_path}/{filename}_line_sort.pdf')
+    pdf_docs.save(f'{out_path}/{filename}')


def draw_layout_sort_bbox(pdf_info, pdf_bytes, out_path, filename):

+ 30 - 30
magic_pdf/libs/pdf_check.py

@@ -1,9 +1,9 @@
import fitz
import numpy as np
from loguru import logger
-# import re
-# from io import BytesIO
-# from pdfminer.high_level import extract_text
+import re
+from io import BytesIO
+from pdfminer.high_level import extract_text


def calculate_sample_count(total_page: int):
@@ -33,33 +33,33 @@ def extract_pages(src_pdf_bytes: bytes) -> fitz.Document:
    return sample_docs


-# def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
-#     """"
-#     检测PDF中是否包含非法字符
-#     """
-#     '''pdfminer比较慢,需要先随机抽取10页左右的sample'''
-#     sample_docs = extract_pages(src_pdf_bytes)
-#     sample_pdf_bytes = sample_docs.tobytes()
-#     sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
-#     text = extract_text(sample_pdf_file_like_object)
-#     text = text.replace("\n", "")
-#     # logger.info(text)
-#     '''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''
-#     cid_pattern = re.compile(r'\(cid:\d+\)')
-#     matches = cid_pattern.findall(text)
-#     cid_count = len(matches)
-#     cid_len = sum(len(match) for match in matches)
-#     text_len = len(text)
-#     if text_len == 0:
-#         cid_chars_radio = 0
-#     else:
-#         cid_chars_radio = cid_count/(cid_count + text_len - cid_len)
-#     logger.info(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}")
-#     '''当一篇文章存在5%以上的文本是乱码时,认为该文档为乱码文档'''
-#     if cid_chars_radio > 0.05:
-#         return False  # 乱码文档
-#     else:
-#         return True   # 正常文档
+def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
+    """"
+    检测PDF中是否包含非法字符
+    """
+    '''pdfminer比较慢,需要先随机抽取10页左右的sample'''
+    sample_docs = extract_pages(src_pdf_bytes)
+    sample_pdf_bytes = sample_docs.tobytes()
+    sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
+    text = extract_text(sample_pdf_file_like_object)
+    text = text.replace("\n", "")
+    # logger.info(text)
+    '''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''
+    cid_pattern = re.compile(r'\(cid:\d+\)')
+    matches = cid_pattern.findall(text)
+    cid_count = len(matches)
+    cid_len = sum(len(match) for match in matches)
+    text_len = len(text)
+    if text_len == 0:
+        cid_chars_radio = 0
+    else:
+        cid_chars_radio = cid_count/(cid_count + text_len - cid_len)
+    logger.info(f"cid_count: {cid_count}, text_len: {text_len}, cid_chars_radio: {cid_chars_radio}")
+    '''当一篇文章存在5%以上的文本是乱码时,认为该文档为乱码文档'''
+    if cid_chars_radio > 0.05:
+        return False  # 乱码文档
+    else:
+        return True   # 正常文档
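As a worked example of the ratio: if pdfminer extracts 1,000 characters of which five are `(cid:xxx)` tokens totalling 45 characters, then cid_chars_radio = 5 / (5 + 1000 - 45) ≈ 0.005, well under the 0.05 cutoff, so the document is treated as normal.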
 
 
 
 
def count_replacement_characters(text: str) -> int:

+ 124 - 0
magic_pdf/model/__init__.py

@@ -1,2 +1,126 @@
+from typing import Callable
+
+from abc import ABC, abstractmethod
+
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.pipe.operators import PipeResult
+
+
 __use_inside_model__ = True
 __model_mode__ = "full"
+
+
+class InferenceResultBase(ABC):
+
+    @abstractmethod
+    def __init__(self, inference_results: list, dataset: Dataset):
+        """Initialized method.
+
+        Args:
+            inference_results (list): the inference result generated by model
+            dataset (Dataset): the dataset related with model inference result
+        """
+        self._infer_res = inference_results
+        self._dataset = dataset
+
+    @abstractmethod
+    def draw_model(self, file_path: str) -> None:
+        """Draw model inference result.
+
+        Args:
+            file_path (str): the output file path
+        """
+        pass
+
+    @abstractmethod
+    def dump_model(self, writer: DataWriter, file_path: str):
+        """Dump model inference result to file.
+
+        Args:
+            writer (DataWriter): writer handle
+            file_path (str): the location of target file
+        """
+        pass
+
+    @abstractmethod
+    def get_infer_res(self):
+        """Get the inference result.
+
+        Returns:
+            list: the inference result generated by model
+        """
+        pass
+
+    @abstractmethod
+    def apply(self, proc: Callable, *args, **kwargs):
+        """Apply callable method which.
+
+        Args:
+            proc (Callable): invoke proc as follows:
+                proc(inference_result, *args, **kwargs)
+
+        Returns:
+            Any: return the result generated by proc
+        """
+        pass
+
+    @abstractmethod
+    def pipe_auto_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result.
+            step1: classify the dataset type
+            step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional):  Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+        pass
+
+    @abstractmethod
+    def pipe_txt_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result, Extract the text using the
+        third library, such as `pymupdf`
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional):  Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+        pass
+
+    @abstractmethod
+    def pipe_ocr_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        pass

+ 119 - 60
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -1,14 +1,34 @@
+import os
import time

import fitz
import numpy as np
from loguru import logger

+# 关闭paddle的信号处理
+import paddle
+paddle.disable_signal_handler()
+
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
+os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
+
+try:
+    import torchtext
+
+    if torchtext.__version__ >= '0.18.0':
+        torchtext.disable_torchtext_deprecation_warning()
+except ImportError:
+    pass
+
+import magic_pdf.model as model_config
+from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_local_models_dir, get_device, get_table_recog_config, get_layout_config, \
-    get_formula_config
+from magic_pdf.libs.config_reader import (get_device, get_formula_config,
+                                          get_layout_config,
+                                          get_local_models_dir,
+                                          get_table_recog_config)
from magic_pdf.model.model_list import MODEL
-import magic_pdf.model as model_config
+from magic_pdf.model.operators import InferenceResult


def dict_compare(d1, d2):
@@ -19,25 +39,31 @@ def remove_duplicates_dicts(lst):
    unique_dicts = []
    for dict_item in lst:
        if not any(
-                dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
+            dict_compare(dict_item, existing_dict) for existing_dict in unique_dicts
        ):
            unique_dicts.append(dict_item)
    return unique_dicts


-def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None) -> list:
+def load_images_from_pdf(
+    pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id=None
+) -> list:
    try:
        from PIL import Image
    except ImportError:
-        logger.error("Pillow not installed, please install by pip.")
+        logger.error('Pillow not installed, please install by pip.')
        exit(1)

    images = []
-    with fitz.open("pdf", pdf_bytes) as doc:
+    with fitz.open('pdf', pdf_bytes) as doc:
        pdf_page_num = doc.page_count
-        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
+        end_page_id = (
+            end_page_id
+            if end_page_id is not None and end_page_id >= 0
+            else pdf_page_num - 1
+        )
        if end_page_id > pdf_page_num - 1:
-            logger.warning("end_page_id is out of range, use images length")
+            logger.warning('end_page_id is out of range, use images length')
            end_page_id = pdf_page_num - 1

        for index in range(0, doc.page_count):
@@ -50,11 +76,11 @@ def load_images_from_pdf(pdf_bytes: bytes, dpi=200, start_page_id=0, end_page_id
                if pm.width > 4500 or pm.height > 4500:
                    pm = page.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)

-                img = Image.frombytes("RGB", (pm.width, pm.height), pm.samples)
+                img = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
                img = np.array(img)
-                img_dict = {"img": img, "width": pm.width, "height": pm.height}
+                img_dict = {'img': img, 'width': pm.width, 'height': pm.height}
            else:
-                img_dict = {"img": [], "width": 0, "height": 0}
+                img_dict = {'img': [], 'width': 0, 'height': 0}

            images.append(img_dict)
    return images
@@ -69,117 +95,150 @@ class ModelSingleton:
            cls._instance = super().__new__(cls)
        return cls._instance

-    def get_model(self, ocr: bool, show_log: bool, lang=None, layout_model=None, formula_enable=None, table_enable=None):
+    def get_model(
+        self,
+        ocr: bool,
+        show_log: bool,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None,
+    ):
        key = (ocr, show_log, lang, layout_model, formula_enable, table_enable)
        if key not in self._models:
-            self._models[key] = custom_model_init(ocr=ocr, show_log=show_log, lang=lang, layout_model=layout_model,
-                                                  formula_enable=formula_enable, table_enable=table_enable)
+            self._models[key] = custom_model_init(
+                ocr=ocr,
+                show_log=show_log,
+                lang=lang,
+                layout_model=layout_model,
+                formula_enable=formula_enable,
+                table_enable=table_enable,
+            )
        return self._models[key]


-def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None,
-                      layout_model=None, formula_enable=None, table_enable=None):
+def custom_model_init(
+    ocr: bool = False,
+    show_log: bool = False,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+):

    model = None

-    if model_config.__model_mode__ == "lite":
-        logger.warning("The Lite mode is provided for developers to conduct testing only, and the output quality is "
-                       "not guaranteed to be reliable.")
+    if model_config.__model_mode__ == 'lite':
+        logger.warning(
+            'The Lite mode is provided for developers to conduct testing only, and the output quality is '
+            'not guaranteed to be reliable.'
+        )
        model = MODEL.Paddle
-    elif model_config.__model_mode__ == "full":
+    elif model_config.__model_mode__ == 'full':
        model = MODEL.PEK

    if model_config.__use_inside_model__:
        model_init_start = time.time()
        if model == MODEL.Paddle:
            from magic_pdf.model.pp_structure_v2 import CustomPaddleModel
+
            custom_model = CustomPaddleModel(ocr=ocr, show_log=show_log, lang=lang)
        elif model == MODEL.PEK:
            from magic_pdf.model.pdf_extract_kit import CustomPEKModel
+
            # 从配置文件读取model-dir和device
            local_models_dir = get_local_models_dir()
            device = get_device()

            layout_config = get_layout_config()
            if layout_model is not None:
-                layout_config["model"] = layout_model
+                layout_config['model'] = layout_model

            formula_config = get_formula_config()
            if formula_enable is not None:
-                formula_config["enable"] = formula_enable
+                formula_config['enable'] = formula_enable

            table_config = get_table_recog_config()
            if table_enable is not None:
-                table_config["enable"] = table_enable
+                table_config['enable'] = table_enable

            model_input = {
-                            "ocr": ocr,
-                            "show_log": show_log,
-                            "models_dir": local_models_dir,
-                            "device": device,
-                            "table_config": table_config,
-                            "layout_config": layout_config,
-                            "formula_config": formula_config,
-                            "lang": lang,
+                'ocr': ocr,
+                'show_log': show_log,
+                'models_dir': local_models_dir,
+                'device': device,
+                'table_config': table_config,
+                'layout_config': layout_config,
+                'formula_config': formula_config,
+                'lang': lang,
            }

            custom_model = CustomPEKModel(**model_input)
        else:
-            logger.error("Not allow model_name!")
+            logger.error('Not allow model_name!')
            exit(1)
        model_init_cost = time.time() - model_init_start
-        logger.info(f"model init cost: {model_init_cost}")
+        logger.info(f'model init cost: {model_init_cost}')
    else:
-        logger.error("use_inside_model is False, not allow to use inside model")
+        logger.error('use_inside_model is False, not allow to use inside model')
        exit(1)

    return custom_model


-def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
-                start_page_id=0, end_page_id=None, lang=None,
-                layout_model=None, formula_enable=None, table_enable=None):
+def doc_analyze(
+    dataset: Dataset,
+    ocr: bool = False,
+    show_log: bool = False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+) -> InferenceResult:

-    if lang == "":
+    if lang == '':
        lang = None

    model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(ocr, show_log, lang, layout_model, formula_enable, table_enable)
-
-    with fitz.open("pdf", pdf_bytes) as doc:
-        pdf_page_num = doc.page_count
-        end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else pdf_page_num - 1
-        if end_page_id > pdf_page_num - 1:
-            logger.warning("end_page_id is out of range, use images length")
-            end_page_id = pdf_page_num - 1
-
-    images = load_images_from_pdf(pdf_bytes, start_page_id=start_page_id, end_page_id=end_page_id)
+    custom_model = model_manager.get_model(
+        ocr, show_log, lang, layout_model, formula_enable, table_enable
+    )

    model_json = []
    doc_analyze_start = time.time()

-    for index, img_dict in enumerate(images):
-        img = img_dict["img"]
-        page_width = img_dict["width"]
-        page_height = img_dict["height"]
+    if end_page_id is None:
+        end_page_id = len(dataset)
+
+    for index in range(len(dataset)):
+        page_data = dataset.get_page(index)
+        img_dict = page_data.get_image()
+        img = img_dict['img']
+        page_width = img_dict['width']
+        page_height = img_dict['height']
        if start_page_id <= index <= end_page_id:
            page_start = time.time()
            result = custom_model(img)
            logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
        else:
            result = []
-        page_info = {"page_no": index, "height": page_height, "width": page_width}
-        page_dict = {"layout_dets": result, "page_info": page_info}
+
+        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+        page_dict = {'layout_dets': result, 'page_info': page_info}
        model_json.append(page_dict)

    gc_start = time.time()
    clean_memory()
    gc_time = round(time.time() - gc_start, 2)
-    logger.info(f"gc time: {gc_time}")
+    logger.info(f'gc time: {gc_time}')

    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round( (end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
-    logger.info(f"doc analyze time: {round(time.time() - doc_analyze_start, 2)},"
-                f" speed: {doc_analyze_speed} pages/second")
+    doc_analyze_speed = round((end_page_id + 1 - start_page_id) / doc_analyze_time, 2)
+    logger.info(
+        f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
+        f' speed: {doc_analyze_speed} pages/second'
+    )

-    return model_json
+    return InferenceResult(model_json, dataset)
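The new signature in use; a minimal sketch (the path is hypothetical):

```python
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

with open('example.pdf', 'rb') as f:       # hypothetical input
    ds = PymuDocDataset(f.read())

infer_result = doc_analyze(ds, ocr=False)  # now takes a Dataset, returns InferenceResult
model_list = infer_result.get_infer_res()  # the raw per-page layout_dets/page_info list
```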

+ 190 - 0
magic_pdf/model/operators.py

@@ -0,0 +1,190 @@
+import copy
+import json
+import os
+from typing import Callable
+
+from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
+from magic_pdf.config.enums import SupportedPdfParseMethod
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.filter import classify
+from magic_pdf.libs.draw_bbox import draw_model_bbox
+from magic_pdf.libs.version import __version__
+from magic_pdf.model import InferenceResultBase
+from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
+from magic_pdf.pipe.operators import PipeResult
+
+
+class InferenceResult(InferenceResultBase):
+    def __init__(self, inference_results: list, dataset: Dataset):
+        """Initialized method.
+
+        Args:
+            inference_results (list): the inference result generated by model
+            dataset (Dataset): the dataset related with model inference result
+        """
+        self._infer_res = inference_results
+        self._dataset = dataset
+
+    def draw_model(self, file_path: str) -> None:
+        """Draw model inference result.
+
+        Args:
+            file_path (str): the output file path
+        """
+        dir_name = os.path.dirname(file_path)
+        base_name = os.path.basename(file_path)
+        if not os.path.exists(dir_name):
+            os.makedirs(dir_name, exist_ok=True)
+        draw_model_bbox(
+            copy.deepcopy(self._infer_res), self._dataset, dir_name, base_name
+        )
+
+    def dump_model(self, writer: DataWriter, file_path: str):
+        """Dump model inference result to file.
+
+        Args:
+            writer (DataWriter): writer handle
+            file_path (str): the location of target file
+        """
+        writer.write_string(
+            file_path, json.dumps(self._infer_res, ensure_ascii=False, indent=4)
+        )
+
+    def get_infer_res(self):
+        """Get the inference result.
+
+        Returns:
+            list: the inference result generated by model
+        """
+        return self._infer_res
+
+    def apply(self, proc: Callable, *args, **kwargs):
+        """Apply callable method which.
+
+        Args:
+            proc (Callable): invoke proc as follows:
+                proc(inference_result, *args, **kwargs)
+
+        Returns:
+            Any: return the result generated by proc
+        """
+        return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
+
+    def pipe_auto_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result.
+            step1: classify the dataset type
+            step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional):  Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+
+        pdf_proc_method = classify(self._dataset.data_bits())
+
+        if pdf_proc_method == SupportedPdfParseMethod.TXT:
+            return self.pipe_txt_mode(
+                imageWriter, start_page_id, end_page_id, debug_mode, lang
+            )
+        else:
+            return self.pipe_ocr_mode(
+                imageWriter, start_page_id, end_page_id, debug_mode, lang
+            )
+
+    def pipe_txt_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result, Extract the text using the
+        third library, such as `pymupdf`
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional):  Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+
+        def proc(*args, **kwargs) -> PipeResult:
+            res = pdf_parse_union(*args, **kwargs)
+            res['_parse_type'] = PARSE_TYPE_TXT
+            res['_version_name'] = __version__
+            if 'lang' in kwargs and kwargs['lang'] is not None:
+                res['lang'] = kwargs['lang']
+            return PipeResult(res, self._dataset)
+
+        res = self.apply(
+            proc,
+            self._dataset,
+            imageWriter,
+            SupportedPdfParseMethod.TXT,
+            start_page_id=start_page_id,
+            end_page_id=end_page_id,
+            debug_mode=debug_mode,
+            lang=lang,
+        )
+        return res
+
+    def pipe_ocr_mode(
+        self,
+        imageWriter: DataWriter,
+        start_page_id=0,
+        end_page_id=None,
+        debug_mode=False,
+        lang=None,
+    ) -> PipeResult:
+        """Post-proc the model inference result, Extract the text using `OCR`
+        technical.
+
+        Args:
+            imageWriter (DataWriter): the image writer handle
+            start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
+            end_page_id (int, optional):  Defaults to the last page index of dataset. Let user select some pages He/She want to process
+            debug_mode (bool, optional): Defaults to False. will dump more log if enabled
+            lang (str, optional): Defaults to None.
+
+        Returns:
+            PipeResult: the result
+        """
+
+        def proc(*args, **kwargs) -> PipeResult:
+            res = pdf_parse_union(*args, **kwargs)
+            res['_parse_type'] = PARSE_TYPE_OCR
+            res['_version_name'] = __version__
+            if 'lang' in kwargs and kwargs['lang'] is not None:
+                res['lang'] = kwargs['lang']
+            return PipeResult(res, self._dataset)
+
+        res = self.apply(
+            proc,
+            self._dataset,
+            imageWriter,
+            SupportedPdfParseMethod.OCR,
+            start_page_id=start_page_id,
+            end_page_id=end_page_id,
+            debug_mode=debug_mode,
+            lang=lang,
+        )
+        return res
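End to end, the new operator objects chain like this; a sketch assuming `FileBasedDataWriter` for the image writer (paths are hypothetical):

```python
from magic_pdf.data.data_reader_writer import FileBasedDataWriter
from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

with open('example.pdf', 'rb') as f:           # hypothetical input
    ds = PymuDocDataset(f.read())

infer_result = doc_analyze(ds, ocr=False)
image_writer = FileBasedDataWriter('out/images')  # hypothetical output dir

# pipe_auto_mode() classifies the PDF and dispatches to pipe_txt_mode or pipe_ocr_mode
pipe_result = infer_result.pipe_auto_mode(image_writer)

infer_result.draw_model('out/example_model.pdf')  # visualize the layout boxes
```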

+ 20 - 1
magic_pdf/model/pdf_extract_kit.py

@@ -179,7 +179,25 @@ class CustomPEKModel:
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
-            layout_res = self.layout_model.predict(image)
+            img_pil = Image.fromarray(image)
+            width, height = img_pil.size
+            # logger.info(f'width: {width}, height: {height}')
+            input_res = {"poly":[0,0,width,0,width,height,0,height]}
+            new_image, useful_list = crop_img(input_res, img_pil, crop_paste_x=width//2, crop_paste_y=0)
+            paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+            layout_res = self.layout_model.predict(new_image)
+            for res in layout_res:
+                p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
+                p1 = p1 - paste_x + xmin
+                p2 = p2 - paste_y + ymin
+                p3 = p3 - paste_x + xmin
+                p4 = p4 - paste_y + ymin
+                p5 = p5 - paste_x + xmin
+                p6 = p6 - paste_y + ymin
+                p7 = p7 - paste_x + xmin
+                p8 = p8 - paste_y + ymin
+                res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
+
        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f'layout detection time: {layout_cost}')

@@ -215,6 +233,7 @@ class CustomPEKModel:
 
 
            # OCR recognition
            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+
             if self.apply_ocr:
                 ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
             else:
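
The DocLayout_YOLO branch above pastes the page image onto a wider canvas before prediction, then maps each detected polygon back into original-page coordinates by undoing the paste offset and restoring the crop origin. A compact sketch of that inverse mapping (names taken from the hunk; ``crop_img`` is assumed to return the paste offsets and crop bounds in ``useful_list``):

.. code:: python

    def map_poly_to_page(poly, paste_x, paste_y, xmin, ymin):
        # poly is [x1, y1, x2, y2, x3, y3, x4, y4] in padded-canvas coordinates
        xs = [x - paste_x + xmin for x in poly[0::2]]
        ys = [y - paste_y + ymin for y in poly[1::2]]
        # interleave back into [x1, y1, ..., x4, y4] in page coordinates
        return [v for xy in zip(xs, ys) for v in xy]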

+ 13 - 3
magic_pdf/model/sub_modules/model_init.py

@@ -92,14 +92,24 @@ class AtomModelSingleton:
         return cls._instance

     def get_atom_model(self, atom_model_name: str, **kwargs):
+
         lang = kwargs.get('lang', None)
         layout_model_name = kwargs.get('layout_model_name', None)
-        key = (atom_model_name, layout_model_name, lang)
+        table_model_name = kwargs.get('table_model_name', None)
+
+        if atom_model_name in [AtomicModel.OCR]:
+            key = (atom_model_name, lang)
+        elif atom_model_name in [AtomicModel.Layout]:
+            key = (atom_model_name, layout_model_name)
+        elif atom_model_name in [AtomicModel.Table]:
+            key = (atom_model_name, table_model_name)
+        else:
+            key = atom_model_name
+
         if key not in self._models:
             self._models[key] = atom_model_init(model_name=atom_model_name, **kwargs)
         return self._models[key]

-
 def atom_model_init(model_name: str, **kwargs):
     atom_model = None
     if model_name == AtomicModel.Layout:
@@ -129,7 +139,7 @@ def atom_model_init(model_name: str, **kwargs):
         atom_model = ocr_model_init(
             kwargs.get('ocr_show_log'),
             kwargs.get('det_db_box_thresh'),
-            kwargs.get('lang')
+            kwargs.get('lang'),
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(
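
With the per-model keys above, the singleton caches one instance per variant, so two OCR models with different ``lang`` values no longer overwrite each other. A hedged usage sketch (the kwargs mirror those used elsewhere in this change set):

.. code:: python

    mgr = AtomModelSingleton()
    ocr_ch = mgr.get_atom_model(atom_model_name=AtomicModel.OCR,
                                ocr_show_log=False, det_db_box_thresh=0.3, lang='ch')
    ocr_en = mgr.get_atom_model(atom_model_name=AtomicModel.OCR,
                                ocr_show_log=False, det_db_box_thresh=0.3, lang='en')
    assert ocr_ch is not ocr_en  # cached under ('ocr', 'ch') and ('ocr', 'en')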

+ 11 - 5
magic_pdf/model/sub_modules/model_utils.py

@@ -42,10 +42,16 @@ def get_res_list_from_layout_res(layout_res):


 def clean_vram(device, vram_threshold=8):
+    total_memory = get_vram(device)
+    if total_memory and total_memory <= vram_threshold:
+        gc_start = time.time()
+        clean_memory()
+        gc_time = round(time.time() - gc_start, 2)
+        logger.info(f"gc time: {gc_time}")
+
+
+def get_vram(device):
     if torch.cuda.is_available() and device != 'cpu':
         total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
-        if total_memory <= vram_threshold:
-            gc_start = time.time()
-            clean_memory()
-            gc_time = round(time.time() - gc_start, 2)
-            logger.info(f"gc time: {gc_time}")
+        return total_memory
+    return None
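
Splitting ``get_vram`` out of ``clean_vram`` lets callers inspect total VRAM themselves; ``clean_vram`` still garbage-collects only on cards at or below the threshold. A small sketch of the intended call pattern (the device string is illustrative):

.. code:: python

    device = 'cuda:0'
    total_gb = get_vram(device)  # None on CPU-only machines
    if total_gb is not None:
        print(f'{total_gb:.1f} GB total VRAM')
    clean_vram(device, vram_threshold=8)  # calls clean_memory() only when total VRAM <= 8 GB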

+ 4 - 5
magic_pdf/pdf_parse_by_ocr.py

@@ -1,9 +1,9 @@
 from magic_pdf.config.enums import SupportedPdfParseMethod
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union


-def parse_pdf_by_ocr(pdf_bytes,
+def parse_pdf_by_ocr(dataset: Dataset,
                      model_list,
                      imageWriter,
                      start_page_id=0,
@@ -11,9 +11,8 @@ def parse_pdf_by_ocr(pdf_bytes,
                      debug_mode=False,
                      lang=None,
                      ):
-    dataset = PymuDocDataset(pdf_bytes)
-    return pdf_parse_union(dataset,
-                           model_list,
+    return pdf_parse_union(model_list,
+                           dataset,
                            imageWriter,
                            SupportedPdfParseMethod.OCR,
                            start_page_id=start_page_id,

+ 4 - 5
magic_pdf/pdf_parse_by_txt.py

@@ -1,10 +1,10 @@
 from magic_pdf.config.enums import SupportedPdfParseMethod
-from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union


 def parse_pdf_by_txt(
-    pdf_bytes,
+    dataset: Dataset,
     model_list,
     imageWriter,
     start_page_id=0,
@@ -12,9 +12,8 @@ def parse_pdf_by_txt(
     debug_mode=False,
     lang=None,
 ):
-    dataset = PymuDocDataset(pdf_bytes)
-    return pdf_parse_union(dataset,
-                           model_list,
+    return pdf_parse_union(model_list,
+                           dataset,
                            imageWriter,
                            SupportedPdfParseMethod.TXT,
                            start_page_id=start_page_id,

+ 10 - 11
magic_pdf/pdf_parse_union_core_v2.py

@@ -4,8 +4,8 @@ import statistics
 import time
 from typing import List

-import torch
 import fitz
+import torch
 from loguru import logger

 from magic_pdf.config.enums import SupportedPdfParseMethod
@@ -16,17 +16,13 @@ from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.hash_utils import compute_md5
-
 from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
 from magic_pdf.model.magic_model import MagicModel

-os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check
-os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
-
 try:
     import torchtext

-    if torchtext.__version__ >= "0.18.0":
+    if torchtext.__version__ >= '0.18.0':
         torchtext.disable_torchtext_deprecation_warning()
 except ImportError:
     pass
@@ -39,6 +35,9 @@ from magic_pdf.pre_proc.ocr_detect_all_bboxes import ocr_prepare_bboxes_for_layo
 from magic_pdf.pre_proc.ocr_dict_merge import fill_spans_in_blocks, fix_block_spans_v2, fix_discarded_block
 from magic_pdf.pre_proc.ocr_span_list_modify import get_qa_need_list_v2, remove_overlaps_low_confidence_spans, remove_overlaps_min_spans

+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check
+os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger
+

 def __replace_STX_ETX(text_str: str):
     """Replace \u0002 and \u0003, as these characters become garbled when extracted using pymupdf. In fact, they were originally quotation marks.
@@ -233,7 +232,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
         # initialize the OCR model
         atom_model_manager = AtomModelSingleton()
         ocr_model = atom_model_manager.get_atom_model(
-            atom_model_name="ocr",
+            atom_model_name='ocr',
             ocr_show_log=False,
             det_db_box_thresh=0.3,
             lang=lang
@@ -241,7 +240,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang

         for span in empty_spans:
             # crop the span's bbox and run OCR on it again
-            span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode="cv2")
+            span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode='cv2')
             ocr_res = ocr_model.ocr(span_img, det=False)
             if ocr_res and len(ocr_res) > 0:
                 if len(ocr_res[0]) > 0:
@@ -681,7 +680,7 @@ def parse_page_core(
     """Build spans according to parse_mode, mainly filling in characters for text content"""
     if parse_mode == SupportedPdfParseMethod.TXT:

-        """Use the new version of the hybrid OCR scheme"""
+        """Use the new version of the hybrid OCR scheme."""
         spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang)

     elif parse_mode == SupportedPdfParseMethod.OCR:
@@ -689,7 +688,6 @@ def parse_page_core(
     else:
         raise Exception('parse_mode must be txt or ocr')

-
     """First handle the discarded_blocks, which need no layout ordering"""
     discarded_block_with_spans, spans = fill_spans_in_blocks(
         all_discarded_blocks, spans, 0.4
@@ -762,8 +760,8 @@ def parse_page_core(


 def pdf_parse_union(
-    dataset: Dataset,
     model_list,
+    dataset: Dataset,
     imageWriter,
     parse_mode,
     start_page_id=0,
@@ -771,6 +769,7 @@ def pdf_parse_union(
     debug_mode=False,
     lang=None,
 ):
+
     pdf_bytes_md5 = compute_md5(dataset.data_bits())

     """Initialize an empty pdf_info_dict"""

+ 3 - 2
magic_pdf/pipe/AbsPipe.py

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 from magic_pdf.config.drop_reason import DropReason
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.dict2md.ocr_mkcontent import union_make
 from magic_pdf.filter.pdf_classify_by_type import classify
 from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
@@ -14,9 +15,9 @@ class AbsPipe(ABC):
     PIP_OCR = 'ocr'
     PIP_TXT = 'txt'

-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
+    def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
-        self.pdf_bytes = pdf_bytes
+        self.dataset = dataset
         self.model_list = model_list
         self.image_writer = image_writer
         self.pdf_mid_data = None  # not compressed

+ 54 - 15
magic_pdf/pipe/OCRPipe.py

@@ -2,40 +2,79 @@ from loguru import logger

 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.user_api import parse_ocr_pdf


 class OCRPipe(AbsPipe):
-
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None,
-                 layout_model=None, formula_enable=None, table_enable=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
-                         layout_model, formula_enable, table_enable)
+    def __init__(
+        self,
+        dataset: Dataset,
+        model_list: list,
+        image_writer: DataWriter,
+        is_debug: bool = False,
+        start_page_id=0,
+        end_page_id=None,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None,
+    ):
+        super().__init__(
+            dataset,
+            model_list,
+            image_writer,
+            is_debug,
+            start_page_id,
+            end_page_id,
+            lang,
+            layout_model,
+            formula_enable,
+            table_enable,
+        )

     def pipe_classify(self):
         pass

     def pipe_analyze(self):
-        self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
-                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                      lang=self.lang, layout_model=self.layout_model,
-                                      formula_enable=self.formula_enable, table_enable=self.table_enable)
+        self.infer_res = doc_analyze(
+            self.dataset,
+            ocr=True,
+            start_page_id=self.start_page_id,
+            end_page_id=self.end_page_id,
+            lang=self.lang,
+            layout_model=self.layout_model,
+            formula_enable=self.formula_enable,
+            table_enable=self.table_enable,
+        )

     def pipe_parse(self):
-        self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang, layout_model=self.layout_model,
-                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
+        self.pdf_mid_data = parse_ocr_pdf(
+            self.dataset,
+            self.infer_res,
+            self.image_writer,
+            is_debug=self.is_debug,
+            start_page_id=self.start_page_id,
+            end_page_id=self.end_page_id,
+            lang=self.lang,
+            layout_model=self.layout_model,
+            formula_enable=self.formula_enable,
+            table_enable=self.table_enable,
+        )

     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         logger.info('ocr_pipe mk content list finished')
         return result

-    def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
+    def pipe_mk_markdown(
+        self,
+        img_parent_path: str,
+        drop_mode=DropMode.WHOLE_PDF,
+        md_make_mode=MakeMode.MM_MD,
+    ):
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
         logger.info(f'ocr_pipe mk {md_make_mode} finished')
         return result

+ 5 - 4
magic_pdf/pipe/TXTPipe.py

@@ -2,6 +2,7 @@ from loguru import logger

 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.user_api import parse_txt_pdf
@@ -9,23 +10,23 @@ from magic_pdf.user_api import parse_txt_pdf

 class TXTPipe(AbsPipe):

-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
+    def __init__(self, dataset: Dataset, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None,
                  layout_model=None, formula_enable=None, table_enable=None):
-        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
+        super().__init__(dataset, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
                          layout_model, formula_enable, table_enable)

     def pipe_classify(self):
         pass

     def pipe_analyze(self):
-        self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
+        self.model_list = doc_analyze(self.dataset, ocr=False,
                                       start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                       lang=self.lang, layout_model=self.layout_model,
                                       formula_enable=self.formula_enable, table_enable=self.table_enable)

     def pipe_parse(self):
-        self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,
+        self.pdf_mid_data = parse_txt_pdf(self.dataset, self.model_list, self.image_writer, is_debug=self.is_debug,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                           lang=self.lang, layout_model=self.layout_model,
                                           formula_enable=self.formula_enable, table_enable=self.table_enable)

+ 82 - 30
magic_pdf/pipe/UNIPipe.py

@@ -4,6 +4,7 @@ from loguru import logger

 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.commons import join_path
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.pipe.AbsPipe import AbsPipe
@@ -12,12 +13,32 @@ from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf

 class UNIPipe(AbsPipe):

-    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: DataWriter, is_debug: bool = False,
-                 start_page_id=0, end_page_id=None, lang=None,
-                 layout_model=None, formula_enable=None, table_enable=None):
+    def __init__(
+        self,
+        dataset: Dataset,
+        jso_useful_key: dict,
+        image_writer: DataWriter,
+        is_debug: bool = False,
+        start_page_id=0,
+        end_page_id=None,
+        lang=None,
+        layout_model=None,
+        formula_enable=None,
+        table_enable=None,
+    ):
         self.pdf_type = jso_useful_key['_pdf_type']
-        super().__init__(pdf_bytes, jso_useful_key['model_list'], image_writer, is_debug, start_page_id, end_page_id,
-                         lang, layout_model, formula_enable, table_enable)
+        super().__init__(
+            dataset,
+            jso_useful_key['model_list'],
+            image_writer,
+            is_debug,
+            start_page_id,
+            end_page_id,
+            lang,
+            layout_model,
+            formula_enable,
+            table_enable,
+        )
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
         else:
@@ -28,35 +49,66 @@ class UNIPipe(AbsPipe):

     def pipe_analyze(self):
         if self.pdf_type == self.PIP_TXT:
-            self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang, layout_model=self.layout_model,
-                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
+            self.model_list = doc_analyze(
+                self.dataset,
+                ocr=False,
+                start_page_id=self.start_page_id,
+                end_page_id=self.end_page_id,
+                lang=self.lang,
+                layout_model=self.layout_model,
+                formula_enable=self.formula_enable,
+                table_enable=self.table_enable,
+            )
         elif self.pdf_type == self.PIP_OCR:
-            self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
-                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                          lang=self.lang, layout_model=self.layout_model,
-                                          formula_enable=self.formula_enable, table_enable=self.table_enable)
+            self.model_list = doc_analyze(
+                self.dataset,
+                ocr=True,
+                start_page_id=self.start_page_id,
+                end_page_id=self.end_page_id,
+                lang=self.lang,
+                layout_model=self.layout_model,
+                formula_enable=self.formula_enable,
+                table_enable=self.table_enable,
+            )

     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
-            self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
-                                                is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
-                                                start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                                lang=self.lang, layout_model=self.layout_model,
-                                                formula_enable=self.formula_enable, table_enable=self.table_enable)
+            self.pdf_mid_data = parse_union_pdf(
+                self.dataset,
+                self.model_list,
+                self.image_writer,
+                is_debug=self.is_debug,
+                start_page_id=self.start_page_id,
+                end_page_id=self.end_page_id,
+                lang=self.lang,
+                layout_model=self.layout_model,
+                formula_enable=self.formula_enable,
+                table_enable=self.table_enable,
+            )
         elif self.pdf_type == self.PIP_OCR:
-            self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
-                                              is_debug=self.is_debug,
-                                              start_page_id=self.start_page_id, end_page_id=self.end_page_id,
-                                              lang=self.lang)
-
-    def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON):
+            self.pdf_mid_data = parse_ocr_pdf(
+                self.dataset,
+                self.model_list,
+                self.image_writer,
+                is_debug=self.is_debug,
+                start_page_id=self.start_page_id,
+                end_page_id=self.end_page_id,
+                lang=self.lang,
+            )
+
+    def pipe_mk_uni_format(
+        self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON
+    ):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         logger.info('uni_pipe mk content list finished')
         return result

-    def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
+    def pipe_mk_markdown(
+        self,
+        img_parent_path: str,
+        drop_mode=DropMode.WHOLE_PDF,
+        md_make_mode=MakeMode.MM_MD,
+    ):
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
         logger.info(f'uni_pipe mk {md_make_mode} finished')
         return result
@@ -65,6 +117,7 @@ class UNIPipe(AbsPipe):
 if __name__ == '__main__':
     # test
     from magic_pdf.data.data_reader_writer import DataReader
+
     drw = DataReader(r'D:/project/20231108code-clean')

     pdf_file_path = r'linshixuqiu\19983-00.pdf'
@@ -82,10 +135,7 @@
     #     "model_list": model_list
     # }

-    jso_useful_key = {
-        '_pdf_type': '',
-        'model_list': model_list
-    }
+    jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
     pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer)
     pipe.pipe_classify()
     pipe.pipe_parse()
@@ -94,5 +144,7 @@

     md_writer = DataWriter(write_path)
     md_writer.write_string('19983-00.md', md_content)
-    md_writer.write_string('19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4))
+    md_writer.write_string(
+        '19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
+    )
     md_writer.write_string('19983-00.txt', str(content_list))
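
Note that every ``Pipe`` now expects a ``Dataset`` rather than raw bytes, so an updated call site would first wrap the bytes; a hedged sketch of the construction (writers as in the ``__main__`` block above):

.. code:: python

    from magic_pdf.data.dataset import PymuDocDataset

    ds = PymuDocDataset(pdf_bytes)
    jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
    pipe = UNIPipe(ds, jso_useful_key, img_writer)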

+ 138 - 0
magic_pdf/pipe/operators.py

@@ -0,0 +1,138 @@
+import copy
+import json
+import os
+from typing import Callable
+
+from magic_pdf.config.make_content_config import DropMode, MakeMode
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.dict2md.ocr_mkcontent import union_make
+from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
+                                      draw_span_bbox)
+from magic_pdf.libs.json_compressor import JsonCompressor
+
+
+class PipeResult:
+    def __init__(self, pipe_res, dataset: Dataset):
+        """Initialize.
+
+        Args:
+            pipe_res (list[dict]): the pipeline processed result of model inference result
+            dataset (Dataset): the dataset associated with pipe_res
+        """
+        self._pipe_res = pipe_res
+        self._dataset = dataset
+
+    def dump_md(
+        self,
+        writer: DataWriter,
+        file_path: str,
+        img_dir_or_bucket_prefix: str,
+        drop_mode=DropMode.WHOLE_PDF,
+        md_make_mode=MakeMode.MM_MD,
+    ):
+        """Dump The Markdown.
+
+        Args:
+            writer (DataWriter): File writer handle
+            file_path (str): The file location of markdown
+            img_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory used to store the figures
+            drop_mode (str, optional): Drop strategy for pages that are corrupted or inappropriate. Defaults to DropMode.WHOLE_PDF.
+            md_make_mode (str, optional): The type of Markdown content to make. Defaults to MakeMode.MM_MD.
+        """
+        pdf_info_list = self._pipe_res['pdf_info']
+        md_content = union_make(
+            pdf_info_list, md_make_mode, drop_mode, img_dir_or_bucket_prefix
+        )
+        writer.write_string(file_path, md_content)
+
+    def dump_content_list(
+        self, writer: DataWriter, file_path: str, image_dir_or_bucket_prefix: str
+    ):
+        """Dump Content List.
+
+        Args:
+            writer (DataWriter): File writer handle
+            file_path (str): The file location of content list
+            image_dir_or_bucket_prefix (str): The s3 bucket prefix or local file directory used to store the figures
+        """
+        pdf_info_list = self._pipe_res['pdf_info']
+        content_list = union_make(
+            pdf_info_list,
+            MakeMode.STANDARD_FORMAT,
+            DropMode.NONE,
+            image_dir_or_bucket_prefix,
+        )
+        writer.write_string(
+            file_path, json.dumps(content_list, ensure_ascii=False, indent=4)
+        )
+
+    def dump_middle_json(self, writer: DataWriter, file_path: str):
+        """Dump the result of pipeline.
+
+        Args:
+            writer (DataWriter): File writer handler
+            file_path (str): The file location of middle json
+        """
+        writer.write_string(
+            file_path, json.dumps(self._pipe_res, ensure_ascii=False, indent=4)
+        )
+
+    def draw_layout(self, file_path: str) -> None:
+        """Draw the layout.
+
+        Args:
+            file_path (str): The file location of layout result file
+        """
+        dir_name = os.path.dirname(file_path)
+        base_name = os.path.basename(file_path)
+        if not os.path.exists(dir_name):
+            os.makedirs(dir_name, exist_ok=True)
+        pdf_info = self._pipe_res['pdf_info']
+        draw_layout_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
+
+    def draw_span(self, file_path: str):
+        """Draw the Span.
+
+        Args:
+            file_path (str): The file location of span result file
+        """
+        dir_name = os.path.dirname(file_path)
+        base_name = os.path.basename(file_path)
+        if not os.path.exists(dir_name):
+            os.makedirs(dir_name, exist_ok=True)
+        pdf_info = self._pipe_res['pdf_info']
+        draw_span_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
+
+    def draw_line_sort(self, file_path: str):
+        """Draw line sort.
+
+        Args:
+            file_path (str): The file location of line sort result file
+        """
+        dir_name = os.path.dirname(file_path)
+        base_name = os.path.basename(file_path)
+        if not os.path.exists(dir_name):
+            os.makedirs(dir_name, exist_ok=True)
+        pdf_info = self._pipe_res['pdf_info']
+        draw_line_sort_bbox(pdf_info, self._dataset.data_bits(), dir_name, base_name)
+
+    def get_compress_pdf_mid_data(self):
+        """Compress the pipeline result.
+
+        Returns:
+            str: the compressed pipeline result
+        """
+        return JsonCompressor.compress_json(self._pipe_res)
+
+    def apply(self, proc: Callable, *args, **kwargs):
+        """Apply a callable to a deep copy of the pipeline result.
+
+        Args:
+            proc (Callable): invoke proc as follows:
+                proc(pipeline_result, *args, **kwargs)
+
+        Returns:
+            Any: return the result generated by proc
+        """
+        return proc(copy.deepcopy(self._pipe_res), *args, **kwargs)
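
``PipeResult.apply`` hands a deep copy of the pipeline result to an arbitrary callable, so custom post-processing cannot mutate the stored state. A short sketch (the page-counting logic is illustrative):

.. code:: python

    def count_pages(pipe_res):
        # pipe_res is a deep copy of the pipeline result dict
        return len(pipe_res['pdf_info'])

    n_pages = pipe_result.apply(count_pages)  # pipe_result: a PipeResult instance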

+ 108 - 59
magic_pdf/tools/common.py

@@ -1,5 +1,3 @@
-import copy
-import json as json_parse
 import os

 import click
@@ -7,13 +5,12 @@ import fitz
 from loguru import logger

 import magic_pdf.model as model_config
+from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.make_content_config import DropMode, MakeMode
 from magic_pdf.data.data_reader_writer import FileBasedDataWriter
-from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
-                                      draw_model_bbox, draw_span_bbox)
-from magic_pdf.pipe.OCRPipe import OCRPipe
-from magic_pdf.pipe.TXTPipe import TXTPipe
-from magic_pdf.pipe.UNIPipe import UNIPipe
+from magic_pdf.data.dataset import PymuDocDataset
+from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
+from magic_pdf.model.operators import InferenceResult

 # from io import BytesIO
 # from pypdf import PdfReader, PdfWriter
@@ -56,7 +53,11 @@ def prepare_env(output_dir, pdf_file_name, method):
 def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_id=None):
     document = fitz.open('pdf', pdf_bytes)
     output_document = fitz.open()
-    end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(document) - 1
+    end_page_id = (
+        end_page_id
+        if end_page_id is not None and end_page_id >= 0
+        else len(document) - 1
+    )
     if end_page_id > len(document) - 1:
         logger.warning('end_page_id is out of range, use pdf_docs length')
         end_page_id = len(document) - 1
@@ -94,78 +95,126 @@ def do_parse(
         f_draw_model_bbox = True
         f_draw_line_sort_bbox = True

-    if lang == "":
+    if lang == '':
         lang = None

-    pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id, end_page_id)
+    pdf_bytes = convert_pdf_bytes_to_bytes_by_pymupdf(
+        pdf_bytes, start_page_id, end_page_id
+    )

-    orig_model_list = copy.deepcopy(model_list)
-    local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
-                                                parse_method)
+    local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)

-    image_writer, md_writer = FileBasedDataWriter(
-        local_image_dir), FileBasedDataWriter(local_md_dir)
+    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
+        local_md_dir
+    )
     image_dir = str(os.path.basename(local_image_dir))

-    if parse_method == 'auto':
-        jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
-        pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
-                       # start_page_id=start_page_id, end_page_id=end_page_id,
-                       lang=lang,
-                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
-    elif parse_method == 'txt':
-        pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       # start_page_id=start_page_id, end_page_id=end_page_id,
-                       lang=lang,
-                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
-    elif parse_method == 'ocr':
-        pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
-                       # start_page_id=start_page_id, end_page_id=end_page_id,
-                       lang=lang,
-                       layout_model=layout_model, formula_enable=formula_enable, table_enable=table_enable)
-    else:
-        logger.error('unknown parse method')
-        exit(1)
-
-    pipe.pipe_classify()
+    ds = PymuDocDataset(pdf_bytes)

     if len(model_list) == 0:
         if model_config.__use_inside_model__:
-            pipe.pipe_analyze()
-            orig_model_list = copy.deepcopy(pipe.model_list)
+            if parse_method == 'auto':
+                if ds.classify() == SupportedPdfParseMethod.TXT:
+                    infer_result = ds.apply(
+                        doc_analyze,
+                        ocr=False,
+                        lang=lang,
+                        layout_model=layout_model,
+                        formula_enable=formula_enable,
+                        table_enable=table_enable,
+                    )
+                    pipe_result = infer_result.pipe_txt_mode(
+                        image_writer, debug_mode=True, lang=lang
+                    )
+                else:
+                    infer_result = ds.apply(
+                        doc_analyze,
+                        ocr=True,
+                        lang=lang,
+                        layout_model=layout_model,
+                        formula_enable=formula_enable,
+                        table_enable=table_enable,
+                    )
+                    pipe_result = infer_result.pipe_ocr_mode(
+                        image_writer, debug_mode=True, lang=lang
+                    )
+
+            elif parse_method == 'txt':
+                infer_result = ds.apply(
+                    doc_analyze,
+                    ocr=False,
+                    lang=lang,
+                    layout_model=layout_model,
+                    formula_enable=formula_enable,
+                    table_enable=table_enable,
+                )
+                pipe_result = infer_result.pipe_txt_mode(
+                    image_writer, debug_mode=True, lang=lang
+                )
+            elif parse_method == 'ocr':
+                infer_result = ds.apply(
+                    doc_analyze,
+                    ocr=True,
+                    lang=lang,
+                    layout_model=layout_model,
+                    formula_enable=formula_enable,
+                    table_enable=table_enable,
+                )
+                pipe_result = infer_result.pipe_ocr_mode(
+                    image_writer, debug_mode=True, lang=lang
+                )
+            else:
+                logger.error('unknown parse method')
+                exit(1)
         else:
             logger.error('need model list input')
             exit(2)
+    else:
+        infer_result = InferenceResult(model_list, ds)
+        if parse_method == 'ocr':
+            pipe_result = infer_result.pipe_ocr_mode(
+                image_writer, debug_mode=True, lang=lang
+            )
+        elif parse_method == 'txt':
+            pipe_result = infer_result.pipe_txt_mode(
+                image_writer, debug_mode=True, lang=lang
+            )
+        else:
+            pipe_result = infer_result.pipe_auto_mode(
+                image_writer, debug_mode=True, lang=lang
+            )
+
+    if f_draw_model_bbox:
+        infer_result.draw_model(
+            os.path.join(local_md_dir, f'{pdf_file_name}_model.pdf')
+        )

-    pipe.pipe_parse()
-    pdf_info = pipe.pdf_mid_data['pdf_info']
     if f_draw_layout_bbox:
-        draw_layout_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+        pipe_result.draw_layout(
+            os.path.join(local_md_dir, f'{pdf_file_name}_layout.pdf')
+        )
     if f_draw_span_bbox:
-        draw_span_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
-    if f_draw_model_bbox:
-        draw_model_bbox(copy.deepcopy(orig_model_list), pdf_bytes, local_md_dir, pdf_file_name)
+        pipe_result.draw_span(os.path.join(local_md_dir, f'{pdf_file_name}_spans.pdf'))
+
     if f_draw_line_sort_bbox:
-        draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
+        pipe_result.draw_line_sort(
+            os.path.join(local_md_dir, f'{pdf_file_name}_line_sort.pdf')
+        )

-    md_content = pipe.pipe_mk_markdown(image_dir, drop_mode=DropMode.NONE, md_make_mode=f_make_md_mode)
     if f_dump_md:
-        md_writer.write_string(
+        pipe_result.dump_md(
+            md_writer,
             f'{pdf_file_name}.md',
-            md_content
+            image_dir,
+            drop_mode=DropMode.NONE,
+            md_make_mode=f_make_md_mode,
         )

     if f_dump_middle_json:
-        md_writer.write_string(
-            f'{pdf_file_name}_middle.json',
-            json_parse.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
-        )
+        pipe_result.dump_middle_json(md_writer, f'{pdf_file_name}_middle.json')

     if f_dump_model_json:
-        md_writer.write_string(
-            f'{pdf_file_name}_model.json',
-            json_parse.dumps(orig_model_list, ensure_ascii=False, indent=4)
-        )
+        infer_result.dump_model(md_writer, f'{pdf_file_name}_model.json')

     if f_dump_orig_pdf:
         md_writer.write(
@@ -173,11 +222,11 @@ def do_parse(
             pdf_bytes,
         )

-    content_list = pipe.pipe_mk_uni_format(image_dir, drop_mode=DropMode.NONE)
     if f_dump_content_list:
-        md_writer.write_string(
+        pipe_result.dump_content_list(
+            md_writer,
             f'{pdf_file_name}_content_list.json',
-            json_parse.dumps(content_list, ensure_ascii=False, indent=4)
+            image_dir
         )

     logger.info(f'local output dir is {local_md_dir}')

+ 47 - 24
magic_pdf/user_api.py

@@ -10,22 +10,29 @@
 from loguru import logger

 from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.version import __version__
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
 from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
-
-PARSE_TYPE_TXT = 'txt'
-PARSE_TYPE_OCR = 'ocr'
-
-
-def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
-                  start_page_id=0, end_page_id=None, lang=None,
-                  *args, **kwargs):
+from magic_pdf.config.constants import PARSE_TYPE_TXT, PARSE_TYPE_OCR
+
+
+def parse_txt_pdf(
+    dataset: Dataset,
+    model_list: list,
+    imageWriter: DataWriter,
+    is_debug=False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    *args,
+    **kwargs
+):
     """Parse a text-based pdf."""
     pdf_info_dict = parse_pdf_by_txt(
-        pdf_bytes,
-        pdf_models,
+        dataset,
+        model_list,
         imageWriter,
         start_page_id=start_page_id,
         end_page_id=end_page_id,
@@ -43,13 +50,21 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
     return pdf_info_dict


-def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
-                  start_page_id=0, end_page_id=None, lang=None,
-                  *args, **kwargs):
+def parse_ocr_pdf(
+    dataset: Dataset,
+    model_list: list,
+    imageWriter: DataWriter,
+    is_debug=False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    *args,
+    **kwargs
+):
     """Parse an OCR-based pdf."""
     pdf_info_dict = parse_pdf_by_ocr(
-        pdf_bytes,
-        pdf_models,
+        dataset,
+        model_list,
         imageWriter,
         start_page_id=start_page_id,
         end_page_id=end_page_id,
@@ -67,17 +82,24 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, i
     return pdf_info_dict


-def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
-                    input_model_is_empty: bool = False,
-                    start_page_id=0, end_page_id=None, lang=None,
-                    *args, **kwargs):
+def parse_union_pdf(
+    dataset: Dataset,
+    model_list: list,
+    imageWriter: DataWriter,
+    is_debug=False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    *args,
+    **kwargs
+):
     """Parse a pdf that mixes OCR and text content, extracting everything."""

     def parse_pdf(method):
         try:
             return method(
-                pdf_bytes,
-                pdf_models,
+                dataset,
+                model_list,
                 imageWriter,
                 start_page_id=start_page_id,
                 end_page_id=end_page_id,
@@ -91,12 +113,12 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter,
     pdf_info_dict = parse_pdf(parse_pdf_by_txt)
     if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False):
         logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr')
-        if input_model_is_empty:
+        if len(model_list) == 0:
             layout_model = kwargs.get('layout_model', None)
             formula_enable = kwargs.get('formula_enable', None)
             table_enable = kwargs.get('table_enable', None)
-            pdf_models = doc_analyze(
-                pdf_bytes,
+            infer_res = doc_analyze(
+                dataset,
                 ocr=True,
                 start_page_id=start_page_id,
                 end_page_id=end_page_id,
@@ -105,6 +127,7 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter,
                 formula_enable=formula_enable,
                 table_enable=table_enable,
             )
+            model_list = infer_res.get_infer_res()
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         if pdf_info_dict is None:
             raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.')
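
All three entry points now take a ``Dataset`` first; with an empty model list, ``parse_union_pdf`` runs ``doc_analyze`` itself when the TXT pass is dropped. A hedged caller sketch (paths illustrative):

.. code:: python

    from magic_pdf.data.data_reader_writer import FileBasedDataReader, FileBasedDataWriter
    from magic_pdf.data.dataset import PymuDocDataset
    from magic_pdf.user_api import parse_union_pdf

    pdf_bytes = FileBasedDataReader("").read("demo.pdf")
    ds = PymuDocDataset(pdf_bytes)
    image_writer = FileBasedDataWriter("output/images")

    # an empty model list triggers in-library inference on the OCR fallback
    pdf_info_dict = parse_union_pdf(ds, [], image_writer)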

File diff not shown because the file is too large
+ 3 - 0
next_docs/en/_static/image/pipeline.drawio.svg


+ 2 - 0
next_docs/en/api.rst

@@ -7,3 +7,5 @@
    api/read_api
    api/schemas
    api/io
+   api/pipe_operators
+   api/model_operators

+ 8 - 0
next_docs/en/api/model_operators.rst

@@ -0,0 +1,8 @@
+
+Model API
+==========
+
+.. autoclass:: magic_pdf.model.InferenceResultBase
+   :members:
+   :inherited-members:
+   :show-inheritance:

+ 9 - 0
next_docs/en/api/pipe_operators.rst

@@ -0,0 +1,9 @@
+
+
+Pipeline API
+=============
+
+.. autoclass:: magic_pdf.pipe.operators.PipeResult
+   :members:
+   :inherited-members:
+   :show-inheritance:

+ 1 - 1
next_docs/en/conf.py

@@ -114,7 +114,7 @@ autodoc_mock_imports = [
     'sentencepiece',
     'vllm.cuda_utils',
     'vllm._C',
-    'numpy',
+    # 'numpy',
     'tqdm',
 ]


+ 52 - 38
next_docs/en/user_guide/quick_start/to_markdown.rst

@@ -12,17 +12,17 @@ Local File Example
     import os

     from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
-    from magic_pdf.config.make_content_config import DropMode, MakeMode
-    from magic_pdf.pipe.OCRPipe import OCRPipe
+    from magic_pdf.data.dataset import PymuDocDataset
+    from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

-
-    ## args
-    model_list = []
+    # args
     pdf_file_name = "abc.pdf"  # replace with the real pdf path
+    name_without_suff = pdf_file_name.split(".")[0]

-
-    ## prepare env
+    # prepare env
     local_image_dir, local_md_dir = "output/images", "output"
+    image_dir = str(os.path.basename(local_image_dir))
+
     os.makedirs(local_image_dir, exist_ok=True)

     image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
@@ -30,27 +30,31 @@ Local File Example
     )
     image_dir = str(os.path.basename(local_image_dir))

+    # read bytes
     reader1 = FileBasedDataReader("")
-    pdf_bytes = reader1.read(pdf_file_name)   # read the pdf content
+    pdf_bytes = reader1.read(pdf_file_name)  # read the pdf content

+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)

-    pipe = OCRPipe(pdf_bytes, model_list, image_writer)
+    ## inference 
+    infer_result = ds.apply(doc_analyze, ocr=True)

-    pipe.pipe_classify()
-    pipe.pipe_analyze()
-    pipe.pipe_parse()
+    ### draw model result on each page
+    infer_result.draw_model(os.path.join(local_md_dir, f"{name_without_suff}_model.pdf"))

-    pdf_info = pipe.pdf_mid_data["pdf_info"]
+    ## pipeline
+    pipe_result = infer_result.pipe_ocr_mode(image_writer)

+    ### draw layout result on each page
+    pipe_result.draw_layout(os.path.join(local_md_dir, f"{name_without_suff}_layout.pdf"))

-    md_content = pipe.pipe_mk_markdown(
-        image_dir, drop_mode=DropMode.NONE, md_make_mode=MakeMode.MM_MD
-    )
+    ### draw spans result on each page
+    pipe_result.draw_span(os.path.join(local_md_dir, f"{name_without_suff}_spans.pdf"))

-    if isinstance(md_content, list):
-        md_writer.write_string(f"{pdf_file_name}.md", "\n".join(md_content))
-    else:
-        md_writer.write_string(f"{pdf_file_name}.md", md_content)
+    ### dump markdown
+    pipe_result.dump_md(md_writer, f"{name_without_suff}.md", image_dir)


 S3 File Example
@@ -61,8 +65,8 @@ S3 File Example
     import os

     from magic_pdf.data.data_reader_writer import S3DataReader, S3DataWriter
-    from magic_pdf.config.make_content_config import DropMode, MakeMode
-    from magic_pdf.pipe.OCRPipe import OCRPipe
+    from magic_pdf.data.dataset import PymuDocDataset
+    from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze

     bucket_name = "{Your S3 Bucket Name}"  # replace with real bucket name
     ak = "{Your S3 access key}"  # replace with real s3 access key
@@ -74,29 +78,39 @@ S3 File Example
     writer = S3DataWriter('unittest/tmp', bucket_name, ak, sk, endpoint_url)
     image_writer = S3DataWriter('unittest/tmp/images', bucket_name, ak, sk, endpoint_url)

-    ## args
-    model_list = []
-    pdf_file_name = f"s3://{bucket_name}/{fake pdf path}"  # replace with the real s3 path
+    # args
+    pdf_file_name = (
+        "s3://llm-pdf-text-1/unittest/tmp/bug5-11.pdf"  # replace with the real s3 path
+    )
+
+    # prepare env
+    local_dir = "output"
+    name_without_suff = os.path.basename(pdf_file_name).split(".")[0]
 
 
+    # read bytes
     pdf_bytes = reader.read(pdf_file_name)  # read the pdf content
     pdf_bytes = reader.read(pdf_file_name)  # read the pdf content
 
 
+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)
 
 
-    pipe = OCRPipe(pdf_bytes, model_list, image_writer)
+    ## inference 
+    infer_result = ds.apply(doc_analyze, ocr=True)
 
 
-    pipe.pipe_classify()
-    pipe.pipe_analyze()
-    pipe.pipe_parse()
+    ### draw model result on each page
+    infer_result.draw_model(os.path.join(local_dir, f'{name_without_suff}_model.pdf'))  # dump to local
 
 
-    pdf_info = pipe.pdf_mid_data["pdf_info"]
+    ## pipeline
+    pipe_result = infer_result.pipe_ocr_mode(image_writer)
 
 
-    md_content = pipe.pipe_mk_markdown(
-        "unittest/tmp/images", drop_mode=DropMode.NONE, md_make_mode=MakeMode.MM_MD
-    )
+    ### draw layout result on each page
+    pipe_result.draw_layout(os.path.join(local_dir, f'{name_without_suff}_layout.pdf'))  # dump to local
+
+    ### draw spans result on each page
+    pipe_result.draw_span(os.path.join(local_dir, f'{name_without_suff}_spans.pdf'))   # dump to local 
 
 
-    if isinstance(md_content, list):
-        writer.write_string(f"{pdf_file_name}.md", "\n".join(md_content))
-    else:
-        writer.write_string(f"{pdf_file_name}.md", md_content)
+    ### dump markdown
+    pipe_result.dump_md(writer, f'{name_without_suff}.md', "unittest/tmp/images")    # dump to remote s3
 
 
 
 
-Check :doc:`../data/data_reader_writer` for more [reader | writer] examples
+Check :doc:`../data/data_reader_writer` for more reader and writer examples, and see :doc:`../../api/pipe_operators` or :doc:`../../api/model_operators` for API details.

+ 3 - 1
next_docs/en/user_guide/tutorial.rst

@@ -7,4 +7,6 @@ From beginning to end, shows how to use mineru via a minimal project
 .. toctree::
 .. toctree::
     :maxdepth: 1
     :maxdepth: 1
 
 
-    tutorial/output_file_description
+    tutorial/output_file_description
+    tutorial/pipeline
+

+ 185 - 0
next_docs/en/user_guide/tutorial/pipeline.rst

@@ -0,0 +1,185 @@
+
+
+Pipeline
+==========
+
+
+Minimal Example 
+^^^^^^^^^^^^^^^^^
+
+.. code:: python
+
+    import os
+
+    from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
+    from magic_pdf.data.dataset import PymuDocDataset
+    from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
+
+    # args
+    pdf_file_name = "abc.pdf"  # replace with the real pdf path
+    name_without_suff = pdf_file_name.split(".")[0]
+
+    # prepare env
+    local_image_dir, local_md_dir = "output/images", "output"
+    image_dir = str(os.path.basename(local_image_dir))
+
+    os.makedirs(local_image_dir, exist_ok=True)
+
+    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
+        local_md_dir
+    )
+    image_dir = str(os.path.basename(local_image_dir))
+
+    # read bytes
+    reader1 = FileBasedDataReader("")
+    pdf_bytes = reader1.read(pdf_file_name)  # read the pdf content
+
+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)
+
+    ds.apply(doc_analyze, ocr=True).pipe_ocr_mode(image_writer).dump_md(md_writer, f"{name_without_suff}.md", image_dir)
+
+Running the above code will result in the following
+
+
+.. code:: bash 
+
+    output/
+    ├── abc.md
+    └── images
+
+
+Setting aside environment setup, such as creating directories and importing dependencies, the code that actually converts the pdf to markdown is as follows
+
+
+.. code:: python 
+
+    # read bytes
+    reader1 = FileBasedDataReader("")
+    pdf_bytes = reader1.read(pdf_file_name)  # read the pdf content
+
+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)
+
+    ds.apply(doc_analyze, ocr=True).pipe_ocr_mode(image_writer).dump_md(md_writer, f"{name_without_suff}.md", image_dir)
+
+``ds.apply(doc_analyze, ocr=True)`` generates an ``InferenceResult`` object. The ``InferenceResult`` object, when executing the ``pipe_ocr_mode`` method, produces a ``PipeResult`` object.
+The ``PipeResult`` object, upon executing ``dump_md``, generates a ``markdown`` file at the specified location.
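+
+The chained call can also be unrolled into explicit stages. The following is the same call sequence written out step by step, naming each intermediate object (no new APIs involved):
+
+.. code:: python
+
+    # stage 1: inference over the dataset -> InferenceResult
+    infer_result = ds.apply(doc_analyze, ocr=True)
+
+    # stage 2: OCR-mode pipeline processing -> PipeResult
+    pipe_result = infer_result.pipe_ocr_mode(image_writer)
+
+    # stage 3: dump the markdown, with images referenced under image_dir
+    pipe_result.dump_md(md_writer, f"{name_without_suff}.md", image_dir)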
+
+
+The pipeline execution process is illustrated in the following diagram
+
+
+.. image:: ../../_static/image/pipeline.drawio.svg 
+
+.. raw:: html
+
+    <br/>
+
+Currently, the process is divided into three stages: data, inference, and processing, which correspond to the ``Dataset``, ``InferenceResult``, and ``PipeResult`` entities in the diagram.
+These stages are linked together through methods such as ``apply``, ``doc_analyze``, and ``pipe_ocr_mode``.
+
+
+.. admonition:: Tip
+    :class: tip
+
+    For more examples on how to use ``Dataset``, ``InferenceResult``, and ``PipeResult``, please refer to :doc:`../quick_start/to_markdown`
+
+    For more detailed information about ``Dataset``, ``InferenceResult``, and ``PipeResult``, please refer to :doc:`../../api/dataset`, :doc:`../../api/model_operators`, :doc:`../../api/pipe_operators`
+
+
+Pipeline Composition
+^^^^^^^^^^^^^^^^^^^^^
+
+.. code:: python 
+
+    class Dataset(ABC):
+        @abstractmethod
+        def apply(self, proc: Callable, *args, **kwargs):
+            """Apply callable method which.
+
+            Args:
+                proc (Callable): invoke proc as follows:
+                    proc(self, *args, **kwargs)
+
+            Returns:
+                Any: return the result generated by proc
+            """
+            pass
+
+    class InferenceResult(InferenceResultBase):
+
+        def apply(self, proc: Callable, *args, **kwargs):
+            """Apply callable method which.
+
+            Args:
+                proc (Callable): invoke proc as follows:
+                    proc(inference_result, *args, **kwargs)
+
+            Returns:
+                Any: return the result generated by proc
+            """
+            return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
+
+        def pipe_ocr_mode(
+            self,
+            imageWriter: DataWriter,
+            start_page_id=0,
+            end_page_id=None,
+            debug_mode=False,
+            lang=None,
+            ) -> PipeResult:
+            pass
+
+    class PipeResult:
+        def apply(self, proc: Callable, *args, **kwargs):
+            """Apply callable method which.
+
+            Args:
+                proc (Callable): invoke proc as follows:
+                    proc(pipeline_result, *args, **kwargs)
+
+            Returns:
+                Any: return the result generated by proc
+            """
+            return proc(copy.deepcopy(self._pipe_res), *args, **kwargs)
+
+
+The ``Dataset``, ``InferenceResult``, and ``PipeResult`` classes all have an ``apply`` method, which can be used to chain different stages of the computation. 
+As shown below, ``MinerU`` provides a set of methods to compose these classes.
+
+
+.. code:: python 
+
+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)
+
+    ds.apply(doc_analyze, ocr=True).pipe_ocr_mode(image_writer).dump_md(md_writer, f"{name_without_suff}.md", image_dir)
+
+
+Users can implement their own functions for chaining as needed. For example, a user could use the ``apply`` method to create a function that counts the number of pages in a ``pdf`` file.
+
+
+.. code:: python
+
+    from magic_pdf.data.data_reader_writer import  FileBasedDataReader
+    from magic_pdf.data.dataset import PymuDocDataset
+
+    # args
+    pdf_file_name = "abc.pdf"  # replace with the real pdf path
+
+    # read bytes
+    reader1 = FileBasedDataReader("")
+    pdf_bytes = reader1.read(pdf_file_name)  # read the pdf content
+
+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)
+
+    def count_page(ds) -> int:
+        return len(ds)
+
+    print("page number: ", ds.apply(count_page))  # prints the page count of `abc.pdf`

+ 5 - 1
next_docs/requirements.txt

@@ -1,3 +1,7 @@
+numpy==1.26.4
+click==8.1.7
+fast-langdetect==0.2.2
+Brotli==1.1.0
 boto3>=1.28.43
 boto3>=1.28.43
 loguru>=0.6.0
 loguru>=0.6.0
 myst-parser
 myst-parser
@@ -9,4 +13,4 @@ sphinx-argparse>=0.5.2
 sphinx-book-theme>=1.1.3
 sphinx-book-theme>=1.1.3
 sphinx-copybutton>=0.5.2
 sphinx-copybutton>=0.5.2
 sphinx_rtd_theme>=3.0.1
 sphinx_rtd_theme>=3.0.1
-autodoc_pydantic>=2.2.0
+autodoc_pydantic>=2.2.0

The file diff is not shown because this file is too large
+ 3 - 0
next_docs/zh_cn/_static/image/pipeline.drawio.svg


+ 53 - 42
next_docs/zh_cn/user_guide/quick_start/to_markdown.rst

@@ -1,28 +1,26 @@
 
 
-
 转换为 Markdown 文件
 转换为 Markdown 文件
 ========================
 ========================
 
 
-
 本地文件示例
 本地文件示例
-^^^^^^^^^^^
+^^^^^^^^^^^^^^^^^^
 
 
 .. code:: python
 .. code:: python
 
 
     import os
     import os
 
 
     from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
     from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
-    from magic_pdf.config.make_content_config import DropMode, MakeMode
-    from magic_pdf.pipe.OCRPipe import OCRPipe
-
+    from magic_pdf.data.dataset import PymuDocDataset
+    from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 
 
-    ## args
-    model_list = []
+    # args
     pdf_file_name = "abc.pdf"  # replace with the real pdf path
     pdf_file_name = "abc.pdf"  # replace with the real pdf path
+    name_without_suff = pdf_file_name.split(".")[0]
 
 
-
-    ## prepare env
+    # prepare env
     local_image_dir, local_md_dir = "output/images", "output"
     local_image_dir, local_md_dir = "output/images", "output"
+    image_dir = str(os.path.basename(local_image_dir))
+
     os.makedirs(local_image_dir, exist_ok=True)
     os.makedirs(local_image_dir, exist_ok=True)
 
 
     image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
     image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
@@ -30,39 +28,43 @@
     )
     )
     image_dir = str(os.path.basename(local_image_dir))
     image_dir = str(os.path.basename(local_image_dir))
 
 
+    # read bytes
     reader1 = FileBasedDataReader("")
     reader1 = FileBasedDataReader("")
-    pdf_bytes = reader1.read(pdf_file_name)   # read the pdf content
+    pdf_bytes = reader1.read(pdf_file_name)  # read the pdf content
 
 
+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)
 
 
-    pipe = OCRPipe(pdf_bytes, model_list, image_writer)
+    ## inference 
+    infer_result = ds.apply(doc_analyze, ocr=True)
 
 
-    pipe.pipe_classify()
-    pipe.pipe_analyze()
-    pipe.pipe_parse()
+    ### draw model result on each page
+    infer_result.draw_model(os.path.join(local_md_dir, f"{name_without_suff}_model.pdf"))
 
 
-    pdf_info = pipe.pdf_mid_data["pdf_info"]
+    ## pipeline
+    pipe_result = infer_result.pipe_ocr_mode(image_writer)
 
 
+    ### draw layout result on each page
+    pipe_result.draw_layout(os.path.join(local_md_dir, f"{name_without_suff}_layout.pdf"))
 
 
-    md_content = pipe.pipe_mk_markdown(
-        image_dir, drop_mode=DropMode.NONE, md_make_mode=MakeMode.MM_MD
-    )
+    ### draw spans result on each page
+    pipe_result.draw_span(os.path.join(local_md_dir, f"{name_without_suff}_spans.pdf"))
 
 
-    if isinstance(md_content, list):
-        md_writer.write_string(f"{pdf_file_name}.md", "\n".join(md_content))
-    else:
-        md_writer.write_string(f"{pdf_file_name}.md", md_content)
+    ### dump markdown
+    pipe_result.dump_md(md_writer, f"{name_without_suff}.md", image_dir)
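+
+``PipeResult.apply`` 会把中间态 dict 的深拷贝传给回调函数,因此也可以直接在流水线上做后处理。下面是一个最小示例,仅依赖文档中描述的 ``pdf_info`` 字段(每页一个元素),其中 ``page_count`` 为演示用的自定义函数:
+
+.. code:: python
+
+    # 演示用函数:通过 middle-json 结构统计解析出的页数;
+    # `pdf_info` 是一个 list,每页对应一个 dict
+    def page_count(pipe_res: dict) -> int:
+        return len(pipe_res["pdf_info"])
+
+    print("parsed pages:", pipe_result.apply(page_count))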
 
 
 
 
-对象存储使用示例
-^^^^^^^^^^^^^^^
+对象存储文件示例
+^^^^^^^^^^^^^^^^
 
 
 .. code:: python
 .. code:: python
 
 
     import os
     import os
 
 
     from magic_pdf.data.data_reader_writer import S3DataReader, S3DataWriter
     from magic_pdf.data.data_reader_writer import S3DataReader, S3DataWriter
-    from magic_pdf.config.make_content_config import DropMode, MakeMode
-    from magic_pdf.pipe.OCRPipe import OCRPipe
+    from magic_pdf.data.dataset import PymuDocDataset
+    from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 
 
     bucket_name = "{Your S3 Bucket Name}"  # replace with real bucket name
     bucket_name = "{Your S3 Bucket Name}"  # replace with real bucket name
     ak = "{Your S3 access key}"  # replace with real s3 access key
     ak = "{Your S3 access key}"  # replace with real s3 access key
@@ -74,30 +76,39 @@
     writer = S3DataWriter('unittest/tmp', bucket_name, ak, sk, endpoint_url)
     writer = S3DataWriter('unittest/tmp', bucket_name, ak, sk, endpoint_url)
     image_writer = S3DataWriter('unittest/tmp/images', bucket_name, ak, sk, endpoint_url)
     image_writer = S3DataWriter('unittest/tmp/images', bucket_name, ak, sk, endpoint_url)
 
 
-    ## args
-    model_list = []
-    pdf_file_name = f"s3://{bucket_name}/{fake pdf path}"  # replace with the real s3 path
+    # args
+    pdf_file_name = (
+        "s3://llm-pdf-text-1/unittest/tmp/bug5-11.pdf"  # replace with the real s3 path
+    )
 
 
+    # prepare env
+    local_dir = "output"
+    name_without_suff = os.path.basename(pdf_file_name).split(".")[0]
+
+    # read bytes
     pdf_bytes = reader.read(pdf_file_name)  # read the pdf content
     pdf_bytes = reader.read(pdf_file_name)  # read the pdf content
 
 
+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)
 
 
-    pipe = OCRPipe(pdf_bytes, model_list, image_writer)
+    ## inference 
+    infer_result = ds.apply(doc_analyze, ocr=True)
 
 
-    pipe.pipe_classify()
-    pipe.pipe_analyze()
-    pipe.pipe_parse()
+    ### draw model result on each page
+    infer_result.draw_model(os.path.join(local_dir, f'{name_without_suff}_model.pdf'))  # dump to local
 
 
-    pdf_info = pipe.pdf_mid_data["pdf_info"]
+    ## pipeline
+    pipe_result = infer_result.pipe_ocr_mode(image_writer)
 
 
-    md_content = pipe.pipe_mk_markdown(
-        "unittest/tmp/images", drop_mode=DropMode.NONE, md_make_mode=MakeMode.MM_MD
-    )
+    ### draw layout result on each page
+    pipe_result.draw_layout(os.path.join(local_dir, f'{name_without_suff}_layout.pdf'))  # dump to local
 
 
-    if isinstance(md_content, list):
-        writer.write_string(f"{pdf_file_name}.md", "\n".join(md_content))
-    else:
-        writer.write_string(f"{pdf_file_name}.md", md_content)
+    ### draw spans result on each page
+    pipe_result.draw_span(os.path.join(local_dir, f'{name_without_suff}_spans.pdf'))  # dump to local 
 
 
+    ### dump markdown
+    pipe_result.dump_md(writer, f'{name_without_suff}.md', "unittest/tmp/images")  # dump to remote s3
 
 
 
 
 前去 :doc:`../data/data_reader_writer` 获取更多有关 **读写** 示例
 前去 :doc:`../data/data_reader_writer` 获取更多有关 **读写** 示例

+ 2 - 0
next_docs/zh_cn/user_guide/tutorial.rst

@@ -9,3 +9,5 @@
     :caption: 教程
     :caption: 教程
 
 
     tutorial/output_file_description
     tutorial/output_file_description
+    tutorial/pipeline
+

+ 59 - 66
next_docs/zh_cn/user_guide/tutorial/output_file_description.rst

@@ -137,49 +137,45 @@ poly 坐标的格式 [x0, y0, x1, y1, x2, y2, x3, y3],
 some_pdf_middle.json
 some_pdf_middle.json
 ~~~~~~~~~~~~~~~~~~~~
 ~~~~~~~~~~~~~~~~~~~~
 
 
-+-----------+----------------------------------------------------------+
-| 字段名    | 解释                                                     |
-+===========+==========================================================+
-| pdf_info  | list,每个                                               |
-|           | 元素都是一个dict,这个dict是每一页pdf的解析结果,详见下表 |
-+-----------+----------------------------------------------------------+
-|              | ocr \| txt,用来标识本次解析的中间态使用的模式           |
-| \_parse_type |                                                          |
-+-----------+----------------------------------------------------------+
-|                | string, 表示本次解析使用的 magic-pdf 的版本号            |
-| \_version_name |                                                          |
-+-----------+----------------------------------------------------------+
++--------------------+----------------------------------------------------------+
+| 字段名              | 解释                                                    |
++====================+==========================================================+
+| pdf_info           | list,每个元素都是一个                                   |
+|                    | dict,这个dict是每一页pdf的解析结果,详见下表            |
++--------------------+----------------------------------------------------------+
+| \_parse_type       | ocr \| txt,用来标识本次解析的中间态使用的模式           |
++--------------------+----------------------------------------------------------+
+| \_version_name     | string,表示本次解析使用的 magic-pdf 的版本号            |
++--------------------+----------------------------------------------------------+
 
 
 **pdf_info** 字段结构说明
 **pdf_info** 字段结构说明
 
 
-+--------------+-------------------------------------------------------+
-| 字段名       | 解释                                                  |
-+==============+=======================================================+
-|                 | pdf预处理后,未分段的中间结果                         |
-| preeproc_blocks |                                                       |
-+--------------+-------------------------------------------------------+
-|               | 布局分割的结果,                                      |
-| layout_bboxes | 含有布局的方向(垂直、水平),和bbox,按阅读顺序排序  |
-+--------------+-------------------------------------------------------+
-| page_idx     | 页码,从0开始                                         |
-+--------------+-------------------------------------------------------+
-| page_size    | 页面的宽度和高度                                      |
-+--------------+-------------------------------------------------------+
-| \            | 布局树状结构                                          |
-| _layout_tree |                                                       |
-+--------------+-------------------------------------------------------+
-| images       | list,每个元素是一个dict,每个dict表示一个img_block   |
-+--------------+-------------------------------------------------------+
-| tables       | list,每个元素是一个dict,每个dict表示一个table_block |
-+--------------+-------------------------------------------------------+
-|                     | list,每个元素                                        |
-| interline_equations | 是一个dict,每个dict表示一个interline_equation_block  |
-+--------------+-------------------------------------------------------+
-|                  | List, 模型返回的需要drop的block信息                   |
-| discarded_blocks |                                                       |
-+--------------+-------------------------------------------------------+
-| para_blocks  | 将preproc_blocks进行分段之后的结果                    |
-+--------------+-------------------------------------------------------+
++---------------------+-------------------------------------------------------+
+| 字段名               | 解释                                                 |
++=====================+=======================================================+
+| preproc_blocks      | pdf预处理后,未分段的中间结果                         |
++---------------------+-------------------------------------------------------+
+|                     | 布局分割的结果,                                      |
+| layout_bboxes       | 含有布局的方向(垂直、水平),和bbox,按阅读顺序排序  |
++---------------------+-------------------------------------------------------+
+| page_idx            | 页码,从0开始                                         |
++---------------------+-------------------------------------------------------+
+| page_size           | 页面的宽度和高度                                      |
++---------------------+-------------------------------------------------------+
+| \_layout_tree       | 布局树状结构                                          |
++---------------------+-------------------------------------------------------+
+| images              | list,每个元素是一个dict,每个dict表示一个img_block   |
++---------------------+-------------------------------------------------------+
+| tables              | list,每个元素是一个dict,每个dict表示一个table_block |
++---------------------+-------------------------------------------------------+
+|                     | list,每个元素是一个                                  |
+| interline_equations | dict,每个dict表示一个interline_equation_block        |
++---------------------+-------------------------------------------------------+
+|                     | List, 模型返回的需要drop的block信息                   |
+| discarded_blocks    |                                                       |
++---------------------+-------------------------------------------------------+
+| para_blocks         | 将preproc_blocks进行分段之后的结果                    |
++---------------------+-------------------------------------------------------+
 
 
 上表中 ``para_blocks``
 上表中 ``para_blocks``
 是个dict的数组,每个dict是一个block结构,block最多支持一次嵌套
 是个dict的数组,每个dict是一个block结构,block最多支持一次嵌套
@@ -200,20 +196,18 @@ blocks list,里面的每个元素都是一个dict格式的二级block
 
 
 二级block中的字段包括
 二级block中的字段包括
 
 
-+-----+----------------------------------------------------------------+
-| 字  | 解释                                                           |
-| 段  |                                                                |
-| 名  |                                                                |
-+=====+================================================================+
-|      | block类型                                                      |
-| type |                                                                |
-+-----+----------------------------------------------------------------+
-|      | block矩形框坐标                                                |
-| bbox |                                                                |
-+-----+----------------------------------------------------------------+
-|       | list,每个元素都是一个dict表示的line,用来描述一行信息的构成   |
-| lines |                                                                |
-+-----+----------------------------------------------------------------+
++----------+----------------------------------------------------------------+
+| 字       | 解释                                                           |
+| 段       |                                                                |
+| 名       |                                                                |
++==========+================================================================+
+|          | block类型                                                      |
+| type     |                                                                |
++----------+----------------------------------------------------------------+
+| bbox     | block矩形框坐标                                                |
++----------+----------------------------------------------------------------+
+| lines    | list,每个元素都是一个dict表示的line,用来描述一行信息的构成   |
++----------+----------------------------------------------------------------+
 
 
 二级block的类型详解
 二级block的类型详解
 
 
@@ -237,22 +231,21 @@ interline_equation 行间公式块
 
 
 line 的 字段格式如下
 line 的 字段格式如下
 
 
-+----+-----------------------------------------------------------------+
-| 字 | 解释                                                            |
-| 段 |                                                                 |
-| 名 |                                                                 |
-+====+=================================================================+
-| bbox  | line的矩形框坐标                                                |
-|       |                                                                 |
-+----+-----------------------------------------------------------------+
-| spans  | list,                                                       |
-|        | 每个元素都是一个dict表示的span,用来描述一个最小组成单元的构成  |
-+----+-----------------------------------------------------------------+
++-----------+-----------------------------------------------------------------+
+| 字        | 解释                                                            |
+| 段        |                                                                 |
+| 名        |                                                                 |
++===========+=================================================================+
+| bbox      | line的矩形框坐标                                                |
++-----------+-----------------------------------------------------------------+
+| spans     | list,                                                          |
+|           | 每个元素都是一个dict表示的span,用来描述一个最小组成单元的构成  |
++-----------+-----------------------------------------------------------------+
 
 
 **span**
 **span**
 
 
 +------------+---------------------------------------------------------+
 +------------+---------------------------------------------------------+
-| 字段名     | 解释                                                    |
+| 字段名      | 解释                                                   |
 +============+=========================================================+
 +============+=========================================================+
 | bbox       | span的矩形框坐标                                        |
 | bbox       | span的矩形框坐标                                        |
 +------------+---------------------------------------------------------+
 +------------+---------------------------------------------------------+

+ 179 - 0
next_docs/zh_cn/user_guide/tutorial/pipeline.rst

@@ -0,0 +1,179 @@
+
+流水线管道
+===========
+
+
+极简示例
+^^^^^^^^
+
+.. code:: python
+
+    import os
+
+    from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
+    from magic_pdf.data.dataset import PymuDocDataset
+    from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
+
+    # args
+    pdf_file_name = "abc.pdf"  # replace with the real pdf path
+    name_without_suff = pdf_file_name.split(".")[0]
+
+    # prepare env
+    local_image_dir, local_md_dir = "output/images", "output"
+    image_dir = str(os.path.basename(local_image_dir))
+
+    os.makedirs(local_image_dir, exist_ok=True)
+
+    image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(
+        local_md_dir
+    )
+    image_dir = str(os.path.basename(local_image_dir))
+
+    # read bytes
+    reader1 = FileBasedDataReader("")
+    pdf_bytes = reader1.read(pdf_file_name)  # read the pdf content
+
+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)
+
+    ds.apply(doc_analyze, ocr=True).pipe_ocr_mode(image_writer).dump_md(md_writer, f"{name_without_suff}.md", image_dir)
+
+
+运行以上的代码,会得到如下的结果
+
+.. code:: bash 
+
+    output/
+    ├── abc.md
+    └── images
+
+
+除去建立目录、导入依赖库等初始化环境的逻辑,真正将 ``pdf`` 转换为 ``markdown`` 的代码片段如下
+
+.. code:: python
+
+    # read bytes
+    reader1 = FileBasedDataReader("")
+    pdf_bytes = reader1.read(pdf_file_name)  # read the pdf content
+
+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)
+
+    ds.apply(doc_analyze, ocr=True).pipe_ocr_mode(image_writer).dump_md(md_writer, f"{name_without_suff}.md", image_dir)
+
+
+``ds.apply(doc_analyze, ocr=True)`` 会生成 ``InferenceResult`` 对象。 ``InferenceResult`` 对象执行 ``pipe_ocr_mode`` 方法会生成 ``PipeResult`` 对象。
+``PipeResult`` 对象执行 ``dump_md`` 会在指定位置生成 ``markdown`` 文件。
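+
+链式调用也可以拆成显式的分步写法,命名出每个中间对象,调用序列完全一致(不涉及任何新的 API):
+
+.. code:: python
+
+    # 第一阶段:对数据集做推理 -> InferenceResult
+    infer_result = ds.apply(doc_analyze, ocr=True)
+
+    # 第二阶段:OCR 模式的流水线处理 -> PipeResult
+    pipe_result = infer_result.pipe_ocr_mode(image_writer)
+
+    # 第三阶段:导出 markdown,图片引用指向 image_dir
+    pipe_result.dump_md(md_writer, f"{name_without_suff}.md", image_dir)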
+
+
+pipeline 的执行过程如下图所示
+
+.. image:: ../../_static/image/pipeline.drawio.svg 
+
+.. raw:: html 
+
+    <br/>
+
+目前划分出数据、推理、程序处理三个阶段,分别对应着图上的 ``Dataset``、``InferenceResult``、``PipeResult`` 这三个实体,通过 ``apply``、``doc_analyze`` 或 ``pipe_ocr_mode`` 等方法链接在一起。
+
+
+.. admonition:: Tip
+    :class: tip
+
+    要想获得更多有关 Dataset、InferenceResult、PipeResult 的使用示例,请前往 :doc:`../quick_start/to_markdown`
+
+    要想获得更多有关 Dataset、InferenceResult、PipeResult 的细节信息,请前往英文版 MinerU 文档查看。
+
+
+
+管道组合
+^^^^^^^^^
+
+.. code:: python
+
+    class Dataset(ABC):
+        @abstractmethod
+        def apply(self, proc: Callable, *args, **kwargs):
+            """Apply callable method which.
+
+            Args:
+                proc (Callable): invoke proc as follows:
+                    proc(self, *args, **kwargs)
+
+            Returns:
+                Any: return the result generated by proc
+            """
+            pass
+
+    class InferenceResult(InferenceResultBase):
+
+        def apply(self, proc: Callable, *args, **kwargs):
+            """Apply callable method which.
+
+            Args:
+                proc (Callable): invoke proc as follows:
+                    proc(inference_result, *args, **kwargs)
+
+            Returns:
+                Any: return the result generated by proc
+            """
+            return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
+
+        def pipe_ocr_mode(
+            self,
+            imageWriter: DataWriter,
+            start_page_id=0,
+            end_page_id=None,
+            debug_mode=False,
+            lang=None,
+            ) -> PipeResult:
+            pass
+
+    class PipeResult:
+        def apply(self, proc: Callable, *args, **kwargs):
+            """Apply callable method which.
+
+            Args:
+                proc (Callable): invoke proc as follows:
+                    proc(pipeline_result, *args, **kwargs)
+
+            Returns:
+                Any: return the result generated by proc
+            """
+            return proc(copy.deepcopy(self._pipe_res), *args, **kwargs)
+
+``Dataset``、``InferenceResult`` 和 ``PipeResult`` 类均有 ``apply`` 方法,可用于组合不同阶段的运算过程。
+如下所示,``MinerU`` 提供了一套组合这些类的方法。
+
+.. code:: python 
+
+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)
+
+    ds.apply(doc_analyze, ocr=True).pipe_ocr_mode(image_writer).dump_md(md_writer, f"{name_without_suff}.md", image_dir)
+
+用户可以根据自己的需求,自行实现一些组合用的函数。例如,用户可以通过 ``apply`` 方法实现一个统计 ``pdf`` 文件页数的功能。
+
+.. code:: python 
+
+    from magic_pdf.data.data_reader_writer import  FileBasedDataReader
+    from magic_pdf.data.dataset import PymuDocDataset
+
+    # args
+    pdf_file_name = "abc.pdf"  # replace with the real pdf path
+
+    # read bytes
+    reader1 = FileBasedDataReader("")
+    pdf_bytes = reader1.read(pdf_file_name)  # read the pdf content
+
+    # proc
+    ## Create Dataset Instance
+    ds = PymuDocDataset(pdf_bytes)
+
+    def count_page(ds) -> int:
+        return len(ds)
+
+    print("page number: ", ds.apply(count_page))  # prints the page count of `abc.pdf`

+ 2 - 2
requirements-docker.txt

@@ -7,9 +7,9 @@ numpy>=1.21.6,<2.0.0
 fast-langdetect==0.2.0
 fast-langdetect==0.2.0
 scikit-learn>=1.0.2
 scikit-learn>=1.0.2
 pdfminer.six==20231228
 pdfminer.six==20231228
-unimernet==0.2.1
+unimernet==0.2.2
 matplotlib
 matplotlib
-ultralytics
+ultralytics>=8.3.48
 paddleocr==2.7.3
 paddleocr==2.7.3
 paddlepaddle==3.0.0b1
 paddlepaddle==3.0.0b1
 struct-eqtable==0.3.2
 struct-eqtable==0.3.2

+ 2 - 2
requirements.txt

@@ -7,7 +7,7 @@ numpy>=1.21.6,<2.0.0
 pydantic>=2.7.2,<2.8.0
 pydantic>=2.7.2,<2.8.0
 PyMuPDF>=1.24.9
 PyMuPDF>=1.24.9
 scikit-learn>=1.0.2
 scikit-learn>=1.0.2
-torch>=2.2.2,<=2.3.1
+torch>=2.2.2
 transformers
 transformers
-# pdfminer.six==20231228
+pdfminer.six==20231228
 # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.
 # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.

+ 4 - 2
setup.py

@@ -36,10 +36,12 @@ if __name__ == '__main__':
                      "paddlepaddle==3.0.0b1;platform_system=='Linux'",
                      "paddlepaddle==3.0.0b1;platform_system=='Linux'",
                      "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",
                      "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",
                      ],
                      ],
-            "full": ["unimernet==0.2.1",  # unimernet升级0.2.1
+            "full": ["unimernet==0.2.2",  # unimernet升级0.2.2,移除torchtext的依赖
+                     "torch>=2.2.2,<=2.3.1",  # torch2.4.0及之后版本未测试,先卡住版本上限
+                     "torchvision>=0.17.2,<=0.18.1",  # torchvision 受torch版本约束
                      "matplotlib<=3.9.0;platform_system=='Windows'",  # 3.9.1及之后不提供windows的预编译包,避免一些没有编译环境的windows设备安装失败
                      "matplotlib<=3.9.0;platform_system=='Windows'",  # 3.9.1及之后不提供windows的预编译包,避免一些没有编译环境的windows设备安装失败
                      "matplotlib;platform_system=='Linux' or platform_system=='Darwin'",  # linux 和 macos 不应限制matplotlib的最高版本,以避免无法更新导致的一些bug
                      "matplotlib;platform_system=='Linux' or platform_system=='Darwin'",  # linux 和 macos 不应限制matplotlib的最高版本,以避免无法更新导致的一些bug
-                     "ultralytics",  # yolov8,公式检测
+                     "ultralytics>=8.3.48",  # yolov8,公式检测
                      "paddleocr==2.7.3",  # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
                      "paddleocr==2.7.3",  # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
                      "paddlepaddle==3.0.0b1;platform_system=='Linux'",  # 解决linux的段异常问题
                      "paddlepaddle==3.0.0b1;platform_system=='Linux'",  # 解决linux的段异常问题
                      "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",  # windows版本3.0.0b1效率下降,需锁定2.6.1
                      "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",  # windows版本3.0.0b1效率下降,需锁定2.6.1

+ 16 - 11
tests/test_cli/test_cli_sdk.py

@@ -7,8 +7,11 @@ from lib import common
 import time
 import time
 import magic_pdf.model as model_config
 import magic_pdf.model as model_config
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.pipe.UNIPipe import UNIPipe
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
-from magic_pdf.rw.S3ReaderWriter import S3ReaderWriter
+import os
+from magic_pdf.data.data_reader_writer import FileBasedDataWriter
+from magic_pdf.data.data_reader_writer import S3DataReader, S3DataWriter
+from magic_pdf.config.make_content_config import DropMode, MakeMode
+from magic_pdf.pipe.OCRPipe import OCRPipe
 model_config.__use_inside_model__ = True
 model_config.__use_inside_model__ = True
 pdf_res_path = conf.conf['pdf_res_path']
 pdf_res_path = conf.conf['pdf_res_path']
 code_path = conf.conf['code_path']
 code_path = conf.conf['code_path']
@@ -41,7 +44,7 @@ class TestCli:
             pdf_bytes = open(pdf_path, 'rb').read()
             pdf_bytes = open(pdf_path, 'rb').read()
             local_image_dir = os.path.join(pdf_dev_path, 'pdf', 'images')
             local_image_dir = os.path.join(pdf_dev_path, 'pdf', 'images')
             image_dir = str(os.path.basename(local_image_dir))
             image_dir = str(os.path.basename(local_image_dir))
-            image_writer = DiskReaderWriter(local_image_dir)
+            image_writer = FileBasedDataWriter(local_image_dir)
             model_json = list()
             model_json = list()
             jso_useful_key = {'_pdf_type': '', 'model_list': model_json}
             jso_useful_key = {'_pdf_type': '', 'model_list': model_json}
             pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
             pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
@@ -77,7 +80,7 @@ class TestCli:
             pdf_bytes = open(pdf_path, 'rb').read()
             pdf_bytes = open(pdf_path, 'rb').read()
             local_image_dir = os.path.join(pdf_dev_path, 'pdf', 'images')
             local_image_dir = os.path.join(pdf_dev_path, 'pdf', 'images')
             image_dir = str(os.path.basename(local_image_dir))
             image_dir = str(os.path.basename(local_image_dir))
-            image_writer = DiskReaderWriter(local_image_dir)
+            image_writer = FileBasedDataWriter(local_image_dir)
             model_json = list()
             model_json = list()
             jso_useful_key = {'_pdf_type': 'ocr', 'model_list': model_json}
             jso_useful_key = {'_pdf_type': 'ocr', 'model_list': model_json}
             pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
             pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
@@ -112,7 +115,7 @@ class TestCli:
             pdf_bytes = open(pdf_path, 'rb').read()
             pdf_bytes = open(pdf_path, 'rb').read()
             local_image_dir = os.path.join(pdf_dev_path, 'pdf', 'images')
             local_image_dir = os.path.join(pdf_dev_path, 'pdf', 'images')
             image_dir = str(os.path.basename(local_image_dir))
             image_dir = str(os.path.basename(local_image_dir))
-            image_writer = DiskReaderWriter(local_image_dir)
+            image_writer = FileBasedDataWriter(local_image_dir)
             model_json = list()
             model_json = list()
             jso_useful_key = {'_pdf_type': 'txt', 'model_list': model_json}
             jso_useful_key = {'_pdf_type': 'txt', 'model_list': model_json}
             pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
             pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer)
@@ -284,12 +287,13 @@ class TestCli:
         pdf_endpoint = os.environ.get('pdf_endpoint', "")
         pdf_endpoint = os.environ.get('pdf_endpoint', "")
         s3_pdf_path = conf.conf["s3_pdf_path"]
         s3_pdf_path = conf.conf["s3_pdf_path"]
         image_dir = "s3://" + pdf_bucket + "/mineru/test/output"
         image_dir = "s3://" + pdf_bucket + "/mineru/test/output"
-        print (image_dir)
-        s3pdf_cli = S3ReaderWriter(pdf_ak, pdf_sk, pdf_endpoint)
-        s3image_cli = S3ReaderWriter(pdf_ak, pdf_sk, pdf_endpoint, parent_path=image_dir)
-        pdf_bytes = s3pdf_cli.read(s3_pdf_path, mode=s3pdf_cli.MODE_BIN)
-        jso_useful_key = {"_pdf_type": "", "model_list": []}
-        pipe = UNIPipe(pdf_bytes, jso_useful_key, s3image_cli)
+        prefix = "mineru/test/output"
+        reader = S3DataReader(prefix, pdf_bucket, pdf_ak, pdf_sk, pdf_endpoint)
+        # = S3DataWriter(prefix, pdf_bucket, pdf_ak, pdf_sk, pdf_endpoint)
+        image_writer = S3DataWriter(prefix, pdf_bucket, pdf_ak, pdf_sk, pdf_endpoint)
+        pdf_bytes = reader.read(s3_pdf_path)
+        model_list = []
+        pipe = OCRPipe(pdf_bytes, model_list, image_writer)
         pipe.pipe_classify()
         pipe.pipe_classify()
         pipe.pipe_analyze()
         pipe.pipe_analyze()
         pipe.pipe_parse()
         pipe.pipe_parse()
@@ -427,3 +431,4 @@ class TestCli:
  
  
 if __name__ == '__main__':
 if __name__ == '__main__':
     pytest.main()
     pytest.main()
+

Some files were not shown in this diff because too many files have changed