model_utils.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. import time
  2. import torch
  3. from PIL import Image
  4. from loguru import logger
  5. from magic_pdf.libs.clean_memory import clean_memory
  6. def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
  7. crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
  8. crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
  9. # Create a white background with an additional width and height of 50
  10. crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
  11. crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
  12. return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
  13. # Crop image
  14. crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
  15. cropped_img = input_pil_img.crop(crop_box)
  16. return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
  17. return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
  18. return return_image, return_list
  19. # Select regions for OCR / formula regions / table regions
  20. def get_res_list_from_layout_res(layout_res):
  21. ocr_res_list = []
  22. table_res_list = []
  23. single_page_mfdetrec_res = []
  24. for res in layout_res:
  25. if int(res['category_id']) in [13, 14]:
  26. single_page_mfdetrec_res.append({
  27. "bbox": [int(res['poly'][0]), int(res['poly'][1]),
  28. int(res['poly'][4]), int(res['poly'][5])],
  29. })
  30. elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
  31. ocr_res_list.append(res)
  32. elif int(res['category_id']) in [5]:
  33. table_res_list.append(res)
  34. return ocr_res_list, table_res_list, single_page_mfdetrec_res
  35. def clean_vram(device, vram_threshold=8):
  36. total_memory = get_vram(device)
  37. if total_memory and total_memory <= vram_threshold:
  38. gc_start = time.time()
  39. clean_memory(device)
  40. gc_time = round(time.time() - gc_start, 2)
  41. logger.info(f"gc time: {gc_time}")
  42. def get_vram(device):
  43. if torch.cuda.is_available() and device != 'cpu':
  44. total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) # 将字节转换为 GB
  45. return total_memory
  46. elif str(device).startswith("npu"):
  47. import torch_npu
  48. if torch.npu.is_available():
  49. total_memory = torch.npu.get_device_properties(device).total_memory / (1024 ** 3) # 转为 GB
  50. return total_memory
  51. else:
  52. return None