model_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. import time
  2. import torch
  3. from loguru import logger
  4. import numpy as np
  5. from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio
  6. from magic_pdf.libs.clean_memory import clean_memory
  7. def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
  8. crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
  9. crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
  10. # Calculate new dimensions
  11. crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
  12. crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
  13. # Create a white background array
  14. return_image = np.ones((crop_new_height, crop_new_width, 3), dtype=np.uint8) * 255
  15. # Crop the original image using numpy slicing
  16. cropped_img = input_np_img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
  17. # Paste the cropped image onto the white background
  18. return_image[crop_paste_y:crop_paste_y + (crop_ymax - crop_ymin),
  19. crop_paste_x:crop_paste_x + (crop_xmax - crop_xmin)] = cropped_img
  20. return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
  21. crop_new_height]
  22. return return_image, return_list
  23. def get_coords_and_area(table):
  24. """Extract coordinates and area from a table."""
  25. xmin, ymin = int(table['poly'][0]), int(table['poly'][1])
  26. xmax, ymax = int(table['poly'][4]), int(table['poly'][5])
  27. area = (xmax - xmin) * (ymax - ymin)
  28. return xmin, ymin, xmax, ymax, area
  29. def calculate_intersection(box1, box2):
  30. """Calculate intersection coordinates between two boxes."""
  31. intersection_xmin = max(box1[0], box2[0])
  32. intersection_ymin = max(box1[1], box2[1])
  33. intersection_xmax = min(box1[2], box2[2])
  34. intersection_ymax = min(box1[3], box2[3])
  35. # Check if intersection is valid
  36. if intersection_xmax <= intersection_xmin or intersection_ymax <= intersection_ymin:
  37. return None
  38. return intersection_xmin, intersection_ymin, intersection_xmax, intersection_ymax
  39. def calculate_iou(box1, box2):
  40. """Calculate IoU between two boxes."""
  41. intersection = calculate_intersection(box1[:4], box2[:4])
  42. if not intersection:
  43. return 0
  44. intersection_xmin, intersection_ymin, intersection_xmax, intersection_ymax = intersection
  45. intersection_area = (intersection_xmax - intersection_xmin) * (intersection_ymax - intersection_ymin)
  46. area1, area2 = box1[4], box2[4]
  47. union_area = area1 + area2 - intersection_area
  48. return intersection_area / union_area if union_area > 0 else 0
  49. def is_inside(small_box, big_box, overlap_threshold=0.8):
  50. """Check if small_box is inside big_box by at least overlap_threshold."""
  51. intersection = calculate_intersection(small_box[:4], big_box[:4])
  52. if not intersection:
  53. return False
  54. intersection_xmin, intersection_ymin, intersection_xmax, intersection_ymax = intersection
  55. intersection_area = (intersection_xmax - intersection_xmin) * (intersection_ymax - intersection_ymin)
  56. # Check if overlap exceeds threshold
  57. return intersection_area >= overlap_threshold * small_box[4]
  58. def do_overlap(box1, box2):
  59. """Check if two boxes overlap."""
  60. return calculate_intersection(box1[:4], box2[:4]) is not None
  61. def merge_high_iou_tables(table_res_list, layout_res, table_indices, iou_threshold=0.7):
  62. """Merge tables with IoU > threshold."""
  63. if len(table_res_list) < 2:
  64. return table_res_list, table_indices
  65. table_info = [get_coords_and_area(table) for table in table_res_list]
  66. merged = True
  67. while merged:
  68. merged = False
  69. i = 0
  70. while i < len(table_res_list) - 1:
  71. j = i + 1
  72. while j < len(table_res_list):
  73. iou = calculate_iou(table_info[i], table_info[j])
  74. if iou > iou_threshold:
  75. # Merge tables by taking their union
  76. x1_min, y1_min, x1_max, y1_max, _ = table_info[i]
  77. x2_min, y2_min, x2_max, y2_max, _ = table_info[j]
  78. union_xmin = min(x1_min, x2_min)
  79. union_ymin = min(y1_min, y2_min)
  80. union_xmax = max(x1_max, x2_max)
  81. union_ymax = max(y1_max, y2_max)
  82. # Create merged table
  83. merged_table = table_res_list[i].copy()
  84. merged_table['poly'][0] = union_xmin
  85. merged_table['poly'][1] = union_ymin
  86. merged_table['poly'][2] = union_xmax
  87. merged_table['poly'][3] = union_ymin
  88. merged_table['poly'][4] = union_xmax
  89. merged_table['poly'][5] = union_ymax
  90. merged_table['poly'][6] = union_xmin
  91. merged_table['poly'][7] = union_ymax
  92. # Update layout_res
  93. to_remove = [table_indices[j], table_indices[i]]
  94. for idx in sorted(to_remove, reverse=True):
  95. del layout_res[idx]
  96. layout_res.append(merged_table)
  97. # Update tracking lists
  98. table_indices = [k if k < min(to_remove) else
  99. k - 1 if k < max(to_remove) else
  100. k - 2 if k > max(to_remove) else
  101. len(layout_res) - 1
  102. for k in table_indices
  103. if k not in to_remove]
  104. table_indices.append(len(layout_res) - 1)
  105. # Update table lists
  106. table_res_list.pop(j)
  107. table_res_list.pop(i)
  108. table_res_list.append(merged_table)
  109. # Update table_info
  110. table_info = [get_coords_and_area(table) for table in table_res_list]
  111. merged = True
  112. break
  113. j += 1
  114. if merged:
  115. break
  116. i += 1
  117. return table_res_list, table_indices
  118. def filter_nested_tables(table_res_list, overlap_threshold=0.8, area_threshold=0.8):
  119. """Remove big tables containing multiple smaller tables within them."""
  120. if len(table_res_list) < 3:
  121. return table_res_list
  122. table_info = [get_coords_and_area(table) for table in table_res_list]
  123. big_tables_idx = []
  124. for i in range(len(table_res_list)):
  125. # Find tables inside this one
  126. tables_inside = [j for j in range(len(table_res_list))
  127. if i != j and is_inside(table_info[j], table_info[i], overlap_threshold)]
  128. # Continue if there are at least 3 tables inside
  129. if len(tables_inside) >= 3:
  130. # Check if inside tables overlap with each other
  131. tables_overlap = any(do_overlap(table_info[tables_inside[idx1]], table_info[tables_inside[idx2]])
  132. for idx1 in range(len(tables_inside))
  133. for idx2 in range(idx1 + 1, len(tables_inside)))
  134. # If no overlaps, check area condition
  135. if not tables_overlap:
  136. total_inside_area = sum(table_info[j][4] for j in tables_inside)
  137. big_table_area = table_info[i][4]
  138. if total_inside_area > area_threshold * big_table_area:
  139. big_tables_idx.append(i)
  140. return [table for i, table in enumerate(table_res_list) if i not in big_tables_idx]
  141. def remove_overlaps_min_blocks(res_list):
  142. # 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
  143. # 删除重叠blocks中较小的那些
  144. need_remove = []
  145. for res1 in res_list:
  146. for res2 in res_list:
  147. if res1 != res2:
  148. overlap_box = get_minbox_if_overlap_by_ratio(
  149. res1['bbox'], res2['bbox'], 0.8
  150. )
  151. if overlap_box is not None:
  152. res_to_remove = next(
  153. (res for res in res_list if res['bbox'] == overlap_box),
  154. None,
  155. )
  156. if (
  157. res_to_remove is not None
  158. and res_to_remove not in need_remove
  159. ):
  160. large_res = res1 if res1 != res_to_remove else res2
  161. x1, y1, x2, y2 = large_res['bbox']
  162. sx1, sy1, sx2, sy2 = res_to_remove['bbox']
  163. x1 = min(x1, sx1)
  164. y1 = min(y1, sy1)
  165. x2 = max(x2, sx2)
  166. y2 = max(y2, sy2)
  167. large_res['bbox'] = [x1, y1, x2, y2]
  168. need_remove.append(res_to_remove)
  169. if len(need_remove) > 0:
  170. for res in need_remove:
  171. res_list.remove(res)
  172. return res_list, need_remove
  173. def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8):
  174. """Extract OCR, table and other regions from layout results."""
  175. ocr_res_list = []
  176. text_res_list = []
  177. table_res_list = []
  178. table_indices = []
  179. single_page_mfdetrec_res = []
  180. # Categorize regions
  181. for i, res in enumerate(layout_res):
  182. category_id = int(res['category_id'])
  183. if category_id in [13, 14]: # Formula regions
  184. single_page_mfdetrec_res.append({
  185. "bbox": [int(res['poly'][0]), int(res['poly'][1]),
  186. int(res['poly'][4]), int(res['poly'][5])],
  187. })
  188. elif category_id in [0, 2, 4, 6, 7]: # OCR regions
  189. ocr_res_list.append(res)
  190. elif category_id == 5: # Table regions
  191. table_res_list.append(res)
  192. table_indices.append(i)
  193. elif category_id in [1]: # Text regions
  194. res['bbox'] = [int(res['poly'][0]), int(res['poly'][1]), int(res['poly'][4]), int(res['poly'][5])]
  195. text_res_list.append(res)
  196. # Process tables: merge high IoU tables first, then filter nested tables
  197. table_res_list, table_indices = merge_high_iou_tables(
  198. table_res_list, layout_res, table_indices, iou_threshold)
  199. filtered_table_res_list = filter_nested_tables(
  200. table_res_list, overlap_threshold, area_threshold)
  201. # Remove filtered out tables from layout_res
  202. if len(filtered_table_res_list) < len(table_res_list):
  203. kept_tables = set(id(table) for table in filtered_table_res_list)
  204. to_remove = [table_indices[i] for i, table in enumerate(table_res_list)
  205. if id(table) not in kept_tables]
  206. for idx in sorted(to_remove, reverse=True):
  207. del layout_res[idx]
  208. # Remove overlaps in OCR and text regions
  209. text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
  210. for res in text_res_list:
  211. # 将res的poly使用bbox重构
  212. res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
  213. res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
  214. # 删除res的bbox
  215. del res['bbox']
  216. ocr_res_list.extend(text_res_list)
  217. if len(need_remove) > 0:
  218. for res in need_remove:
  219. del res['bbox']
  220. layout_res.remove(res)
  221. return ocr_res_list, filtered_table_res_list, single_page_mfdetrec_res
  222. def clean_vram(device, vram_threshold=8):
  223. total_memory = get_vram(device)
  224. if total_memory and total_memory <= vram_threshold:
  225. gc_start = time.time()
  226. clean_memory(device)
  227. gc_time = round(time.time() - gc_start, 2)
  228. logger.info(f"gc time: {gc_time}")
  229. def get_vram(device):
  230. if torch.cuda.is_available() and str(device).startswith("cuda"):
  231. total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
  232. return total_memory
  233. elif str(device).startswith("npu"):
  234. import torch_npu
  235. if torch_npu.npu.is_available():
  236. total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
  237. return total_memory
  238. else:
  239. return None