Browse source code

feat(model-config): Unify all device selections through a single YAML file

myhloli, 1 year ago
parent
commit
45e7fbd2d8

+ 4 - 3
magic_pdf/model/pdf_extract_kit.py

@@ -19,8 +19,8 @@ from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
 
 
-def layout_model_init(weight, config_file):
-    model = Layoutlmv3_Predictor(weight, config_file)
+def layout_model_init(weight, config_file, device):
+    model = Layoutlmv3_Predictor(weight, config_file, device)
     return model
 
 
@@ -89,7 +89,8 @@ class CustomPEKModel:
         # 初始化layout模型
         self.layout_model = layout_model_init(
             os.path.join(root_dir, self.configs['weights']['layout']),
-            os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")
+            os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml"),
+            device=self.device
         )
         # 初始化公式识别
         if self.apply_formula:

+ 8 - 3
magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py

@@ -61,16 +61,21 @@ def add_vit_config(cfg):
     _C.SOLVER.GRADIENT_ACCUMULATION_STEPS = 1
 
 
-def setup(args):
+def setup(args, device):
     """
     Create configs and perform basic setups.
     """
     cfg = get_cfg()
+
     # add_coat_config(cfg)
     add_vit_config(cfg)
     cfg.merge_from_file(args.config_file)
     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2  # set threshold for this model
     cfg.merge_from_list(args.opts)
+
+    # 使用统一的device配置
+    cfg.MODEL.DEVICE = device
+
     cfg.freeze()
     default_setup(cfg, args)
 
@@ -101,7 +106,7 @@ class DotDict(dict):
 
 
 class Layoutlmv3_Predictor(object):
-    def __init__(self, weights, config_file):
+    def __init__(self, weights, config_file, device):
         layout_args = {
             "config_file": config_file,
             "resume": False,
@@ -114,7 +119,7 @@ class Layoutlmv3_Predictor(object):
         }
         layout_args = DotDict(layout_args)
 
-        cfg = setup(layout_args)
+        cfg = setup(layout_args, device)
         self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption",
                         "table_footnote", "isolate_formula", "formula_caption"]
         MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping

+ 1 - 1
magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml

@@ -69,7 +69,7 @@ MODEL:
     FREEZE_AT: 2
     NAME: build_vit_fpn_backbone
   CONFIG_PATH: ''
-  DEVICE: cpu
+  DEVICE: cuda
   FPN:
     FUSE_TYPE: sum
     IN_FEATURES: