# pp_structure_v2.py (2.9 KB)
  1. import random
  2. from loguru import logger
  3. try:
  4. from paddleocr import PPStructure
  5. except ImportError:
  6. logger.error('paddleocr not installed, please install by "pip install magic-pdf[lite]"')
  7. exit(1)
  8. def region_to_bbox(region):
  9. x0 = region[0][0]
  10. y0 = region[0][1]
  11. x1 = region[2][0]
  12. y1 = region[2][1]
  13. return [x0, y0, x1, y1]
  14. class CustomPaddleModel:
  15. def __init__(self, ocr: bool = False, show_log: bool = False, lang=None):
  16. if lang is not None:
  17. self.model = PPStructure(table=False, ocr=ocr, show_log=show_log, lang=lang)
  18. else:
  19. self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
  20. def __call__(self, img):
  21. try:
  22. import cv2
  23. except ImportError:
  24. logger.error("opencv-python not installed, please install by pip.")
  25. exit(1)
  26. # 将RGB图片转换为BGR格式适配paddle
  27. img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  28. result = self.model(img)
  29. spans = []
  30. for line in result:
  31. line.pop("img")
  32. """
  33. 为paddle输出适配type no.
  34. title: 0 # 标题
  35. text: 1 # 文本
  36. header: 2 # abandon
  37. footer: 2 # abandon
  38. reference: 1 # 文本 or abandon
  39. equation: 8 # 行间公式 block
  40. equation: 14 # 行间公式 text
  41. figure: 3 # 图片
  42. figure_caption: 4 # 图片描述
  43. table: 5 # 表格
  44. table_caption: 6 # 表格描述
  45. """
  46. if line["type"] == "title":
  47. line["category_id"] = 0
  48. elif line["type"] in ["text", "reference"]:
  49. line["category_id"] = 1
  50. elif line["type"] == "figure":
  51. line["category_id"] = 3
  52. elif line["type"] == "figure_caption":
  53. line["category_id"] = 4
  54. elif line["type"] == "table":
  55. line["category_id"] = 5
  56. elif line["type"] == "table_caption":
  57. line["category_id"] = 6
  58. elif line["type"] == "equation":
  59. line["category_id"] = 8
  60. elif line["type"] in ["header", "footer"]:
  61. line["category_id"] = 2
  62. else:
  63. logger.warning(f"unknown type: {line['type']}")
  64. # 兼容不输出score的paddleocr版本
  65. if line.get("score") is None:
  66. line["score"] = 0.5 + random.random() * 0.5
  67. res = line.pop("res", None)
  68. if res is not None and len(res) > 0:
  69. for span in res:
  70. new_span = {
  71. "category_id": 15,
  72. "bbox": region_to_bbox(span["text_region"]),
  73. "score": span["confidence"],
  74. "text": span["text"],
  75. }
  76. spans.append(new_span)
  77. if len(spans) > 0:
  78. result.extend(spans)
  79. return result