utils.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import os
  3. from pathlib import Path
  4. import yaml
  5. os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
  6. from magic_pdf.config.constants import MODEL_NAME
  7. from magic_pdf.data.utils import load_images_from_pdf
  8. from magic_pdf.libs.config_reader import get_local_models_dir, get_device
  9. from magic_pdf.libs.pdf_check import extract_pages
  10. from magic_pdf.model.model_list import AtomicModel
  11. from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
  12. def get_model_config():
  13. local_models_dir = get_local_models_dir()
  14. device = get_device()
  15. current_file_path = os.path.abspath(__file__)
  16. root_dir = Path(current_file_path).parents[3]
  17. model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
  18. config_path = os.path.join(model_config_dir, 'model_configs.yaml')
  19. with open(config_path, 'r', encoding='utf-8') as f:
  20. configs = yaml.load(f, Loader=yaml.FullLoader)
  21. return root_dir, local_models_dir, device, configs
  22. def get_text_images(simple_images):
  23. _, local_models_dir, device, configs = get_model_config()
  24. atom_model_manager = AtomModelSingleton()
  25. temp_layout_model = atom_model_manager.get_atom_model(
  26. atom_model_name=AtomicModel.Layout,
  27. layout_model_name=MODEL_NAME.DocLayout_YOLO,
  28. doclayout_yolo_weights=str(
  29. os.path.join(
  30. local_models_dir, configs['weights'][MODEL_NAME.DocLayout_YOLO]
  31. )
  32. ),
  33. device=device,
  34. )
  35. text_images = []
  36. for simple_image in simple_images:
  37. image = simple_image['img']
  38. layout_res = temp_layout_model.predict(image)
  39. # 给textblock截图
  40. for res in layout_res:
  41. if res['category_id'] in [1]:
  42. x1, y1, _, _, x2, y2, _, _ = res['poly']
  43. # 初步清洗(宽和高都小于100)
  44. if x2 - x1 < 100 and y2 - y1 < 100:
  45. continue
  46. text_images.append(image[y1:y2, x1:x2])
  47. return text_images
  48. def auto_detect_lang(pdf_bytes: bytes):
  49. sample_docs = extract_pages(pdf_bytes)
  50. sample_pdf_bytes = sample_docs.tobytes()
  51. simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
  52. text_images = get_text_images(simple_images)
  53. langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
  54. lang = langdetect_model.do_detect(text_images)
  55. return lang
  56. def model_init(model_name: str):
  57. atom_model_manager = AtomModelSingleton()
  58. if model_name == MODEL_NAME.YOLO_V11_LangDetect:
  59. root_dir, _, device, _ = get_model_config()
  60. model = atom_model_manager.get_atom_model(
  61. atom_model_name=AtomicModel.LangDetect,
  62. langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
  63. langdetect_model_weight=str(os.path.join(root_dir, 'resources', 'yolov11-langdetect', 'yolo_v11_ft.pt')),
  64. device=device,
  65. )
  66. else:
  67. raise ValueError(f"model_name {model_name} not found")
  68. return model