|
|
@@ -61,16 +61,21 @@ def add_vit_config(cfg):
|
|
|
_C.SOLVER.GRADIENT_ACCUMULATION_STEPS = 1
|
|
|
|
|
|
|
|
|
-def setup(args):
|
|
|
+def setup(args, device):
|
|
|
"""
|
|
|
Create configs and perform basic setups.
|
|
|
"""
|
|
|
cfg = get_cfg()
|
|
|
+
|
|
|
# add_coat_config(cfg)
|
|
|
add_vit_config(cfg)
|
|
|
cfg.merge_from_file(args.config_file)
|
|
|
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2 # set threshold for this model
|
|
|
cfg.merge_from_list(args.opts)
|
|
|
+
|
|
|
+ # 使用统一的device配置
|
|
|
+ cfg.MODEL.DEVICE = device
|
|
|
+
|
|
|
cfg.freeze()
|
|
|
default_setup(cfg, args)
|
|
|
|
|
|
@@ -101,7 +106,7 @@ class DotDict(dict):
|
|
|
|
|
|
|
|
|
class Layoutlmv3_Predictor(object):
|
|
|
- def __init__(self, weights, config_file):
|
|
|
+ def __init__(self, weights, config_file, device):
|
|
|
layout_args = {
|
|
|
"config_file": config_file,
|
|
|
"resume": False,
|
|
|
@@ -114,7 +119,7 @@ class Layoutlmv3_Predictor(object):
|
|
|
}
|
|
|
layout_args = DotDict(layout_args)
|
|
|
|
|
|
- cfg = setup(layout_args)
|
|
|
+ cfg = setup(layout_args, device)
|
|
|
self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption",
|
|
|
"table_footnote", "isolate_formula", "formula_caption"]
|
|
|
MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping
|