Browse Source

fix: use the new data API to replace the old rw API

icecraft 1 year ago
parent
commit
6a481320ea

+ 3 - 4
magic_pdf/integrations/rag/utils.py

@@ -5,14 +5,13 @@ from pathlib import Path
 from loguru import logger
 from loguru import logger
 
 
 import magic_pdf.model as model_config
 import magic_pdf.model as model_config
+from magic_pdf.data.data_reader_writer import FileBasedDataReader
 from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
 from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
 from magic_pdf.integrations.rag.type import (CategoryType, ContentObject,
 from magic_pdf.integrations.rag.type import (CategoryType, ContentObject,
                                              ElementRelation, ElementRelType,
                                              ElementRelation, ElementRelType,
                                              LayoutElements,
                                              LayoutElements,
                                              LayoutElementsExtra, PageInfo)
                                              LayoutElementsExtra, PageInfo)
 from magic_pdf.libs.ocr_content_type import BlockType, ContentType
 from magic_pdf.libs.ocr_content_type import BlockType, ContentType
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 from magic_pdf.tools.common import do_parse, prepare_env
 from magic_pdf.tools.common import do_parse, prepare_env
 
 
 
 
@@ -224,8 +223,8 @@ def inference(path, output_dir, method):
                                                 str(Path(path).stem), method)
                                                 str(Path(path).stem), method)
 
 
     def read_fn(path):
     def read_fn(path):
-        disk_rw = DiskReaderWriter(os.path.dirname(path))
-        return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
+        disk_rw = FileBasedDataReader(os.path.dirname(path))
+        return disk_rw.read(os.path.basename(path))
 
 
     def parse_doc(doc_path: str):
     def parse_doc(doc_path: str):
         try:
         try:

+ 9 - 11
magic_pdf/libs/pdf_image_tools.py

@@ -1,23 +1,21 @@
 
 
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.libs.commons import fitz
-from magic_pdf.libs.commons import join_path
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.libs.commons import fitz, join_path
 from magic_pdf.libs.hash_utils import compute_sha256
 from magic_pdf.libs.hash_utils import compute_sha256
 
 
 
 
-def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: AbsReaderWriter):
-    """
-    从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径
-    save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。
-    """
+def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: DataWriter):
+    """从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
+    图片存放在save_path下,文件名是:
+    {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""
     # 拼接文件名
     # 拼接文件名
-    filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}"
+    filename = f'{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}'
 
 
     # 老版本返回不带bucket的路径
     # 老版本返回不带bucket的路径
     img_path = join_path(return_path, filename) if return_path is not None else None
     img_path = join_path(return_path, filename) if return_path is not None else None
 
 
     # 新版本生成平铺路径
     # 新版本生成平铺路径
-    img_hash256_path = f"{compute_sha256(img_path)}.jpg"
+    img_hash256_path = f'{compute_sha256(img_path)}.jpg'
 
 
     # 将坐标转换为fitz.Rect对象
     # 将坐标转换为fitz.Rect对象
     rect = fitz.Rect(*bbox)
     rect = fitz.Rect(*bbox)
@@ -28,6 +26,6 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
 
 
     byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
     byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
 
 
-    imageWriter.write(byte_data, img_hash256_path, AbsReaderWriter.MODE_BIN)
+    imageWriter.write(img_hash256_path, byte_data)
 
 
     return img_hash256_path
     return img_hash256_path

+ 11 - 11
magic_pdf/model/magic_model.py

@@ -1,6 +1,8 @@
 import enum
 import enum
 import json
 import json
 
 
+from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
+                                               FileBasedDataWriter)
 from magic_pdf.data.dataset import Dataset
 from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
 from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
                                     bbox_relative_pos, box_area, calculate_iou,
                                     bbox_relative_pos, box_area, calculate_iou,
@@ -12,8 +14,6 @@ from magic_pdf.libs.local_math import float_gt
 from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
 from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
 from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
 from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
 from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
 from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 
 
 CAPATION_OVERLAP_AREA_RATIO = 0.6
 CAPATION_OVERLAP_AREA_RATIO = 0.6
 MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
 MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
@@ -1050,27 +1050,27 @@ class MagicModel:
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
-    drw = DiskReaderWriter(r'D:/project/20231108code-clean')
+    drw = FileBasedDataReader(r'D:/project/20231108code-clean')
     if 0:
     if 0:
         pdf_file_path = r'linshixuqiu\19983-00.pdf'
         pdf_file_path = r'linshixuqiu\19983-00.pdf'
         model_file_path = r'linshixuqiu\19983-00_new.json'
         model_file_path = r'linshixuqiu\19983-00_new.json'
-        pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
-        model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
+        pdf_bytes = drw.read(pdf_file_path)
+        model_json_txt = drw.read(model_file_path).decode()
         model_list = json.loads(model_json_txt)
         model_list = json.loads(model_json_txt)
         write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
         write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
         img_bucket_path = 'imgs'
         img_bucket_path = 'imgs'
-        img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
+        img_writer = FileBasedDataWriter(join_path(write_path, img_bucket_path))
         pdf_docs = fitz.open('pdf', pdf_bytes)
         pdf_docs = fitz.open('pdf', pdf_bytes)
         magic_model = MagicModel(model_list, pdf_docs)
         magic_model = MagicModel(model_list, pdf_docs)
 
 
     if 1:
     if 1:
+        from magic_pdf.data.dataset import PymuDocDataset
+
         model_list = json.loads(
         model_list = json.loads(
             drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.json')
             drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.json')
         )
         )
-        pdf_bytes = drw.read(
-            '/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf', AbsReaderWriter.MODE_BIN
-        )
-        pdf_docs = fitz.open('pdf', pdf_bytes)
-        magic_model = MagicModel(model_list, pdf_docs)
+        pdf_bytes = drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf')
+
+        magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
         for i in range(7):
         for i in range(7):
             print(magic_model.get_imgs(i))
             print(magic_model.get_imgs(i))

+ 27 - 43
magic_pdf/pipe/AbsPipe.py

@@ -1,22 +1,20 @@
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
 
 
+from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.dict2md.ocr_mkcontent import union_make
 from magic_pdf.dict2md.ocr_mkcontent import union_make
 from magic_pdf.filter.pdf_classify_by_type import classify
 from magic_pdf.filter.pdf_classify_by_type import classify
 from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
 from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
-from magic_pdf.libs.MakeContentConfig import MakeMode, DropMode
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.json_compressor import JsonCompressor
 from magic_pdf.libs.json_compressor import JsonCompressor
+from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 
 
 
 
 class AbsPipe(ABC):
 class AbsPipe(ABC):
-    """
-    txt和ocr处理的抽象类
-    """
-    PIP_OCR = "ocr"
-    PIP_TXT = "txt"
+    """txt和ocr处理的抽象类."""
+    PIP_OCR = 'ocr'
+    PIP_TXT = 'txt'
 
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
                  start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_bytes = pdf_bytes
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
         self.model_list = model_list
@@ -29,29 +27,23 @@ class AbsPipe(ABC):
         self.layout_model = layout_model
         self.layout_model = layout_model
         self.formula_enable = formula_enable
         self.formula_enable = formula_enable
         self.table_enable = table_enable
         self.table_enable = table_enable
-    
+
     def get_compress_pdf_mid_data(self):
     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)
         return JsonCompressor.compress_json(self.pdf_mid_data)
 
 
     @abstractmethod
     @abstractmethod
     def pipe_classify(self):
     def pipe_classify(self):
-        """
-        有状态的分类
-        """
+        """有状态的分类."""
         raise NotImplementedError
         raise NotImplementedError
 
 
     @abstractmethod
     @abstractmethod
     def pipe_analyze(self):
     def pipe_analyze(self):
-        """
-        有状态的跑模型分析
-        """
+        """有状态的跑模型分析."""
         raise NotImplementedError
         raise NotImplementedError
 
 
     @abstractmethod
     @abstractmethod
     def pipe_parse(self):
     def pipe_parse(self):
-        """
-        有状态的解析
-        """
+        """有状态的解析."""
         raise NotImplementedError
         raise NotImplementedError
 
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
@@ -64,27 +56,25 @@ class AbsPipe(ABC):
 
 
     @staticmethod
     @staticmethod
     def classify(pdf_bytes: bytes) -> str:
     def classify(pdf_bytes: bytes) -> str:
-        """
-        根据pdf的元数据,判断是文本pdf,还是ocr pdf
-        """
+        """根据pdf的元数据,判断是文本pdf,还是ocr pdf."""
         pdf_meta = pdf_meta_scan(pdf_bytes)
         pdf_meta = pdf_meta_scan(pdf_bytes)
-        if pdf_meta.get("_need_drop", False):  # 如果返回了需要丢弃的标志,则抛出异常
+        if pdf_meta.get('_need_drop', False):  # 如果返回了需要丢弃的标志,则抛出异常
             raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
             raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
         else:
         else:
-            is_encrypted = pdf_meta["is_encrypted"]
-            is_needs_password = pdf_meta["is_needs_password"]
+            is_encrypted = pdf_meta['is_encrypted']
+            is_needs_password = pdf_meta['is_needs_password']
             if is_encrypted or is_needs_password:  # 加密的,需要密码的,没有页面的,都不处理
             if is_encrypted or is_needs_password:  # 加密的,需要密码的,没有页面的,都不处理
-                raise Exception(f"pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}")
+                raise Exception(f'pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}')
             else:
             else:
                 is_text_pdf, results = classify(
                 is_text_pdf, results = classify(
-                    pdf_meta["total_page"],
-                    pdf_meta["page_width_pts"],
-                    pdf_meta["page_height_pts"],
-                    pdf_meta["image_info_per_page"],
-                    pdf_meta["text_len_per_page"],
-                    pdf_meta["imgs_per_page"],
-                    pdf_meta["text_layout_per_page"],
-                    pdf_meta["invalid_chars"],
+                    pdf_meta['total_page'],
+                    pdf_meta['page_width_pts'],
+                    pdf_meta['page_height_pts'],
+                    pdf_meta['image_info_per_page'],
+                    pdf_meta['text_len_per_page'],
+                    pdf_meta['imgs_per_page'],
+                    pdf_meta['text_layout_per_page'],
+                    pdf_meta['invalid_chars'],
                 )
                 )
                 if is_text_pdf:
                 if is_text_pdf:
                     return AbsPipe.PIP_TXT
                     return AbsPipe.PIP_TXT
@@ -93,22 +83,16 @@ class AbsPipe(ABC):
 
 
     @staticmethod
     @staticmethod
     def mk_uni_format(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF) -> list:
     def mk_uni_format(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF) -> list:
-        """
-        根据pdf类型,生成统一格式content_list
-        """
+        """根据pdf类型,生成统一格式content_list."""
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
-        pdf_info_list = pdf_mid_data["pdf_info"]
+        pdf_info_list = pdf_mid_data['pdf_info']
         content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
         content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
         return content_list
         return content_list
 
 
     @staticmethod
     @staticmethod
     def mk_markdown(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD) -> list:
     def mk_markdown(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD) -> list:
-        """
-        根据pdf类型,markdown
-        """
+        """根据pdf类型,markdown."""
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
-        pdf_info_list = pdf_mid_data["pdf_info"]
+        pdf_info_list = pdf_mid_data['pdf_info']
         md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
         md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
         return md_content
         return md_content
-
-

+ 4 - 4
magic_pdf/pipe/OCRPipe.py

@@ -1,15 +1,15 @@
 from loguru import logger
 from loguru import logger
 
 
+from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.user_api import parse_ocr_pdf
 from magic_pdf.user_api import parse_ocr_pdf
 
 
 
 
 class OCRPipe(AbsPipe):
 class OCRPipe(AbsPipe):
 
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None,
                  start_page_id=0, end_page_id=None, lang=None,
                  layout_model=None, formula_enable=None, table_enable=None):
                  layout_model=None, formula_enable=None, table_enable=None):
         super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
         super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
@@ -32,10 +32,10 @@ class OCRPipe(AbsPipe):
 
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
-        logger.info("ocr_pipe mk content list finished")
+        logger.info('ocr_pipe mk content list finished')
         return result
         return result
 
 
     def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
     def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
-        logger.info(f"ocr_pipe mk {md_make_mode} finished")
+        logger.info(f'ocr_pipe mk {md_make_mode} finished')
         return result
         return result

+ 4 - 5
magic_pdf/pipe/TXTPipe.py

@@ -1,16 +1,15 @@
 from loguru import logger
 from loguru import logger
 
 
+from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.libs.json_compressor import JsonCompressor
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.user_api import parse_txt_pdf
 from magic_pdf.user_api import parse_txt_pdf
 
 
 
 
 class TXTPipe(AbsPipe):
 class TXTPipe(AbsPipe):
 
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None,
                  start_page_id=0, end_page_id=None, lang=None,
                  layout_model=None, formula_enable=None, table_enable=None):
                  layout_model=None, formula_enable=None, table_enable=None):
         super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
         super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
@@ -33,10 +32,10 @@ class TXTPipe(AbsPipe):
 
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
-        logger.info("txt_pipe mk content list finished")
+        logger.info('txt_pipe mk content list finished')
         return result
         return result
 
 
     def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
     def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
-        logger.info(f"txt_pipe mk {md_make_mode} finished")
+        logger.info(f'txt_pipe mk {md_make_mode} finished')
         return result
         return result

+ 23 - 24
magic_pdf/pipe/UNIPipe.py

@@ -2,22 +2,21 @@ import json
 
 
 from loguru import logger
 from loguru import logger
 
 
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.libs.commons import join_path
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
-from magic_pdf.libs.commons import join_path
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.pipe.AbsPipe import AbsPipe
-from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
+from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf
 
 
 
 
 class UNIPipe(AbsPipe):
 class UNIPipe(AbsPipe):
 
 
-    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
+    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None,
                  start_page_id=0, end_page_id=None, lang=None,
                  layout_model=None, formula_enable=None, table_enable=None):
                  layout_model=None, formula_enable=None, table_enable=None):
-        self.pdf_type = jso_useful_key["_pdf_type"]
-        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
+        self.pdf_type = jso_useful_key['_pdf_type']
+        super().__init__(pdf_bytes, jso_useful_key['model_list'], image_writer, is_debug, start_page_id, end_page_id,
                          lang, layout_model, formula_enable, table_enable)
                          lang, layout_model, formula_enable, table_enable)
         if len(self.model_list) == 0:
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
             self.input_model_is_empty = True
@@ -54,27 +53,28 @@ class UNIPipe(AbsPipe):
 
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON):
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
-        logger.info("uni_pipe mk content list finished")
+        logger.info('uni_pipe mk content list finished')
         return result
         return result
 
 
     def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
     def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
-        logger.info(f"uni_pipe mk {md_make_mode} finished")
+        logger.info(f'uni_pipe mk {md_make_mode} finished')
         return result
         return result
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
     # 测试
     # 测试
-    drw = DiskReaderWriter(r"D:/project/20231108code-clean")
+    from magic_pdf.data.data_reader_writer import DataReader
+    drw = DataReader(r'D:/project/20231108code-clean')
 
 
-    pdf_file_path = r"linshixuqiu\19983-00.pdf"
-    model_file_path = r"linshixuqiu\19983-00.json"
-    pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
-    model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
+    pdf_file_path = r'linshixuqiu\19983-00.pdf'
+    model_file_path = r'linshixuqiu\19983-00.json'
+    pdf_bytes = drw.read(pdf_file_path)
+    model_json_txt = drw.read(model_file_path).decode()
     model_list = json.loads(model_json_txt)
     model_list = json.loads(model_json_txt)
-    write_path = r"D:\project\20231108code-clean\linshixuqiu\19983-00"
-    img_bucket_path = "imgs"
-    img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
+    write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
+    img_bucket_path = 'imgs'
+    img_writer = DataWriter(join_path(write_path, img_bucket_path))
 
 
     # pdf_type = UNIPipe.classify(pdf_bytes)
     # pdf_type = UNIPipe.classify(pdf_bytes)
     # jso_useful_key = {
     # jso_useful_key = {
@@ -83,8 +83,8 @@ if __name__ == '__main__':
     # }
     # }
 
 
     jso_useful_key = {
     jso_useful_key = {
-        "_pdf_type": "",
-        "model_list": model_list
+        '_pdf_type': '',
+        'model_list': model_list
     }
     }
     pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer)
     pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer)
     pipe.pipe_classify()
     pipe.pipe_classify()
@@ -92,8 +92,7 @@ if __name__ == '__main__':
     md_content = pipe.pipe_mk_markdown(img_bucket_path)
     md_content = pipe.pipe_mk_markdown(img_bucket_path)
     content_list = pipe.pipe_mk_uni_format(img_bucket_path)
     content_list = pipe.pipe_mk_uni_format(img_bucket_path)
 
 
-    md_writer = DiskReaderWriter(write_path)
-    md_writer.write(md_content, "19983-00.md", AbsReaderWriter.MODE_TXT)
-    md_writer.write(json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4), "19983-00.json",
-                    AbsReaderWriter.MODE_TXT)
-    md_writer.write(str(content_list), "19983-00.txt", AbsReaderWriter.MODE_TXT)
+    md_writer = DataWriter(write_path)
+    md_writer.write_string('19983-00.md', md_content)
+    md_writer.write_string('19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4))
+    md_writer.write_string('19983-00.txt', str(content_list))

+ 3 - 4
magic_pdf/tools/cli.py

@@ -5,9 +5,8 @@ import click
 from loguru import logger
 from loguru import logger
 
 
 import magic_pdf.model as model_config
 import magic_pdf.model as model_config
+from magic_pdf.data.data_reader_writer import FileBasedDataReader
 from magic_pdf.libs.version import __version__
 from magic_pdf.libs.version import __version__
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 from magic_pdf.tools.common import do_parse, parse_pdf_methods
 from magic_pdf.tools.common import do_parse, parse_pdf_methods
 
 
 
 
@@ -86,8 +85,8 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
     os.makedirs(output_dir, exist_ok=True)
     os.makedirs(output_dir, exist_ok=True)
 
 
     def read_fn(path):
     def read_fn(path):
-        disk_rw = DiskReaderWriter(os.path.dirname(path))
-        return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
+        disk_rw = FileBasedDataReader(os.path.dirname(path))
+        return disk_rw.read(os.path.basename(path))
 
 
     def parse_doc(doc_path: str):
     def parse_doc(doc_path: str):
         try:
         try:

+ 6 - 9
magic_pdf/tools/cli_dev.py

@@ -5,13 +5,11 @@ from pathlib import Path
 import click
 import click
 
 
 import magic_pdf.model as model_config
 import magic_pdf.model as model_config
+from magic_pdf.data.data_reader_writer import FileBasedDataReader, S3DataReader
 from magic_pdf.libs.config_reader import get_s3_config
 from magic_pdf.libs.config_reader import get_s3_config
 from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
 from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
                                        remove_non_official_s3_args)
                                        remove_non_official_s3_args)
 from magic_pdf.libs.version import __version__
 from magic_pdf.libs.version import __version__
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
-from magic_pdf.rw.S3ReaderWriter import S3ReaderWriter
 from magic_pdf.tools.common import do_parse, parse_pdf_methods
 from magic_pdf.tools.common import do_parse, parse_pdf_methods
 
 
 
 
@@ -19,15 +17,14 @@ def read_s3_path(s3path):
     bucket, key = parse_s3path(s3path)
     bucket, key = parse_s3path(s3path)
 
 
     s3_ak, s3_sk, s3_endpoint = get_s3_config(bucket)
     s3_ak, s3_sk, s3_endpoint = get_s3_config(bucket)
-    s3_rw = S3ReaderWriter(s3_ak, s3_sk, s3_endpoint, 'auto',
-                           remove_non_official_s3_args(s3path))
+    s3_rw = S3DataReader('', bucket, s3_ak, s3_sk, s3_endpoint, 'auto')
     may_range_params = parse_s3_range_params(s3path)
     may_range_params = parse_s3_range_params(s3path)
     if may_range_params is None or 2 != len(may_range_params):
     if may_range_params is None or 2 != len(may_range_params):
-        byte_start, byte_end = 0, None
+        byte_start, byte_end = 0, -1
     else:
     else:
         byte_start, byte_end = int(may_range_params[0]), int(
         byte_start, byte_end = int(may_range_params[0]), int(
             may_range_params[1])
             may_range_params[1])
-    return s3_rw.read_offset(
+    return s3_rw.read_at(
         remove_non_official_s3_args(s3path),
         remove_non_official_s3_args(s3path),
         byte_start,
         byte_start,
         byte_end,
         byte_end,
@@ -129,8 +126,8 @@ def pdf(pdf, json_data, output_dir, method):
     os.makedirs(output_dir, exist_ok=True)
     os.makedirs(output_dir, exist_ok=True)
 
 
     def read_fn(path):
     def read_fn(path):
-        disk_rw = DiskReaderWriter(os.path.dirname(path))
-        return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
+        disk_rw = FileBasedDataReader(os.path.dirname(path))
+        return disk_rw.read(os.path.basename(path))
 
 
     model_json_list = json_parse.loads(read_fn(json_data).decode('utf-8'))
     model_json_list = json_parse.loads(read_fn(json_data).decode('utf-8'))
 
 

+ 22 - 35
magic_pdf/tools/common.py

@@ -3,18 +3,18 @@ import json as json_parse
 import os
 import os
 
 
 import click
 import click
+import fitz
 from loguru import logger
 from loguru import logger
 
 
 import magic_pdf.model as model_config
 import magic_pdf.model as model_config
+from magic_pdf.data.data_reader_writer import FileBasedDataWriter
 from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
 from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
                                       draw_model_bbox, draw_span_bbox)
                                       draw_model_bbox, draw_span_bbox)
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
 from magic_pdf.pipe.UNIPipe import UNIPipe
 from magic_pdf.pipe.UNIPipe import UNIPipe
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
-import fitz
+
 # from io import BytesIO
 # from io import BytesIO
 # from pypdf import PdfReader, PdfWriter
 # from pypdf import PdfReader, PdfWriter
 
 
@@ -54,11 +54,11 @@ def prepare_env(output_dir, pdf_file_name, method):
 
 
 
 
 def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_id=None):
 def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_id=None):
-    document = fitz.open("pdf", pdf_bytes)
+    document = fitz.open('pdf', pdf_bytes)
     output_document = fitz.open()
     output_document = fitz.open()
     end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(document) - 1
     end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(document) - 1
     if end_page_id > len(document) - 1:
     if end_page_id > len(document) - 1:
-        logger.warning("end_page_id is out of range, use pdf_docs length")
+        logger.warning('end_page_id is out of range, use pdf_docs length')
         end_page_id = len(document) - 1
         end_page_id = len(document) - 1
     output_document.insert_pdf(document, from_page=start_page_id, to_page=end_page_id)
     output_document.insert_pdf(document, from_page=start_page_id, to_page=end_page_id)
     output_bytes = output_document.tobytes()
     output_bytes = output_document.tobytes()
@@ -100,8 +100,8 @@ def do_parse(
     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
                                                 parse_method)
                                                 parse_method)
 
 
-    image_writer, md_writer = DiskReaderWriter(
-        local_image_dir), DiskReaderWriter(local_md_dir)
+    image_writer, md_writer = FileBasedDataWriter(
+        local_image_dir), FileBasedDataWriter(local_md_dir)
     image_dir = str(os.path.basename(local_image_dir))
     image_dir = str(os.path.basename(local_image_dir))
 
 
     if parse_method == 'auto':
     if parse_method == 'auto':
@@ -145,49 +145,36 @@ def do_parse(
     if f_draw_line_sort_bbox:
     if f_draw_line_sort_bbox:
         draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
         draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
 
 
-    md_content = pipe.pipe_mk_markdown(image_dir,
-                                       drop_mode=DropMode.NONE,
-                                       md_make_mode=f_make_md_mode)
+    md_content = pipe.pipe_mk_markdown(image_dir, drop_mode=DropMode.NONE, md_make_mode=f_make_md_mode)
     if f_dump_md:
     if f_dump_md:
-        md_writer.write(
-            content=md_content,
-            path=f'{pdf_file_name}.md',
-            mode=AbsReaderWriter.MODE_TXT,
+        md_writer.write_string(
+            f'{pdf_file_name}.md',
+            md_content
         )
         )
 
 
     if f_dump_middle_json:
     if f_dump_middle_json:
-        md_writer.write(
-            content=json_parse.dumps(pipe.pdf_mid_data,
-                                     ensure_ascii=False,
-                                     indent=4),
-            path=f'{pdf_file_name}_middle.json',
-            mode=AbsReaderWriter.MODE_TXT,
+        md_writer.write_string(
+            f'{pdf_file_name}_middle.json',
+            json_parse.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
         )
         )
 
 
     if f_dump_model_json:
     if f_dump_model_json:
-        md_writer.write(
-            content=json_parse.dumps(orig_model_list,
-                                     ensure_ascii=False,
-                                     indent=4),
-            path=f'{pdf_file_name}_model.json',
-            mode=AbsReaderWriter.MODE_TXT,
+        md_writer.write_string(
+            f'{pdf_file_name}_model.json',
+            json_parse.dumps(orig_model_list, ensure_ascii=False, indent=4)
         )
         )
 
 
     if f_dump_orig_pdf:
     if f_dump_orig_pdf:
         md_writer.write(
         md_writer.write(
-            content=pdf_bytes,
-            path=f'{pdf_file_name}_origin.pdf',
-            mode=AbsReaderWriter.MODE_BIN,
+            f'{pdf_file_name}_origin.pdf',
+            pdf_bytes,
         )
         )
 
 
     content_list = pipe.pipe_mk_uni_format(image_dir, drop_mode=DropMode.NONE)
     content_list = pipe.pipe_mk_uni_format(image_dir, drop_mode=DropMode.NONE)
     if f_dump_content_list:
     if f_dump_content_list:
-        md_writer.write(
-            content=json_parse.dumps(content_list,
-                                     ensure_ascii=False,
-                                     indent=4),
-            path=f'{pdf_file_name}_content_list.json',
-            mode=AbsReaderWriter.MODE_TXT,
+        md_writer.write_string(
+            f'{pdf_file_name}_content_list.json',
+            json_parse.dumps(content_list, ensure_ascii=False, indent=4)
         )
         )
 
 
     logger.info(f'local output dir is {local_md_dir}')
     logger.info(f'local output dir is {local_md_dir}')

+ 26 - 38
magic_pdf/user_api.py

@@ -1,36 +1,28 @@
-"""
-用户输入:
-    model数组,每个元素代表一个页面
-    pdf在s3的路径
-    截图保存的s3位置
+"""用户输入: model数组,每个元素代表一个页面 pdf在s3的路径 截图保存的s3位置.
 
 
 然后:
 然后:
     1)根据s3路径,调用spark集群的api,拿到ak,sk,endpoint,构造出s3PDFReader
     1)根据s3路径,调用spark集群的api,拿到ak,sk,endpoint,构造出s3PDFReader
     2)根据用户输入的s3地址,调用spark集群的api,拿到ak,sk,endpoint,构造出s3ImageWriter
     2)根据用户输入的s3地址,调用spark集群的api,拿到ak,sk,endpoint,构造出s3ImageWriter
 
 
 其余部分至于构造s3cli, 获取ak,sk都在code-clean里写代码完成。不要反向依赖!!!
 其余部分至于构造s3cli, 获取ak,sk都在code-clean里写代码完成。不要反向依赖!!!
-
 """
 """
-import re
 
 
 from loguru import logger
 from loguru import logger
 
 
+from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.libs.version import __version__
 from magic_pdf.libs.version import __version__
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.rw import AbsReaderWriter
 from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
 from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
 from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
 from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
 
 
-PARSE_TYPE_TXT = "txt"
-PARSE_TYPE_OCR = "ocr"
+PARSE_TYPE_TXT = 'txt'
+PARSE_TYPE_OCR = 'ocr'
 
 
 
 
-def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
+def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
                   start_page_id=0, end_page_id=None, lang=None,
                   start_page_id=0, end_page_id=None, lang=None,
                   *args, **kwargs):
                   *args, **kwargs):
-    """
-    解析文本类pdf
-    """
+    """解析文本类pdf."""
     pdf_info_dict = parse_pdf_by_txt(
     pdf_info_dict = parse_pdf_by_txt(
         pdf_bytes,
         pdf_bytes,
         pdf_models,
         pdf_models,
@@ -40,22 +32,20 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
         debug_mode=is_debug,
         debug_mode=is_debug,
     )
     )
 
 
-    pdf_info_dict["_parse_type"] = PARSE_TYPE_TXT
+    pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
 
 
-    pdf_info_dict["_version_name"] = __version__
+    pdf_info_dict['_version_name'] = __version__
 
 
     if lang is not None:
     if lang is not None:
-        pdf_info_dict["_lang"] = lang
+        pdf_info_dict['_lang'] = lang
 
 
     return pdf_info_dict
     return pdf_info_dict
 
 
 
 
-def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
+def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
                   start_page_id=0, end_page_id=None, lang=None,
                   start_page_id=0, end_page_id=None, lang=None,
                   *args, **kwargs):
                   *args, **kwargs):
-    """
-    解析ocr类pdf
-    """
+    """解析ocr类pdf."""
     pdf_info_dict = parse_pdf_by_ocr(
     pdf_info_dict = parse_pdf_by_ocr(
         pdf_bytes,
         pdf_bytes,
         pdf_models,
         pdf_models,
@@ -65,23 +55,21 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
         debug_mode=is_debug,
         debug_mode=is_debug,
     )
     )
 
 
-    pdf_info_dict["_parse_type"] = PARSE_TYPE_OCR
+    pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
 
 
-    pdf_info_dict["_version_name"] = __version__
+    pdf_info_dict['_version_name'] = __version__
 
 
     if lang is not None:
     if lang is not None:
-        pdf_info_dict["_lang"] = lang
+        pdf_info_dict['_lang'] = lang
 
 
     return pdf_info_dict
     return pdf_info_dict
 
 
 
 
-def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
+def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
                     input_model_is_empty: bool = False,
                     input_model_is_empty: bool = False,
                     start_page_id=0, end_page_id=None, lang=None,
                     start_page_id=0, end_page_id=None, lang=None,
                     *args, **kwargs):
                     *args, **kwargs):
-    """
-    ocr和文本混合的pdf,全部解析出来
-    """
+    """ocr和文本混合的pdf,全部解析出来."""
 
 
     def parse_pdf(method):
     def parse_pdf(method):
         try:
         try:
@@ -98,12 +86,12 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
             return None
             return None
 
 
     pdf_info_dict = parse_pdf(parse_pdf_by_txt)
     pdf_info_dict = parse_pdf(parse_pdf_by_txt)
-    if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
-        logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
+    if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False):
+        logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr')
         if input_model_is_empty:
         if input_model_is_empty:
-            layout_model = kwargs.get("layout_model", None)
-            formula_enable = kwargs.get("formula_enable", None)
-            table_enable = kwargs.get("table_enable", None)
+            layout_model = kwargs.get('layout_model', None)
+            formula_enable = kwargs.get('formula_enable', None)
+            table_enable = kwargs.get('table_enable', None)
             pdf_models = doc_analyze(
             pdf_models = doc_analyze(
                 pdf_bytes,
                 pdf_bytes,
                 ocr=True,
                 ocr=True,
@@ -116,15 +104,15 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
             )
             )
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         if pdf_info_dict is None:
         if pdf_info_dict is None:
-            raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
+            raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.')
         else:
         else:
-            pdf_info_dict["_parse_type"] = PARSE_TYPE_OCR
+            pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
     else:
     else:
-        pdf_info_dict["_parse_type"] = PARSE_TYPE_TXT
+        pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
 
 
-    pdf_info_dict["_version_name"] = __version__
+    pdf_info_dict['_version_name'] = __version__
 
 
     if lang is not None:
     if lang is not None:
-        pdf_info_dict["_lang"] = lang
+        pdf_info_dict['_lang'] = lang
 
 
     return pdf_info_dict
     return pdf_info_dict

+ 49 - 52
projects/gradio_app/app.py

@@ -2,39 +2,37 @@
 
 
 import base64
 import base64
 import os
 import os
+import re
 import time
 import time
 import uuid
 import uuid
 import zipfile
 import zipfile
 from pathlib import Path
 from pathlib import Path
-import re
 
 
+import gradio as gr
 import pymupdf
 import pymupdf
+from gradio_pdf import PDF
 from loguru import logger
 from loguru import logger
 
 
+from magic_pdf.data.data_reader_writer import FileBasedDataReader
 from magic_pdf.libs.hash_utils import compute_sha256
 from magic_pdf.libs.hash_utils import compute_sha256
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 from magic_pdf.tools.common import do_parse, prepare_env
 from magic_pdf.tools.common import do_parse, prepare_env
 
 
-import gradio as gr
-from gradio_pdf import PDF
-
 
 
 def read_fn(path):
 def read_fn(path):
-    disk_rw = DiskReaderWriter(os.path.dirname(path))
-    return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
+    disk_rw = FileBasedDataReader(os.path.dirname(path))
+    return disk_rw.read(os.path.basename(path))
 
 
 
 
 def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_enable, table_enable, language):
 def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_enable, table_enable, language):
     os.makedirs(output_dir, exist_ok=True)
     os.makedirs(output_dir, exist_ok=True)
 
 
     try:
     try:
-        file_name = f"{str(Path(doc_path).stem)}_{time.time()}"
+        file_name = f'{str(Path(doc_path).stem)}_{time.time()}'
         pdf_data = read_fn(doc_path)
         pdf_data = read_fn(doc_path)
         if is_ocr:
         if is_ocr:
-            parse_method = "ocr"
+            parse_method = 'ocr'
         else:
         else:
-            parse_method = "auto"
+            parse_method = 'auto'
         local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
         local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
         do_parse(
         do_parse(
             output_dir,
             output_dir,
@@ -55,8 +53,7 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_en
 
 
 
 
 def compress_directory_to_zip(directory_path, output_zip_path):
 def compress_directory_to_zip(directory_path, output_zip_path):
-    """
-    压缩指定目录到一个 ZIP 文件。
+    """压缩指定目录到一个 ZIP 文件。
 
 
     :param directory_path: 要压缩的目录路径
     :param directory_path: 要压缩的目录路径
     :param output_zip_path: 输出的 ZIP 文件路径
     :param output_zip_path: 输出的 ZIP 文件路径
@@ -80,7 +77,7 @@ def compress_directory_to_zip(directory_path, output_zip_path):
 
 
 
 
 def image_to_base64(image_path):
 def image_to_base64(image_path):
-    with open(image_path, "rb") as image_file:
+    with open(image_path, 'rb') as image_file:
         return base64.b64encode(image_file.read()).decode('utf-8')
         return base64.b64encode(image_file.read()).decode('utf-8')
 
 
 
 
@@ -93,7 +90,7 @@ def replace_image_with_base64(markdown_text, image_dir_path):
         relative_path = match.group(1)
         relative_path = match.group(1)
         full_path = os.path.join(image_dir_path, relative_path)
         full_path = os.path.join(image_dir_path, relative_path)
         base64_image = image_to_base64(full_path)
         base64_image = image_to_base64(full_path)
-        return f"![{relative_path}](data:image/jpeg;base64,{base64_image})"
+        return f'![{relative_path}](data:image/jpeg;base64,{base64_image})'
 
 
     # 应用替换
     # 应用替换
     return re.sub(pattern, replace, markdown_text)
     return re.sub(pattern, replace, markdown_text)
@@ -103,34 +100,34 @@ def to_markdown(file_path, end_pages, is_ocr, layout_mode, formula_enable, table
     # 获取识别的md文件以及压缩包文件路径
     # 获取识别的md文件以及压缩包文件路径
     local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr,
     local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr,
                                         layout_mode, formula_enable, table_enable, language)
                                         layout_mode, formula_enable, table_enable, language)
-    archive_zip_path = os.path.join("./output", compute_sha256(local_md_dir) + ".zip")
+    archive_zip_path = os.path.join('./output', compute_sha256(local_md_dir) + '.zip')
     zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
     zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
     if zip_archive_success == 0:
     if zip_archive_success == 0:
-        logger.info("压缩成功")
+        logger.info('压缩成功')
     else:
     else:
-        logger.error("压缩失败")
-    md_path = os.path.join(local_md_dir, file_name + ".md")
+        logger.error('压缩失败')
+    md_path = os.path.join(local_md_dir, file_name + '.md')
     with open(md_path, 'r', encoding='utf-8') as f:
     with open(md_path, 'r', encoding='utf-8') as f:
         txt_content = f.read()
         txt_content = f.read()
     md_content = replace_image_with_base64(txt_content, local_md_dir)
     md_content = replace_image_with_base64(txt_content, local_md_dir)
     # 返回转换后的PDF路径
     # 返回转换后的PDF路径
-    new_pdf_path = os.path.join(local_md_dir, file_name + "_layout.pdf")
+    new_pdf_path = os.path.join(local_md_dir, file_name + '_layout.pdf')
 
 
     return md_content, txt_content, archive_zip_path, new_pdf_path
     return md_content, txt_content, archive_zip_path, new_pdf_path
 
 
 
 
-latex_delimiters = [{"left": "$$", "right": "$$", "display": True},
-                    {"left": '$', "right": '$', "display": False}]
+latex_delimiters = [{'left': '$$', 'right': '$$', 'display': True},
+                    {'left': '$', 'right': '$', 'display': False}]
 
 
 
 
 def init_model():
 def init_model():
     from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
     from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
     try:
     try:
         model_manager = ModelSingleton()
         model_manager = ModelSingleton()
-        txt_model = model_manager.get_model(False, False)
-        logger.info(f"txt_model init final")
-        ocr_model = model_manager.get_model(True, False)
-        logger.info(f"ocr_model init final")
+        txt_model = model_manager.get_model(False, False)  # noqa: F841
+        logger.info('txt_model init final')
+        ocr_model = model_manager.get_model(True, False)  # noqa: F841
+        logger.info('ocr_model init final')
         return 0
         return 0
     except Exception as e:
     except Exception as e:
         logger.exception(e)
         logger.exception(e)
@@ -138,31 +135,31 @@ def init_model():
 
 
 
 
 model_init = init_model()
 model_init = init_model()
-logger.info(f"model_init: {model_init}")
+logger.info(f'model_init: {model_init}')
 
 
 
 
-with open("header.html", "r") as file:
+with open('header.html', 'r') as file:
     header = file.read()
     header = file.read()
 
 
 
 
 latin_lang = [
 latin_lang = [
-        'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
+        'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',  # noqa: E126
         'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
         'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
         'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
         'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
         'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
         'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
 ]
 ]
 arabic_lang = ['ar', 'fa', 'ug', 'ur']
 arabic_lang = ['ar', 'fa', 'ug', 'ur']
 cyrillic_lang = [
 cyrillic_lang = [
-        'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
+        'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',  # noqa: E126
         'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
         'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
 ]
 ]
 devanagari_lang = [
 devanagari_lang = [
-        'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
+        'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',  # noqa: E126
         'sa', 'bgc'
         'sa', 'bgc'
 ]
 ]
 other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
 other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
 
 
-all_lang = [""]
+all_lang = ['']
 all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
 all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
 
 
 
 
@@ -174,7 +171,7 @@ def to_pdf(file_path):
             pdf_bytes = f.convert_to_pdf()
             pdf_bytes = f.convert_to_pdf()
             # 将pdfbytes 写入到uuid.pdf中
             # 将pdfbytes 写入到uuid.pdf中
             # 生成唯一的文件名
             # 生成唯一的文件名
-            unique_filename = f"{uuid.uuid4()}.pdf"
+            unique_filename = f'{uuid.uuid4()}.pdf'
 
 
             # 构建完整的文件路径
             # 构建完整的文件路径
             tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename)
             tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename)
@@ -186,43 +183,43 @@ def to_pdf(file_path):
             return tmp_file_path
             return tmp_file_path
 
 
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     with gr.Blocks() as demo:
     with gr.Blocks() as demo:
         gr.HTML(header)
         gr.HTML(header)
         with gr.Row():
         with gr.Row():
             with gr.Column(variant='panel', scale=5):
             with gr.Column(variant='panel', scale=5):
-                file = gr.File(label="Please upload a PDF or image", file_types=[".pdf", ".png", ".jpeg", ".jpg"])
-                max_pages = gr.Slider(1, 10, 5, step=1, label="Max convert pages")
+                file = gr.File(label='Please upload a PDF or image', file_types=['.pdf', '.png', '.jpeg', '.jpg'])
+                max_pages = gr.Slider(1, 10, 5, step=1, label='Max convert pages')
                 with gr.Row():
                 with gr.Row():
-                    layout_mode = gr.Dropdown(["layoutlmv3", "doclayout_yolo"], label="Layout model", value="layoutlmv3")
-                    language = gr.Dropdown(all_lang, label="Language", value="")
+                    layout_mode = gr.Dropdown(['layoutlmv3', 'doclayout_yolo'], label='Layout model', value='layoutlmv3')
+                    language = gr.Dropdown(all_lang, label='Language', value='')
                 with gr.Row():
                 with gr.Row():
-                    formula_enable = gr.Checkbox(label="Enable formula recognition", value=True)
-                    is_ocr = gr.Checkbox(label="Force enable OCR", value=False)
-                    table_enable = gr.Checkbox(label="Enable table recognition(test)", value=False)
+                    formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
+                    is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
+                    table_enable = gr.Checkbox(label='Enable table recognition(test)', value=False)
                 with gr.Row():
                 with gr.Row():
-                    change_bu = gr.Button("Convert")
-                    clear_bu = gr.ClearButton(value="Clear")
-                pdf_show = PDF(label="PDF preview", interactive=True, height=800)
-                with gr.Accordion("Examples:"):
-                    example_root = os.path.join(os.path.dirname(__file__), "examples")
+                    change_bu = gr.Button('Convert')
+                    clear_bu = gr.ClearButton(value='Clear')
+                pdf_show = PDF(label='PDF preview', interactive=True, height=800)
+                with gr.Accordion('Examples:'):
+                    example_root = os.path.join(os.path.dirname(__file__), 'examples')
                     gr.Examples(
                     gr.Examples(
                         examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
                         examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
-                                  _.endswith("pdf")],
+                                  _.endswith('pdf')],
                         inputs=pdf_show
                         inputs=pdf_show
                     )
                     )
 
 
             with gr.Column(variant='panel', scale=5):
             with gr.Column(variant='panel', scale=5):
-                output_file = gr.File(label="convert result", interactive=False)
+                output_file = gr.File(label='convert result', interactive=False)
                 with gr.Tabs():
                 with gr.Tabs():
-                    with gr.Tab("Markdown rendering"):
-                        md = gr.Markdown(label="Markdown rendering", height=900, show_copy_button=True,
+                    with gr.Tab('Markdown rendering'):
+                        md = gr.Markdown(label='Markdown rendering', height=900, show_copy_button=True,
                                          latex_delimiters=latex_delimiters, line_breaks=True)
                                          latex_delimiters=latex_delimiters, line_breaks=True)
-                    with gr.Tab("Markdown text"):
+                    with gr.Tab('Markdown text'):
                         md_text = gr.TextArea(lines=45, show_copy_button=True)
                         md_text = gr.TextArea(lines=45, show_copy_button=True)
         file.upload(fn=to_pdf, inputs=file, outputs=pdf_show)
         file.upload(fn=to_pdf, inputs=file, outputs=pdf_show)
         change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
         change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
                         outputs=[md, md_text, output_file, pdf_show])
                         outputs=[md, md_text, output_file, pdf_show])
         clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language])
         clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language])
 
 
-    demo.launch(server_name="0.0.0.0")
+    demo.launch(server_name='0.0.0.0')

+ 1 - 1
tests/unittest/test_data/assets/jsonl/test_02.jsonl

@@ -1 +1 @@
-{"track_id":"e8824f5a-9fcb-4ee5-b2d4-6bf2c67019dc","path":"tests/test_data/assets/pdfs/test_02.pdf","file_type":"pdf","content_type":"application/pdf","content_length":80078,"title":"German Idealism and the Concept of Punishment || Conclusion","remark":{"file_id":"scihub_78800000/libgen.scimag78872000-78872999.zip_10.1017/cbo9780511770425.012","file_source_type":"paper","original_file_id":"10.1017/cbo9780511770425.012","file_name":"10.1017/cbo9780511770425.012.pdf","author":"Merle, Jean-Christophe"}}
+{"track_id":"e8824f5a-9fcb-4ee5-b2d4-6bf2c67019dc","path":"tests/unittest/test_data/assets/pdfs/test_02.pdf","file_type":"pdf","content_type":"application/pdf","content_length":80078,"title":"German Idealism and the Concept of Punishment || Conclusion","remark":{"file_id":"scihub_78800000/libgen.scimag78872000-78872999.zip_10.1017/cbo9780511770425.012","file_source_type":"paper","original_file_id":"10.1017/cbo9780511770425.012","file_name":"10.1017/cbo9780511770425.012.pdf","author":"Merle, Jean-Christophe"}}

+ 2 - 2
tests/unittest/test_data/test_dataset.py

@@ -3,7 +3,7 @@ from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
 
 
 
 
 def test_pymudataset():
 def test_pymudataset():
-    with open('tests/test_data/assets/pdfs/test_01.pdf', 'rb') as f:
+    with open('tests/unittest/test_data/assets/pdfs/test_01.pdf', 'rb') as f:
         bits = f.read()
         bits = f.read()
     datasets = PymuDocDataset(bits)
     datasets = PymuDocDataset(bits)
     assert len(datasets) > 0
     assert len(datasets) > 0
@@ -11,7 +11,7 @@ def test_pymudataset():
 
 
 
 
 def test_imagedataset():
 def test_imagedataset():
-    with open('tests/test_data/assets/pngs/test_01.png', 'rb') as f:
+    with open('tests/unittest/test_data/assets/pngs/test_01.png', 'rb') as f:
         bits = f.read()
         bits = f.read()
     datasets = ImageDataset(bits)
     datasets = ImageDataset(bits)
     assert len(datasets) == 1
     assert len(datasets) == 1

+ 4 - 4
tests/unittest/test_data/test_read_api.py

@@ -9,7 +9,7 @@ from magic_pdf.data.schemas import S3Config
 
 
 
 
 def test_read_local_pdfs():
 def test_read_local_pdfs():
-    datasets = read_local_pdfs('tests/test_data/assets/pdfs')
+    datasets = read_local_pdfs('tests/unittest/test_data/assets/pdfs')
     assert len(datasets) == 2
     assert len(datasets) == 2
     assert len(datasets[0]) > 0
     assert len(datasets[0]) > 0
     assert len(datasets[1]) > 0
     assert len(datasets[1]) > 0
@@ -19,7 +19,7 @@ def test_read_local_pdfs():
 
 
 
 
 def test_read_local_images():
 def test_read_local_images():
-    datasets = read_local_images('tests/test_data/assets/pngs', suffixes=['png'])
+    datasets = read_local_images('tests/unittest/test_data/assets/pngs', suffixes=['png'])
     assert len(datasets) == 2
     assert len(datasets) == 2
     assert len(datasets[0]) == 1
     assert len(datasets[0]) == 1
     assert len(datasets[1]) == 1
     assert len(datasets[1]) == 1
@@ -69,10 +69,10 @@ def test_read_json():
     assert len(datasets) > 0
     assert len(datasets) > 0
     assert len(datasets[0]) == 10
     assert len(datasets[0]) == 10
 
 
-    datasets = read_jsonl('tests/test_data/assets/jsonl/test_01.jsonl', reader)
+    datasets = read_jsonl('tests/unittest/test_data/assets/jsonl/test_01.jsonl', reader)
     assert len(datasets) == 1
     assert len(datasets) == 1
     assert len(datasets[0]) == 10
     assert len(datasets[0]) == 10
 
 
-    datasets = read_jsonl('tests/test_data/assets/jsonl/test_02.jsonl')
+    datasets = read_jsonl('tests/unittest/test_data/assets/jsonl/test_02.jsonl')
     assert len(datasets) == 1
     assert len(datasets) == 1
     assert len(datasets[0]) == 1
     assert len(datasets[0]) == 1

+ 2 - 2
tests/unittest/test_integrations/test_rag/test_api.py

@@ -17,7 +17,7 @@ def test_rag_document_reader():
     os.makedirs(temp_output_dir, exist_ok=True)
     os.makedirs(temp_output_dir, exist_ok=True)
 
 
     # test
     # test
-    with open('tests/test_integrations/test_rag/assets/middle.json') as f:
+    with open('tests/unittest/test_integrations/test_rag/assets/middle.json') as f:
         json_data = json.load(f)
         json_data = json.load(f)
     res = convert_middle_json_to_layout_elements(json_data, temp_output_dir)
     res = convert_middle_json_to_layout_elements(json_data, temp_output_dir)
 
 
@@ -43,7 +43,7 @@ def test_data_reader():
     os.makedirs(temp_output_dir, exist_ok=True)
     os.makedirs(temp_output_dir, exist_ok=True)
 
 
     # test
     # test
-    data_reader = DataReader('tests/test_integrations/test_rag/assets', 'ocr',
+    data_reader = DataReader('tests/unittest/test_integrations/test_rag/assets', 'ocr',
                              temp_output_dir)
                              temp_output_dir)
 
 
     assert data_reader.get_documents_count() == 2
     assert data_reader.get_documents_count() == 2

+ 3 - 3
tests/unittest/test_integrations/test_rag/test_utils.py

@@ -16,7 +16,7 @@ def test_convert_middle_json_to_layout_elements():
     os.makedirs(temp_output_dir, exist_ok=True)
     os.makedirs(temp_output_dir, exist_ok=True)
 
 
     # test
     # test
-    with open('tests/test_integrations/test_rag/assets/middle.json') as f:
+    with open('tests/unittest/test_integrations/test_rag/assets/middle.json') as f:
         json_data = json.load(f)
         json_data = json.load(f)
     res = convert_middle_json_to_layout_elements(json_data, temp_output_dir)
     res = convert_middle_json_to_layout_elements(json_data, temp_output_dir)
 
 
@@ -32,7 +32,7 @@ def test_convert_middle_json_to_layout_elements():
 
 
 def test_inference():
 def test_inference():
 
 
-    asset_dir = 'tests/test_integrations/test_rag/assets'
+    asset_dir = 'tests/unittest/test_integrations/test_rag/assets'
     # setup
     # setup
     unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
     unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
     os.makedirs(unitest_dir, exist_ok=True)
     os.makedirs(unitest_dir, exist_ok=True)
@@ -48,7 +48,7 @@ def test_inference():
 
 
     assert res is not None
     assert res is not None
     assert len(res) == 1
     assert len(res) == 1
-    assert len(res[0].layout_dets) == 10
+    assert len(res[0].layout_dets) == 11
     assert res[0].layout_dets[0].anno_id == 0
     assert res[0].layout_dets[0].anno_id == 0
     assert res[0].layout_dets[0].category_type == CategoryType.text
     assert res[0].layout_dets[0].category_type == CategoryType.text
     assert len(res[0].extra.element_relation) == 3
     assert len(res[0].extra.element_relation) == 3

+ 4 - 4
tests/unittest/test_model/test_magic_model.py

@@ -5,8 +5,8 @@ from magic_pdf.model.magic_model import MagicModel
 
 
 
 
 def test_magic_model_image_v2():
 def test_magic_model_image_v2():
-    datasets = read_local_pdfs('tests/test_model/assets/test_01.pdf')
-    with open('tests/test_model/assets/test_01.model.json') as f:
+    datasets = read_local_pdfs('tests/unittest/test_model/assets/test_01.pdf')
+    with open('tests/unittest/test_model/assets/test_01.model.json') as f:
         model_json = json.load(f)
         model_json = json.load(f)
 
 
     magic_model = MagicModel(model_json, datasets[0])
     magic_model = MagicModel(model_json, datasets[0])
@@ -19,8 +19,8 @@ def test_magic_model_image_v2():
 
 
 
 
 def test_magic_model_table_v2():
 def test_magic_model_table_v2():
-    datasets = read_local_pdfs('tests/test_model/assets/test_02.pdf')
-    with open('tests/test_model/assets/test_02.model.json') as f:
+    datasets = read_local_pdfs('tests/unittest/test_model/assets/test_02.pdf')
+    with open('tests/unittest/test_model/assets/test_02.model.json') as f:
         model_json = json.load(f)
         model_json = json.load(f)
 
 
     magic_model = MagicModel(model_json, datasets[0])
     magic_model = MagicModel(model_json, datasets[0])

File diff suppressed because it is too large
+ 0 - 0
tests/unittest/test_tools/assets/cli_dev/cli_test_01.jsonl


+ 52 - 51
tests/unittest/test_tools/test_cli.py

@@ -1,6 +1,7 @@
-import tempfile
 import os
 import os
 import shutil
 import shutil
+import tempfile
+
 from click.testing import CliRunner
 from click.testing import CliRunner
 
 
 from magic_pdf.tools.cli import cli
 from magic_pdf.tools.cli import cli
@@ -8,19 +9,19 @@ from magic_pdf.tools.cli import cli
 
 
 def test_cli_pdf():
 def test_cli_pdf():
     # setup
     # setup
-    unitest_dir = "/tmp/magic_pdf/unittest/tools"
-    filename = "cli_test_01"
+    unitest_dir = '/tmp/magic_pdf/unittest/tools'
+    filename = 'cli_test_01'
     os.makedirs(unitest_dir, exist_ok=True)
     os.makedirs(unitest_dir, exist_ok=True)
-    temp_output_dir = tempfile.mkdtemp(dir="/tmp/magic_pdf/unittest/tools")
+    temp_output_dir = tempfile.mkdtemp(dir='/tmp/magic_pdf/unittest/tools')
 
 
     # run
     # run
     runner = CliRunner()
     runner = CliRunner()
     result = runner.invoke(
     result = runner.invoke(
         cli,
         cli,
         [
         [
-            "-p",
-            "tests/test_tools/assets/cli/pdf/cli_test_01.pdf",
-            "-o",
+            '-p',
+            'tests/unittest/test_tools/assets/cli/pdf/cli_test_01.pdf',
+            '-o',
             temp_output_dir,
             temp_output_dir,
         ],
         ],
     )
     )
@@ -28,29 +29,29 @@ def test_cli_pdf():
     # check
     # check
     assert result.exit_code == 0
     assert result.exit_code == 0
 
 
-    base_output_dir = os.path.join(temp_output_dir, "cli_test_01/auto")
+    base_output_dir = os.path.join(temp_output_dir, 'cli_test_01/auto')
 
 
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 7000
     assert r.st_size > 7000
 
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
     assert r.st_size > 200000
 
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
     assert r.st_size > 15000
 
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
 
-    assert os.path.exists(os.path.join(base_output_dir, "images")) is True
-    assert os.path.isdir(os.path.join(base_output_dir, "images")) is True
-    assert os.path.exists(os.path.join(base_output_dir, "content_list.json")) is False
+    assert os.path.exists(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.isdir(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.exists(os.path.join(base_output_dir, f'{filename}_content_list.json')) is True
 
 
     # teardown
     # teardown
     shutil.rmtree(temp_output_dir)
     shutil.rmtree(temp_output_dir)
@@ -58,68 +59,68 @@ def test_cli_pdf():
 
 
 def test_cli_path():
 def test_cli_path():
     # setup
     # setup
-    unitest_dir = "/tmp/magic_pdf/unittest/tools"
+    unitest_dir = '/tmp/magic_pdf/unittest/tools'
     os.makedirs(unitest_dir, exist_ok=True)
     os.makedirs(unitest_dir, exist_ok=True)
-    temp_output_dir = tempfile.mkdtemp(dir="/tmp/magic_pdf/unittest/tools")
+    temp_output_dir = tempfile.mkdtemp(dir='/tmp/magic_pdf/unittest/tools')
 
 
     # run
     # run
     runner = CliRunner()
     runner = CliRunner()
     result = runner.invoke(
     result = runner.invoke(
-        cli, ["-p", "tests/test_tools/assets/cli/path", "-o", temp_output_dir]
+        cli, ['-p', 'tests/unittest/test_tools/assets/cli/path', '-o', temp_output_dir]
     )
     )
 
 
     # check
     # check
     assert result.exit_code == 0
     assert result.exit_code == 0
 
 
-    filename = "cli_test_01"
-    base_output_dir = os.path.join(temp_output_dir, "cli_test_01/auto")
+    filename = 'cli_test_01'
+    base_output_dir = os.path.join(temp_output_dir, 'cli_test_01/auto')
 
 
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 7000
     assert r.st_size > 7000
 
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
     assert r.st_size > 200000
 
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
     assert r.st_size > 15000
 
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
 
-    assert os.path.exists(os.path.join(base_output_dir, "images")) is True
-    assert os.path.isdir(os.path.join(base_output_dir, "images")) is True
-    assert os.path.exists(os.path.join(base_output_dir, "content_list.json")) is False
+    assert os.path.exists(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.isdir(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.exists(os.path.join(base_output_dir, f'{filename}_content_list.json')) is True
 
 
-    base_output_dir = os.path.join(temp_output_dir, "cli_test_02/auto")
-    filename = "cli_test_02"
+    base_output_dir = os.path.join(temp_output_dir, 'cli_test_02/auto')
+    filename = 'cli_test_02'
 
 
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 5000
     assert r.st_size > 5000
 
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
     assert r.st_size > 200000
 
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
     assert r.st_size > 15000
 
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
 
-    assert os.path.exists(os.path.join(base_output_dir, "images")) is True
-    assert os.path.isdir(os.path.join(base_output_dir, "images")) is True
-    assert os.path.exists(os.path.join(base_output_dir, "content_list.json")) is False
+    assert os.path.exists(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.isdir(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.exists(os.path.join(base_output_dir, f'{filename}_content_list.json')) is True
 
 
     # teardown
     # teardown
     shutil.rmtree(temp_output_dir)
     shutil.rmtree(temp_output_dir)

+ 46 - 46
tests/unittest/test_tools/test_cli_dev.py

@@ -1,6 +1,7 @@
-import tempfile
 import os
 import os
 import shutil
 import shutil
+import tempfile
+
 from click.testing import CliRunner
 from click.testing import CliRunner
 
 
 from magic_pdf.tools import cli_dev
 from magic_pdf.tools import cli_dev
@@ -8,22 +9,22 @@ from magic_pdf.tools import cli_dev
 
 
 def test_cli_pdf():
 def test_cli_pdf():
     # setup
     # setup
-    unitest_dir = "/tmp/magic_pdf/unittest/tools"
-    filename = "cli_test_01"
+    unitest_dir = '/tmp/magic_pdf/unittest/tools'
+    filename = 'cli_test_01'
     os.makedirs(unitest_dir, exist_ok=True)
     os.makedirs(unitest_dir, exist_ok=True)
-    temp_output_dir = tempfile.mkdtemp(dir="/tmp/magic_pdf/unittest/tools")
+    temp_output_dir = tempfile.mkdtemp(dir='/tmp/magic_pdf/unittest/tools')
 
 
     # run
     # run
     runner = CliRunner()
     runner = CliRunner()
     result = runner.invoke(
     result = runner.invoke(
         cli_dev.cli,
         cli_dev.cli,
         [
         [
-            "pdf",
-            "-p",
-            "tests/test_tools/assets/cli/pdf/cli_test_01.pdf",
-            "-j",
-            "tests/test_tools/assets/cli_dev/cli_test_01.model.json",
-            "-o",
+            'pdf',
+            '-p',
+            'tests/unittest/test_tools/assets/cli/pdf/cli_test_01.pdf',
+            '-j',
+            'tests/unittest/test_tools/assets/cli_dev/cli_test_01.model.json',
+            '-o',
             temp_output_dir,
             temp_output_dir,
         ],
         ],
     )
     )
@@ -31,31 +32,30 @@ def test_cli_pdf():
     # check
     # check
     assert result.exit_code == 0
     assert result.exit_code == 0
 
 
-    base_output_dir = os.path.join(temp_output_dir, "cli_test_01/auto")
+    base_output_dir = os.path.join(temp_output_dir, 'cli_test_01/auto')
 
 
-    r = os.stat(os.path.join(base_output_dir, "content_list.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_content_list.json'))
     assert r.st_size > 5000
     assert r.st_size > 5000
-
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 7000
     assert r.st_size > 7000
 
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
     assert r.st_size > 200000
 
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
     assert r.st_size > 15000
 
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
 
-    assert os.path.exists(os.path.join(base_output_dir, "images")) is True
-    assert os.path.isdir(os.path.join(base_output_dir, "images")) is True
+    assert os.path.exists(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.isdir(os.path.join(base_output_dir, 'images')) is True
 
 
     # teardown
     # teardown
     shutil.rmtree(temp_output_dir)
     shutil.rmtree(temp_output_dir)
@@ -63,26 +63,26 @@ def test_cli_pdf():
 
 
 def test_cli_jsonl():
 def test_cli_jsonl():
     # setup
     # setup
-    unitest_dir = "/tmp/magic_pdf/unittest/tools"
-    filename = "cli_test_01"
+    unitest_dir = '/tmp/magic_pdf/unittest/tools'
+    filename = 'cli_test_01'
     os.makedirs(unitest_dir, exist_ok=True)
     os.makedirs(unitest_dir, exist_ok=True)
-    temp_output_dir = tempfile.mkdtemp(dir="/tmp/magic_pdf/unittest/tools")
+    temp_output_dir = tempfile.mkdtemp(dir='/tmp/magic_pdf/unittest/tools')
 
 
     def mock_read_s3_path(s3path):
     def mock_read_s3_path(s3path):
-        with open(s3path, "rb") as f:
+        with open(s3path, 'rb') as f:
             return f.read()
             return f.read()
 
 
-    cli_dev.read_s3_path = mock_read_s3_path # mock
+    cli_dev.read_s3_path = mock_read_s3_path  # mock
 
 
     # run
     # run
     runner = CliRunner()
     runner = CliRunner()
     result = runner.invoke(
     result = runner.invoke(
         cli_dev.cli,
         cli_dev.cli,
         [
         [
-            "jsonl",
-            "-j",
-            "tests/test_tools/assets/cli_dev/cli_test_01.jsonl",
-            "-o",
+            'jsonl',
+            '-j',
+            'tests/unittest/test_tools/assets/cli_dev/cli_test_01.jsonl',
+            '-o',
             temp_output_dir,
             temp_output_dir,
         ],
         ],
     )
     )
@@ -90,31 +90,31 @@ def test_cli_jsonl():
     # check
     # check
     assert result.exit_code == 0
     assert result.exit_code == 0
 
 
-    base_output_dir = os.path.join(temp_output_dir, "cli_test_01/auto")
+    base_output_dir = os.path.join(temp_output_dir, 'cli_test_01/auto')
 
 
-    r = os.stat(os.path.join(base_output_dir, "content_list.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_content_list.json'))
     assert r.st_size > 5000
     assert r.st_size > 5000
 
 
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 7000
     assert r.st_size > 7000
 
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
     assert r.st_size > 200000
 
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
     assert r.st_size > 15000
 
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
 
-    assert os.path.exists(os.path.join(base_output_dir, "images")) is True
-    assert os.path.isdir(os.path.join(base_output_dir, "images")) is True
+    assert os.path.exists(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.isdir(os.path.join(base_output_dir, 'images')) is True
 
 
     # teardown
     # teardown
     shutil.rmtree(temp_output_dir)
     shutil.rmtree(temp_output_dir)

+ 21 - 19
tests/unittest/test_tools/test_common.py

@@ -1,23 +1,25 @@
-import tempfile
 import os
 import os
 import shutil
 import shutil
+import tempfile
 
 
 import pytest
 import pytest
 
 
 from magic_pdf.tools.common import do_parse
 from magic_pdf.tools.common import do_parse
 
 
 
 
-@pytest.mark.parametrize("method", ["auto", "txt", "ocr"])
+@pytest.mark.parametrize('method', ['auto', 'txt', 'ocr'])
 def test_common_do_parse(method):
 def test_common_do_parse(method):
+    import magic_pdf.model as model_config
+    model_config.__use_inside_model__ = True
     # setup
     # setup
-    unitest_dir = "/tmp/magic_pdf/unittest/tools"
-    filename = "fake"
+    unitest_dir = '/tmp/magic_pdf/unittest/tools'
+    filename = 'fake'
     os.makedirs(unitest_dir, exist_ok=True)
     os.makedirs(unitest_dir, exist_ok=True)
 
 
-    temp_output_dir = tempfile.mkdtemp(dir="/tmp/magic_pdf/unittest/tools")
+    temp_output_dir = tempfile.mkdtemp(dir='/tmp/magic_pdf/unittest/tools')
 
 
     # run
     # run
-    with open("tests/test_tools/assets/common/cli_test_01.pdf", "rb") as f:
+    with open('tests/unittest/test_tools/assets/common/cli_test_01.pdf', 'rb') as f:
         bits = f.read()
         bits = f.read()
     do_parse(temp_output_dir,
     do_parse(temp_output_dir,
              filename,
              filename,
@@ -27,31 +29,31 @@ def test_common_do_parse(method):
              f_dump_content_list=True)
              f_dump_content_list=True)
 
 
     # check
     # check
-    base_output_dir = os.path.join(temp_output_dir, f"fake/{method}")
+    base_output_dir = os.path.join(temp_output_dir, f'fake/{method}')
 
 
-    r = os.stat(os.path.join(base_output_dir, "content_list.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_content_list.json'))
     assert r.st_size > 5000
     assert r.st_size > 5000
 
 
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 7000
     assert r.st_size > 7000
 
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
     assert r.st_size > 200000
 
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
     assert r.st_size > 15000
 
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
 
-    os.path.exists(os.path.join(base_output_dir, "images"))
-    os.path.isdir(os.path.join(base_output_dir, "images"))
+    os.path.exists(os.path.join(base_output_dir, 'images'))
+    os.path.isdir(os.path.join(base_output_dir, 'images'))
 
 
     # teardown
     # teardown
     shutil.rmtree(temp_output_dir)
     shutil.rmtree(temp_output_dir)

Some files were not shown because too many files changed in this diff