Browse Source

update cli

赵小蒙 1 year ago
parent
commit
ca8788fbe7
1 changed files with 24 additions and 15 deletions
  1. 24 15
      magic_pdf/cli/magicpdf.py

+ 24 - 15
magic_pdf/cli/magicpdf.py

@@ -48,6 +48,7 @@ import csv
 import copy
 
 parse_pdf_methods = click.Choice(["ocr", "txt", "auto"])
+use_inside_model = False
 
 
 def prepare_env(pdf_file_name, method):
@@ -70,17 +71,17 @@ def write_to_csv(csv_file_path, csv_data):
 
 
 def do_parse(
-    pdf_file_name,
-    pdf_bytes,
-    model_list,
-    parse_method,
-    f_draw_span_bbox=True,
-    f_draw_layout_bbox=True,
-    f_dump_md=True,
-    f_dump_middle_json=True,
-    f_dump_model_json=True,
-    f_dump_orig_pdf=True,
-    f_dump_content_list=True,
+        pdf_file_name,
+        pdf_bytes,
+        model_list,
+        parse_method,
+        f_draw_span_bbox=True,
+        f_draw_layout_bbox=True,
+        f_dump_md=True,
+        f_dump_middle_json=True,
+        f_dump_model_json=True,
+        f_dump_orig_pdf=True,
+        f_dump_content_list=True,
 ):
     orig_model_list = copy.deepcopy(model_list)
 
@@ -96,14 +97,18 @@ def do_parse(
     elif parse_method == "ocr":
         pipe = OCRPipe(pdf_bytes, model_list, image_writer, is_debug=True)
     else:
-        print("unknown parse method")
+        logger.error("unknown parse method")
         sys.exit(1)
 
     pipe.pipe_classify()
 
-    """如果没有传入有效的模型数据,则使用内置paddle解析"""
+    """如果没有传入有效的模型数据,则使用内置model解析"""
     if len(model_list) == 0:
-        pipe.pipe_analyze()
+        if use_inside_model:
+            pipe.pipe_analyze()
+        else:
+            logger.error("need model list input")
+            sys.exit(1)
 
     pipe.pipe_parse()
     pdf_info = pipe.pdf_mid_data["pdf_info"]
@@ -267,7 +272,11 @@ def local_json_command(local_json, method):
     help="指定解析方法。txt: 文本型 pdf 解析方法, ocr: 光学识别解析 pdf, auto: 程序智能选择解析方法",
     default="auto",
 )
-def pdf_command(pdf, model, method):
+@click.option("--inside_model", type=click.BOOL, default=False, help="使用内置模型测试")
+def pdf_command(pdf, model, method, inside_model):
+    global use_inside_model
+    use_inside_model = inside_model
+
     def read_fn(path):
         disk_rw = DiskReaderWriter(os.path.dirname(path))
         return disk_rw.read(os.path.basename(path), AbsReaderWriter.MODE_BIN)