|
|
@@ -1 +1,433 @@
|
|
|
-pip install onnx2pytorch
|
|
|
+# Unified PyTorch Models
|
|
|
+
|
|
|
+统一的 PyTorch 模型推理接口,支持布局检测、OCR、文档方向分类等功能。
|
|
|
+
|
|
|
+## 📂 目录结构
|
|
|
+
|
|
|
+```
|
|
|
+unified_pytorch_models/
|
|
|
+├── Layout/ # 布局检测与方向分类模型
|
|
|
+│ ├── PP-LCNet_x1_0_doc_ori.onnx # 文档方向分类模型
|
|
|
+│ └── RT-DETR-H_layout_17cls.onnx # 布局检测模型
|
|
|
+├── OCR/ # OCR 模型目录(从 ModelScope 下载)
|
|
|
+│ ├── Cls/ # 方向分类器模型
|
|
|
+│ └── Rec/ # 文本识别模型
|
|
|
+├── vendor/ # ✨ 核心依赖模块
|
|
|
+│ ├── __init__.py
|
|
|
+│ ├── device_utils.py # 设备检测工具
|
|
|
+│ ├── ocr_utils.py # OCR 工具函数
|
|
|
+│ ├── README.md # Vendor 依赖说明
|
|
|
+│ ├── infer/ # 推理模块
|
|
|
+│ │ ├── predict_det.py # 文本检测
|
|
|
+│ │ ├── predict_rec.py # 文本识别
|
|
|
+│ │ ├── predict_cls.py # 方向分类
|
|
|
+│ │ ├── predict_system.py # OCR 系统
|
|
|
+│ │ └── pytorchocr_utility.py # 工具函数
|
|
|
+│ └── pytorchocr/ # PytorchOCR 核心
|
|
|
+│ ├── modeling/ # 模型架构
|
|
|
+│ ├── postprocess/ # 后处理
|
|
|
+│ └── utils/ # 工具与资源
|
|
|
+│ └── resources/ # 配置与字典
|
|
|
+│ ├── models_config.yml
|
|
|
+│ └── dict/ # 多语言字典
|
|
|
+├── pytorch_paddle.py # ✨ PytorchPaddleOCR 主模块
|
|
|
+├── layout_detect_onnx.py # 布局检测器
|
|
|
+├── orientation_classifier_v2.py # 增强版方向分类器
|
|
|
+├── doc_preprocessor_v2.py # 文档预处理 Pipeline
|
|
|
+├── doc_preprocess_result.py # 预处理结果数据类
|
|
|
+├── paddle_to_pytorch_universal.py # Paddle 模型转换工具
|
|
|
+├── .env # 环境变量配置
|
|
|
+└── README.md
|
|
|
+```
|
|
|
+
|
|
|
+## 🚀 快速开始
|
|
|
+
|
|
|
+### 1. 安装依赖
|
|
|
+
|
|
|
+```bash
|
|
|
+pip install torch torchvision opencv-python onnxruntime numpy pyyaml shapely pyclipper loguru python-dotenv
|
|
|
+```
|
|
|
+
|
|
|
+### 2. 准备模型文件
|
|
|
+
|
|
|
+```bash
|
|
|
+# 模型会自动从 ModelScope 下载到:
|
|
|
+# ~/.cache/modelscope/models/OpenDataLab/PDF-Extract-Kit-1.0/
|
|
|
+
|
|
|
+# 或者设置自定义缓存目录:
|
|
|
+export MODELSCOPE_CACHE_DIR="/path/to/your/cache"
|
|
|
+```
|
|
|
+
|
|
|
+### 3. 配置环境变量
|
|
|
+
|
|
|
+创建 `.env` 文件(可选):
|
|
|
+
|
|
|
+```bash
|
|
|
+# ModelScope 缓存目录
|
|
|
+MODELSCOPE_CACHE_DIR=/Users/zhch158/models/modelscope_cache
|
|
|
+
|
|
|
+# 设备配置
|
|
|
+DEVICE=cpu # 或 cuda, mps
|
|
|
+```
|
|
|
+
|
|
|
+### 4. 运行测试
|
|
|
+
|
|
|
+```bash
|
|
|
+cd /Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models
|
|
|
+
|
|
|
+# 测试 OCR 功能(含可视化)
|
|
|
+python pytorch_paddle.py
|
|
|
+
|
|
|
+# 测试文档预处理
|
|
|
+python doc_preprocessor_v2.py
|
|
|
+
|
|
|
+# 测试布局检测
|
|
|
+python layout_detect_onnx.py
|
|
|
+
|
|
|
+# 测试方向分类
|
|
|
+python orientation_classifier_v2.py
|
|
|
+```
|
|
|
+
|
|
|
+## 📖 使用示例
|
|
|
+
|
|
|
+### 1. OCR 识别(推荐)
|
|
|
+
|
|
|
+```python
|
|
|
+from pytorch_paddle import PytorchPaddleOCR
|
|
|
+import cv2
|
|
|
+
|
|
|
+# 初始化 OCR 引擎
|
|
|
+ocr = PytorchPaddleOCR(
|
|
|
+ lang='ch', # 语言: ch, en, ch_lite, korean, japan 等
|
|
|
+ device='cpu', # 设备: cpu, cuda, mps
|
|
|
+ use_orientation_cls=True, # ✨ 启用方向分类
|
|
|
+ orientation_model_path='./Layout/PP-LCNet_x1_0_doc_ori.onnx',
|
|
|
+ rec_batch_num=6, # 识别批大小
|
|
|
+ enable_merge_det_boxes=True # 合并检测框
|
|
|
+)
|
|
|
+
|
|
|
+# 读取图像
|
|
|
+img = cv2.imread("test.jpg")
|
|
|
+
|
|
|
+# 执行 OCR
|
|
|
+results = ocr.ocr(img, det=True, rec=True)
|
|
|
+
|
|
|
+# 打印结果
|
|
|
+if results and results[0]:
|
|
|
+ for box, (text, conf) in results[0]:
|
|
|
+ print(f"{text} (confidence={conf:.3f})")
|
|
|
+
|
|
|
+# 可视化结果
|
|
|
+img_vis = ocr.visualize(
|
|
|
+ img,
|
|
|
+ results,
|
|
|
+ output_path="output_ocr.jpg",
|
|
|
+ show_text=True,
|
|
|
+ show_confidence=True
|
|
|
+)
|
|
|
+```
|
|
|
+
|
|
|
+### 2. 布局检测
|
|
|
+
|
|
|
+```python
|
|
|
+from layout_detect_onnx import LayoutDetectorONNX
|
|
|
+import cv2
|
|
|
+
|
|
|
+# 初始化检测器
|
|
|
+detector = LayoutDetectorONNX(
|
|
|
+ onnx_path="./Layout/RT-DETR-H_layout_17cls.onnx",
|
|
|
+ use_gpu=False
|
|
|
+)
|
|
|
+
|
|
|
+# 检测
|
|
|
+img = cv2.imread("test.jpg")
|
|
|
+results = detector.predict(
|
|
|
+ img,
|
|
|
+ conf_threshold=0.5,
|
|
|
+ return_debug=True
|
|
|
+)
|
|
|
+
|
|
|
+# 打印结果
|
|
|
+for box in results:
|
|
|
+ print(f"{box['category_name']}: {box['bbox']}, score={box['score']:.3f}")
|
|
|
+
|
|
|
+# 可视化
|
|
|
+img_vis = detector.visualize(
|
|
|
+ img,
|
|
|
+ results,
|
|
|
+ output_path="output_layout.jpg"
|
|
|
+)
|
|
|
+```
|
|
|
+
|
|
|
+### 3. 文档方向分类
|
|
|
+
|
|
|
+```python
|
|
|
+from orientation_classifier_v2 import OrientationClassifierV2
|
|
|
+from pytorch_paddle import PytorchPaddleOCR
|
|
|
+import cv2
|
|
|
+
|
|
|
+# 初始化 OCR(用于辅助判断)
|
|
|
+ocr = PytorchPaddleOCR(lang='ch', device='cpu')
|
|
|
+
|
|
|
+# 初始化方向分类器
|
|
|
+classifier = OrientationClassifierV2(
|
|
|
+ model_path="./Layout/PP-LCNet_x1_0_doc_ori.onnx",
|
|
|
+ text_detector=ocr, # ✨ 传入文本检测器
|
|
|
+ aspect_ratio_threshold=1.2, # 长宽比阈值
|
|
|
+ vertical_text_ratio=0.28, # 垂直文本占比阈值
|
|
|
+ vertical_text_min_count=3, # 最小垂直文本数量
|
|
|
+ use_gpu=False
|
|
|
+)
|
|
|
+
|
|
|
+# 预测方向
|
|
|
+img = cv2.imread("test.jpg")
|
|
|
+result = classifier.predict(img, return_debug=True)
|
|
|
+
|
|
|
+print(f"Rotation angle: {result.rotation_angle}°")
|
|
|
+print(f"Confidence: {result.confidence:.3f}")
|
|
|
+print(f"Needs rotation: {result.needs_rotation}")
|
|
|
+
|
|
|
+# 如果需要旋转
|
|
|
+if result.needs_rotation:
|
|
|
+ img_rotated = classifier.rotate_image(img, result.rotation_angle)
|
|
|
+ cv2.imwrite("rotated.jpg", img_rotated)
|
|
|
+```
|
|
|
+
|
|
|
+### 4. 完整文档预处理流程
|
|
|
+
|
|
|
+```python
|
|
|
+from pytorch_paddle import PytorchPaddleOCR
|
|
|
+from doc_preprocessor_v2 import DocPreprocessorV2
|
|
|
+import cv2
|
|
|
+
|
|
|
+# 初始化 OCR
|
|
|
+ocr = PytorchPaddleOCR(lang='ch', device='cpu')
|
|
|
+
|
|
|
+# 初始化预处理 Pipeline
|
|
|
+pipeline = DocPreprocessorV2(
|
|
|
+ orientation_model="./Layout/PP-LCNet_x1_0_doc_ori.onnx",
|
|
|
+ text_detector=ocr,
|
|
|
+ use_orientation_classify=True
|
|
|
+)
|
|
|
+
|
|
|
+# 预测
|
|
|
+img = cv2.imread("test.jpg")
|
|
|
+results = pipeline.predict(img, return_debug=True)
|
|
|
+
|
|
|
+print(results[0]) # DocPreprocessResult 对象
|
|
|
+```
|
|
|
+
|
|
|
+## 🔧 配置说明
|
|
|
+
|
|
|
+### OCR 引擎参数
|
|
|
+
|
|
|
+```python
|
|
|
+PytorchPaddleOCR(
|
|
|
+ lang='ch', # 语言
|
|
|
+ device='cpu', # 设备
|
|
|
+ use_orientation_cls=True, # 启用方向分类
|
|
|
+ orientation_model_path='...', # 方向分类模型路径
|
|
|
+ rec_batch_num=6, # 识别批大小
|
|
|
+ det_db_thresh=0.3, # 检测二值化阈值
|
|
|
+ det_db_box_thresh=0.6, # 检测框过滤阈值
|
|
|
+ enable_merge_det_boxes=True, # 合并检测框
|
|
|
+ drop_score=0.5 # 最低置信度
|
|
|
+)
|
|
|
+```
|
|
|
+
|
|
|
+### 支持的语言
|
|
|
+
|
|
|
+| 语言代码 | 说明 | 推荐用途 |
|
|
|
+|---------|------|---------|
|
|
|
+| `ch` | 中文(标准) | 通用中文识别 |
|
|
|
+| `ch_lite` | 中文(轻量) | CPU 环境 |
|
|
|
+| `ch_server` | 中文(服务器) | 高精度场景 |
|
|
|
+| `en` | 英文 | 英文识别 |
|
|
|
+| `korean` | 韩文 | 韩文识别 |
|
|
|
+| `japan` | 日文 | 日文识别 |
|
|
|
+| `chinese_cht` | 繁体中文 | 繁体中文识别 |
|
|
|
+| `latin` | 拉丁字母 | 多语言拉丁字母 |
|
|
|
+| `arabic` | 阿拉伯语 | 阿拉伯语识别 |
|
|
|
+| `cyrillic` | 西里尔字母 | 俄语等 |
|
|
|
+| `devanagari` | 梵文字母 | 印地语等 |
|
|
|
+
|
|
|
+### 方向分类器参数
|
|
|
+
|
|
|
+```python
|
|
|
+OrientationClassifierV2(
|
|
|
+ model_path="...",
|
|
|
+ text_detector=ocr, # ✨ 文本检测器(辅助判断)
|
|
|
+ aspect_ratio_threshold=1.2, # 长宽比阈值(h/w > 1.2 触发检测)
|
|
|
+ vertical_text_ratio=0.28, # 垂直文本占比阈值(>28% 判定为横向扫描)
|
|
|
+ vertical_text_min_count=3, # 最小垂直文本数量
|
|
|
+ use_gpu=False
|
|
|
+)
|
|
|
+```
|
|
|
+
|
|
|
+## 🎨 可视化功能
|
|
|
+
|
|
|
+### OCR 可视化
|
|
|
+
|
|
|
+```python
|
|
|
+img_vis = ocr.visualize(
|
|
|
+ img,
|
|
|
+ results,
|
|
|
+ output_path="output.jpg",
|
|
|
+ show_text=True, # 显示识别文字
|
|
|
+ show_confidence=True, # 显示置信度
|
|
|
+ font_scale=0.5, # 字体大小
|
|
|
+ thickness=2 # 边框粗细
|
|
|
+)
|
|
|
+```
|
|
|
+
|
|
|
+**颜色编码**:
|
|
|
+- 🟢 **绿色**: 高置信度 (≥0.9)
|
|
|
+- 🟡 **黄色**: 中置信度 (0.7-0.9)
|
|
|
+- 🟠 **橙色**: 低置信度 (<0.7)
|
|
|
+
|
|
|
+### 布局检测可视化
|
|
|
+
|
|
|
+```python
|
|
|
+img_vis = detector.visualize(
|
|
|
+ img,
|
|
|
+ results,
|
|
|
+ output_path="layout.jpg",
|
|
|
+ show_labels=True,
|
|
|
+ show_scores=True
|
|
|
+)
|
|
|
+```
|
|
|
+
|
|
|
+## 📝 注意事项
|
|
|
+
|
|
|
+### 1. Vendor 依赖
|
|
|
+
|
|
|
+首次使用需要确保 `vendor/` 目录下的模块完整:
|
|
|
+
|
|
|
+```bash
|
|
|
+cd vendor
|
|
|
+# 查看 README.md 了解依赖说明
|
|
|
+```
|
|
|
+
|
|
|
+### 2. 模型路径
|
|
|
+
|
|
|
+**自动下载** (推荐):
|
|
|
+```python
|
|
|
+# 设置环境变量
|
|
|
+export MODELSCOPE_CACHE_DIR="/path/to/cache"
|
|
|
+
|
|
|
+# 模型会自动下载到:
|
|
|
+# $MODELSCOPE_CACHE_DIR/models/OpenDataLab/PDF-Extract-Kit-1.0/
|
|
|
+```
|
|
|
+
|
|
|
+**手动指定**:
|
|
|
+```python
|
|
|
+ocr = PytorchPaddleOCR(
|
|
|
+ det_model_path="/path/to/det_model.pth",
|
|
|
+ rec_model_path="/path/to/rec_model.pth",
|
|
|
+ rec_char_dict_path="/path/to/dict.txt"
|
|
|
+)
|
|
|
+```
|
|
|
+
|
|
|
+### 3. GPU 支持
|
|
|
+
|
|
|
+```python
|
|
|
+# CUDA (NVIDIA GPU)
|
|
|
+ocr = PytorchPaddleOCR(device='cuda')
|
|
|
+
|
|
|
+# MPS (Apple Silicon M1/M2/M3)
|
|
|
+ocr = PytorchPaddleOCR(device='mps')
|
|
|
+
|
|
|
+# CPU
|
|
|
+ocr = PytorchPaddleOCR(device='cpu')
|
|
|
+```
|
|
|
+
|
|
|
+### 4. 内存优化
|
|
|
+
|
|
|
+```python
|
|
|
+# CPU 环境使用轻量模型
|
|
|
+ocr = PytorchPaddleOCR(lang='ch_lite', device='cpu')
|
|
|
+
|
|
|
+# 调整批大小
|
|
|
+ocr = PytorchPaddleOCR(rec_batch_num=4) # 默认 6
|
|
|
+```
|
|
|
+
|
|
|
+## 🐛 故障排除
|
|
|
+
|
|
|
+### 1. 识别结果全是空字符串
|
|
|
+
|
|
|
+**原因**: 字符集未正确加载
|
|
|
+
|
|
|
+**解决方案**:
|
|
|
+```python
|
|
|
+# 初始化后验证字符集
|
|
|
+if hasattr(ocr.text_recognizer, 'postprocess_op'):
|
|
|
+ char_count = len(ocr.text_recognizer.postprocess_op.character)
|
|
|
+ print(f"Character set size: {char_count}") # 应该 > 0
|
|
|
+```
|
|
|
+
|
|
|
+### 2. 横向扫描图片无法识别
|
|
|
+
|
|
|
+**原因**: 图像方向未矫正
|
|
|
+
|
|
|
+**解决方案**:
|
|
|
+```python
|
|
|
+# 启用方向分类
|
|
|
+ocr = PytorchPaddleOCR(
|
|
|
+ use_orientation_cls=True,
|
|
|
+ orientation_model_path='./Layout/PP-LCNet_x1_0_doc_ori.onnx'
|
|
|
+)
|
|
|
+```
|
|
|
+
|
|
|
+### 3. ImportError: No module named 'vendor'
|
|
|
+
|
|
|
+**解决方案**:
|
|
|
+```python
|
|
|
+import sys
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+# 添加项目根目录到 Python 路径
|
|
|
+root_dir = Path(__file__).resolve().parent
|
|
|
+sys.path.insert(0, str(root_dir))
|
|
|
+```
|
|
|
+
|
|
|
+### 4. 模型加载失败
|
|
|
+
|
|
|
+```bash
|
|
|
+MODELSCOPE_CACHE_DIR="/Users/zhch158/models/modelscope_cache"
|
|
|
+# 检查模型文件
|
|
|
+ls $MODELSCOPE_CACHE_DIR/models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/
|
|
|
+
|
|
|
+# 清除缓存重新下载
|
|
|
+rm -rf $MODELSCOPE_CACHE_DIR/models/OpenDataLab/PDF-Extract-Kit-1.0/
|
|
|
+```
|
|
|
+
|
|
|
+## 🔄 模型转换
|
|
|
+
|
|
|
+将 Paddle 模型转换为 PyTorch:
|
|
|
+
|
|
|
+```python
|
|
|
+from paddle_to_pytorch_universal import PaddleModelConverter
|
|
|
+
|
|
|
+converter = PaddleModelConverter(
|
|
|
+ paddle_model_path="paddle_model.pdparams",
|
|
|
+ paddle_config_path="config.yml"
|
|
|
+)
|
|
|
+
|
|
|
+pytorch_model = converter.convert()
|
|
|
+torch.save(pytorch_model.state_dict(), "pytorch_model.pth")
|
|
|
+```
|
|
|
+
|
|
|
+## 📚 参考资料
|
|
|
+
|
|
|
+- [MinerU](https://github.com/opendatalab/MinerU) - PDF 文档解析工具
|
|
|
+- [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) - 百度 OCR 工具包
|
|
|
+- [PaddleX](https://github.com/PaddlePaddle/PaddleX) - 飞桨低代码开发工具
|
|
|
+- [PDF-Extract-Kit](https://modelscope.cn/models/OpenDataLab/PDF-Extract-Kit-1.0) - ModelScope 模型
|
|
|
+
|
|
|
+## 📄 许可证
|
|
|
+
|
|
|
+本项目仅供学习研究使用,模型版权归原作者所有。
|
|
|
+
|
|
|
+---
|
|
|
+
|
|
|
+**最后更新**: 2024-10-30
|