Pārlūkot izejas kodu

fix: standardize parameter names for formula and table options across modules

myhloli 4 mēneši atpakaļ
vecāks
revīzija
677a301a9f

+ 23 - 9
mineru/backend/vlm/vlm_middle_json_mkcontent.py

@@ -1,3 +1,5 @@
+import os
+
 from mineru.utils.config_reader import get_latex_delimiter_config
 from mineru.utils.enum_class import MakeMode, BlockType, ContentType
 
@@ -16,7 +18,7 @@ display_right_delimiter = delimiters['display']['right']
 inline_left_delimiter = delimiters['inline']['left']
 inline_right_delimiter = delimiters['inline']['right']
 
-def merge_para_with_text(para_block):
+def merge_para_with_text(para_block, formula_enable=True, img_buket_path=''):
     para_text = ''
     for line in para_block['lines']:
         for j, span in enumerate(line['spans']):
@@ -27,7 +29,11 @@ def merge_para_with_text(para_block):
             elif span_type == ContentType.INLINE_EQUATION:
                 content = f"{inline_left_delimiter}{span['content']}{inline_right_delimiter}"
             elif span_type == ContentType.INTERLINE_EQUATION:
-                content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
+                if formula_enable:
+                    content = f"\n{display_left_delimiter}\n{span['content']}\n{display_right_delimiter}\n"
+                else:
+                    if span.get('image_path', ''):
+                        content = f"![]({img_buket_path}/{span['image_path']})"
             # content = content.strip()
             if content:
                 if span_type in [ContentType.TEXT, ContentType.INLINE_EQUATION]:
@@ -39,13 +45,13 @@ def merge_para_with_text(para_block):
                     para_text += content
     return para_text
 
-def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
+def mk_blocks_to_markdown(para_blocks, make_mode, formula_enable, table_enable, img_buket_path=''):
     page_markdown = []
     for para_block in para_blocks:
         para_text = ''
         para_type = para_block['type']
         if para_type in [BlockType.TEXT, BlockType.LIST, BlockType.INDEX, BlockType.INTERLINE_EQUATION]:
-            para_text = merge_para_with_text(para_block)
+            para_text = merge_para_with_text(para_block, formula_enable=formula_enable, img_buket_path=img_buket_path)
         elif para_type == BlockType.TITLE:
             title_level = get_title_level(para_block)
             para_text = f'{"#" * title_level} {merge_para_with_text(para_block)}'
@@ -95,10 +101,14 @@ def mk_blocks_to_markdown(para_blocks, make_mode, img_buket_path=''):
                             for span in line['spans']:
                                 if span['type'] == ContentType.TABLE:
                                     # if processed by table model
-                                    if span.get('html', ''):
-                                        para_text += f"\n{span['html']}\n"
-                                    elif span.get('image_path', ''):
-                                        para_text += f"![]({img_buket_path}/{span['image_path']})"
+                                    if table_enable:
+                                        if span.get('html', ''):
+                                            para_text += f"\n{span['html']}\n"
+                                        elif span.get('image_path', ''):
+                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
+                                    else:
+                                        if span.get('image_path', ''):
+                                            para_text += f"![]({img_buket_path}/{span['image_path']})"
                 for block in para_block['blocks']:  # 3rd.拼table_footnote
                     if block['type'] == BlockType.TABLE_FOOTNOTE:
                         para_text += '\n' + merge_para_with_text(block) + '  '
@@ -177,6 +187,10 @@ def union_make(pdf_info_dict: list,
                make_mode: str,
                img_buket_path: str = '',
                ):
+
+    formula_enable = os.getenv('MINERU_FORMULA_ENABLE', 'True').lower() == 'true'
+    table_enable = os.getenv('MINERU_TABLE_ENABLE', 'True').lower() == 'true'
+
     output_content = []
     for page_info in pdf_info_dict:
         paras_of_layout = page_info.get('para_blocks')
@@ -184,7 +198,7 @@ def union_make(pdf_info_dict: list,
         if not paras_of_layout:
             continue
         if make_mode in [MakeMode.MM_MD, MakeMode.NLP_MD]:
-            page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, img_buket_path)
+            page_markdown = mk_blocks_to_markdown(paras_of_layout, make_mode, formula_enable, table_enable, img_buket_path)
             output_content.extend(page_markdown)
         elif make_mode == MakeMode.CONTENT_LIST:
             for para_block in paras_of_layout:

+ 2 - 2
mineru/cli/client.py

@@ -180,8 +180,8 @@ def main(input_path, output_dir, method, backend, lang, server_url, start_page_i
                 p_lang_list=lang_list,
                 backend=backend,
                 parse_method=method,
-                p_formula_enable=formula_enable,
-                p_table_enable=table_enable,
+                formula_enable=formula_enable,
+                table_enable=table_enable,
                 server_url=server_url,
                 start_page_id=start_page_id,
                 end_page_id=end_page_id

+ 12 - 6
mineru/cli/common.py

@@ -298,8 +298,8 @@ def do_parse(
         p_lang_list: list[str],
         backend="pipeline",
         parse_method="auto",
-        p_formula_enable=True,
-        p_table_enable=True,
+        formula_enable=True,
+        table_enable=True,
         server_url=None,
         f_draw_layout_bbox=True,
         f_draw_span_bbox=True,
@@ -318,7 +318,7 @@ def do_parse(
     if backend == "pipeline":
         _process_pipeline(
             output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
-            parse_method, p_formula_enable, p_table_enable,
+            parse_method, formula_enable, table_enable,
             f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
             f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
         )
@@ -326,6 +326,9 @@ def do_parse(
         if backend.startswith("vlm-"):
             backend = backend[4:]
 
+        os.environ['MINERU_FORMULA_ENABLE'] = str(formula_enable)
+        os.environ['MINERU_TABLE_ENABLE'] = str(table_enable)
+
         _process_vlm(
             output_dir, pdf_file_names, pdf_bytes_list, backend,
             f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
@@ -341,8 +344,8 @@ async def aio_do_parse(
         p_lang_list: list[str],
         backend="pipeline",
         parse_method="auto",
-        p_formula_enable=True,
-        p_table_enable=True,
+        formula_enable=True,
+        table_enable=True,
         server_url=None,
         f_draw_layout_bbox=True,
         f_draw_span_bbox=True,
@@ -362,7 +365,7 @@ async def aio_do_parse(
         # pipeline模式暂不支持异步,使用同步处理方式
         _process_pipeline(
             output_dir, pdf_file_names, pdf_bytes_list, p_lang_list,
-            parse_method, p_formula_enable, p_table_enable,
+            parse_method, formula_enable, table_enable,
             f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,
             f_dump_model_output, f_dump_orig_pdf, f_dump_content_list, f_make_md_mode
         )
@@ -370,6 +373,9 @@ async def aio_do_parse(
         if backend.startswith("vlm-"):
             backend = backend[4:]
 
+        os.environ['MINERU_FORMULA_ENABLE'] = str(formula_enable)
+        os.environ['MINERU_TABLE_ENABLE'] = str(table_enable)
+
         await _async_process_vlm(
             output_dir, pdf_file_names, pdf_bytes_list, backend,
             f_draw_layout_bbox, f_draw_span_bbox, f_dump_md, f_dump_middle_json,

+ 17 - 13
mineru/cli/gradio_app.py

@@ -38,8 +38,8 @@ async def parse_pdf(doc_path, output_dir, end_page_id, is_ocr, formula_enable, t
             p_lang_list=[language],
             parse_method=parse_method,
             end_page_id=end_page_id,
-            p_formula_enable=formula_enable,
-            p_table_enable=table_enable,
+            formula_enable=formula_enable,
+            table_enable=table_enable,
             backend=backend,
             server_url=url,
         )
@@ -179,11 +179,11 @@ def to_pdf(file_path):
 # 更新界面函数
 def update_interface(backend_choice):
     if backend_choice in ["vlm-transformers", "vlm-sglang-engine"]:
-        return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
+        return gr.update(visible=False), gr.update(visible=False)
     elif backend_choice in ["vlm-sglang-client"]:  # pipeline
-        return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
+        return gr.update(visible=True), gr.update(visible=False)
     elif backend_choice in ["pipeline"]:
-        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
+        return gr.update(visible=False), gr.update(visible=True)
     else:
         pass
 
@@ -266,14 +266,18 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil
                         drop_list = ["pipeline", "vlm-transformers", "vlm-sglang-client"]
                         preferred_option = "pipeline"
                     backend = gr.Dropdown(drop_list, label="Backend", value=preferred_option)
-                with gr.Row(visible=False) as ocr_options:
-                    language = gr.Dropdown(all_lang, label='Language', value='ch')
+                # with gr.Row(visible=False) as lang_options:
+
                 with gr.Row(visible=False) as client_options:
                     url = gr.Textbox(label='Server URL', value='http://localhost:30000', placeholder='http://localhost:30000')
-                with gr.Row(visible=False) as pipeline_options:
-                    is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
-                    formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
-                    table_enable = gr.Checkbox(label='Enable table recognition(test)', value=True)
+                with gr.Row(equal_height=True):
+                    with gr.Column():
+                        gr.Markdown("**Recognition Options:**")
+                        formula_enable = gr.Checkbox(label='Enable formula recognition', value=True)
+                        table_enable = gr.Checkbox(label='Enable table recognition', value=True)
+                    with gr.Column(visible=False) as ocr_options:
+                        language = gr.Dropdown(all_lang, label='Language', value='ch')
+                        is_ocr = gr.Checkbox(label='Force enable OCR', value=False)
                 with gr.Row():
                     change_bu = gr.Button('Convert')
                     clear_bu = gr.ClearButton(value='Clear')
@@ -302,14 +306,14 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil
         backend.change(
             fn=update_interface,
             inputs=[backend],
-            outputs=[client_options, ocr_options, pipeline_options],
+            outputs=[client_options, ocr_options],
             api_name=False
         )
         # 添加demo.load事件,在页面加载时触发一次界面更新
         demo.load(
             fn=update_interface,
             inputs=[backend],
-            outputs=[client_options, ocr_options, pipeline_options],
+            outputs=[client_options, ocr_options],
             api_name=False
         )
         clear_bu.add([input_file, md, pdf_show, md_text, output_file, is_ocr])