model_init.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. from .visualizer import Visualizer
  2. from .rcnn_vl import *
  3. from .backbone import *
  4. from detectron2.config import get_cfg
  5. from detectron2.config import CfgNode as CN
  6. from detectron2.data import MetadataCatalog, DatasetCatalog
  7. from detectron2.data.datasets import register_coco_instances
  8. from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, DefaultPredictor
  9. def add_vit_config(cfg):
  10. """
  11. Add config for VIT.
  12. """
  13. _C = cfg
  14. _C.MODEL.VIT = CN()
  15. # CoaT model name.
  16. _C.MODEL.VIT.NAME = ""
  17. # Output features from CoaT backbone.
  18. _C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"]
  19. _C.MODEL.VIT.IMG_SIZE = [224, 224]
  20. _C.MODEL.VIT.POS_TYPE = "shared_rel"
  21. _C.MODEL.VIT.DROP_PATH = 0.
  22. _C.MODEL.VIT.MODEL_KWARGS = "{}"
  23. _C.SOLVER.OPTIMIZER = "ADAMW"
  24. _C.SOLVER.BACKBONE_MULTIPLIER = 1.0
  25. _C.AUG = CN()
  26. _C.AUG.DETR = False
  27. _C.MODEL.IMAGE_ONLY = True
  28. _C.PUBLAYNET_DATA_DIR_TRAIN = ""
  29. _C.PUBLAYNET_DATA_DIR_TEST = ""
  30. _C.FOOTNOTE_DATA_DIR_TRAIN = ""
  31. _C.FOOTNOTE_DATA_DIR_VAL = ""
  32. _C.SCIHUB_DATA_DIR_TRAIN = ""
  33. _C.SCIHUB_DATA_DIR_TEST = ""
  34. _C.JIAOCAI_DATA_DIR_TRAIN = ""
  35. _C.JIAOCAI_DATA_DIR_TEST = ""
  36. _C.ICDAR_DATA_DIR_TRAIN = ""
  37. _C.ICDAR_DATA_DIR_TEST = ""
  38. _C.M6DOC_DATA_DIR_TEST = ""
  39. _C.DOCSTRUCTBENCH_DATA_DIR_TEST = ""
  40. _C.DOCSTRUCTBENCHv2_DATA_DIR_TEST = ""
  41. _C.CACHE_DIR = ""
  42. _C.MODEL.CONFIG_PATH = ""
  43. # effective update steps would be MAX_ITER/GRADIENT_ACCUMULATION_STEPS
  44. # maybe need to set MAX_ITER *= GRADIENT_ACCUMULATION_STEPS
  45. _C.SOLVER.GRADIENT_ACCUMULATION_STEPS = 1
  46. def setup(args, device):
  47. """
  48. Create configs and perform basic setups.
  49. """
  50. cfg = get_cfg()
  51. # add_coat_config(cfg)
  52. add_vit_config(cfg)
  53. cfg.merge_from_file(args.config_file)
  54. cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2 # set threshold for this model
  55. cfg.merge_from_list(args.opts)
  56. # 使用统一的device配置
  57. cfg.MODEL.DEVICE = device
  58. cfg.freeze()
  59. default_setup(cfg, args)
  60. register_coco_instances(
  61. "scihub_train",
  62. {},
  63. cfg.SCIHUB_DATA_DIR_TRAIN + ".json",
  64. cfg.SCIHUB_DATA_DIR_TRAIN
  65. )
  66. return cfg
  67. class DotDict(dict):
  68. def __init__(self, *args, **kwargs):
  69. super(DotDict, self).__init__(*args, **kwargs)
  70. def __getattr__(self, key):
  71. if key not in self.keys():
  72. return None
  73. value = self[key]
  74. if isinstance(value, dict):
  75. value = DotDict(value)
  76. return value
  77. def __setattr__(self, key, value):
  78. self[key] = value
  79. class Layoutlmv3_Predictor(object):
  80. def __init__(self, weights, config_file, device):
  81. layout_args = {
  82. "config_file": config_file,
  83. "resume": False,
  84. "eval_only": False,
  85. "num_gpus": 1,
  86. "num_machines": 1,
  87. "machine_rank": 0,
  88. "dist_url": "tcp://127.0.0.1:57823",
  89. "opts": ["MODEL.WEIGHTS", weights],
  90. }
  91. layout_args = DotDict(layout_args)
  92. cfg = setup(layout_args, device)
  93. self.mapping = ["title", "plain text", "abandon", "figure", "figure_caption", "table", "table_caption",
  94. "table_footnote", "isolate_formula", "formula_caption"]
  95. MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping
  96. self.predictor = DefaultPredictor(cfg)
  97. def __call__(self, image, ignore_catids=[]):
  98. # page_layout_result = {
  99. # "layout_dets": []
  100. # }
  101. layout_dets = []
  102. outputs = self.predictor(image)
  103. boxes = outputs["instances"].to("cpu")._fields["pred_boxes"].tensor.tolist()
  104. labels = outputs["instances"].to("cpu")._fields["pred_classes"].tolist()
  105. scores = outputs["instances"].to("cpu")._fields["scores"].tolist()
  106. for bbox_idx in range(len(boxes)):
  107. if labels[bbox_idx] in ignore_catids:
  108. continue
  109. layout_dets.append({
  110. "category_id": labels[bbox_idx],
  111. "poly": [
  112. boxes[bbox_idx][0], boxes[bbox_idx][1],
  113. boxes[bbox_idx][2], boxes[bbox_idx][1],
  114. boxes[bbox_idx][2], boxes[bbox_idx][3],
  115. boxes[bbox_idx][0], boxes[bbox_idx][3],
  116. ],
  117. "score": scores[bbox_idx]
  118. })
  119. return layout_dets