test_single_model.py 6.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from paddlex import create_model
  2. import time
  3. from pathlib import Path
  4. # 定义paddlex模型名称列表
  5. MODEL_LIST = [
  6. # OCR文本检测模型
  7. {"model_name": "PP-OCRv5_mobile_det", "description": "轻量级OCR文本检测模型,适用于移动端部署"},
  8. {"model_name": "PP-OCRv5_server_det", "description": "PP-OCRv5_rec 是新一代文本识别模型。该模型致力于以单一模型高效、精准地支持简体中文、繁体中文、英文、日文四种主要语言,以及手写、竖版、拼音、生僻字等复杂文本场景的识别。在保持识别效果的同时,兼顾推理速度和模型鲁棒性,为各种场景下的文档理解提供高效、精准的技术支撑。"},
  9. # OCR文本识别模型
  10. {"model_name": "PP-OCRv5_mobile_rec", "description": "轻量级OCR文本识别模型,适用于移动端部署"},
  11. {"model_name": "PP-OCRv5_server_rec", "description": "服务端OCR文本识别模型,高精度识别"},
  12. # 版面区域检测模型
  13. {"model_name": "PP-DocLayout_plus-L", "description": "版面检测模型,包含20个常见的类别:文档标题、段落标题、文本、页码、摘要、目录、参考文献、脚注、页眉、页脚、算法、公式、公式编号、图像、表格、图和表标题(图标题、表格标题和图表标题)、印章、图表、侧栏文本和参考文献内容"},
  14. {"model_name": "PP-DocBlockLayout", "description": "文档图像版面子模块检测,包含1个 版面区域 类别,能检测多栏的报纸、杂志的每个子文章的文本区域"},
  15. # 表格分类模型
  16. {"model_name": "PP-LCNet_x1_0_table_cls", "description": "wired_table, wireless_table"},
  17. # 表格识别模型
  18. {"model_name": "SLANet_plus", "description": "SLANet_plus 是百度飞桨视觉团队自研的表格结构识别模型 SLANet 的增强版。相较于 SLANet,SLANet_plus 对无线表、复杂表格的识别能力得到了大幅提升,并降低了模型对表格定位准确性的敏感度,即使表格定位出现偏移,也能够较准确地进行识别。"},
  19. {"model_name": "SLANeXt_wired", "description": "SLANeXt 系列是百度飞桨视觉团队自研的新一代表格结构识别模型。相较于 SLANet 和 SLANet_plus,SLANeXt 专注于对表格结构进行识别,并且对有线表格(wired)和无线表格(wireless)的识别分别训练了专用的权重,对各类型表格的识别能力都得到了明显提高,特别是对有线表格的识别能力得到了大幅提升。"},
  20. {"model_name": "SLANeXt_wireless", "description": "SLANeXt 系列是百度飞桨视觉团队自研的新一代表格结构识别模型。相较于 SLANet 和 SLANet_plus,SLANeXt 专注于对表格结构进行识别,并且对有线表格(wired)和无线表格(wireless)的识别分别训练了专用的权重,对各类型表格的识别能力都得到了明显提高,特别是对无线表格的识别能力得到了大幅提升。"},
  21. # 公式识别模型
  22. {"model_name": "PP-FormulaNet_plus-L", "description": "负责将图像中的数学公式转换为可编辑的文本或计算机可识别的格式。该模块的性能直接影响到整个OCR系统的准确性和效率。公式识别模块通常会输出数学公式的 LaTeX 或 MathML 代码"},
  23. # 文档图像方向分类模型
  24. {"model_name": "PP-LCNet_x1_0_doc_ori", "description": "基于PP-LCNet_x1_0的文档图像分类模型,含有四个类别,即0度,90度,180度,270度"},
  25. # 文本图像矫正模型
  26. {"model_name": "UVDoc", "description": "针对图像进行几何变换,以纠正图像中的文档扭曲、倾斜、透视变形等问题,以供后续的文本识别进行更加准确"},
  27. # 印章检测模型
  28. {"model_name": "PP-OCRv4_mobile_seal_det", "description": "PP-OCRv4的移动端印章文本检测模型,效率更高,适合在端侧部署"},
  29. {"model_name": "PP-OCRv4_server_seal_det", "description": "PP-OCRv4的服务端印章文本检测模型,精度更高,适合在较好的服务器上部署"},
  30. ]
  31. def test_single_model(
  32. model_name: str,
  33. # input_path: str = "sample_data/300674-母公司现金流量表-扫描.png",
  34. input_path: str,
  35. output_path: str = "./sample_data/output/"):
  36. """
  37. Test single model for layout detection.
  38. """
  39. start_time = time.time()
  40. print(f"\nTesting model: {model_name}")
  41. # Create the model
  42. model = create_model(model_name=model_name)
  43. # 参数通常用于目标检测模型(如 DetPredictor)中,用于重叠框过滤。
  44. # 检查模型是否支持 layout_nms 参数
  45. predict_kwargs = {}
  46. if hasattr(model._predictor, 'layout_nms'):
  47. predict_kwargs['layout_nms'] = True
  48. # 特殊处理 Doc VLM 模型(如 PP-Chart2Table)
  49. if model_name in ["PP-Chart2Table", "PP-DocBee-2B", "PP-DocBee-7B", "PP-DocBee2-3B"]:
  50. # Doc VLM 模型需要字典格式的输入
  51. input_data = {
  52. "image": input_path,
  53. "query": "请将图表转换为表格格式" # 或其他适合的查询
  54. }
  55. output = model.predict(input_data, batch_size=1, **predict_kwargs)
  56. else:
  57. # 其他模型使用标准输入格式
  58. output = model.predict(input_path, batch_size=1, **predict_kwargs)
  59. for res in output:
  60. res.print()
  61. res.save_all(save_path=output_path) # Save all results to the specified path
  62. end_time = time.time()
  63. elapsed_time = end_time - start_time
  64. print(f"Total time taken for {model_name}: {elapsed_time:.2f} seconds")
  65. if __name__ == "__main__":
  66. # Specify the model name
  67. # 循环,如何指定模型名称,直到quit
  68. while True:
  69. model_name = input("请输入模型名称(或输入'quit'退出):")
  70. if model_name.lower() == "quit":
  71. break
  72. output_path = Path(f"./sample_data/single_model_output/{model_name}/")
  73. output_path.mkdir(parents=True, exist_ok=True)
  74. test_single_model(model_name=model_name,
  75. # input_path="sample_data/300674-母公司现金流量表-扫描.png",
  76. input_path="/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/data_PPStructureV3_Results/B用户_扫描流水/B用户_扫描流水_page_002.png",
  77. output_path=output_path.as_posix())
  78. print("\n" + "="*50 + "\n")