Przeglądaj źródła

fix: using new data api replace old rw api

icecraft 1 rok temu
rodzic
commit
6a481320ea

+ 3 - 4
magic_pdf/integrations/rag/utils.py

@@ -5,14 +5,13 @@ from pathlib import Path
 from loguru import logger
 
 import magic_pdf.model as model_config
+from magic_pdf.data.data_reader_writer import FileBasedDataReader
 from magic_pdf.dict2md.ocr_mkcontent import merge_para_with_text
 from magic_pdf.integrations.rag.type import (CategoryType, ContentObject,
                                              ElementRelation, ElementRelType,
                                              LayoutElements,
                                              LayoutElementsExtra, PageInfo)
 from magic_pdf.libs.ocr_content_type import BlockType, ContentType
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 from magic_pdf.tools.common import do_parse, prepare_env
 
 
@@ -224,8 +223,8 @@ def inference(path, output_dir, method):
                                                 str(Path(path).stem), method)
 
     def read_fn(path):
-        disk_rw = DiskReaderWriter(os.path.dirname(path))
-        return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
+        disk_rw = FileBasedDataReader(os.path.dirname(path))
+        return disk_rw.read(os.path.basename(path))
 
     def parse_doc(doc_path: str):
         try:

+ 9 - 11
magic_pdf/libs/pdf_image_tools.py

@@ -1,23 +1,21 @@
 
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.libs.commons import fitz
-from magic_pdf.libs.commons import join_path
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.libs.commons import fitz, join_path
 from magic_pdf.libs.hash_utils import compute_sha256
 
 
-def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: AbsReaderWriter):
-    """
-    从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径
-    save_path:需要同时支持s3和本地, 图片存放在save_path下,文件名是: {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。
-    """
+def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWriter: DataWriter):
+    """从第page_num页的page中,根据bbox进行裁剪出一张jpg图片,返回图片路径 save_path:需要同时支持s3和本地,
+    图片存放在save_path下,文件名是:
+    {page_num}_{bbox[0]}_{bbox[1]}_{bbox[2]}_{bbox[3]}.jpg , bbox内数字取整。"""
     # 拼接文件名
-    filename = f"{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}"
+    filename = f'{page_num}_{int(bbox[0])}_{int(bbox[1])}_{int(bbox[2])}_{int(bbox[3])}'
 
     # 老版本返回不带bucket的路径
     img_path = join_path(return_path, filename) if return_path is not None else None
 
     # 新版本生成平铺路径
-    img_hash256_path = f"{compute_sha256(img_path)}.jpg"
+    img_hash256_path = f'{compute_sha256(img_path)}.jpg'
 
     # 将坐标转换为fitz.Rect对象
     rect = fitz.Rect(*bbox)
@@ -28,6 +26,6 @@ def cut_image(bbox: tuple, page_num: int, page: fitz.Page, return_path, imageWri
 
     byte_data = pix.tobytes(output='jpeg', jpg_quality=95)
 
-    imageWriter.write(byte_data, img_hash256_path, AbsReaderWriter.MODE_BIN)
+    imageWriter.write(img_hash256_path, byte_data)
 
     return img_hash256_path

+ 11 - 11
magic_pdf/model/magic_model.py

@@ -1,6 +1,8 @@
 import enum
 import json
 
+from magic_pdf.data.data_reader_writer import (FileBasedDataReader,
+                                               FileBasedDataWriter)
 from magic_pdf.data.dataset import Dataset
 from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance,
                                     bbox_relative_pos, box_area, calculate_iou,
@@ -12,8 +14,6 @@ from magic_pdf.libs.local_math import float_gt
 from magic_pdf.libs.ModelBlockTypeEnum import ModelBlockTypeEnum
 from magic_pdf.libs.ocr_content_type import CategoryId, ContentType
 from magic_pdf.pre_proc.remove_bbox_overlap import _remove_overlap_between_bbox
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 
 CAPATION_OVERLAP_AREA_RATIO = 0.6
 MERGE_BOX_OVERLAP_AREA_RATIO = 1.1
@@ -1050,27 +1050,27 @@ class MagicModel:
 
 
 if __name__ == '__main__':
-    drw = DiskReaderWriter(r'D:/project/20231108code-clean')
+    drw = FileBasedDataReader(r'D:/project/20231108code-clean')
     if 0:
         pdf_file_path = r'linshixuqiu\19983-00.pdf'
         model_file_path = r'linshixuqiu\19983-00_new.json'
-        pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
-        model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
+        pdf_bytes = drw.read(pdf_file_path)
+        model_json_txt = drw.read(model_file_path).decode()
         model_list = json.loads(model_json_txt)
         write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
         img_bucket_path = 'imgs'
-        img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
+        img_writer = FileBasedDataWriter(join_path(write_path, img_bucket_path))
         pdf_docs = fitz.open('pdf', pdf_bytes)
         magic_model = MagicModel(model_list, pdf_docs)
 
     if 1:
+        from magic_pdf.data.dataset import PymuDocDataset
+
         model_list = json.loads(
             drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.json')
         )
-        pdf_bytes = drw.read(
-            '/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf', AbsReaderWriter.MODE_BIN
-        )
-        pdf_docs = fitz.open('pdf', pdf_bytes)
-        magic_model = MagicModel(model_list, pdf_docs)
+        pdf_bytes = drw.read('/opt/data/pdf/20240418/j.chroma.2009.03.042.pdf')
+
+        magic_model = MagicModel(model_list, PymuDocDataset(pdf_bytes))
         for i in range(7):
             print(magic_model.get_imgs(i))

+ 27 - 43
magic_pdf/pipe/AbsPipe.py

@@ -1,22 +1,20 @@
 from abc import ABC, abstractmethod
 
+from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.dict2md.ocr_mkcontent import union_make
 from magic_pdf.filter.pdf_classify_by_type import classify
 from magic_pdf.filter.pdf_meta_scan import pdf_meta_scan
-from magic_pdf.libs.MakeContentConfig import MakeMode, DropMode
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.json_compressor import JsonCompressor
+from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 
 
 class AbsPipe(ABC):
-    """
-    txt和ocr处理的抽象类
-    """
-    PIP_OCR = "ocr"
-    PIP_TXT = "txt"
+    """txt和ocr处理的抽象类."""
+    PIP_OCR = 'ocr'
+    PIP_TXT = 'txt'
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None, layout_model=None, formula_enable=None, table_enable=None):
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
@@ -29,29 +27,23 @@ class AbsPipe(ABC):
         self.layout_model = layout_model
         self.formula_enable = formula_enable
         self.table_enable = table_enable
-    
+
     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)
 
     @abstractmethod
     def pipe_classify(self):
-        """
-        有状态的分类
-        """
+        """有状态的分类."""
         raise NotImplementedError
 
     @abstractmethod
     def pipe_analyze(self):
-        """
-        有状态的跑模型分析
-        """
+        """有状态的跑模型分析."""
         raise NotImplementedError
 
     @abstractmethod
     def pipe_parse(self):
-        """
-        有状态的解析
-        """
+        """有状态的解析."""
         raise NotImplementedError
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
@@ -64,27 +56,25 @@ class AbsPipe(ABC):
 
     @staticmethod
     def classify(pdf_bytes: bytes) -> str:
-        """
-        根据pdf的元数据,判断是文本pdf,还是ocr pdf
-        """
+        """根据pdf的元数据,判断是文本pdf,还是ocr pdf."""
         pdf_meta = pdf_meta_scan(pdf_bytes)
-        if pdf_meta.get("_need_drop", False):  # 如果返回了需要丢弃的标志,则抛出异常
+        if pdf_meta.get('_need_drop', False):  # 如果返回了需要丢弃的标志,则抛出异常
             raise Exception(f"pdf meta_scan need_drop,reason is {pdf_meta['_drop_reason']}")
         else:
-            is_encrypted = pdf_meta["is_encrypted"]
-            is_needs_password = pdf_meta["is_needs_password"]
+            is_encrypted = pdf_meta['is_encrypted']
+            is_needs_password = pdf_meta['is_needs_password']
             if is_encrypted or is_needs_password:  # 加密的,需要密码的,没有页面的,都不处理
-                raise Exception(f"pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}")
+                raise Exception(f'pdf meta_scan need_drop,reason is {DropReason.ENCRYPTED}')
             else:
                 is_text_pdf, results = classify(
-                    pdf_meta["total_page"],
-                    pdf_meta["page_width_pts"],
-                    pdf_meta["page_height_pts"],
-                    pdf_meta["image_info_per_page"],
-                    pdf_meta["text_len_per_page"],
-                    pdf_meta["imgs_per_page"],
-                    pdf_meta["text_layout_per_page"],
-                    pdf_meta["invalid_chars"],
+                    pdf_meta['total_page'],
+                    pdf_meta['page_width_pts'],
+                    pdf_meta['page_height_pts'],
+                    pdf_meta['image_info_per_page'],
+                    pdf_meta['text_len_per_page'],
+                    pdf_meta['imgs_per_page'],
+                    pdf_meta['text_layout_per_page'],
+                    pdf_meta['invalid_chars'],
                 )
                 if is_text_pdf:
                     return AbsPipe.PIP_TXT
@@ -93,22 +83,16 @@ class AbsPipe(ABC):
 
     @staticmethod
     def mk_uni_format(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF) -> list:
-        """
-        根据pdf类型,生成统一格式content_list
-        """
+        """根据pdf类型,生成统一格式content_list."""
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
-        pdf_info_list = pdf_mid_data["pdf_info"]
+        pdf_info_list = pdf_mid_data['pdf_info']
         content_list = union_make(pdf_info_list, MakeMode.STANDARD_FORMAT, drop_mode, img_buket_path)
         return content_list
 
     @staticmethod
     def mk_markdown(compressed_pdf_mid_data: str, img_buket_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD) -> list:
-        """
-        根据pdf类型,markdown
-        """
+        """根据pdf类型,markdown."""
         pdf_mid_data = JsonCompressor.decompress_json(compressed_pdf_mid_data)
-        pdf_info_list = pdf_mid_data["pdf_info"]
+        pdf_info_list = pdf_mid_data['pdf_info']
         md_content = union_make(pdf_info_list, md_make_mode, drop_mode, img_buket_path)
         return md_content
-
-

+ 4 - 4
magic_pdf/pipe/OCRPipe.py

@@ -1,15 +1,15 @@
 from loguru import logger
 
+from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.user_api import parse_ocr_pdf
 
 
 class OCRPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None,
                  layout_model=None, formula_enable=None, table_enable=None):
         super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
@@ -32,10 +32,10 @@ class OCRPipe(AbsPipe):
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
-        logger.info("ocr_pipe mk content list finished")
+        logger.info('ocr_pipe mk content list finished')
         return result
 
     def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
-        logger.info(f"ocr_pipe mk {md_make_mode} finished")
+        logger.info(f'ocr_pipe mk {md_make_mode} finished')
         return result

+ 4 - 5
magic_pdf/pipe/TXTPipe.py

@@ -1,16 +1,15 @@
 from loguru import logger
 
+from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.libs.json_compressor import JsonCompressor
 from magic_pdf.pipe.AbsPipe import AbsPipe
 from magic_pdf.user_api import parse_txt_pdf
 
 
 class TXTPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+    def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None,
                  layout_model=None, formula_enable=None, table_enable=None):
         super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang,
@@ -33,10 +32,10 @@ class TXTPipe(AbsPipe):
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
-        logger.info("txt_pipe mk content list finished")
+        logger.info('txt_pipe mk content list finished')
         return result
 
     def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
-        logger.info(f"txt_pipe mk {md_make_mode} finished")
+        logger.info(f'txt_pipe mk {md_make_mode} finished')
         return result

+ 23 - 24
magic_pdf/pipe/UNIPipe.py

@@ -2,22 +2,21 @@ import json
 
 from loguru import logger
 
+from magic_pdf.data.data_reader_writer import DataWriter
+from magic_pdf.libs.commons import join_path
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
-from magic_pdf.libs.commons import join_path
 from magic_pdf.pipe.AbsPipe import AbsPipe
-from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
+from magic_pdf.user_api import parse_ocr_pdf, parse_union_pdf
 
 
 class UNIPipe(AbsPipe):
 
-    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
+    def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: DataWriter, is_debug: bool = False,
                  start_page_id=0, end_page_id=None, lang=None,
                  layout_model=None, formula_enable=None, table_enable=None):
-        self.pdf_type = jso_useful_key["_pdf_type"]
-        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id,
+        self.pdf_type = jso_useful_key['_pdf_type']
+        super().__init__(pdf_bytes, jso_useful_key['model_list'], image_writer, is_debug, start_page_id, end_page_id,
                          lang, layout_model, formula_enable, table_enable)
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
@@ -54,27 +53,28 @@ class UNIPipe(AbsPipe):
 
     def pipe_mk_uni_format(self, img_parent_path: str, drop_mode=DropMode.NONE_WITH_REASON):
         result = super().pipe_mk_uni_format(img_parent_path, drop_mode)
-        logger.info("uni_pipe mk content list finished")
+        logger.info('uni_pipe mk content list finished')
         return result
 
     def pipe_mk_markdown(self, img_parent_path: str, drop_mode=DropMode.WHOLE_PDF, md_make_mode=MakeMode.MM_MD):
         result = super().pipe_mk_markdown(img_parent_path, drop_mode, md_make_mode)
-        logger.info(f"uni_pipe mk {md_make_mode} finished")
+        logger.info(f'uni_pipe mk {md_make_mode} finished')
         return result
 
 
 if __name__ == '__main__':
     # 测试
-    drw = DiskReaderWriter(r"D:/project/20231108code-clean")
+    from magic_pdf.data.data_reader_writer import DataReader
+    drw = DataReader(r'D:/project/20231108code-clean')
 
-    pdf_file_path = r"linshixuqiu\19983-00.pdf"
-    model_file_path = r"linshixuqiu\19983-00.json"
-    pdf_bytes = drw.read(pdf_file_path, AbsReaderWriter.MODE_BIN)
-    model_json_txt = drw.read(model_file_path, AbsReaderWriter.MODE_TXT)
+    pdf_file_path = r'linshixuqiu\19983-00.pdf'
+    model_file_path = r'linshixuqiu\19983-00.json'
+    pdf_bytes = drw.read(pdf_file_path)
+    model_json_txt = drw.read(model_file_path).decode()
     model_list = json.loads(model_json_txt)
-    write_path = r"D:\project\20231108code-clean\linshixuqiu\19983-00"
-    img_bucket_path = "imgs"
-    img_writer = DiskReaderWriter(join_path(write_path, img_bucket_path))
+    write_path = r'D:\project\20231108code-clean\linshixuqiu\19983-00'
+    img_bucket_path = 'imgs'
+    img_writer = DataWriter(join_path(write_path, img_bucket_path))
 
     # pdf_type = UNIPipe.classify(pdf_bytes)
     # jso_useful_key = {
@@ -83,8 +83,8 @@ if __name__ == '__main__':
     # }
 
     jso_useful_key = {
-        "_pdf_type": "",
-        "model_list": model_list
+        '_pdf_type': '',
+        'model_list': model_list
     }
     pipe = UNIPipe(pdf_bytes, jso_useful_key, img_writer)
     pipe.pipe_classify()
@@ -92,8 +92,7 @@ if __name__ == '__main__':
     md_content = pipe.pipe_mk_markdown(img_bucket_path)
     content_list = pipe.pipe_mk_uni_format(img_bucket_path)
 
-    md_writer = DiskReaderWriter(write_path)
-    md_writer.write(md_content, "19983-00.md", AbsReaderWriter.MODE_TXT)
-    md_writer.write(json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4), "19983-00.json",
-                    AbsReaderWriter.MODE_TXT)
-    md_writer.write(str(content_list), "19983-00.txt", AbsReaderWriter.MODE_TXT)
+    md_writer = DataWriter(write_path)
+    md_writer.write_string('19983-00.md', md_content)
+    md_writer.write_string('19983-00.json', json.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4))
+    md_writer.write_string('19983-00.txt', str(content_list))

+ 3 - 4
magic_pdf/tools/cli.py

@@ -5,9 +5,8 @@ import click
 from loguru import logger
 
 import magic_pdf.model as model_config
+from magic_pdf.data.data_reader_writer import FileBasedDataReader
 from magic_pdf.libs.version import __version__
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 from magic_pdf.tools.common import do_parse, parse_pdf_methods
 
 
@@ -86,8 +85,8 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
     os.makedirs(output_dir, exist_ok=True)
 
     def read_fn(path):
-        disk_rw = DiskReaderWriter(os.path.dirname(path))
-        return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
+        disk_rw = FileBasedDataReader(os.path.dirname(path))
+        return disk_rw.read(os.path.basename(path))
 
     def parse_doc(doc_path: str):
         try:

+ 6 - 9
magic_pdf/tools/cli_dev.py

@@ -5,13 +5,11 @@ from pathlib import Path
 import click
 
 import magic_pdf.model as model_config
+from magic_pdf.data.data_reader_writer import FileBasedDataReader, S3DataReader
 from magic_pdf.libs.config_reader import get_s3_config
 from magic_pdf.libs.path_utils import (parse_s3_range_params, parse_s3path,
                                        remove_non_official_s3_args)
 from magic_pdf.libs.version import __version__
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
-from magic_pdf.rw.S3ReaderWriter import S3ReaderWriter
 from magic_pdf.tools.common import do_parse, parse_pdf_methods
 
 
@@ -19,15 +17,14 @@ def read_s3_path(s3path):
     bucket, key = parse_s3path(s3path)
 
     s3_ak, s3_sk, s3_endpoint = get_s3_config(bucket)
-    s3_rw = S3ReaderWriter(s3_ak, s3_sk, s3_endpoint, 'auto',
-                           remove_non_official_s3_args(s3path))
+    s3_rw = S3DataReader('', bucket, s3_ak, s3_sk, s3_endpoint, 'auto')
     may_range_params = parse_s3_range_params(s3path)
     if may_range_params is None or 2 != len(may_range_params):
-        byte_start, byte_end = 0, None
+        byte_start, byte_end = 0, -1
     else:
         byte_start, byte_end = int(may_range_params[0]), int(
             may_range_params[1])
-    return s3_rw.read_offset(
+    return s3_rw.read_at(
         remove_non_official_s3_args(s3path),
         byte_start,
         byte_end,
@@ -129,8 +126,8 @@ def pdf(pdf, json_data, output_dir, method):
     os.makedirs(output_dir, exist_ok=True)
 
     def read_fn(path):
-        disk_rw = DiskReaderWriter(os.path.dirname(path))
-        return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
+        disk_rw = FileBasedDataReader(os.path.dirname(path))
+        return disk_rw.read(os.path.basename(path))
 
     model_json_list = json_parse.loads(read_fn(json_data).decode('utf-8'))
 

+ 22 - 35
magic_pdf/tools/common.py

@@ -3,18 +3,18 @@ import json as json_parse
 import os
 
 import click
+import fitz
 from loguru import logger
 
 import magic_pdf.model as model_config
+from magic_pdf.data.data_reader_writer import FileBasedDataWriter
 from magic_pdf.libs.draw_bbox import (draw_layout_bbox, draw_line_sort_bbox,
                                       draw_model_bbox, draw_span_bbox)
 from magic_pdf.libs.MakeContentConfig import DropMode, MakeMode
 from magic_pdf.pipe.OCRPipe import OCRPipe
 from magic_pdf.pipe.TXTPipe import TXTPipe
 from magic_pdf.pipe.UNIPipe import UNIPipe
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
-import fitz
+
 # from io import BytesIO
 # from pypdf import PdfReader, PdfWriter
 
@@ -54,11 +54,11 @@ def prepare_env(output_dir, pdf_file_name, method):
 
 
 def convert_pdf_bytes_to_bytes_by_pymupdf(pdf_bytes, start_page_id=0, end_page_id=None):
-    document = fitz.open("pdf", pdf_bytes)
+    document = fitz.open('pdf', pdf_bytes)
     output_document = fitz.open()
     end_page_id = end_page_id if end_page_id is not None and end_page_id >= 0 else len(document) - 1
     if end_page_id > len(document) - 1:
-        logger.warning("end_page_id is out of range, use pdf_docs length")
+        logger.warning('end_page_id is out of range, use pdf_docs length')
         end_page_id = len(document) - 1
     output_document.insert_pdf(document, from_page=start_page_id, to_page=end_page_id)
     output_bytes = output_document.tobytes()
@@ -100,8 +100,8 @@ def do_parse(
     local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name,
                                                 parse_method)
 
-    image_writer, md_writer = DiskReaderWriter(
-        local_image_dir), DiskReaderWriter(local_md_dir)
+    image_writer, md_writer = FileBasedDataWriter(
+        local_image_dir), FileBasedDataWriter(local_md_dir)
     image_dir = str(os.path.basename(local_image_dir))
 
     if parse_method == 'auto':
@@ -145,49 +145,36 @@ def do_parse(
     if f_draw_line_sort_bbox:
         draw_line_sort_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name)
 
-    md_content = pipe.pipe_mk_markdown(image_dir,
-                                       drop_mode=DropMode.NONE,
-                                       md_make_mode=f_make_md_mode)
+    md_content = pipe.pipe_mk_markdown(image_dir, drop_mode=DropMode.NONE, md_make_mode=f_make_md_mode)
     if f_dump_md:
-        md_writer.write(
-            content=md_content,
-            path=f'{pdf_file_name}.md',
-            mode=AbsReaderWriter.MODE_TXT,
+        md_writer.write_string(
+            f'{pdf_file_name}.md',
+            md_content
         )
 
     if f_dump_middle_json:
-        md_writer.write(
-            content=json_parse.dumps(pipe.pdf_mid_data,
-                                     ensure_ascii=False,
-                                     indent=4),
-            path=f'{pdf_file_name}_middle.json',
-            mode=AbsReaderWriter.MODE_TXT,
+        md_writer.write_string(
+            f'{pdf_file_name}_middle.json',
+            json_parse.dumps(pipe.pdf_mid_data, ensure_ascii=False, indent=4)
         )
 
     if f_dump_model_json:
-        md_writer.write(
-            content=json_parse.dumps(orig_model_list,
-                                     ensure_ascii=False,
-                                     indent=4),
-            path=f'{pdf_file_name}_model.json',
-            mode=AbsReaderWriter.MODE_TXT,
+        md_writer.write_string(
+            f'{pdf_file_name}_model.json',
+            json_parse.dumps(orig_model_list, ensure_ascii=False, indent=4)
         )
 
     if f_dump_orig_pdf:
         md_writer.write(
-            content=pdf_bytes,
-            path=f'{pdf_file_name}_origin.pdf',
-            mode=AbsReaderWriter.MODE_BIN,
+            f'{pdf_file_name}_origin.pdf',
+            pdf_bytes,
         )
 
     content_list = pipe.pipe_mk_uni_format(image_dir, drop_mode=DropMode.NONE)
     if f_dump_content_list:
-        md_writer.write(
-            content=json_parse.dumps(content_list,
-                                     ensure_ascii=False,
-                                     indent=4),
-            path=f'{pdf_file_name}_content_list.json',
-            mode=AbsReaderWriter.MODE_TXT,
+        md_writer.write_string(
+            f'{pdf_file_name}_content_list.json',
+            json_parse.dumps(content_list, ensure_ascii=False, indent=4)
         )
 
     logger.info(f'local output dir is {local_md_dir}')

+ 26 - 38
magic_pdf/user_api.py

@@ -1,36 +1,28 @@
-"""
-用户输入:
-    model数组,每个元素代表一个页面
-    pdf在s3的路径
-    截图保存的s3位置
+"""用户输入: model数组,每个元素代表一个页面 pdf在s3的路径 截图保存的s3位置.
 
 然后:
     1)根据s3路径,调用spark集群的api,拿到ak,sk,endpoint,构造出s3PDFReader
     2)根据用户输入的s3地址,调用spark集群的api,拿到ak,sk,endpoint,构造出s3ImageWriter
 
 其余部分至于构造s3cli, 获取ak,sk都在code-clean里写代码完成。不要反向依赖!!!
-
 """
-import re
 
 from loguru import logger
 
+from magic_pdf.data.data_reader_writer import DataWriter
 from magic_pdf.libs.version import __version__
 from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.rw import AbsReaderWriter
 from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
 from magic_pdf.pdf_parse_by_txt import parse_pdf_by_txt
 
-PARSE_TYPE_TXT = "txt"
-PARSE_TYPE_OCR = "ocr"
+PARSE_TYPE_TXT = 'txt'
+PARSE_TYPE_OCR = 'ocr'
 
 
-def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
+def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
                   start_page_id=0, end_page_id=None, lang=None,
                   *args, **kwargs):
-    """
-    解析文本类pdf
-    """
+    """解析文本类pdf."""
     pdf_info_dict = parse_pdf_by_txt(
         pdf_bytes,
         pdf_models,
@@ -40,22 +32,20 @@ def parse_txt_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
         debug_mode=is_debug,
     )
 
-    pdf_info_dict["_parse_type"] = PARSE_TYPE_TXT
+    pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
 
-    pdf_info_dict["_version_name"] = __version__
+    pdf_info_dict['_version_name'] = __version__
 
     if lang is not None:
-        pdf_info_dict["_lang"] = lang
+        pdf_info_dict['_lang'] = lang
 
     return pdf_info_dict
 
 
-def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
+def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
                   start_page_id=0, end_page_id=None, lang=None,
                   *args, **kwargs):
-    """
-    解析ocr类pdf
-    """
+    """解析ocr类pdf."""
     pdf_info_dict = parse_pdf_by_ocr(
         pdf_bytes,
         pdf_models,
@@ -65,23 +55,21 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
         debug_mode=is_debug,
     )
 
-    pdf_info_dict["_parse_type"] = PARSE_TYPE_OCR
+    pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
 
-    pdf_info_dict["_version_name"] = __version__
+    pdf_info_dict['_version_name'] = __version__
 
     if lang is not None:
-        pdf_info_dict["_lang"] = lang
+        pdf_info_dict['_lang'] = lang
 
     return pdf_info_dict
 
 
-def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
+def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: DataWriter, is_debug=False,
                     input_model_is_empty: bool = False,
                     start_page_id=0, end_page_id=None, lang=None,
                     *args, **kwargs):
-    """
-    ocr和文本混合的pdf,全部解析出来
-    """
+    """ocr和文本混合的pdf,全部解析出来."""
 
     def parse_pdf(method):
         try:
@@ -98,12 +86,12 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
             return None
 
     pdf_info_dict = parse_pdf(parse_pdf_by_txt)
-    if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
-        logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
+    if pdf_info_dict is None or pdf_info_dict.get('_need_drop', False):
+        logger.warning('parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr')
         if input_model_is_empty:
-            layout_model = kwargs.get("layout_model", None)
-            formula_enable = kwargs.get("formula_enable", None)
-            table_enable = kwargs.get("table_enable", None)
+            layout_model = kwargs.get('layout_model', None)
+            formula_enable = kwargs.get('formula_enable', None)
+            table_enable = kwargs.get('table_enable', None)
             pdf_models = doc_analyze(
                 pdf_bytes,
                 ocr=True,
@@ -116,15 +104,15 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
             )
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         if pdf_info_dict is None:
-            raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")
+            raise Exception('Both parse_pdf_by_txt and parse_pdf_by_ocr failed.')
         else:
-            pdf_info_dict["_parse_type"] = PARSE_TYPE_OCR
+            pdf_info_dict['_parse_type'] = PARSE_TYPE_OCR
     else:
-        pdf_info_dict["_parse_type"] = PARSE_TYPE_TXT
+        pdf_info_dict['_parse_type'] = PARSE_TYPE_TXT
 
-    pdf_info_dict["_version_name"] = __version__
+    pdf_info_dict['_version_name'] = __version__
 
     if lang is not None:
-        pdf_info_dict["_lang"] = lang
+        pdf_info_dict['_lang'] = lang
 
     return pdf_info_dict

+ 49 - 52
projects/gradio_app/app.py

@@ -2,39 +2,37 @@
 
 import base64
 import os
+import re
 import time
 import uuid
 import zipfile
 from pathlib import Path
-import re
 
+import gradio as gr
 import pymupdf
+from gradio_pdf import PDF
 from loguru import logger
 
+from magic_pdf.data.data_reader_writer import DataReader
 from magic_pdf.libs.hash_utils import compute_sha256
-from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
-from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
 from magic_pdf.tools.common import do_parse, prepare_env
 
-import gradio as gr
-from gradio_pdf import PDF
-
 
 def read_fn(path):
-    disk_rw = DiskReaderWriter(os.path.dirname(path))
-    return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
+    disk_rw = DataReader(os.path.dirname(path))
+    return disk_rw.read(os.path.basename(path))
 
 
 def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_enable, table_enable, language):
     os.makedirs(output_dir, exist_ok=True)
 
     try:
-        file_name = f"{str(Path(doc_path).stem)}_{time.time()}"
+        file_name = f'{str(Path(doc_path).stem)}_{time.time()}'
         pdf_data = read_fn(doc_path)
         if is_ocr:
-            parse_method = "ocr"
+            parse_method = 'ocr'
         else:
-            parse_method = "auto"
+            parse_method = 'auto'
         local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
         do_parse(
             output_dir,
@@ -55,8 +53,7 @@ def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, layout_mode, formula_en
 
 
 def compress_directory_to_zip(directory_path, output_zip_path):
-    """
-    压缩指定目录到一个 ZIP 文件。
+    """压缩指定目录到一个 ZIP 文件。
 
     :param directory_path: 要压缩的目录路径
     :param output_zip_path: 输出的 ZIP 文件路径
@@ -80,7 +77,7 @@ def compress_directory_to_zip(directory_path, output_zip_path):
 
 
 def image_to_base64(image_path):
-    with open(image_path, "rb") as image_file:
+    with open(image_path, 'rb') as image_file:
         return base64.b64encode(image_file.read()).decode('utf-8')
 
 
@@ -93,7 +90,7 @@ def replace_image_with_base64(markdown_text, image_dir_path):
         relative_path = match.group(1)
         full_path = os.path.join(image_dir_path, relative_path)
         base64_image = image_to_base64(full_path)
-        return f"![{relative_path}](data:image/jpeg;base64,{base64_image})"
+        return f'![{relative_path}](data:image/jpeg;base64,{base64_image})'
 
     # 应用替换
     return re.sub(pattern, replace, markdown_text)
@@ -103,34 +100,34 @@ def to_markdown(file_path, end_pages, is_ocr, layout_mode, formula_enable, table
     # 获取识别的md文件以及压缩包文件路径
     local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1, is_ocr,
                                         layout_mode, formula_enable, table_enable, language)
-    archive_zip_path = os.path.join("./output", compute_sha256(local_md_dir) + ".zip")
+    archive_zip_path = os.path.join('./output', compute_sha256(local_md_dir) + '.zip')
     zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
     if zip_archive_success == 0:
-        logger.info("压缩成功")
+        logger.info('压缩成功')
     else:
-        logger.error("压缩失败")
-    md_path = os.path.join(local_md_dir, file_name + ".md")
+        logger.error('压缩失败')
+    md_path = os.path.join(local_md_dir, file_name + '.md')
     with open(md_path, 'r', encoding='utf-8') as f:
         txt_content = f.read()
     md_content = replace_image_with_base64(txt_content, local_md_dir)
     # 返回转换后的PDF路径
-    new_pdf_path = os.path.join(local_md_dir, file_name + "_layout.pdf")
+    new_pdf_path = os.path.join(local_md_dir, file_name + '_layout.pdf')
 
     return md_content, txt_content, archive_zip_path, new_pdf_path
 
 
-latex_delimiters = [{"left": "$$", "right": "$$", "display": True},
-                    {"left": '$', "right": '$', "display": False}]
+latex_delimiters = [{'left': '$$', 'right': '$$', 'display': True},
+                    {'left': '$', 'right': '$', 'display': False}]
 
 
 def init_model():
     from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
     try:
         model_manager = ModelSingleton()
-        txt_model = model_manager.get_model(False, False)
-        logger.info(f"txt_model init final")
-        ocr_model = model_manager.get_model(True, False)
-        logger.info(f"ocr_model init final")
+        txt_model = model_manager.get_model(False, False)  # noqa: F841
+        logger.info('txt_model init final')
+        ocr_model = model_manager.get_model(True, False)  # noqa: F841
+        logger.info('ocr_model init final')
         return 0
     except Exception as e:
         logger.exception(e)
@@ -138,31 +135,31 @@ def init_model():
 
 
 model_init = init_model()
-logger.info(f"model_init: {model_init}")
+logger.info(f'model_init: {model_init}')
 
 
-with open("header.html", "r") as file:
+with open('header.html', 'r') as file:
     header = file.read()
 
 
 latin_lang = [
-        'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
+        'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',  # noqa: E126
         'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
         'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
         'sw', 'tl', 'tr', 'uz', 'vi', 'french', 'german'
 ]
 arabic_lang = ['ar', 'fa', 'ug', 'ur']
 cyrillic_lang = [
-        'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
+        'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',  # noqa: E126
         'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
 ]
 devanagari_lang = [
-        'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
+        'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',  # noqa: E126
         'sa', 'bgc'
 ]
 other_lang = ['ch', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka']
 
-all_lang = [""]
+all_lang = ['']
 all_lang.extend([*other_lang, *latin_lang, *arabic_lang, *cyrillic_lang, *devanagari_lang])
 
 
@@ -174,7 +171,7 @@ def to_pdf(file_path):
             pdf_bytes = f.convert_to_pdf()
             # 将pdfbytes 写入到uuid.pdf中
             # 生成唯一的文件名
-            unique_filename = f"{uuid.uuid4()}.pdf"
+            unique_filename = f'{uuid.uuid4()}.pdf'
 
             # 构建完整的文件路径
             tmp_file_path = os.path.join(os.path.dirname(file_path), unique_filename)
@@ -186,43 +183,43 @@ def to_pdf(file_path):
             return tmp_file_path
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     with gr.Blocks() as demo:
         gr.HTML(header)
         with gr.Row():
             with gr.Column(variant='panel', scale=5):
-                file = gr.File(label="Please upload a PDF or image", file_types=[".pdf", ".png", ".jpeg", ".jpg"])
-                max_pages = gr.Slider(1, 10, 5, step=1, label="Max convert pages")
+                file = gr.File(label='Please upload a PDF or image', file_types=['.pdf', '.png', '.jpeg', '.jpg'])
+                max_pages = gr.Slider(1, 10, 5, step=1, label='Max convert pages')
                 with gr.Row():
-                    layout_mode = gr.Dropdown(["layoutlmv3", "doclayout_yolo"], label="Layout model", value="layoutlmv3")
-                    language = gr.Dropdown(all_lang, label="Language", value="")
+                    layout_mode = gr.Dropdown(['layoutlmv3', 'doclayout_yolo'], label='Layout model', value='layoutlmv3')
+                    language = gr.Dropdown(all_lang, label='Language', value='')
                 with gr.Row():
-                    formula_enable = gr.Checkbox(label="Enable formula recognition", value=True)
-                    is_ocr = gr.Checkbox(label="Force enable OCR", value=False)
-                    table_enable = gr.Checkbox(label="Enable table recognition(test)", value=False)
+                    formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
+                    is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
+                    table_enable = gr.Checkbox(label='Enable table recognition(test)', value=False)
                 with gr.Row():
-                    change_bu = gr.Button("Convert")
-                    clear_bu = gr.ClearButton(value="Clear")
-                pdf_show = PDF(label="PDF preview", interactive=True, height=800)
-                with gr.Accordion("Examples:"):
-                    example_root = os.path.join(os.path.dirname(__file__), "examples")
+                    change_bu = gr.Button('Convert')
+                    clear_bu = gr.ClearButton(value='Clear')
+                pdf_show = PDF(label='PDF preview', interactive=True, height=800)
+                with gr.Accordion('Examples:'):
+                    example_root = os.path.join(os.path.dirname(__file__), 'examples')
                     gr.Examples(
                         examples=[os.path.join(example_root, _) for _ in os.listdir(example_root) if
-                                  _.endswith("pdf")],
+                                  _.endswith('pdf')],
                         inputs=pdf_show
                     )
 
             with gr.Column(variant='panel', scale=5):
-                output_file = gr.File(label="convert result", interactive=False)
+                output_file = gr.File(label='convert result', interactive=False)
                 with gr.Tabs():
-                    with gr.Tab("Markdown rendering"):
-                        md = gr.Markdown(label="Markdown rendering", height=900, show_copy_button=True,
+                    with gr.Tab('Markdown rendering'):
+                        md = gr.Markdown(label='Markdown rendering', height=900, show_copy_button=True,
                                          latex_delimiters=latex_delimiters, line_breaks=True)
-                    with gr.Tab("Markdown text"):
+                    with gr.Tab('Markdown text'):
                         md_text = gr.TextArea(lines=45, show_copy_button=True)
         file.upload(fn=to_pdf, inputs=file, outputs=pdf_show)
         change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
                         outputs=[md, md_text, output_file, pdf_show])
         clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language])
 
-    demo.launch(server_name="0.0.0.0")
+    demo.launch(server_name='0.0.0.0')

+ 1 - 1
tests/unittest/test_data/assets/jsonl/test_02.jsonl

@@ -1 +1 @@
-{"track_id":"e8824f5a-9fcb-4ee5-b2d4-6bf2c67019dc","path":"tests/test_data/assets/pdfs/test_02.pdf","file_type":"pdf","content_type":"application/pdf","content_length":80078,"title":"German Idealism and the Concept of Punishment || Conclusion","remark":{"file_id":"scihub_78800000/libgen.scimag78872000-78872999.zip_10.1017/cbo9780511770425.012","file_source_type":"paper","original_file_id":"10.1017/cbo9780511770425.012","file_name":"10.1017/cbo9780511770425.012.pdf","author":"Merle, Jean-Christophe"}}
+{"track_id":"e8824f5a-9fcb-4ee5-b2d4-6bf2c67019dc","path":"tests/unittest/test_data/assets/pdfs/test_02.pdf","file_type":"pdf","content_type":"application/pdf","content_length":80078,"title":"German Idealism and the Concept of Punishment || Conclusion","remark":{"file_id":"scihub_78800000/libgen.scimag78872000-78872999.zip_10.1017/cbo9780511770425.012","file_source_type":"paper","original_file_id":"10.1017/cbo9780511770425.012","file_name":"10.1017/cbo9780511770425.012.pdf","author":"Merle, Jean-Christophe"}}

+ 2 - 2
tests/unittest/test_data/test_dataset.py

@@ -3,7 +3,7 @@ from magic_pdf.data.dataset import ImageDataset, PymuDocDataset
 
 
 def test_pymudataset():
-    with open('tests/test_data/assets/pdfs/test_01.pdf', 'rb') as f:
+    with open('tests/unittest/test_data/assets/pdfs/test_01.pdf', 'rb') as f:
         bits = f.read()
     datasets = PymuDocDataset(bits)
     assert len(datasets) > 0
@@ -11,7 +11,7 @@ def test_pymudataset():
 
 
 def test_imagedataset():
-    with open('tests/test_data/assets/pngs/test_01.png', 'rb') as f:
+    with open('tests/unittest/test_data/assets/pngs/test_01.png', 'rb') as f:
         bits = f.read()
     datasets = ImageDataset(bits)
     assert len(datasets) == 1

+ 4 - 4
tests/unittest/test_data/test_read_api.py

@@ -9,7 +9,7 @@ from magic_pdf.data.schemas import S3Config
 
 
 def test_read_local_pdfs():
-    datasets = read_local_pdfs('tests/test_data/assets/pdfs')
+    datasets = read_local_pdfs('tests/unittest/test_data/assets/pdfs')
     assert len(datasets) == 2
     assert len(datasets[0]) > 0
     assert len(datasets[1]) > 0
@@ -19,7 +19,7 @@ def test_read_local_pdfs():
 
 
 def test_read_local_images():
-    datasets = read_local_images('tests/test_data/assets/pngs', suffixes=['png'])
+    datasets = read_local_images('tests/unittest/test_data/assets/pngs', suffixes=['png'])
     assert len(datasets) == 2
     assert len(datasets[0]) == 1
     assert len(datasets[1]) == 1
@@ -69,10 +69,10 @@ def test_read_json():
     assert len(datasets) > 0
     assert len(datasets[0]) == 10
 
-    datasets = read_jsonl('tests/test_data/assets/jsonl/test_01.jsonl', reader)
+    datasets = read_jsonl('tests/unittest/test_data/assets/jsonl/test_01.jsonl', reader)
     assert len(datasets) == 1
     assert len(datasets[0]) == 10
 
-    datasets = read_jsonl('tests/test_data/assets/jsonl/test_02.jsonl')
+    datasets = read_jsonl('tests/unittest/test_data/assets/jsonl/test_02.jsonl')
     assert len(datasets) == 1
     assert len(datasets[0]) == 1

+ 2 - 2
tests/unittest/test_integrations/test_rag/test_api.py

@@ -17,7 +17,7 @@ def test_rag_document_reader():
     os.makedirs(temp_output_dir, exist_ok=True)
 
     # test
-    with open('tests/test_integrations/test_rag/assets/middle.json') as f:
+    with open('tests/unittest/test_integrations/test_rag/assets/middle.json') as f:
         json_data = json.load(f)
     res = convert_middle_json_to_layout_elements(json_data, temp_output_dir)
 
@@ -43,7 +43,7 @@ def test_data_reader():
     os.makedirs(temp_output_dir, exist_ok=True)
 
     # test
-    data_reader = DataReader('tests/test_integrations/test_rag/assets', 'ocr',
+    data_reader = DataReader('tests/unittest/test_integrations/test_rag/assets', 'ocr',
                              temp_output_dir)
 
     assert data_reader.get_documents_count() == 2

+ 3 - 3
tests/unittest/test_integrations/test_rag/test_utils.py

@@ -16,7 +16,7 @@ def test_convert_middle_json_to_layout_elements():
     os.makedirs(temp_output_dir, exist_ok=True)
 
     # test
-    with open('tests/test_integrations/test_rag/assets/middle.json') as f:
+    with open('tests/unittest/test_integrations/test_rag/assets/middle.json') as f:
         json_data = json.load(f)
     res = convert_middle_json_to_layout_elements(json_data, temp_output_dir)
 
@@ -32,7 +32,7 @@ def test_convert_middle_json_to_layout_elements():
 
 def test_inference():
 
-    asset_dir = 'tests/test_integrations/test_rag/assets'
+    asset_dir = 'tests/unittest/test_integrations/test_rag/assets'
     # setup
     unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
     os.makedirs(unitest_dir, exist_ok=True)
@@ -48,7 +48,7 @@ def test_inference():
 
     assert res is not None
     assert len(res) == 1
-    assert len(res[0].layout_dets) == 10
+    assert len(res[0].layout_dets) == 11
     assert res[0].layout_dets[0].anno_id == 0
     assert res[0].layout_dets[0].category_type == CategoryType.text
     assert len(res[0].extra.element_relation) == 3

+ 4 - 4
tests/unittest/test_model/test_magic_model.py

@@ -5,8 +5,8 @@ from magic_pdf.model.magic_model import MagicModel
 
 
 def test_magic_model_image_v2():
-    datasets = read_local_pdfs('tests/test_model/assets/test_01.pdf')
-    with open('tests/test_model/assets/test_01.model.json') as f:
+    datasets = read_local_pdfs('tests/unittest/test_model/assets/test_01.pdf')
+    with open('tests/unittest/test_model/assets/test_01.model.json') as f:
         model_json = json.load(f)
 
     magic_model = MagicModel(model_json, datasets[0])
@@ -19,8 +19,8 @@ def test_magic_model_image_v2():
 
 
 def test_magic_model_table_v2():
-    datasets = read_local_pdfs('tests/test_model/assets/test_02.pdf')
-    with open('tests/test_model/assets/test_02.model.json') as f:
+    datasets = read_local_pdfs('tests/unittest/test_model/assets/test_02.pdf')
+    with open('tests/unittest/test_model/assets/test_02.model.json') as f:
         model_json = json.load(f)
 
     magic_model = MagicModel(model_json, datasets[0])

Plik diff jest za duży
+ 0 - 0
tests/unittest/test_tools/assets/cli_dev/cli_test_01.jsonl


+ 52 - 51
tests/unittest/test_tools/test_cli.py

@@ -1,6 +1,7 @@
-import tempfile
 import os
 import shutil
+import tempfile
+
 from click.testing import CliRunner
 
 from magic_pdf.tools.cli import cli
@@ -8,19 +9,19 @@ from magic_pdf.tools.cli import cli
 
 def test_cli_pdf():
     # setup
-    unitest_dir = "/tmp/magic_pdf/unittest/tools"
-    filename = "cli_test_01"
+    unitest_dir = '/tmp/magic_pdf/unittest/tools'
+    filename = 'cli_test_01'
     os.makedirs(unitest_dir, exist_ok=True)
-    temp_output_dir = tempfile.mkdtemp(dir="/tmp/magic_pdf/unittest/tools")
+    temp_output_dir = tempfile.mkdtemp(dir='/tmp/magic_pdf/unittest/tools')
 
     # run
     runner = CliRunner()
     result = runner.invoke(
         cli,
         [
-            "-p",
-            "tests/test_tools/assets/cli/pdf/cli_test_01.pdf",
-            "-o",
+            '-p',
+            'tests/unittest/test_tools/assets/cli/pdf/cli_test_01.pdf',
+            '-o',
             temp_output_dir,
         ],
     )
@@ -28,29 +29,29 @@ def test_cli_pdf():
     # check
     assert result.exit_code == 0
 
-    base_output_dir = os.path.join(temp_output_dir, "cli_test_01/auto")
+    base_output_dir = os.path.join(temp_output_dir, 'cli_test_01/auto')
 
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 7000
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
-    assert os.path.exists(os.path.join(base_output_dir, "images")) is True
-    assert os.path.isdir(os.path.join(base_output_dir, "images")) is True
-    assert os.path.exists(os.path.join(base_output_dir, "content_list.json")) is False
+    assert os.path.exists(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.isdir(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.exists(os.path.join(base_output_dir, f'{filename}_content_list.json')) is True
 
     # teardown
     shutil.rmtree(temp_output_dir)
@@ -58,68 +59,68 @@ def test_cli_pdf():
 
 def test_cli_path():
     # setup
-    unitest_dir = "/tmp/magic_pdf/unittest/tools"
+    unitest_dir = '/tmp/magic_pdf/unittest/tools'
     os.makedirs(unitest_dir, exist_ok=True)
-    temp_output_dir = tempfile.mkdtemp(dir="/tmp/magic_pdf/unittest/tools")
+    temp_output_dir = tempfile.mkdtemp(dir='/tmp/magic_pdf/unittest/tools')
 
     # run
     runner = CliRunner()
     result = runner.invoke(
-        cli, ["-p", "tests/test_tools/assets/cli/path", "-o", temp_output_dir]
+        cli, ['-p', 'tests/unittest/test_tools/assets/cli/path', '-o', temp_output_dir]
     )
 
     # check
     assert result.exit_code == 0
 
-    filename = "cli_test_01"
-    base_output_dir = os.path.join(temp_output_dir, "cli_test_01/auto")
+    filename = 'cli_test_01'
+    base_output_dir = os.path.join(temp_output_dir, 'cli_test_01/auto')
 
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 7000
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
-    assert os.path.exists(os.path.join(base_output_dir, "images")) is True
-    assert os.path.isdir(os.path.join(base_output_dir, "images")) is True
-    assert os.path.exists(os.path.join(base_output_dir, "content_list.json")) is False
+    assert os.path.exists(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.isdir(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.exists(os.path.join(base_output_dir, f'{filename}_content_list.json')) is True
 
-    base_output_dir = os.path.join(temp_output_dir, "cli_test_02/auto")
-    filename = "cli_test_02"
+    base_output_dir = os.path.join(temp_output_dir, 'cli_test_02/auto')
+    filename = 'cli_test_02'
 
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 5000
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
-    assert os.path.exists(os.path.join(base_output_dir, "images")) is True
-    assert os.path.isdir(os.path.join(base_output_dir, "images")) is True
-    assert os.path.exists(os.path.join(base_output_dir, "content_list.json")) is False
+    assert os.path.exists(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.isdir(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.exists(os.path.join(base_output_dir, f'{filename}_content_list.json')) is True
 
     # teardown
     shutil.rmtree(temp_output_dir)

+ 46 - 46
tests/unittest/test_tools/test_cli_dev.py

@@ -1,6 +1,7 @@
-import tempfile
 import os
 import shutil
+import tempfile
+
 from click.testing import CliRunner
 
 from magic_pdf.tools import cli_dev
@@ -8,22 +9,22 @@ from magic_pdf.tools import cli_dev
 
 def test_cli_pdf():
     # setup
-    unitest_dir = "/tmp/magic_pdf/unittest/tools"
-    filename = "cli_test_01"
+    unitest_dir = '/tmp/magic_pdf/unittest/tools'
+    filename = 'cli_test_01'
     os.makedirs(unitest_dir, exist_ok=True)
-    temp_output_dir = tempfile.mkdtemp(dir="/tmp/magic_pdf/unittest/tools")
+    temp_output_dir = tempfile.mkdtemp(dir='/tmp/magic_pdf/unittest/tools')
 
     # run
     runner = CliRunner()
     result = runner.invoke(
         cli_dev.cli,
         [
-            "pdf",
-            "-p",
-            "tests/test_tools/assets/cli/pdf/cli_test_01.pdf",
-            "-j",
-            "tests/test_tools/assets/cli_dev/cli_test_01.model.json",
-            "-o",
+            'pdf',
+            '-p',
+            'tests/unittest/test_tools/assets/cli/pdf/cli_test_01.pdf',
+            '-j',
+            'tests/unittest/test_tools/assets/cli_dev/cli_test_01.model.json',
+            '-o',
             temp_output_dir,
         ],
     )
@@ -31,31 +32,30 @@ def test_cli_pdf():
     # check
     assert result.exit_code == 0
 
-    base_output_dir = os.path.join(temp_output_dir, "cli_test_01/auto")
+    base_output_dir = os.path.join(temp_output_dir, 'cli_test_01/auto')
 
-    r = os.stat(os.path.join(base_output_dir, "content_list.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_content_list.json'))
     assert r.st_size > 5000
-
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 7000
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
-    assert os.path.exists(os.path.join(base_output_dir, "images")) is True
-    assert os.path.isdir(os.path.join(base_output_dir, "images")) is True
+    assert os.path.exists(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.isdir(os.path.join(base_output_dir, 'images')) is True
 
     # teardown
     shutil.rmtree(temp_output_dir)
@@ -63,26 +63,26 @@ def test_cli_pdf():
 
 def test_cli_jsonl():
     # setup
-    unitest_dir = "/tmp/magic_pdf/unittest/tools"
-    filename = "cli_test_01"
+    unitest_dir = '/tmp/magic_pdf/unittest/tools'
+    filename = 'cli_test_01'
     os.makedirs(unitest_dir, exist_ok=True)
-    temp_output_dir = tempfile.mkdtemp(dir="/tmp/magic_pdf/unittest/tools")
+    temp_output_dir = tempfile.mkdtemp(dir='/tmp/magic_pdf/unittest/tools')
 
     def mock_read_s3_path(s3path):
-        with open(s3path, "rb") as f:
+        with open(s3path, 'rb') as f:
             return f.read()
 
-    cli_dev.read_s3_path = mock_read_s3_path # mock
+    cli_dev.read_s3_path = mock_read_s3_path  # mock
 
     # run
     runner = CliRunner()
     result = runner.invoke(
         cli_dev.cli,
         [
-            "jsonl",
-            "-j",
-            "tests/test_tools/assets/cli_dev/cli_test_01.jsonl",
-            "-o",
+            'jsonl',
+            '-j',
+            'tests/unittest/test_tools/assets/cli_dev/cli_test_01.jsonl',
+            '-o',
             temp_output_dir,
         ],
     )
@@ -90,31 +90,31 @@ def test_cli_jsonl():
     # check
     assert result.exit_code == 0
 
-    base_output_dir = os.path.join(temp_output_dir, "cli_test_01/auto")
+    base_output_dir = os.path.join(temp_output_dir, 'cli_test_01/auto')
 
-    r = os.stat(os.path.join(base_output_dir, "content_list.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_content_list.json'))
     assert r.st_size > 5000
 
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 7000
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
-    assert os.path.exists(os.path.join(base_output_dir, "images")) is True
-    assert os.path.isdir(os.path.join(base_output_dir, "images")) is True
+    assert os.path.exists(os.path.join(base_output_dir, 'images')) is True
+    assert os.path.isdir(os.path.join(base_output_dir, 'images')) is True
 
     # teardown
     shutil.rmtree(temp_output_dir)

+ 21 - 19
tests/unittest/test_tools/test_common.py

@@ -1,23 +1,25 @@
-import tempfile
 import os
 import shutil
+import tempfile
 
 import pytest
 
 from magic_pdf.tools.common import do_parse
 
 
-@pytest.mark.parametrize("method", ["auto", "txt", "ocr"])
+@pytest.mark.parametrize('method', ['auto', 'txt', 'ocr'])
 def test_common_do_parse(method):
+    import magic_pdf.model as model_config
+    model_config.__use_inside_model__ = True
     # setup
-    unitest_dir = "/tmp/magic_pdf/unittest/tools"
-    filename = "fake"
+    unitest_dir = '/tmp/magic_pdf/unittest/tools'
+    filename = 'fake'
     os.makedirs(unitest_dir, exist_ok=True)
 
-    temp_output_dir = tempfile.mkdtemp(dir="/tmp/magic_pdf/unittest/tools")
+    temp_output_dir = tempfile.mkdtemp(dir='/tmp/magic_pdf/unittest/tools')
 
     # run
-    with open("tests/test_tools/assets/common/cli_test_01.pdf", "rb") as f:
+    with open('tests/unittest/test_tools/assets/common/cli_test_01.pdf', 'rb') as f:
         bits = f.read()
     do_parse(temp_output_dir,
              filename,
@@ -27,31 +29,31 @@ def test_common_do_parse(method):
              f_dump_content_list=True)
 
     # check
-    base_output_dir = os.path.join(temp_output_dir, f"fake/{method}")
+    base_output_dir = os.path.join(temp_output_dir, f'fake/{method}')
 
-    r = os.stat(os.path.join(base_output_dir, "content_list.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_content_list.json'))
     assert r.st_size > 5000
 
-    r = os.stat(os.path.join(base_output_dir, f"{filename}.md"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}.md'))
     assert r.st_size > 7000
 
-    r = os.stat(os.path.join(base_output_dir, "middle.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_middle.json'))
     assert r.st_size > 200000
 
-    r = os.stat(os.path.join(base_output_dir, "model.json"))
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_model.json'))
     assert r.st_size > 15000
 
-    r = os.stat(os.path.join(base_output_dir, "origin.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_origin.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "layout.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_layout.pdf'))
+    assert r.st_size > 400000
 
-    r = os.stat(os.path.join(base_output_dir, "spans.pdf"))
-    assert r.st_size > 500000
+    r = os.stat(os.path.join(base_output_dir, f'{filename}_spans.pdf'))
+    assert r.st_size > 400000
 
-    os.path.exists(os.path.join(base_output_dir, "images"))
-    os.path.isdir(os.path.join(base_output_dir, "images"))
+    os.path.exists(os.path.join(base_output_dir, 'images'))
+    os.path.isdir(os.path.join(base_output_dir, 'images'))
 
     # teardown
     shutil.rmtree(temp_output_dir)

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików