model_utils.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import time
  2. import torch
  3. from loguru import logger
  4. import numpy as np
  5. from magic_pdf.libs.clean_memory import clean_memory
  6. def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
  7. crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
  8. crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
  9. # Calculate new dimensions
  10. crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
  11. crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
  12. # Create a white background array
  13. return_image = np.ones((crop_new_height, crop_new_width, 3), dtype=np.uint8) * 255
  14. # Crop the original image using numpy slicing
  15. cropped_img = input_np_img[crop_ymin:crop_ymax, crop_xmin:crop_xmax]
  16. # Paste the cropped image onto the white background
  17. return_image[crop_paste_y:crop_paste_y + (crop_ymax - crop_ymin),
  18. crop_paste_x:crop_paste_x + (crop_xmax - crop_xmin)] = cropped_img
  19. return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width,
  20. crop_new_height]
  21. return return_image, return_list
  22. # Select regions for OCR / formula regions / table regions
  23. def get_res_list_from_layout_res(layout_res):
  24. ocr_res_list = []
  25. table_res_list = []
  26. single_page_mfdetrec_res = []
  27. for res in layout_res:
  28. if int(res['category_id']) in [13, 14]:
  29. single_page_mfdetrec_res.append({
  30. "bbox": [int(res['poly'][0]), int(res['poly'][1]),
  31. int(res['poly'][4]), int(res['poly'][5])],
  32. })
  33. elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
  34. ocr_res_list.append(res)
  35. elif int(res['category_id']) in [5]:
  36. table_res_list.append(res)
  37. return ocr_res_list, table_res_list, single_page_mfdetrec_res
  38. def clean_vram(device, vram_threshold=8):
  39. total_memory = get_vram(device)
  40. if total_memory and total_memory <= vram_threshold:
  41. gc_start = time.time()
  42. clean_memory(device)
  43. gc_time = round(time.time() - gc_start, 2)
  44. logger.info(f"gc time: {gc_time}")
  45. def get_vram(device):
  46. if torch.cuda.is_available() and str(device).startswith("cuda"):
  47. total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
  48. return total_memory
  49. elif str(device).startswith("npu"):
  50. import torch_npu
  51. if torch_npu.npu.is_available():
  52. total_memory = torch_npu.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
  53. return total_memory
  54. else:
  55. return None