pp_structure_v2.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import random
  2. from loguru import logger
  3. try:
  4. from paddleocr import PPStructure
  5. except ImportError:
  6. logger.warning('paddleocr not installed, please install by "pip install magic-pdf[cpu]" or "pip install magic-pdf[gpu]"')
  7. exit(1)
  8. def region_to_bbox(region):
  9. x0 = region[0][0]
  10. y0 = region[0][1]
  11. x1 = region[2][0]
  12. y1 = region[2][1]
  13. return [x0, y0, x1, y1]
  14. class CustomPaddleModel:
  15. def __init__(self, ocr: bool = False, show_log: bool = False):
  16. self.model = PPStructure(table=False, ocr=ocr, show_log=show_log)
  17. def __call__(self, img):
  18. result = self.model(img)
  19. spans = []
  20. for line in result:
  21. line.pop("img")
  22. """
  23. 为paddle输出适配type no.
  24. title: 0 # 标题
  25. text: 1 # 文本
  26. header: 2 # abandon
  27. footer: 2 # abandon
  28. reference: 1 # 文本 or abandon
  29. equation: 8 # 行间公式 block
  30. equation: 14 # 行间公式 text
  31. figure: 3 # 图片
  32. figure_caption: 4 # 图片描述
  33. table: 5 # 表格
  34. table_caption: 6 # 表格描述
  35. """
  36. if line["type"] == "title":
  37. line["category_id"] = 0
  38. elif line["type"] in ["text", "reference"]:
  39. line["category_id"] = 1
  40. elif line["type"] == "figure":
  41. line["category_id"] = 3
  42. elif line["type"] == "figure_caption":
  43. line["category_id"] = 4
  44. elif line["type"] == "table":
  45. line["category_id"] = 5
  46. elif line["type"] == "table_caption":
  47. line["category_id"] = 6
  48. elif line["type"] == "equation":
  49. line["category_id"] = 8
  50. elif line["type"] in ["header", "footer"]:
  51. line["category_id"] = 2
  52. else:
  53. logger.warning(f"unknown type: {line['type']}")
  54. # 兼容不输出score的paddleocr版本
  55. if line.get("score") is None:
  56. line["score"] = 0.5 + random.random() * 0.5
  57. res = line.pop("res", None)
  58. if res is not None and len(res) > 0:
  59. for span in res:
  60. new_span = {
  61. "category_id": 15,
  62. "bbox": region_to_bbox(span["text_region"]),
  63. "score": span["confidence"],
  64. "text": span["text"],
  65. }
  66. spans.append(new_span)
  67. if len(spans) > 0:
  68. result.extend(spans)
  69. return result