瀏覽代碼

fix: add magic-pdf-dev case

quyuan 1 年之前
父節點
當前提交
f9df92aa34

+ 1 - 0
.github/ISSUE_TEMPLATE/bug_report.yml

@@ -80,6 +80,7 @@ body:
         -
         - "0.6.x"
         - "0.7.x"
+        - "0.8.x"
     validations:
       required: true
 

文件差異過大導致無法顯示
+ 1 - 0
README.md


文件差異過大導致無法顯示
+ 1 - 0
README_zh-CN.md


+ 167 - 0
app.py

@@ -0,0 +1,167 @@
+# Copyright (c) Opendatalab. All rights reserved.
+
+import base64
+import os
+import time
+import zipfile
+from pathlib import Path
+import re
+
+from loguru import logger
+
+from magic_pdf.libs.hash_utils import compute_sha256
+from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
+from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
+from magic_pdf.tools.common import do_parse, prepare_env
+
+os.system("pip install gradio")
+os.system("pip install gradio-pdf")
+import gradio as gr
+from gradio_pdf import PDF
+
+
+def read_fn(path):
+    disk_rw = DiskReaderWriter(os.path.dirname(path))
+    return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)
+
+
+def parse_pdf(doc_path, output_dir, end_page_id):
+    os.makedirs(output_dir, exist_ok=True)
+
+    try:
+        file_name = f"{str(Path(doc_path).stem)}_{time.time()}"
+        pdf_data = read_fn(doc_path)
+        parse_method = "auto"
+        local_image_dir, local_md_dir = prepare_env(output_dir, file_name, parse_method)
+        do_parse(
+            output_dir,
+            file_name,
+            pdf_data,
+            [],
+            parse_method,
+            False,
+            end_page_id=end_page_id,
+        )
+        return local_md_dir, file_name
+    except Exception as e:
+        logger.exception(e)
+
+
+def compress_directory_to_zip(directory_path, output_zip_path):
+    """
+    压缩指定目录到一个 ZIP 文件。
+
+    :param directory_path: 要压缩的目录路径
+    :param output_zip_path: 输出的 ZIP 文件路径
+    """
+    try:
+        with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
+
+            # 遍历目录中的所有文件和子目录
+            for root, dirs, files in os.walk(directory_path):
+                for file in files:
+                    # 构建完整的文件路径
+                    file_path = os.path.join(root, file)
+                    # 计算相对路径
+                    arcname = os.path.relpath(file_path, directory_path)
+                    # 添加文件到 ZIP 文件
+                    zipf.write(file_path, arcname)
+        return 0
+    except Exception as e:
+        logger.exception(e)
+        return -1
+
+
+def image_to_base64(image_path):
+    with open(image_path, "rb") as image_file:
+        return base64.b64encode(image_file.read()).decode('utf-8')
+
+
+def replace_image_with_base64(markdown_text, image_dir_path):
+    # 匹配Markdown中的图片标签
+    pattern = r'\!\[(?:[^\]]*)\]\(([^)]+)\)'
+
+    # 替换图片链接
+    def replace(match):
+        relative_path = match.group(1)
+        full_path = os.path.join(image_dir_path, relative_path)
+        base64_image = image_to_base64(full_path)
+        return f"![{relative_path}](data:image/jpeg;base64,{base64_image})"
+
+    # 应用替换
+    return re.sub(pattern, replace, markdown_text)
+
+
+def to_markdown(file_path, end_pages):
+    # 获取识别的md文件以及压缩包文件路径
+    local_md_dir, file_name = parse_pdf(file_path, './output', end_pages - 1)
+    archive_zip_path = os.path.join("./output", compute_sha256(local_md_dir) + ".zip")
+    zip_archive_success = compress_directory_to_zip(local_md_dir, archive_zip_path)
+    if zip_archive_success == 0:
+        logger.info("压缩成功")
+    else:
+        logger.error("压缩失败")
+    md_path = os.path.join(local_md_dir, file_name + ".md")
+    with open(md_path, 'r', encoding='utf-8') as f:
+        txt_content = f.read()
+    md_content = replace_image_with_base64(txt_content, local_md_dir)
+    # 返回转换后的PDF路径
+    new_pdf_path = os.path.join(local_md_dir, file_name + "_layout.pdf")
+
+    return md_content, txt_content, archive_zip_path, new_pdf_path
+
+
+# def show_pdf(file_path):
+#     with open(file_path, "rb") as f:
+#         base64_pdf = base64.b64encode(f.read()).decode('utf-8')
+#     pdf_display = f'<embed src="data:application/pdf;base64,{base64_pdf}" ' \
+#                   f'width="100%" height="1000" type="application/pdf">'
+#     return pdf_display
+
+
+latex_delimiters = [{"left": "$$", "right": "$$", "display": True},
+                    {"left": '$', "right": '$', "display": False}]
+
+
+def init_model():
+    from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+    try:
+        model_manager = ModelSingleton()
+        txt_model = model_manager.get_model(False, False)
+        logger.info(f"txt_model init final")
+        ocr_model = model_manager.get_model(True, False)
+        logger.info(f"ocr_model init final")
+        return 0
+    except Exception as e:
+        logger.exception(e)
+        return -1
+
+
+model_init = init_model()
+logger.info(f"model_init: {model_init}")
+
+
+if __name__ == "__main__":
+    with gr.Blocks() as demo:
+        with gr.Row():
+            with gr.Column(variant='panel', scale=5):
+                pdf_show = gr.Markdown()
+                max_pages = gr.Slider(1, 10, 5, step=1, label="Max convert pages")
+                with gr.Row() as bu_flow:
+                    change_bu = gr.Button("Convert")
+                    clear_bu = gr.ClearButton([pdf_show], value="Clear")
+                pdf_show = PDF(label="Please upload pdf", interactive=True, height=800)
+
+            with gr.Column(variant='panel', scale=5):
+                output_file = gr.File(label="convert result", interactive=False)
+                with gr.Tabs():
+                    with gr.Tab("Markdown rendering"):
+                        md = gr.Markdown(label="Markdown rendering", height=900, show_copy_button=True,
+                                         latex_delimiters=latex_delimiters, line_breaks=True)
+                    with gr.Tab("Markdown text"):
+                        md_text = gr.TextArea(lines=45, show_copy_button=True)
+        change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages], outputs=[md, md_text, output_file, pdf_show])
+        clear_bu.add([md, pdf_show, md_text, output_file])
+
+    demo.launch()
+

+ 4 - 0
magic_pdf/libs/version.py

@@ -1 +1,5 @@
+<<<<<<< HEAD
 __version__ = "0.7.1"
+=======
+__version__ = "0.8.0"
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999

+ 4 - 0
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -106,7 +106,11 @@ def custom_model_init(ocr: bool = False, show_log: bool = False, lang=None):
 
 
 def doc_analyze(pdf_bytes: bytes, ocr: bool = False, show_log: bool = False,
+<<<<<<< HEAD
                 start_page_id=0, end_page_id=None, lang=None):
+=======
+                start_page_id=0, end_page_id=None):
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(ocr, show_log, lang)

+ 13 - 0
magic_pdf/model/pdf_extract_kit.py

@@ -74,11 +74,16 @@ def layout_model_init(weight, config_file, device):
     return model
 
 
+<<<<<<< HEAD
 def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None):
     if lang is not None:
         model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang)
     else:
         model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
+=======
+def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3):
+    model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
     return model
 
 
@@ -137,8 +142,12 @@ def atom_model_init(model_name: str, **kwargs):
     elif model_name == AtomicModel.OCR:
         atom_model = ocr_model_init(
             kwargs.get("ocr_show_log"),
+<<<<<<< HEAD
             kwargs.get("det_db_box_thresh"),
             kwargs.get("lang")
+=======
+            kwargs.get("det_db_box_thresh")
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
         )
     elif model_name == AtomicModel.Table:
         atom_model = table_model_init(
@@ -235,8 +244,12 @@ class CustomPEKModel:
             self.ocr_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.OCR,
                 ocr_show_log=show_log,
+<<<<<<< HEAD
                 det_db_box_thresh=0.3,
                 lang=self.lang
+=======
+                det_db_box_thresh=0.3
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
             )
         # init table model
         if self.apply_table:

+ 7 - 0
magic_pdf/pipe/AbsPipe.py

@@ -17,7 +17,11 @@ class AbsPipe(ABC):
     PIP_TXT = "txt"
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+<<<<<<< HEAD
                  start_page_id=0, end_page_id=None, lang=None):
+=======
+                 start_page_id=0, end_page_id=None):
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
         self.pdf_bytes = pdf_bytes
         self.model_list = model_list
         self.image_writer = image_writer
@@ -25,7 +29,10 @@ class AbsPipe(ABC):
         self.is_debug = is_debug
         self.start_page_id = start_page_id
         self.end_page_id = end_page_id
+<<<<<<< HEAD
         self.lang = lang
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
     
     def get_compress_pdf_mid_data(self):
         return JsonCompressor.compress_json(self.pdf_mid_data)

+ 9 - 0
magic_pdf/pipe/OCRPipe.py

@@ -10,16 +10,25 @@ from magic_pdf.user_api import parse_ocr_pdf
 class OCRPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+<<<<<<< HEAD
                  start_page_id=0, end_page_id=None, lang=None):
         super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
+=======
+                 start_page_id=0, end_page_id=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
+<<<<<<< HEAD
                                       start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                       lang=self.lang)
+=======
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,

+ 9 - 0
magic_pdf/pipe/TXTPipe.py

@@ -11,16 +11,25 @@ from magic_pdf.user_api import parse_txt_pdf
 class TXTPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, model_list: list, image_writer: AbsReaderWriter, is_debug: bool = False,
+<<<<<<< HEAD
                  start_page_id=0, end_page_id=None, lang=None):
         super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id, lang)
+=======
+                 start_page_id=0, end_page_id=None):
+        super().__init__(pdf_bytes, model_list, image_writer, is_debug, start_page_id, end_page_id)
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 
     def pipe_classify(self):
         pass
 
     def pipe_analyze(self):
         self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
+<<<<<<< HEAD
                                       start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                       lang=self.lang)
+=======
+                                      start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 
     def pipe_parse(self):
         self.pdf_mid_data = parse_txt_pdf(self.pdf_bytes, self.model_list, self.image_writer, is_debug=self.is_debug,

+ 17 - 0
magic_pdf/pipe/UNIPipe.py

@@ -14,9 +14,15 @@ from magic_pdf.user_api import parse_union_pdf, parse_ocr_pdf
 class UNIPipe(AbsPipe):
 
     def __init__(self, pdf_bytes: bytes, jso_useful_key: dict, image_writer: AbsReaderWriter, is_debug: bool = False,
+<<<<<<< HEAD
                  start_page_id=0, end_page_id=None, lang=None):
         self.pdf_type = jso_useful_key["_pdf_type"]
         super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id, lang)
+=======
+                 start_page_id=0, end_page_id=None):
+        self.pdf_type = jso_useful_key["_pdf_type"]
+        super().__init__(pdf_bytes, jso_useful_key["model_list"], image_writer, is_debug, start_page_id, end_page_id)
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
         if len(self.model_list) == 0:
             self.input_model_is_empty = True
         else:
@@ -28,19 +34,30 @@ class UNIPipe(AbsPipe):
     def pipe_analyze(self):
         if self.pdf_type == self.PIP_TXT:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=False,
+<<<<<<< HEAD
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                           lang=self.lang)
         elif self.pdf_type == self.PIP_OCR:
             self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
                                           start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                           lang=self.lang)
+=======
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+        elif self.pdf_type == self.PIP_OCR:
+            self.model_list = doc_analyze(self.pdf_bytes, ocr=True,
+                                          start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 
     def pipe_parse(self):
         if self.pdf_type == self.PIP_TXT:
             self.pdf_mid_data = parse_union_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                                 is_debug=self.is_debug, input_model_is_empty=self.input_model_is_empty,
+<<<<<<< HEAD
                                                 start_page_id=self.start_page_id, end_page_id=self.end_page_id,
                                                 lang=self.lang)
+=======
+                                                start_page_id=self.start_page_id, end_page_id=self.end_page_id)
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
         elif self.pdf_type == self.PIP_OCR:
             self.pdf_mid_data = parse_ocr_pdf(self.pdf_bytes, self.model_list, self.image_writer,
                                               is_debug=self.is_debug,

+ 4 - 0
magic_pdf/resources/model_config/model_configs.yaml

@@ -10,6 +10,10 @@ config:
 weights:
   layout: Layout/model_final.pth
   mfd: MFD/weights.pt
+<<<<<<< HEAD
   mfr: MFR/unimernet_base
+=======
+  mfr: MFR/UniMERNet
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
   struct_eqtable: TabRec/StructEqTable
   TableMaster: TabRec/TableMaster

+ 10 - 0
magic_pdf/tools/cli.py

@@ -45,6 +45,7 @@ without method specified, auto will be used by default.""",
     default='auto',
 )
 @click.option(
+<<<<<<< HEAD
     '-l',
     '--lang',
     'lang',
@@ -57,6 +58,8 @@ without method specified, auto will be used by default.""",
     default=None,
 )
 @click.option(
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
     '-d',
     '--debug',
     'debug_able',
@@ -80,7 +83,11 @@ without method specified, auto will be used by default.""",
     help='The ending page for PDF parsing, beginning from 0.',
     default=None,
 )
+<<<<<<< HEAD
 def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
+=======
+def cli(path, output_dir, method, debug_able, start_page_id, end_page_id):
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
     model_config.__use_inside_model__ = True
     model_config.__model_mode__ = 'full'
     os.makedirs(output_dir, exist_ok=True)
@@ -102,7 +109,10 @@ def cli(path, output_dir, method, lang, debug_able, start_page_id, end_page_id):
                 debug_able,
                 start_page_id=start_page_id,
                 end_page_id=end_page_id,
+<<<<<<< HEAD
                 lang=lang
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
             )
 
         except Exception as e:

+ 13 - 0
magic_pdf/tools/common.py

@@ -44,7 +44,10 @@ def do_parse(
     f_draw_model_bbox=False,
     start_page_id=0,
     end_page_id=None,
+<<<<<<< HEAD
     lang=None,
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 ):
     if debug_able:
         logger.warning("debug mode is on")
@@ -62,6 +65,7 @@ def do_parse(
     if parse_method == 'auto':
         jso_useful_key = {'_pdf_type': '', 'model_list': model_list}
         pipe = UNIPipe(pdf_bytes, jso_useful_key, image_writer, is_debug=True,
+<<<<<<< HEAD
                        start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
     elif parse_method == 'txt':
         pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
@@ -69,6 +73,15 @@ def do_parse(
     elif parse_method == 'ocr':
         pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
                        start_page_id=start_page_id, end_page_id=end_page_id, lang=lang)
+=======
+                       start_page_id=start_page_id, end_page_id=end_page_id)
+    elif parse_method == 'txt':
+        pipe = TXTPipe(pdf_bytes, model_list, image_writer, is_debug=True,
+                       start_page_id=start_page_id, end_page_id=end_page_id)
+    elif parse_method == 'ocr':
+        pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True,
+                       start_page_id=start_page_id, end_page_id=end_page_id)
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
     else:
         logger.error('unknown parse method')
         exit(1)

+ 10 - 0
magic_pdf/user_api.py

@@ -71,7 +71,11 @@ def parse_ocr_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWrit
 
 def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWriter, is_debug=False,
                     input_model_is_empty: bool = False,
+<<<<<<< HEAD
                     start_page_id=0, end_page_id=None, lang=None,
+=======
+                    start_page_id=0, end_page_id=None,
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
                     *args, **kwargs):
     """
     ocr和文本混合的pdf,全部解析出来
@@ -95,11 +99,17 @@ def parse_union_pdf(pdf_bytes: bytes, pdf_models: list, imageWriter: AbsReaderWr
     if pdf_info_dict is None or pdf_info_dict.get("_need_drop", False):
         logger.warning(f"parse_pdf_by_txt drop or error, switch to parse_pdf_by_ocr")
         if input_model_is_empty:
+<<<<<<< HEAD
             pdf_models = doc_analyze(pdf_bytes,
                                      ocr=True,
                                      start_page_id=start_page_id,
                                      end_page_id=end_page_id,
                                      lang=lang)
+=======
+            pdf_models = doc_analyze(pdf_bytes, ocr=True,
+                                     start_page_id=start_page_id,
+                                     end_page_id=end_page_id)
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
         pdf_info_dict = parse_pdf(parse_pdf_by_ocr)
         if pdf_info_dict is None:
             raise Exception("Both parse_pdf_by_txt and parse_pdf_by_ocr failed.")

+ 3 - 0
projects/README.md

@@ -3,6 +3,9 @@
 ## Project List
 
 - [llama_index_rag](./llama_index_rag/README.md): Build a lightweight RAG system based on llama_index
+<<<<<<< HEAD
 - [gradio_app](./gradio_app/README.md): Build a web app based on gradio
 
 
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999

+ 3 - 0
projects/README_zh-CN.md

@@ -3,5 +3,8 @@
 ## 项目列表
 
 - [llama_index_rag](./llama_index_rag/README_zh-CN.md): 基于 llama_index 构建轻量级 RAG 系统
+<<<<<<< HEAD
 - [gradio_app](./gradio_app/README_zh-CN.md): 基于 Gradio 的 Web 应用
 
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999

+ 18 - 0
projects/llama_index_rag/README_zh-CN.md

@@ -59,7 +59,10 @@ Server: Docker Engine - Community
 ```bash
 # install
 pip install modelscope==1.14.0
+<<<<<<< HEAD
 
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 pip install llama-index-vector-stores-elasticsearch==0.2.0
 pip install llama-index-embeddings-dashscope==0.2.0
 pip install llama-index-core==0.10.68
@@ -71,13 +74,19 @@ pip install accelerate==0.33.0
 pip uninstall transformer-engine
 ```
 
+<<<<<<< HEAD
 
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 ## 示例
 
 ````bash
 cd  projects/llama_index_rag
 
+<<<<<<< HEAD
 
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 docker compose up -d
 
 or
@@ -85,14 +94,20 @@ or
 docker-compose up -d
 
 
+<<<<<<< HEAD
 # 配置环境变量
 
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 export ES_USER=elastic
 export ES_PASSWORD=llama_index
 export ES_URL=http://127.0.0.1:9200
 export DASHSCOPE_API_KEY={some_key}
 
+<<<<<<< HEAD
 
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 DASHSCOPE_API_KEY 开通参考[文档](https://help.aliyun.com/zh/dashscope/opening-service)
 
 # 未导入数据,查询问题。返回通义千问默认答案
@@ -120,7 +135,10 @@ python data_ingestion.py -p example/data/declaration_of_the_rights_of_man_1789.p
 
 
 # 导入数据后,查询问题。通义千问模型会根据 RAG 系统的检索结果,结合上下文,给出答案。
+<<<<<<< HEAD
 
+=======
+>>>>>>> 0140d7d271ac3b1561ca2272030e9e038b469999
 python query.py -q 'how about the rights of men'
 
 ## outputs

+ 54 - 0
tests/test_cli/test_bench.py

@@ -0,0 +1,54 @@
+"""
+bench
+"""
+import os
+import shutil
+import json
+from lib import calculate_score
+import pytest
+from conf import conf
+
+code_path = os.environ.get('GITHUB_WORKSPACE')
+pdf_dev_path = conf.conf["pdf_dev_path"]
+pdf_res_path = conf.conf["pdf_res_path"]
+
+class TestBench():
+    """
+    test bench
+    """
+    def test_ci_ben(self):
+        """
+        ci benchmark
+        """
+        fr = open(os.path.join(pdf_dev_path, "result.json"), "r", encoding="utf-8")
+        lines = fr.readlines()
+        last_line = lines[-1].strip()
+        last_score = json.loads(last_line)
+        last_simscore = last_score["average_sim_score"]
+        last_editdistance = last_score["average_edit_distance"]
+        last_bleu = last_score["average_bleu_score"]
+        os.system(f"python tests/test_cli/lib/pre_clean.py --tool_name mineru --download_dir {pdf_dev_path}")
+        now_score = get_score()
+        print ("now_score:", now_score)
+        if not os.path.exists(os.path.join(pdf_dev_path, "ci")):
+            os.makedirs(os.path.join(pdf_dev_path, "ci"), exist_ok=True)
+        fw = open(os.path.join(pdf_dev_path, "ci", "result.json"), "w+", encoding="utf-8")
+        fw.write(json.dumps(now_score) + "\n")
+        now_simscore = now_score["average_sim_score"]
+        now_editdistance = now_score["average_edit_distance"]
+        now_bleu = now_score["average_bleu_score"]
+        assert last_simscore <= now_simscore
+        assert last_editdistance <= now_editdistance
+        assert last_bleu <= now_bleu
+
+
+def get_score():
+    """
+    get score
+    """
+    score = calculate_score.Scoring(os.path.join(pdf_dev_path, "result.json"))
+    score.calculate_similarity_total("mineru", pdf_dev_path)
+    res = score.summary_scores()
+    return res
+
+

二進制
tests/test_table/assets/table.jpg


+ 18 - 0
tests/test_table/test_tablemaster.py

@@ -0,0 +1,18 @@
+import unittest
+from PIL import Image
+from magic_pdf.model.ppTableModel import ppTableModel
+
+class TestppTableModel(unittest.TestCase):
+    def test_image2html(self):
+        img = Image.open("tests/test_table/assets/table.jpg")
+        # 修改table模型路径
+        config = {"device": "cuda",
+                  "model_dir": "D:/models/PDF-Extract-Kit/models/TabRec/TableMaster"}
+        table_model = ppTableModel(config)
+        res = table_model.img2html(img)
+        true_value = """<td><table  border="1"><thead><tr><td><b>Methods</b></td><td><b>R</b></td><td><b>P</b></td><td><b>F</b></td><td><b>FPS</b></td></tr></thead><tbody><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2 </td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td>-</td></tr><tr><td>FTSN[3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td>-</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td>-</td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN [16]</td><td>79</td><td>88.</td><td>83</td><td>-</td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>-</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG [41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></tbody></table></td>\n"""
+        self.assertEqual(true_value, res)
+
+
+if __name__ == "__main__":
+    unittest.main()

部分文件因文件數量過多而無法顯示