|
|
@@ -13,12 +13,9 @@ import numpy as np
|
|
|
import yaml
|
|
|
from loguru import logger
|
|
|
import argparse
|
|
|
+from typing import Optional
|
|
|
|
|
|
-# ✅ 修改导入
|
|
|
-try:
|
|
|
- from .device_utils import get_device
|
|
|
-except ImportError:
|
|
|
- from device_utils import get_device
|
|
|
+from vendor import get_device
|
|
|
|
|
|
# 当作为脚本运行时,添加父目录到 Python 路径
|
|
|
current_dir = Path(__file__).resolve().parent
|
|
|
@@ -26,17 +23,12 @@ parent_dir = current_dir.parent
|
|
|
if str(parent_dir) not in sys.path:
|
|
|
sys.path.insert(0, str(parent_dir))
|
|
|
|
|
|
-# ✅ 根据运行模式选择导入方式
|
|
|
-try:
|
|
|
- # 作为模块导入时(from vendor.pytorch_paddle import ...)
|
|
|
- from .ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
|
|
|
- from .infer.predict_system import TextSystem
|
|
|
- from .infer import pytorchocr_utility as utility
|
|
|
-except ImportError:
|
|
|
- # 作为脚本运行时(python pytorch_paddle.py)
|
|
|
- from vendor.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
|
|
|
- from vendor.infer.predict_system import TextSystem
|
|
|
- from vendor.infer import pytorchocr_utility as utility
|
|
|
+# ✅ 导入方向分类器
|
|
|
+from orientation_classifier_v2 import OrientationClassifierV2
|
|
|
+
|
|
|
+from vendor import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
|
|
|
+from vendor.infer import TextSystem
|
|
|
+from vendor.infer import pytorchocr_utility as utility
|
|
|
|
|
|
latin_lang = [
|
|
|
"af", "az", "bs", "cs", "cy", "da", "de", "es", "et", "fr", "ga", "hr", "hu",
|
|
|
@@ -108,7 +100,14 @@ class PytorchPaddleOCR(TextSystem):
|
|
|
# 获取语言设置
|
|
|
self.lang = kwargs.get('lang', 'ch')
|
|
|
self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)
|
|
|
-
|
|
|
+
|
|
|
+ # ✅ 新增:方向分类器配置
|
|
|
+ self.use_orientation_cls = kwargs.get("use_orientation_cls", False)
|
|
|
+ self.orientation_model_path = kwargs.get(
|
|
|
+ "orientation_model_path",
|
|
|
+ None
|
|
|
+ )
|
|
|
+
|
|
|
# 自动检测设备
|
|
|
device = kwargs.get('device', get_device())
|
|
|
|
|
|
@@ -130,7 +129,7 @@ class PytorchPaddleOCR(TextSystem):
|
|
|
self.lang = 'devanagari'
|
|
|
|
|
|
# ✅ 读取模型配置
|
|
|
- models_config_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'models_config.yml'
|
|
|
+ models_config_path = root_dir / 'vendor' / 'pytorchocr' / 'utils' / 'resources' / 'models_config.yml'
|
|
|
|
|
|
if not models_config_path.exists():
|
|
|
raise FileNotFoundError(f"❌ Config file not found: {models_config_path}")
|
|
|
@@ -178,7 +177,7 @@ class PytorchPaddleOCR(TextSystem):
|
|
|
|
|
|
# ✅ 字典路径
|
|
|
if 'rec_char_dict_path' not in kwargs:
|
|
|
- dict_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'dict' / dict_file
|
|
|
+ dict_path = root_dir / 'vendor' / 'pytorchocr' / 'utils' / 'resources' / 'dict' / dict_file
|
|
|
|
|
|
if not dict_path.exists():
|
|
|
logger.error(f"❌ Dictionary file not found: {dict_path}")
|
|
|
@@ -204,6 +203,55 @@ class PytorchPaddleOCR(TextSystem):
|
|
|
# ✅ 初始化 TextSystem
|
|
|
super().__init__(final_args)
|
|
|
|
|
|
+ # ✅ 初始化方向分类器(在 TextSystem 之后,这样可以使用 self.text_detector)
|
|
|
+ self.orientation_classifier = None
|
|
|
+ if self.use_orientation_cls and self.orientation_model_path:
|
|
|
+ try:
|
|
|
+ logger.info(f"🔄 Initializing orientation classifier...")
|
|
|
+ logger.info(f" Model: {self.orientation_model_path}")
|
|
|
+
|
|
|
+ # ✅ 创建一个简单的检测器适配器
|
|
|
class TextDetectorAdapter:
    """Adapt a bare text detector to an ``ocr(img, det, rec)`` style call.

    The wrapped object is invoked as ``text_detector(img)`` and is expected
    to return a ``(boxes, extra)`` pair; only ``boxes`` is used.
    """

    def __init__(self, text_detector):
        self.text_detector = text_detector

    def ocr(self, img, det=True, rec=False):
        """Run detection only and wrap the boxes in a one-image result list.

        Args:
            img: Image handed unchanged to the wrapped detector.
            det: When false, no detection is run and ``None`` is returned.
            rec: Accepted for call compatibility; not used here.

        Returns:
            ``None`` if ``det`` is false; ``[[]]`` if nothing was detected
            or the detector raised; otherwise ``[[box1, box2, ...]]`` where
            each box is converted to nested Python lists.
        """
        if not det:
            return None

        try:
            dt_boxes, _ = self.text_detector(img)

            if dt_boxes is None or len(dt_boxes) == 0:
                return [[]]

            # Convert box by box instead of calling dt_boxes.tolist(): the
            # result is identical for a single ndarray, and it also works
            # when the detector hands back a list of boxes or polygons with
            # differing point counts.
            return [[np.asarray(box).tolist() for box in dt_boxes]]
        except Exception as e:
            # Best effort: a detector failure is reported as "no boxes"
            # rather than propagated to the caller.
            logger.warning(f"Detection failed in adapter: {e}")
            return [[]]
|
|
|
+
|
|
|
+ # ✅ 传入检测器适配器
|
|
|
+ detector_adapter = TextDetectorAdapter(self.text_detector)
|
|
|
+
|
|
|
+ self.orientation_classifier = OrientationClassifierV2(
|
|
|
+ model_path=self.orientation_model_path,
|
|
|
+ text_detector=detector_adapter, # ✅ 传入检测器
|
|
|
+ aspect_ratio_threshold=kwargs.get("aspect_ratio_threshold", 1.2),
|
|
|
+ vertical_text_ratio=kwargs.get("vertical_text_ratio", 0.28),
|
|
|
+ vertical_text_min_count=kwargs.get("vertical_text_min_count", 3),
|
|
|
+ use_gpu=kwargs.get("device", "cpu") != "cpu"
|
|
|
+ )
|
|
|
+
|
|
|
+ logger.info(f"✅ Orientation classifier initialized")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"⚠️ Failed to initialize orientation classifier: {e}")
|
|
|
+ self.orientation_classifier = None
|
|
|
+
|
|
|
# ✅ 验证并修复字符集
|
|
|
logger.info(f"🔍 Validating text recognizer...")
|
|
|
|
|
|
@@ -284,7 +332,9 @@ class PytorchPaddleOCR(TextSystem):
|
|
|
logger.error('When input is a list of images, det must be False')
|
|
|
exit(1)
|
|
|
|
|
|
+ # np.ndarray: BGR 格式图像
|
|
|
img = check_img(img)
|
|
|
+
|
|
|
imgs = [img]
|
|
|
|
|
|
with warnings.catch_warnings():
|
|
|
@@ -398,12 +448,11 @@ class PytorchPaddleOCR(TextSystem):
|
|
|
filter_rec_res.append(rec_result)
|
|
|
|
|
|
return filter_boxes, filter_rec_res
|
|
|
-
|
|
|
def visualize(
|
|
|
self,
|
|
|
img: np.ndarray,
|
|
|
ocr_results: list,
|
|
|
- output_path: str = None,
|
|
|
+ output_path: Optional[str] = None,
|
|
|
show_text: bool = True,
|
|
|
show_confidence: bool = True,
|
|
|
font_scale: float = 0.5,
|
|
|
@@ -532,27 +581,28 @@ if __name__ == '__main__':
|
|
|
logger.remove()
|
|
|
logger.add(sys.stderr, level="INFO")
|
|
|
|
|
|
- print("🚀 Testing PytorchPaddleOCR with Visualization...")
|
|
|
+ print("🚀 Testing PytorchPaddleOCR with Orientation Classifier...")
|
|
|
|
|
|
try:
|
|
|
- ocr = PytorchPaddleOCR(lang='ch', device='cpu')
|
|
|
+ # ✅ 启用方向分类器
|
|
|
+ ocr = PytorchPaddleOCR(
|
|
|
+ lang='ch',
|
|
|
+ device='cpu',
|
|
|
+ use_orientation_cls=True, # ✅ 启用
|
|
|
+ orientation_model_path="/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/PP-LCNet_x1_0_doc_ori.onnx"
|
|
|
+ )
|
|
|
|
|
|
# ✅ 验证字符集
|
|
|
if hasattr(ocr.text_recognizer, 'postprocess_op'):
|
|
|
char_count = len(ocr.text_recognizer.postprocess_op.character)
|
|
|
print(f"\n📊 Character set loaded: {char_count} characters")
|
|
|
- if char_count > 0:
|
|
|
- print(f" Sample chars: {ocr.text_recognizer.postprocess_op.character[:20]}")
|
|
|
- else:
|
|
|
- print(f" ❌ ERROR: Character set is empty!")
|
|
|
- sys.exit(1)
|
|
|
|
|
|
# 测试图像列表
|
|
|
test_images = [
|
|
|
- {
|
|
|
- 'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_001.png",
|
|
|
- 'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_001_ocr_vis.png"
|
|
|
- },
|
|
|
+ # {
|
|
|
+ # 'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_001.png",
|
|
|
+ # 'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_001_ocr_vis.png"
|
|
|
+ # },
|
|
|
{
|
|
|
'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
|
|
|
'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_003_ocr_vis.png"
|
|
|
@@ -567,55 +617,47 @@ if __name__ == '__main__':
|
|
|
print(f"📄 Processing: {Path(img_path).name}")
|
|
|
print(f"{'='*60}")
|
|
|
|
|
|
- if not Path(img_path).exists():
|
|
|
- print(f"❌ Image not found: {img_path}")
|
|
|
- continue
|
|
|
-
|
|
|
img = cv2.imread(img_path)
|
|
|
-
|
|
|
- if img is None:
|
|
|
- print(f"❌ Failed to load image")
|
|
|
- continue
|
|
|
-
|
|
|
print(f"📖 Original image: {img.shape}")
|
|
|
|
|
|
- # 执行 OCR
|
|
|
+ # np.ndarray: BGR 格式图像
|
|
|
+ img = check_img(img)
|
|
|
+
|
|
|
+ # ✅ 使用方向分类器(如果启用)
|
|
|
+ if ocr.orientation_classifier is not None and isinstance(img, np.ndarray):
|
|
|
+ logger.info(f"🔄 Checking image orientation...")
|
|
|
+
|
|
|
+ # 预测方向
|
|
|
+ orientation_result = ocr.orientation_classifier.predict(img, return_debug=True)
|
|
|
+
|
|
|
+ logger.info(f" Orientation result: {orientation_result}")
|
|
|
+
|
|
|
+ # 如果需要旋转
|
|
|
+ if orientation_result.needs_rotation:
|
|
|
+ logger.info(f"🔄 Rotating image by {orientation_result.rotation_angle}°...")
|
|
|
+ img = ocr.orientation_classifier.rotate_image(img, orientation_result.rotation_angle)
|
|
|
+ logger.info(f" New shape: {img.shape}")
|
|
|
+
|
|
|
+ # 执行 OCR(会自动检测并旋转)
|
|
|
results = ocr.ocr(img, det=True, rec=True)
|
|
|
|
|
|
if results and results[0]:
|
|
|
- print(f"\n✅ OCR completed! Found {len(results[0])} text regions")
|
|
|
-
|
|
|
- # 打印前 10 个结果
|
|
|
- print(f"\n📝 First 10 results:")
|
|
|
+ print(f"\n✅ Found {len(results[0])} text regions:")
|
|
|
for i, (box, (text, conf)) in enumerate(results[0][:10], 1):
|
|
|
print(f" [{i}] {text} (conf={conf:.3f})")
|
|
|
|
|
|
- if len(results[0]) > 10:
|
|
|
- print(f" ... and {len(results[0]) - 10} more")
|
|
|
-
|
|
|
# 可视化
|
|
|
- print(f"\n🎨 Visualizing results...")
|
|
|
- img_vis = ocr.visualize(
|
|
|
- img,
|
|
|
- results,
|
|
|
- output_path=output_path,
|
|
|
- show_text=True,
|
|
|
- show_confidence=True
|
|
|
- )
|
|
|
+ img_vis = ocr.visualize(img, results, output_path=output_path, show_text=False, show_confidence=False)
|
|
|
|
|
|
- # 统计信息
|
|
|
- total_boxes = len(results[0])
|
|
|
- avg_conf = sum(conf for _, (_, conf) in results[0]) / total_boxes
|
|
|
- high_conf = sum(1 for _, (_, conf) in results[0] if conf >= 0.9)
|
|
|
- low_conf = sum(1 for _, (_, conf) in results[0] if conf < 0.7)
|
|
|
+ # 统计
|
|
|
+ total = len(results[0])
|
|
|
non_empty = sum(1 for _, (text, _) in results[0] if text)
|
|
|
+ avg_conf = sum(conf for _, (_, conf) in results[0]) / total
|
|
|
|
|
|
print(f"\n📊 Statistics:")
|
|
|
- print(f" Total boxes: {total_boxes}")
|
|
|
+ print(f" Total: {total}")
|
|
|
print(f" Non-empty: {non_empty}")
|
|
|
- print(f" Average confidence: {avg_conf:.3f}")
|
|
|
- print(f" High confidence (≥0.9): {high_conf}")
|
|
|
- print(f" Low confidence (<0.7): {low_conf}")
|
|
|
+ print(f" Avg confidence: {avg_conf:.3f}")
|
|
|
else:
|
|
|
print(f"⚠️ No results found")
|
|
|
|