model_utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. import time
  2. import gc
  3. from PIL import Image
  4. from loguru import logger
  5. import numpy as np
  6. from mineru.utils.boxbase import get_minbox_if_overlap_by_ratio
  7. try:
  8. import torch
  9. import torch_npu
  10. except ImportError:
  11. pass
  12. def crop_img(input_res, input_img, crop_paste_x=0, crop_paste_y=0):
  13. crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
  14. crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
  15. # Calculate new dimensions
  16. crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
  17. crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
  18. if isinstance(input_img, np.ndarray):
  19. # Create a white background array
  20. return_image = np.ones((crop_new_height, crop_new_width, 3), dtype=np.uint8) * 255
  21. # Crop the original image using numpy slicing
  22. cropped_img = input_img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
  23. # Paste the cropped image onto the white background
  24. return_image[crop_paste_y:crop_paste_y + (crop_ymax - crop_ymin),
  25. crop_paste_x:crop_paste_x + (crop_xmax - crop_xmin)] = cropped_img
  26. else:
  27. # Create a white background array
  28. return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
  29. # Crop image
  30. crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
  31. cropped_img = input_img.crop(crop_box)
  32. return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
  33. return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
  34. crop_new_height]
  35. return return_image, return_list
  36. def get_coords_and_area(block_with_poly):
  37. """Extract coordinates and area from a table."""
  38. xmin, ymin = int(block_with_poly['poly'][0]), int(block_with_poly['poly'][1])
  39. xmax, ymax = int(block_with_poly['poly'][4]), int(block_with_poly['poly'][5])
  40. area = (xmax - xmin) * (ymax - ymin)
  41. return xmin, ymin, xmax, ymax, area
  42. def calculate_intersection(box1, box2):
  43. """Calculate intersection coordinates between two boxes."""
  44. intersection_xmin = max(box1[0], box2[0])
  45. intersection_ymin = max(box1[1], box2[1])
  46. intersection_xmax = min(box1[2], box2[2])
  47. intersection_ymax = min(box1[3], box2[3])
  48. # Check if intersection is valid
  49. if intersection_xmax <= intersection_xmin or intersection_ymax <= intersection_ymin:
  50. return None
  51. return intersection_xmin, intersection_ymin, intersection_xmax, intersection_ymax
  52. def calculate_iou(box1, box2):
  53. """Calculate IoU between two boxes."""
  54. intersection = calculate_intersection(box1[:4], box2[:4])
  55. if not intersection:
  56. return 0
  57. intersection_xmin, intersection_ymin, intersection_xmax, intersection_ymax = intersection
  58. intersection_area = (intersection_xmax - intersection_xmin) * (intersection_ymax - intersection_ymin)
  59. area1, area2 = box1[4], box2[4]
  60. union_area = area1 + area2 - intersection_area
  61. return intersection_area / union_area if union_area > 0 else 0
  62. def is_inside(small_box, big_box, overlap_threshold=0.8):
  63. """Check if small_box is inside big_box by at least overlap_threshold."""
  64. intersection = calculate_intersection(small_box[:4], big_box[:4])
  65. if not intersection:
  66. return False
  67. intersection_xmin, intersection_ymin, intersection_xmax, intersection_ymax = intersection
  68. intersection_area = (intersection_xmax - intersection_xmin) * (intersection_ymax - intersection_ymin)
  69. # Check if overlap exceeds threshold
  70. return intersection_area >= overlap_threshold * small_box[4]
  71. def do_overlap(box1, box2):
  72. """Check if two boxes overlap."""
  73. return calculate_intersection(box1[:4], box2[:4]) is not None
  74. def merge_high_iou_tables(table_res_list, layout_res, table_indices, iou_threshold=0.7):
  75. """Merge tables with IoU > threshold."""
  76. if len(table_res_list) < 2:
  77. return table_res_list, table_indices
  78. table_info = [get_coords_and_area(table) for table in table_res_list]
  79. merged = True
  80. while merged:
  81. merged = False
  82. i = 0
  83. while i < len(table_res_list) - 1:
  84. j = i + 1
  85. while j < len(table_res_list):
  86. iou = calculate_iou(table_info[i], table_info[j])
  87. if iou > iou_threshold:
  88. # Merge tables by taking their union
  89. x1_min, y1_min, x1_max, y1_max, _ = table_info[i]
  90. x2_min, y2_min, x2_max, y2_max, _ = table_info[j]
  91. union_xmin = min(x1_min, x2_min)
  92. union_ymin = min(y1_min, y2_min)
  93. union_xmax = max(x1_max, x2_max)
  94. union_ymax = max(y1_max, y2_max)
  95. # Create merged table
  96. merged_table = table_res_list[i].copy()
  97. merged_table['poly'][0] = union_xmin
  98. merged_table['poly'][1] = union_ymin
  99. merged_table['poly'][2] = union_xmax
  100. merged_table['poly'][3] = union_ymin
  101. merged_table['poly'][4] = union_xmax
  102. merged_table['poly'][5] = union_ymax
  103. merged_table['poly'][6] = union_xmin
  104. merged_table['poly'][7] = union_ymax
  105. # Update layout_res
  106. to_remove = [table_indices[j], table_indices[i]]
  107. for idx in sorted(to_remove, reverse=True):
  108. del layout_res[idx]
  109. layout_res.append(merged_table)
  110. # Update tracking lists
  111. table_indices = [k if k < min(to_remove) else
  112. k - 1 if k < max(to_remove) else
  113. k - 2 if k > max(to_remove) else
  114. len(layout_res) - 1
  115. for k in table_indices
  116. if k not in to_remove]
  117. table_indices.append(len(layout_res) - 1)
  118. # Update table lists
  119. table_res_list.pop(j)
  120. table_res_list.pop(i)
  121. table_res_list.append(merged_table)
  122. # Update table_info
  123. table_info = [get_coords_and_area(table) for table in table_res_list]
  124. merged = True
  125. break
  126. j += 1
  127. if merged:
  128. break
  129. i += 1
  130. return table_res_list, table_indices
  131. def filter_nested_tables(table_res_list, overlap_threshold=0.8, area_threshold=0.8):
  132. """Remove big tables containing multiple smaller tables within them."""
  133. if len(table_res_list) < 3:
  134. return table_res_list
  135. table_info = [get_coords_and_area(table) for table in table_res_list]
  136. big_tables_idx = []
  137. for i in range(len(table_res_list)):
  138. # Find tables inside this one
  139. tables_inside = [j for j in range(len(table_res_list))
  140. if i != j and is_inside(table_info[j], table_info[i], overlap_threshold)]
  141. # Continue if there are at least 3 tables inside
  142. if len(tables_inside) >= 3:
  143. # Check if inside tables overlap with each other
  144. tables_overlap = any(do_overlap(table_info[tables_inside[idx1]], table_info[tables_inside[idx2]])
  145. for idx1 in range(len(tables_inside))
  146. for idx2 in range(idx1 + 1, len(tables_inside)))
  147. # If no overlaps, check area condition
  148. if not tables_overlap:
  149. total_inside_area = sum(table_info[j][4] for j in tables_inside)
  150. big_table_area = table_info[i][4]
  151. if total_inside_area > area_threshold * big_table_area:
  152. big_tables_idx.append(i)
  153. return [table for i, table in enumerate(table_res_list) if i not in big_tables_idx]
  154. def remove_overlaps_min_blocks(res_list):
  155. # 重叠block,小的不能直接删除,需要和大的那个合并成一个更大的。
  156. # 删除重叠blocks中较小的那些
  157. need_remove = []
  158. for i in range(len(res_list)):
  159. # 如果当前元素已在需要移除列表中,则跳过
  160. if res_list[i] in need_remove:
  161. continue
  162. for j in range(i + 1, len(res_list)):
  163. # 如果比较对象已在需要移除列表中,则跳过
  164. if res_list[j] in need_remove:
  165. continue
  166. overlap_box = get_minbox_if_overlap_by_ratio(
  167. res_list[i]['bbox'], res_list[j]['bbox'], 0.8
  168. )
  169. if overlap_box is not None:
  170. res_to_remove = None
  171. large_res = None
  172. # 确定哪个是小块(要移除的)
  173. if overlap_box == res_list[i]['bbox']:
  174. res_to_remove = res_list[i]
  175. large_res = res_list[j]
  176. elif overlap_box == res_list[j]['bbox']:
  177. res_to_remove = res_list[j]
  178. large_res = res_list[i]
  179. if res_to_remove is not None and res_to_remove not in need_remove:
  180. # 更新大块的边界为两者的并集
  181. x1, y1, x2, y2 = large_res['bbox']
  182. sx1, sy1, sx2, sy2 = res_to_remove['bbox']
  183. x1 = min(x1, sx1)
  184. y1 = min(y1, sy1)
  185. x2 = max(x2, sx2)
  186. y2 = max(y2, sy2)
  187. large_res['bbox'] = [x1, y1, x2, y2]
  188. need_remove.append(res_to_remove)
  189. # 从列表中移除标记的元素
  190. for res in need_remove:
  191. res_list.remove(res)
  192. return res_list, need_remove
  193. def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshold=0.8, area_threshold=0.8):
  194. """Extract OCR, table and other regions from layout results."""
  195. ocr_res_list = []
  196. text_res_list = []
  197. table_res_list = []
  198. table_indices = []
  199. single_page_mfdetrec_res = []
  200. # Categorize regions
  201. for i, res in enumerate(layout_res):
  202. category_id = int(res['category_id'])
  203. if category_id in [13, 14]: # Formula regions
  204. single_page_mfdetrec_res.append({
  205. "bbox": [int(res['poly'][0]), int(res['poly'][1]),
  206. int(res['poly'][4]), int(res['poly'][5])],
  207. })
  208. elif category_id in [0, 2, 4, 6, 7, 3]: # OCR regions
  209. ocr_res_list.append(res)
  210. elif category_id == 5: # Table regions
  211. table_res_list.append(res)
  212. table_indices.append(i)
  213. elif category_id in [1]: # Text regions
  214. res['bbox'] = [int(res['poly'][0]), int(res['poly'][1]), int(res['poly'][4]), int(res['poly'][5])]
  215. text_res_list.append(res)
  216. # Process tables: merge high IoU tables first, then filter nested tables
  217. table_res_list, table_indices = merge_high_iou_tables(
  218. table_res_list, layout_res, table_indices, iou_threshold)
  219. filtered_table_res_list = filter_nested_tables(
  220. table_res_list, overlap_threshold, area_threshold)
  221. # Remove filtered out tables from layout_res
  222. if len(filtered_table_res_list) < len(table_res_list):
  223. kept_tables = set(id(table) for table in filtered_table_res_list)
  224. to_remove = [table_indices[i] for i, table in enumerate(table_res_list)
  225. if id(table) not in kept_tables]
  226. for idx in sorted(to_remove, reverse=True):
  227. del layout_res[idx]
  228. # Remove overlaps in OCR and text regions
  229. text_res_list, need_remove = remove_overlaps_min_blocks(text_res_list)
  230. for res in text_res_list:
  231. # 将res的poly使用bbox重构
  232. res['poly'] = [res['bbox'][0], res['bbox'][1], res['bbox'][2], res['bbox'][1],
  233. res['bbox'][2], res['bbox'][3], res['bbox'][0], res['bbox'][3]]
  234. # 删除res的bbox
  235. del res['bbox']
  236. ocr_res_list.extend(text_res_list)
  237. if len(need_remove) > 0:
  238. for res in need_remove:
  239. del res['bbox']
  240. layout_res.remove(res)
  241. # 新增:检测大block内部是否包含多个小block
  242. # 合并ocr和table列表进行检测
  243. combined_res_list = ocr_res_list + filtered_table_res_list
  244. # 计算每个block的坐标和面积
  245. block_info = []
  246. for block in combined_res_list:
  247. xmin, ymin = int(block['poly'][0]), int(block['poly'][1])
  248. xmax, ymax = int(block['poly'][4]), int(block['poly'][5])
  249. area = (xmax - xmin) * (ymax - ymin)
  250. score = block.get('score', 0.5) # 如果没有score字段,默认为0.5
  251. block_info.append((xmin, ymin, xmax, ymax, area, score, block))
  252. blocks_to_remove = []
  253. # 检查每个block内部是否有3个及以上的小block
  254. for i, (xmin, ymin, xmax, ymax, area, score, block) in enumerate(block_info):
  255. # 查找内部的小block
  256. blocks_inside = [(j, j_score, j_block) for j, (xj_min, yj_min, xj_max, yj_max, j_area, j_score, j_block) in
  257. enumerate(block_info)
  258. if i != j and is_inside(block_info[j], block_info[i])]
  259. # 如果内部有3个及以上的小block
  260. if len(blocks_inside) >= 3:
  261. # 计算小block的平均分数
  262. avg_score = sum(s for _, s, _ in blocks_inside) / len(blocks_inside)
  263. # 比较大block的分数和小block的平均分数
  264. if score > avg_score:
  265. # 保留大block,扩展其边界
  266. # 首先将所有小block标记为要删除
  267. for j, _, j_block in blocks_inside:
  268. if j_block not in blocks_to_remove:
  269. blocks_to_remove.append(j_block)
  270. # 扩展大block的边界以包含所有小block
  271. new_xmin, new_ymin, new_xmax, new_ymax = xmin, ymin, xmax, ymax
  272. for _, _, j_block in blocks_inside:
  273. j_xmin, j_ymin = int(j_block['poly'][0]), int(j_block['poly'][1])
  274. j_xmax, j_ymax = int(j_block['poly'][4]), int(j_block['poly'][5])
  275. new_xmin = min(new_xmin, j_xmin)
  276. new_ymin = min(new_ymin, j_ymin)
  277. new_xmax = max(new_xmax, j_xmax)
  278. new_ymax = max(new_ymax, j_ymax)
  279. # 更新大block的边界
  280. block['poly'][0] = block['poly'][6] = new_xmin
  281. block['poly'][1] = block['poly'][3] = new_ymin
  282. block['poly'][2] = block['poly'][4] = new_xmax
  283. block['poly'][5] = block['poly'][7] = new_ymax
  284. else:
  285. # 保留小blocks,删除大block
  286. blocks_to_remove.append(block)
  287. # 移除需要删除的blocks
  288. for block in blocks_to_remove:
  289. if block in ocr_res_list:
  290. ocr_res_list.remove(block)
  291. elif block in filtered_table_res_list:
  292. filtered_table_res_list.remove(block)
  293. # 同时从layout_res中删除
  294. if block in layout_res:
  295. layout_res.remove(block)
  296. return ocr_res_list, filtered_table_res_list, single_page_mfdetrec_res
  297. def clean_memory(device='cuda'):
  298. if device == 'cuda':
  299. if torch.cuda.is_available():
  300. torch.cuda.empty_cache()
  301. torch.cuda.ipc_collect()
  302. elif str(device).startswith("npu"):
  303. if torch_npu.npu.is_available():
  304. torch_npu.npu.empty_cache()
  305. elif str(device).startswith("mps"):
  306. torch.mps.empty_cache()
  307. gc.collect()
  308. def clean_vram(device, vram_threshold=8):
  309. total_memory = get_vram(device)
  310. if total_memory and total_memory <= vram_threshold:
  311. gc_start = time.time()
  312. clean_memory(device)
  313. gc_time = round(time.time() - gc_start, 2)
  314. logger.info(f"gc time: {gc_time}")
  315. def get_vram(device):
  316. if torch.cuda.is_available() and str(device).startswith("cuda"):
  317. total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
  318. return total_memory
  319. elif str(device).startswith("npu"):
  320. if torch_npu.npu.is_available():
  321. total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
  322. return total_memory
  323. else:
  324. return None