Ver código fonte

Merge pull request #584 from myhloli/update-unimernet-to-0.2.0

refactor(pdf_extract_kit): update model config and weight paths for UniMERNet-0.2.0
Xiaomeng Zhao 1 ano atrás
pai
commit
73f66af9bd

+ 1 - 1
magic_pdf/model/pdf_extract_kit.py

@@ -58,7 +58,7 @@ def mfd_model_init(weight):
 def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
     args = argparse.Namespace(cfg_path=cfg_path, options=None)
     cfg = Config(args)
-    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
+    cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
     cfg.config.model.model_config.model_name = weight_dir
     cfg.config.model.tokenizer_config.path = weight_dir
     task = tasks.setup_task(cfg)

+ 7 - 7
magic_pdf/resources/model_config/UniMERNet/demo.yaml

@@ -2,13 +2,13 @@ model:
   arch: unimernet
   model_type: unimernet
   model_config:
-    model_name: ./models
-    max_seq_len: 1024
-    length_aware: False
+    model_name: ./models/unimernet_base
+    max_seq_len: 1536
+
   load_pretrained: True
-  pretrained: ./models/pytorch_model.bin
+  pretrained: './models/unimernet_base/pytorch_model.pth'
   tokenizer_config:
-    path: ./models
+    path: ./models/unimernet_base
 
 datasets:
   formula_rec_eval:
@@ -18,7 +18,7 @@ datasets:
         image_size:
           - 192
           - 672
-   
+
 run:
   runner: runner_iter
   task: unimernet_train
@@ -43,4 +43,4 @@ run:
   distributed_type: ddp  # or fsdp when train llm
 
   generate_cfg:
-    temperature: 0.0
+    temperature: 0.0

+ 1 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -10,6 +10,6 @@ config:
 weights:
   layout: Layout/model_final.pth
   mfd: MFD/weights.pt
-  mfr: MFR/UniMERNet
+  mfr: MFR/unimernet_base
   struct_eqtable: TabRec/StructEqTable
   TableMaster: TabRec/TableMaster

+ 1 - 1
requirements-docker.txt

@@ -8,7 +8,7 @@ fast-langdetect==0.2.0
 wordninja>=2.0.0
 scikit-learn>=1.0.2
 pdfminer.six==20231228
-unimernet==0.1.6
+unimernet==0.2.0
 matplotlib
 ultralytics
 paddleocr==2.7.3

+ 1 - 1
setup.py

@@ -36,7 +36,7 @@ if __name__ == '__main__':
                      "paddlepaddle==3.0.0b1;platform_system=='Linux'",
                      "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",
                      ],
-            "full": ["unimernet==0.1.6",  # 0.1.6版本大幅裁剪依赖包范围,推荐使用此版本
+            "full": ["unimernet==0.2.0",  # unimernet升级0.2.0
                      "matplotlib<=3.9.0;platform_system=='Windows'",  # 3.9.1及之后不提供windows的预编译包,避免一些没有编译环境的windows设备安装失败
                      "matplotlib;platform_system=='Linux' or platform_system=='Darwin'",  # linux 和 macos 不应限制matplotlib的最高版本,以避免无法更新导致的一些bug
                      "ultralytics",  # yolov8,公式检测