# model_init.py
  1. from .visualizer import Visualizer
  2. from .rcnn_vl import *
  3. from .backbone import *
  4. from detectron2.config import get_cfg
  5. from detectron2.config import CfgNode as CN
  6. from detectron2.data import MetadataCatalog, DatasetCatalog
  7. from detectron2.data.datasets import register_coco_instances
  8. from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch, DefaultPredictor
  def add_vit_config(cfg):
      """
      Add the ViT-related and project-specific keys to ``cfg`` in place.

      Creates the ``MODEL.VIT`` and ``AUG`` nodes, adds a few ``SOLVER`` and
      ``MODEL`` entries, and declares the dataset / cache path keys. Every key
      gets the default assigned below; nothing is returned.
      """
      # Backbone options are grouped under their own node.
      cfg.MODEL.VIT = CN()
      vit = cfg.MODEL.VIT
      # Backbone model name.
      vit.NAME = ""
      # Output features from the backbone.
      vit.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"]
      vit.IMG_SIZE = [224, 224]
      vit.POS_TYPE = "shared_rel"
      vit.DROP_PATH = 0.
      vit.MODEL_KWARGS = "{}"

      solver = cfg.SOLVER
      solver.OPTIMIZER = "ADAMW"
      solver.BACKBONE_MULTIPLIER = 1.0

      cfg.AUG = CN()
      cfg.AUG.DETR = False

      cfg.MODEL.IMAGE_ONLY = True

      # Dataset locations and the cache directory all start out empty.
      for path_key in (
          "PUBLAYNET_DATA_DIR_TRAIN",
          "PUBLAYNET_DATA_DIR_TEST",
          "FOOTNOTE_DATA_DIR_TRAIN",
          "FOOTNOTE_DATA_DIR_VAL",
          "SCIHUB_DATA_DIR_TRAIN",
          "SCIHUB_DATA_DIR_TEST",
          "JIAOCAI_DATA_DIR_TRAIN",
          "JIAOCAI_DATA_DIR_TEST",
          "ICDAR_DATA_DIR_TRAIN",
          "ICDAR_DATA_DIR_TEST",
          "M6DOC_DATA_DIR_TEST",
          "DOCSTRUCTBENCH_DATA_DIR_TEST",
          "DOCSTRUCTBENCHv2_DATA_DIR_TEST",
          "CACHE_DIR",
      ):
          setattr(cfg, path_key, "")

      cfg.MODEL.CONFIG_PATH = ""

      # effective update steps would be MAX_ITER/GRADIENT_ACCUMULATION_STEPS
      # maybe need to set MAX_ITER *= GRADIENT_ACCUMULATION_STEPS
      solver.GRADIENT_ACCUMULATION_STEPS = 1
  def setup(args):
      """
      Create configs and perform basic setups.

      Starts from ``get_cfg()``, adds the project keys via ``add_vit_config``,
      merges ``args.config_file`` and then ``args.opts``, freezes the config,
      runs ``default_setup`` and registers the ``"scihub_train"`` COCO dataset.

      Args:
          args: Object exposing ``config_file`` (path merged into the config)
              and ``opts`` (override list handed to ``merge_from_list``). It is
              also passed unchanged to ``default_setup``.

      Returns:
          The frozen config object.
      """
      cfg = get_cfg()
      add_vit_config(cfg)
      cfg.merge_from_file(args.config_file)
      # Set threshold for this model. Applied before ``args.opts`` is merged,
      # so an explicit override in ``opts`` still takes precedence.
      cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.2
      cfg.merge_from_list(args.opts)
      cfg.freeze()
      default_setup(cfg, args)

      # detectron2's DatasetCatalog rejects a name that is already registered,
      # so only register on the first call; this keeps setup() safe to call
      # more than once per process (e.g. when several predictors are built).
      if "scihub_train" not in DatasetCatalog.list():
          register_coco_instances(
              "scihub_train",
              {},
              cfg.SCIHUB_DATA_DIR_TRAIN + ".json",
              cfg.SCIHUB_DATA_DIR_TRAIN,
          )
      return cfg
  class DotDict(dict):
      """
      Dictionary whose items can also be read and written as attributes.

      * Reading a key that is absent yields ``None`` instead of raising.
      * A stored ``dict`` value is handed back wrapped in a new ``DotDict``.
        The wrapper is a shallow copy, so attribute writes on the returned
        object do not reach the stored dictionary.
      * Missing ``__dunder__`` names raise ``AttributeError`` so that protocol
        probes (``copy``, ``pickle``, ``hasattr`` checks) see "not defined"
        rather than a bogus ``None`` hook.
      * Attribute assignment stores the value as a dictionary item.
      """

      def __getattr__(self, key):
          # Only reached when normal attribute lookup has already failed.
          if key not in self:
              if key.startswith("__") and key.endswith("__"):
                  raise AttributeError(key)
              return None
          value = self[key]
          if isinstance(value, dict):
              value = DotDict(value)
          return value

      def __setattr__(self, key, value):
          self[key] = value
  class Layoutlmv3_Predictor(object):
      """
      Callable wrapper that runs a detectron2 ``DefaultPredictor`` on an image
      and returns the detections as a plain dictionary.
      """

      def __init__(self, weights, config_file):
          """
          Build the config and the underlying predictor.

          Args:
              weights: Value assigned to ``MODEL.WEIGHTS`` through the config
                  override list.
              config_file: Path of the config file merged by ``setup``.
          """
          # Stand-in for the argument namespace that ``setup`` expects.
          layout_args = DotDict({
              "config_file": config_file,
              "resume": False,
              "eval_only": False,
              "num_gpus": 1,
              "num_machines": 1,
              "machine_rank": 0,
              "dist_url": "tcp://127.0.0.1:57823",
              "opts": ["MODEL.WEIGHTS", weights],
          })
          cfg = setup(layout_args)
          # Class names attached to the training dataset's metadata
          # (presumably indexed by predicted class id - confirm against the model).
          self.mapping = [
              "title",
              "plain text",
              "abandon",
              "figure",
              "figure_caption",
              "table",
              "table_caption",
              "table_footnote",
              "isolate_formula",
              "formula_caption",
          ]
          MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes = self.mapping
          self.predictor = DefaultPredictor(cfg)

      def __call__(self, image, ignore_catids=()):
          """
          Run detection on ``image`` and collect the results.

          Args:
              image: Input forwarded unchanged to the underlying predictor.
              ignore_catids: Container of predicted class ids to leave out of
                  the result.

          Returns:
              ``{"layout_dets": [...]}`` where each entry holds
              ``"category_id"`` (predicted class id), ``"poly"`` (the box
              expanded to four corner points as a flat list of eight numbers:
              (b0, b1), (b2, b1), (b2, b3), (b0, b3)) and ``"score"``.
          """
          # Move the predictions to the CPU once and reuse them for every field.
          instances = self.predictor(image)["instances"].to("cpu")
          boxes = instances.pred_boxes.tensor.tolist()
          labels = instances.pred_classes.tolist()
          scores = instances.scores.tolist()

          layout_dets = []
          for box, label, score in zip(boxes, labels, scores):
              if label in ignore_catids:
                  continue
              x0, y0, x1, y1 = box[:4]
              layout_dets.append({
                  "category_id": label,
                  "poly": [
                      x0, y0,
                      x1, y0,
                      x1, y1,
                      x0, y1,
                  ],
                  "score": score,
              })
          return {"layout_dets": layout_dets}