vlm_magic_model.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. from typing import Literal
  2. from .boxbase import bbox_distance, is_in
  3. def __reduct_overlap(bboxes):
  4. N = len(bboxes)
  5. keep = [True] * N
  6. for i in range(N):
  7. for j in range(N):
  8. if i == j:
  9. continue
  10. if is_in(bboxes[i]["bbox"], bboxes[j]["bbox"]):
  11. keep[i] = False
  12. return [bboxes[i] for i in range(N) if keep[i]]
  13. def __tie_up_category_by_distance_v3(
  14. blocks: list,
  15. subject_block_type: str,
  16. object_block_type: str,
  17. ):
  18. subjects = __reduct_overlap(
  19. list(
  20. map(
  21. lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
  22. filter(
  23. lambda x: x["type"] == subject_block_type,
  24. blocks,
  25. ),
  26. )
  27. )
  28. )
  29. objects = __reduct_overlap(
  30. list(
  31. map(
  32. lambda x: {"bbox": x["bbox"], "lines": x["lines"], "index": x["index"]},
  33. filter(
  34. lambda x: x["type"] == object_block_type,
  35. blocks,
  36. ),
  37. )
  38. )
  39. )
  40. ret = []
  41. N, M = len(subjects), len(objects)
  42. subjects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)
  43. objects.sort(key=lambda x: x["bbox"][0] ** 2 + x["bbox"][1] ** 2)
  44. OBJ_IDX_OFFSET = 10000
  45. SUB_BIT_KIND, OBJ_BIT_KIND = 0, 1
  46. all_boxes_with_idx = [(i, SUB_BIT_KIND, sub["bbox"][0], sub["bbox"][1]) for i, sub in enumerate(subjects)] + [
  47. (i + OBJ_IDX_OFFSET, OBJ_BIT_KIND, obj["bbox"][0], obj["bbox"][1]) for i, obj in enumerate(objects)
  48. ]
  49. seen_idx = set()
  50. seen_sub_idx = set()
  51. while N > len(seen_sub_idx):
  52. candidates = []
  53. for idx, kind, x0, y0 in all_boxes_with_idx:
  54. if idx in seen_idx:
  55. continue
  56. candidates.append((idx, kind, x0, y0))
  57. if len(candidates) == 0:
  58. break
  59. left_x = min([v[2] for v in candidates])
  60. top_y = min([v[3] for v in candidates])
  61. candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
  62. fst_idx, fst_kind, left_x, top_y = candidates[0]
  63. candidates.sort(key=lambda x: (x[2] - left_x) ** 2 + (x[3] - top_y) ** 2)
  64. nxt = None
  65. for i in range(1, len(candidates)):
  66. if candidates[i][1] ^ fst_kind == 1:
  67. nxt = candidates[i]
  68. break
  69. if nxt is None:
  70. break
  71. if fst_kind == SUB_BIT_KIND:
  72. sub_idx, obj_idx = fst_idx, nxt[0] - OBJ_IDX_OFFSET
  73. else:
  74. sub_idx, obj_idx = nxt[0], fst_idx - OBJ_IDX_OFFSET
  75. pair_dis = bbox_distance(subjects[sub_idx]["bbox"], objects[obj_idx]["bbox"])
  76. nearest_dis = float("inf")
  77. for i in range(N):
  78. if i in seen_idx or i == sub_idx:
  79. continue
  80. nearest_dis = min(nearest_dis, bbox_distance(subjects[i]["bbox"], objects[obj_idx]["bbox"]))
  81. if pair_dis >= 3 * nearest_dis:
  82. seen_idx.add(sub_idx)
  83. continue
  84. seen_idx.add(sub_idx)
  85. seen_idx.add(obj_idx + OBJ_IDX_OFFSET)
  86. seen_sub_idx.add(sub_idx)
  87. ret.append(
  88. {
  89. "sub_bbox": {
  90. "bbox": subjects[sub_idx]["bbox"],
  91. "lines": subjects[sub_idx]["lines"],
  92. "index": subjects[sub_idx]["index"],
  93. },
  94. "obj_bboxes": [
  95. {"bbox": objects[obj_idx]["bbox"], "lines": objects[obj_idx]["lines"], "index": objects[obj_idx]["index"]}
  96. ],
  97. "sub_idx": sub_idx,
  98. }
  99. )
  100. for i in range(len(objects)):
  101. j = i + OBJ_IDX_OFFSET
  102. if j in seen_idx:
  103. continue
  104. seen_idx.add(j)
  105. nearest_dis, nearest_sub_idx = float("inf"), -1
  106. for k in range(len(subjects)):
  107. dis = bbox_distance(objects[i]["bbox"], subjects[k]["bbox"])
  108. if dis < nearest_dis:
  109. nearest_dis = dis
  110. nearest_sub_idx = k
  111. for k in range(len(subjects)):
  112. if k != nearest_sub_idx:
  113. continue
  114. if k in seen_sub_idx:
  115. for kk in range(len(ret)):
  116. if ret[kk]["sub_idx"] == k:
  117. ret[kk]["obj_bboxes"].append(
  118. {"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
  119. )
  120. break
  121. else:
  122. ret.append(
  123. {
  124. "sub_bbox": {
  125. "bbox": subjects[k]["bbox"],
  126. "lines": subjects[k]["lines"],
  127. "index": subjects[k]["index"],
  128. },
  129. "obj_bboxes": [
  130. {"bbox": objects[i]["bbox"], "lines": objects[i]["lines"], "index": objects[i]["index"]}
  131. ],
  132. "sub_idx": k,
  133. }
  134. )
  135. seen_sub_idx.add(k)
  136. seen_idx.add(k)
  137. for i in range(len(subjects)):
  138. if i in seen_sub_idx:
  139. continue
  140. ret.append(
  141. {
  142. "sub_bbox": {
  143. "bbox": subjects[i]["bbox"],
  144. "lines": subjects[i]["lines"],
  145. "index": subjects[i]["index"],
  146. },
  147. "obj_bboxes": [],
  148. "sub_idx": i,
  149. }
  150. )
  151. return ret
  152. def get_type_blocks(blocks, block_type: Literal["image", "table"]):
  153. with_captions = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_caption")
  154. with_footnotes = __tie_up_category_by_distance_v3(blocks, f"{block_type}_body", f"{block_type}_footnote")
  155. ret = []
  156. for v in with_captions:
  157. record = {
  158. f"{block_type}_body": v["sub_bbox"],
  159. f"{block_type}_caption_list": v["obj_bboxes"],
  160. }
  161. filter_idx = v["sub_idx"]
  162. d = next(filter(lambda x: x["sub_idx"] == filter_idx, with_footnotes))
  163. record[f"{block_type}_footnote_list"] = d["obj_bboxes"]
  164. ret.append(record)
  165. return ret
  166. def fix_two_layer_blocks(blocks, fix_type: Literal["image", "table"]):
  167. need_fix_blocks = get_type_blocks(blocks, fix_type)
  168. fixed_blocks = []
  169. for block in need_fix_blocks:
  170. body = block[f"{fix_type}_body"]
  171. caption_list = block[f"{fix_type}_caption_list"]
  172. footnote_list = block[f"{fix_type}_footnote_list"]
  173. body["type"] = f"{fix_type}_body"
  174. for caption in caption_list:
  175. caption["type"] = f"{fix_type}_caption"
  176. for footnote in footnote_list:
  177. footnote["type"] = f"{fix_type}_footnote"
  178. two_layer_block = {
  179. "type": fix_type,
  180. "bbox": body["bbox"],
  181. "blocks": [
  182. body,
  183. ],
  184. "index": body["index"],
  185. }
  186. two_layer_block["blocks"].extend([*caption_list, *footnote_list])
  187. fixed_blocks.append(two_layer_block)
  188. return fixed_blocks