pp_structure_v2.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import random
  2. from loguru import logger
  3. try:
  4. from paddleocr import PPStructure
  5. except ImportError:
  6. logger.error('paddleocr not installed, please install by "pip install magic-pdf[lite]"')
  7. exit(1)
  8. def region_to_bbox(region):
  9. x0 = region[0][0]
  10. y0 = region[0][1]
  11. x1 = region[2][0]
  12. y1 = region[2][1]
  13. return [x0, y0, x1, y1]
  14. class CustomPaddleModel:
  15. def __init__(self,
  16. ocr: bool = False,
  17. show_log: bool = False,
  18. lang=None,
  19. det_db_box_thresh=0.3,
  20. use_dilation=True,
  21. det_db_unclip_ratio=1.8
  22. ):
  23. if lang is not None:
  24. self.model = PPStructure(table=False,
  25. ocr=True,
  26. show_log=show_log,
  27. lang=lang,
  28. det_db_box_thresh=det_db_box_thresh,
  29. use_dilation=use_dilation,
  30. det_db_unclip_ratio=det_db_unclip_ratio,
  31. )
  32. else:
  33. self.model = PPStructure(table=False,
  34. ocr=True,
  35. show_log=show_log,
  36. det_db_box_thresh=det_db_box_thresh,
  37. use_dilation=use_dilation,
  38. det_db_unclip_ratio=det_db_unclip_ratio,
  39. )
  40. def __call__(self, img):
  41. try:
  42. import cv2
  43. except ImportError:
  44. logger.error("opencv-python not installed, please install by pip.")
  45. exit(1)
  46. # 将RGB图片转换为BGR格式适配paddle
  47. img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  48. result = self.model(img)
  49. spans = []
  50. for line in result:
  51. line.pop("img")
  52. """
  53. 为paddle输出适配type no.
  54. title: 0 # 标题
  55. text: 1 # 文本
  56. header: 2 # abandon
  57. footer: 2 # abandon
  58. reference: 1 # 文本 or abandon
  59. equation: 8 # 行间公式 block
  60. equation: 14 # 行间公式 text
  61. figure: 3 # 图片
  62. figure_caption: 4 # 图片描述
  63. table: 5 # 表格
  64. table_caption: 6 # 表格描述
  65. """
  66. if line["type"] == "title":
  67. line["category_id"] = 0
  68. elif line["type"] in ["text", "reference"]:
  69. line["category_id"] = 1
  70. elif line["type"] == "figure":
  71. line["category_id"] = 3
  72. elif line["type"] == "figure_caption":
  73. line["category_id"] = 4
  74. elif line["type"] == "table":
  75. line["category_id"] = 5
  76. elif line["type"] == "table_caption":
  77. line["category_id"] = 6
  78. elif line["type"] == "equation":
  79. line["category_id"] = 8
  80. elif line["type"] in ["header", "footer"]:
  81. line["category_id"] = 2
  82. else:
  83. logger.warning(f"unknown type: {line['type']}")
  84. # 兼容不输出score的paddleocr版本
  85. if line.get("score") is None:
  86. line["score"] = 0.5 + random.random() * 0.5
  87. res = line.pop("res", None)
  88. if res is not None and len(res) > 0:
  89. for span in res:
  90. new_span = {
  91. "category_id": 15,
  92. "bbox": region_to_bbox(span["text_region"]),
  93. "score": span["confidence"],
  94. "text": span["text"],
  95. }
  96. spans.append(new_span)
  97. if len(spans) > 0:
  98. result.extend(spans)
  99. return result