|
|
@@ -0,0 +1,625 @@
|
|
|
+"""
|
|
|
+PytorchPaddleOCR - 从 MinerU 移植
|
|
|
+提供完整的 OCR 功能(检测 + 识别)
|
|
|
+"""
|
|
|
+import copy
|
|
|
+import os
|
|
|
+import sys
|
|
|
+import warnings
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+import cv2
|
|
|
+import numpy as np
|
|
|
+import yaml
|
|
|
+from loguru import logger
|
|
|
+import argparse
|
|
|
+
|
|
|
+# ✅ 修改导入
|
|
|
+try:
|
|
|
+ from .device_utils import get_device
|
|
|
+except ImportError:
|
|
|
+ from device_utils import get_device
|
|
|
+
|
|
|
+# 当作为脚本运行时,添加父目录到 Python 路径
|
|
|
+current_dir = Path(__file__).resolve().parent
|
|
|
+parent_dir = current_dir.parent
|
|
|
+if str(parent_dir) not in sys.path:
|
|
|
+ sys.path.insert(0, str(parent_dir))
|
|
|
+
|
|
|
+# ✅ 根据运行模式选择导入方式
|
|
|
+try:
|
|
|
+ # 作为模块导入时(from vendor.pytorch_paddle import ...)
|
|
|
+ from .ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
|
|
|
+ from .infer.predict_system import TextSystem
|
|
|
+ from .infer import pytorchocr_utility as utility
|
|
|
+except ImportError:
|
|
|
+ # 作为脚本运行时(python pytorch_paddle.py)
|
|
|
+ from vendor.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
|
|
|
+ from vendor.infer.predict_system import TextSystem
|
|
|
+ from vendor.infer import pytorchocr_utility as utility
|
|
|
+
|
|
|
+latin_lang = [
|
|
|
+ "af", "az", "bs", "cs", "cy", "da", "de", "es", "et", "fr", "ga", "hr", "hu",
|
|
|
+ "id", "is", "it", "ku", "la", "lt", "lv", "mi", "ms", "mt", "nl", "no", "oc",
|
|
|
+ "pi", "pl", "pt", "ro", "rs_latin", "sk", "sl", "sq", "sv", "sw", "tl", "tr",
|
|
|
+ "uz", "vi", "french", "german", "fi", "eu", "gl", "lb", "rm", "ca", "qu",
|
|
|
+]
|
|
|
+arabic_lang = ["ar", "fa", "ug", "ur", "ps", "ku", "sd", "bal"]
|
|
|
+cyrillic_lang = [
|
|
|
+ "ru", "rs_cyrillic", "be", "bg", "uk", "mn", "abq", "ady", "kbd", "ava", "dar",
|
|
|
+ "inh", "che", "lbe", "lez", "tab", "kk", "ky", "tg", "mk", "tt", "cv", "ba",
|
|
|
+ "mhr", "mo", "udm", "kv", "os", "bua", "xal", "tyv", "sah", "kaa",
|
|
|
+]
|
|
|
+east_slavic_lang = ["ru", "be", "uk"]
|
|
|
+devanagari_lang = [
|
|
|
+ "hi", "mr", "ne", "bh", "mai", "ang", "bho", "mah", "sck", "new", "gom", "sa", "bgc",
|
|
|
+]
|
|
|
+
|
|
|
+
|
|
|
+def get_model_params(lang, config):
|
|
|
+ """从配置文件获取模型参数"""
|
|
|
+ if lang in config['lang']:
|
|
|
+ params = config['lang'][lang]
|
|
|
+ det = params.get('det')
|
|
|
+ rec = params.get('rec')
|
|
|
+ dict_file = params.get('dict')
|
|
|
+ return det, rec, dict_file
|
|
|
+ else:
|
|
|
+ raise Exception(f'Language {lang} not supported')
|
|
|
+
|
|
|
+
|
|
|
+def auto_download_and_get_model_root_path(model_path):
|
|
|
+ """
|
|
|
+ 模拟 MinerU 的模型下载逻辑
|
|
|
+ """
|
|
|
+ modelscope_cache = os.getenv('MODELSCOPE_CACHE_DIR', str(Path.home() / '.cache' / 'modelscope'))
|
|
|
+ return modelscope_cache
|
|
|
+
|
|
|
+
|
|
|
+# 当前文件所在目录
|
|
|
+root_dir = Path(__file__).resolve().parent
|
|
|
+
|
|
|
+
|
|
|
+class PytorchPaddleOCR(TextSystem):
|
|
|
+ """
|
|
|
+ PytorchPaddleOCR - OCR 引擎
|
|
|
+ 继承 TextSystem,提供完整的 OCR 功能
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, *args, **kwargs):
|
|
|
+ """初始化 OCR 引擎
|
|
|
+
|
|
|
+ Args:
|
|
|
+ lang (str): 语言 ('ch', 'en', 'latin', 'korean', 'japan', 等)
|
|
|
+ det_db_thresh (float): 检测二值化阈值
|
|
|
+ det_db_box_thresh (float): 检测框过滤阈值
|
|
|
+ rec_batch_num (int): 识别批大小
|
|
|
+ enable_merge_det_boxes (bool): 是否合并检测框
|
|
|
+ device (str): 设备 ('cpu', 'cuda:0', 'mps')
|
|
|
+ det_model_path (str): 自定义检测模型路径
|
|
|
+ rec_model_path (str): 自定义识别模型路径
|
|
|
+ rec_char_dict_path (str): 自定义字典路径
|
|
|
+ """
|
|
|
+ # 初始化参数解析器
|
|
|
+ parser = utility.init_args()
|
|
|
+ args_list = list(args) if args else []
|
|
|
+ parsed_args = parser.parse_args(args_list)
|
|
|
+
|
|
|
+ # 获取语言设置
|
|
|
+ self.lang = kwargs.get('lang', 'ch')
|
|
|
+ self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)
|
|
|
+
|
|
|
+ # 自动检测设备
|
|
|
+ device = kwargs.get('device', get_device())
|
|
|
+
|
|
|
+ # ✅ CPU 优化:自动切换到轻量模型
|
|
|
+ if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
|
|
|
+ logger.warning("CPU device detected. Switching to ch_lite for better performance.")
|
|
|
+ self.lang = 'ch_lite'
|
|
|
+
|
|
|
+ # ✅ 语言映射
|
|
|
+ if self.lang in latin_lang:
|
|
|
+ self.lang = 'latin'
|
|
|
+ elif self.lang in east_slavic_lang:
|
|
|
+ self.lang = 'east_slavic'
|
|
|
+ elif self.lang in arabic_lang:
|
|
|
+ self.lang = 'arabic'
|
|
|
+ elif self.lang in cyrillic_lang:
|
|
|
+ self.lang = 'cyrillic'
|
|
|
+ elif self.lang in devanagari_lang:
|
|
|
+ self.lang = 'devanagari'
|
|
|
+
|
|
|
+ # ✅ 读取模型配置
|
|
|
+ models_config_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'models_config.yml'
|
|
|
+
|
|
|
+ if not models_config_path.exists():
|
|
|
+ raise FileNotFoundError(f"❌ Config file not found: {models_config_path}")
|
|
|
+
|
|
|
+ logger.info(f"📄 Reading config: {models_config_path}")
|
|
|
+
|
|
|
+ with open(models_config_path) as file:
|
|
|
+ config = yaml.safe_load(file)
|
|
|
+ det, rec, dict_file = get_model_params(self.lang, config)
|
|
|
+
|
|
|
+ logger.info(f"📋 Config for lang '{self.lang}':")
|
|
|
+ logger.info(f" Det model: {det}")
|
|
|
+ logger.info(f" Rec model: {rec}")
|
|
|
+ logger.info(f" Dict file: {dict_file}")
|
|
|
+
|
|
|
+ # ✅ 模型路径(优先使用 kwargs 中的自定义路径)
|
|
|
+ if 'det_model_path' not in kwargs:
|
|
|
+ ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch"
|
|
|
+ det_model_path = f"{ocr_models_dir}/{det}"
|
|
|
+ det_model_full_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
|
|
|
+ kwargs['det_model_path'] = det_model_full_path
|
|
|
+
|
|
|
+ logger.info(f"🔍 Detection model path: {det_model_full_path}")
|
|
|
+
|
|
|
+ # ✅ 验证模型文件存在
|
|
|
+ if not Path(det_model_full_path).exists():
|
|
|
+ logger.warning(f"⚠️ Detection model file not found: {det_model_full_path}")
|
|
|
+ else:
|
|
|
+ logger.info(f"✅ Detection model file exists: {det_model_full_path}")
|
|
|
+
|
|
|
+ # ✅ 识别模型路径
|
|
|
+ if 'rec_model_path' not in kwargs:
|
|
|
+ ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch"
|
|
|
+ rec_model_path = f"{ocr_models_dir}/{rec}"
|
|
|
+ rec_model_full_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)
|
|
|
+ kwargs['rec_model_path'] = rec_model_full_path
|
|
|
+
|
|
|
+ logger.info(f"🔍 Recognition model path: {rec_model_full_path}")
|
|
|
+
|
|
|
+ # ✅ 验证模型文件存在
|
|
|
+ if not Path(rec_model_full_path).exists():
|
|
|
+ logger.warning(f"⚠️ Recognition model file not found: {rec_model_full_path}")
|
|
|
+ else:
|
|
|
+ logger.info(f"✅ Recognition model file exists: {rec_model_full_path}")
|
|
|
+
|
|
|
+ # ✅ 字典路径
|
|
|
+ if 'rec_char_dict_path' not in kwargs:
|
|
|
+ dict_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'dict' / dict_file
|
|
|
+
|
|
|
+ if not dict_path.exists():
|
|
|
+ logger.error(f"❌ Dictionary file not found: {dict_path}")
|
|
|
+ raise FileNotFoundError(f"Dictionary file not found: {dict_path}")
|
|
|
+
|
|
|
+ kwargs['rec_char_dict_path'] = str(dict_path)
|
|
|
+ logger.info(f"📖 Dictionary: {dict_path.name}")
|
|
|
+
|
|
|
+ # ✅ 默认参数
|
|
|
+ kwargs.setdefault('rec_batch_num', 6)
|
|
|
+ kwargs.setdefault('device', device)
|
|
|
+
|
|
|
+ # ✅ 合并参数
|
|
|
+ default_args = vars(parsed_args)
|
|
|
+ default_args.update(kwargs)
|
|
|
+ final_args = argparse.Namespace(**default_args)
|
|
|
+
|
|
|
+ logger.info(f"🔧 Initializing TextSystem...")
|
|
|
+ logger.info(f" device: {final_args.device}")
|
|
|
+ logger.info(f" rec_char_dict_path: {final_args.rec_char_dict_path}")
|
|
|
+ logger.info(f" rec_batch_num: {final_args.rec_batch_num}")
|
|
|
+
|
|
|
+ # ✅ 初始化 TextSystem
|
|
|
+ super().__init__(final_args)
|
|
|
+
|
|
|
+ # ✅ 验证并修复字符集
|
|
|
+ logger.info(f"🔍 Validating text recognizer...")
|
|
|
+
|
|
|
+ if hasattr(self, 'text_recognizer'):
|
|
|
+ logger.info(f" ✅ text_recognizer exists")
|
|
|
+
|
|
|
+ # ❌ text_recognizer 本身没有 character 属性
|
|
|
+ # ✅ 字符集存储在 postprocess_op.character 中
|
|
|
+ if hasattr(self.text_recognizer, 'postprocess_op'):
|
|
|
+ postprocess_op = self.text_recognizer.postprocess_op
|
|
|
+ logger.info(f" ✅ postprocess_op exists: {type(postprocess_op)}")
|
|
|
+
|
|
|
+ if hasattr(postprocess_op, 'character'):
|
|
|
+ char_count = len(postprocess_op.character)
|
|
|
+ logger.info(f" ✅ postprocess_op.character exists")
|
|
|
+ logger.info(f" Character set size: {char_count}")
|
|
|
+
|
|
|
+ if char_count == 0:
|
|
|
+ logger.error(f" ❌ Character set is EMPTY!")
|
|
|
+ logger.error(f" Reloading dictionary...")
|
|
|
+
|
|
|
+ # 强制重新加载字典
|
|
|
+ dict_path = final_args.rec_char_dict_path
|
|
|
+ with open(dict_path, 'rb') as f:
|
|
|
+ lines = f.readlines()
|
|
|
+ character = []
|
|
|
+ for line in lines:
|
|
|
+ line = line.decode('utf-8').strip('\n').strip('\r\n')
|
|
|
+ character.append(line)
|
|
|
+
|
|
|
+ # 更新字符集
|
|
|
+ postprocess_op.character = character
|
|
|
+ logger.info(f" ✅ Character set reloaded: {len(character)} chars")
|
|
|
+ else:
|
|
|
+ logger.info(f" ✅ Character set loaded successfully")
|
|
|
+ # 显示前10个字符(调试用)
|
|
|
+ sample_chars = postprocess_op.character[:10]
|
|
|
+ logger.info(f" First 10 chars: {sample_chars}")
|
|
|
+ else:
|
|
|
+ logger.error(f" ❌ postprocess_op.character NOT found!")
|
|
|
+ logger.error(f" Available attributes: {dir(postprocess_op)}")
|
|
|
+ else:
|
|
|
+ logger.error(f" ❌ postprocess_op NOT found!")
|
|
|
+ logger.error(f" Available attributes: {dir(self.text_recognizer)}")
|
|
|
+ else:
|
|
|
+ logger.error(f" ❌ text_recognizer NOT found!")
|
|
|
+
|
|
|
+ logger.info(f"✅ OCR engine initialized")
|
|
|
+
|
|
|
+ def ocr(
|
|
|
+ self,
|
|
|
+ img,
|
|
|
+ det=True,
|
|
|
+ rec=True,
|
|
|
+ mfd_res=None,
|
|
|
+ tqdm_enable=False,
|
|
|
+ tqdm_desc="OCR-rec Predict",
|
|
|
+ ):
|
|
|
+ """
|
|
|
+ 执行 OCR
|
|
|
+
|
|
|
+ Args:
|
|
|
+ img: BGR 图像、图像列表、文件路径或字节流
|
|
|
+ det (bool): 是否执行检测
|
|
|
+ rec (bool): 是否执行识别
|
|
|
+ mfd_res (list): 公式检测结果(用于过滤文本框)
|
|
|
+ tqdm_enable (bool): 是否显示进度条
|
|
|
+ tqdm_desc (str): 进度条描述
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ det=True, rec=True: [[[box], (text, conf)], ...]
|
|
|
+ det=True, rec=False: [[boxes], ...]
|
|
|
+ det=False, rec=True: [[(text, conf), ...]]
|
|
|
+ """
|
|
|
+ assert isinstance(img, (np.ndarray, list, str, bytes))
|
|
|
+
|
|
|
+ if isinstance(img, list) and det:
|
|
|
+ logger.error('When input is a list of images, det must be False')
|
|
|
+ exit(1)
|
|
|
+
|
|
|
+ img = check_img(img)
|
|
|
+ imgs = [img]
|
|
|
+
|
|
|
+ with warnings.catch_warnings():
|
|
|
+ warnings.simplefilter("ignore", category=RuntimeWarning)
|
|
|
+
|
|
|
+ if det and rec:
|
|
|
+ # 检测 + 识别
|
|
|
+ ocr_res = []
|
|
|
+ for img in imgs:
|
|
|
+ img = preprocess_image(img)
|
|
|
+ dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
|
|
|
+ if not dt_boxes and not rec_res:
|
|
|
+ ocr_res.append(None)
|
|
|
+ continue
|
|
|
+ tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
|
|
|
+ ocr_res.append(tmp_res)
|
|
|
+ return ocr_res
|
|
|
+
|
|
|
+ elif det and not rec:
|
|
|
+ # 仅检测
|
|
|
+ ocr_res = []
|
|
|
+ for img in imgs:
|
|
|
+ img = preprocess_image(img)
|
|
|
+ dt_boxes, elapse = self.text_detector(img)
|
|
|
+ if dt_boxes is None:
|
|
|
+ ocr_res.append(None)
|
|
|
+ continue
|
|
|
+ dt_boxes = sorted_boxes(dt_boxes)
|
|
|
+ if self.enable_merge_det_boxes:
|
|
|
+ dt_boxes = merge_det_boxes(dt_boxes)
|
|
|
+ if mfd_res:
|
|
|
+ dt_boxes = update_det_boxes(dt_boxes, mfd_res)
|
|
|
+ tmp_res = [box.tolist() for box in dt_boxes]
|
|
|
+ ocr_res.append(tmp_res)
|
|
|
+ return ocr_res
|
|
|
+
|
|
|
+ elif not det and rec:
|
|
|
+ # 仅识别
|
|
|
+ ocr_res = []
|
|
|
+ for img in imgs:
|
|
|
+ if not isinstance(img, list):
|
|
|
+ img = preprocess_image(img)
|
|
|
+ img = [img]
|
|
|
+ rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable, tqdm_desc=tqdm_desc)
|
|
|
+ ocr_res.append(rec_res)
|
|
|
+ return ocr_res
|
|
|
+
|
|
|
+ def __call__(self, img, mfd_res=None):
|
|
|
+ """
|
|
|
+ 单张图像的 OCR(检测 + 识别)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ img: 预处理后的图像
|
|
|
+ mfd_res: 公式检测结果
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ (dt_boxes, rec_res): 检测框和识别结果
|
|
|
+ """
|
|
|
+ if img is None:
|
|
|
+ logger.debug("no valid image provided")
|
|
|
+ return None, None
|
|
|
+
|
|
|
+ ori_im = img.copy()
|
|
|
+ dt_boxes, elapse = self.text_detector(img)
|
|
|
+
|
|
|
+ if dt_boxes is None:
|
|
|
+ logger.debug(f"no dt_boxes found, elapsed: {elapse}")
|
|
|
+ return None, None
|
|
|
+
|
|
|
+ # 排序
|
|
|
+ dt_boxes = sorted_boxes(dt_boxes)
|
|
|
+
|
|
|
+ # 合并相邻框
|
|
|
+ if self.enable_merge_det_boxes:
|
|
|
+ dt_boxes = merge_det_boxes(dt_boxes)
|
|
|
+
|
|
|
+ # 过滤与公式重叠的框
|
|
|
+ if mfd_res:
|
|
|
+ dt_boxes = update_det_boxes(dt_boxes, mfd_res)
|
|
|
+
|
|
|
+ # 裁剪文本区域
|
|
|
+ img_crop_list = []
|
|
|
+ for bno in range(len(dt_boxes)):
|
|
|
+ tmp_box = copy.deepcopy(dt_boxes[bno])
|
|
|
+ img_crop = get_rotate_crop_image(ori_im, tmp_box)
|
|
|
+ img_crop_list.append(img_crop)
|
|
|
+
|
|
|
+ logger.info(f"🔤 Recognizing {len(img_crop_list)} text regions...")
|
|
|
+ rec_res, elapse = self.text_recognizer(img_crop_list)
|
|
|
+
|
|
|
+ # 统计结果
|
|
|
+ non_empty = sum(1 for text, _ in rec_res if text)
|
|
|
+ logger.info(f" Found {non_empty}/{len(rec_res)} non-empty results")
|
|
|
+
|
|
|
+ if non_empty > 0:
|
|
|
+ logger.info(f" First 5 non-empty results:")
|
|
|
+ count = 0
|
|
|
+ for i, (text, conf) in enumerate(rec_res):
|
|
|
+ if text:
|
|
|
+ logger.info(f" [{i+1}] '{text}' (conf={conf:.3f})")
|
|
|
+ count += 1
|
|
|
+ if count >= 5:
|
|
|
+ break
|
|
|
+
|
|
|
+ # 过滤低分结果
|
|
|
+ filter_boxes, filter_rec_res = [], []
|
|
|
+ for box, rec_result in zip(dt_boxes, rec_res):
|
|
|
+ text, score = rec_result
|
|
|
+ if score >= self.drop_score:
|
|
|
+ filter_boxes.append(box)
|
|
|
+ filter_rec_res.append(rec_result)
|
|
|
+
|
|
|
+ return filter_boxes, filter_rec_res
|
|
|
+
|
|
|
+ def visualize(
|
|
|
+ self,
|
|
|
+ img: np.ndarray,
|
|
|
+ ocr_results: list,
|
|
|
+ output_path: str = None,
|
|
|
+ show_text: bool = True,
|
|
|
+ show_confidence: bool = True,
|
|
|
+ font_scale: float = 0.5,
|
|
|
+ thickness: int = 2
|
|
|
+ ) -> np.ndarray:
|
|
|
+ """
|
|
|
+ 可视化 OCR 检测结果
|
|
|
+
|
|
|
+ Args:
|
|
|
+ img: 原始图像(BGR 格式)
|
|
|
+ ocr_results: OCR 结果 [[box, (text, conf)], ...]
|
|
|
+ output_path: 输出路径(可选)
|
|
|
+ show_text: 是否显示识别的文字
|
|
|
+ show_confidence: 是否显示置信度
|
|
|
+ font_scale: 字体大小
|
|
|
+ thickness: 边框粗细
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ 标注后的图像
|
|
|
+ """
|
|
|
+ img_vis = img.copy()
|
|
|
+
|
|
|
+ if not ocr_results or ocr_results[0] is None:
|
|
|
+ logger.warning("No OCR results to visualize")
|
|
|
+ return img_vis
|
|
|
+
|
|
|
+ # 颜色映射(根据置信度)
|
|
|
+ def get_color_by_confidence(conf: float) -> tuple:
|
|
|
+ """根据置信度返回颜色 (绿色->黄色->红色)"""
|
|
|
+ if conf >= 0.9:
|
|
|
+ return (0, 255, 0) # 高置信度:绿色
|
|
|
+ elif conf >= 0.7:
|
|
|
+ return (0, 255, 255) # 中置信度:黄色
|
|
|
+ else:
|
|
|
+ return (0, 165, 255) # 低置信度:橙色
|
|
|
+
|
|
|
+ for idx, result in enumerate(ocr_results[0], 1):
|
|
|
+ box, (text, conf) = result
|
|
|
+
|
|
|
+ # 转换坐标格式:[[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
|
|
+ if isinstance(box, list):
|
|
|
+ points = np.array(box, dtype=np.int32).reshape((-1, 2))
|
|
|
+ else:
|
|
|
+ points = box.astype(np.int32)
|
|
|
+
|
|
|
+ # 获取边框颜色
|
|
|
+ color = get_color_by_confidence(conf)
|
|
|
+
|
|
|
+ # 绘制多边形边框
|
|
|
+ cv2.polylines(img_vis, [points], True, color, thickness)
|
|
|
+
|
|
|
+ # 计算文本位置(左上角)
|
|
|
+ x1, y1 = points[0]
|
|
|
+
|
|
|
+ # 构造标签
|
|
|
+ if show_text and show_confidence:
|
|
|
+ label = f"[{idx}] {text} ({conf:.2f})"
|
|
|
+ elif show_text:
|
|
|
+ label = f"[{idx}] {text}"
|
|
|
+ elif show_confidence:
|
|
|
+ label = f"[{idx}] {conf:.2f}"
|
|
|
+ else:
|
|
|
+ label = f"[{idx}]"
|
|
|
+
|
|
|
+ # 计算标签尺寸
|
|
|
+ label_size, baseline = cv2.getTextSize(
|
|
|
+ label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
|
|
|
+ )
|
|
|
+ label_w, label_h = label_size
|
|
|
+
|
|
|
+ # 确保标签不超出图像边界
|
|
|
+ y_label = max(y1 - 5, label_h + 5)
|
|
|
+
|
|
|
+ # 绘制标签背景(半透明)
|
|
|
+ overlay = img_vis.copy()
|
|
|
+ cv2.rectangle(
|
|
|
+ overlay,
|
|
|
+ (x1, y_label - label_h - 5),
|
|
|
+ (x1 + label_w + 10, y_label + baseline),
|
|
|
+ color,
|
|
|
+ -1
|
|
|
+ )
|
|
|
+ cv2.addWeighted(overlay, 0.6, img_vis, 0.4, 0, img_vis)
|
|
|
+
|
|
|
+ # 绘制标签文字
|
|
|
+ cv2.putText(
|
|
|
+ img_vis,
|
|
|
+ label,
|
|
|
+ (x1 + 5, y_label - 5),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX,
|
|
|
+ font_scale,
|
|
|
+ (255, 255, 255), # 白色文字
|
|
|
+ thickness,
|
|
|
+ cv2.LINE_AA
|
|
|
+ )
|
|
|
+
|
|
|
+ # 添加统计信息
|
|
|
+ total_boxes = len(ocr_results[0])
|
|
|
+ avg_conf = sum(conf for _, (_, conf) in ocr_results[0]) / total_boxes if total_boxes > 0 else 0
|
|
|
+
|
|
|
+ stats_text = f"Total: {total_boxes} boxes | Avg Conf: {avg_conf:.3f}"
|
|
|
+ cv2.putText(
|
|
|
+ img_vis,
|
|
|
+ stats_text,
|
|
|
+ (10, 30),
|
|
|
+ cv2.FONT_HERSHEY_SIMPLEX,
|
|
|
+ 0.7,
|
|
|
+ (0, 0, 255), # 红色
|
|
|
+ 2,
|
|
|
+ cv2.LINE_AA
|
|
|
+ )
|
|
|
+
|
|
|
+ # 保存图像
|
|
|
+ if output_path:
|
|
|
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
|
|
+ cv2.imwrite(output_path, img_vis)
|
|
|
+ logger.info(f"✅ Visualization saved to: {output_path}")
|
|
|
+
|
|
|
+ return img_vis
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ from dotenv import load_dotenv
|
|
|
+ load_dotenv(override=True)
|
|
|
+
|
|
|
+ logger.remove()
|
|
|
+ logger.add(sys.stderr, level="INFO")
|
|
|
+
|
|
|
+ print("🚀 Testing PytorchPaddleOCR with Visualization...")
|
|
|
+
|
|
|
+ try:
|
|
|
+ ocr = PytorchPaddleOCR(lang='ch', device='cpu')
|
|
|
+
|
|
|
+ # ✅ 验证字符集
|
|
|
+ if hasattr(ocr.text_recognizer, 'postprocess_op'):
|
|
|
+ char_count = len(ocr.text_recognizer.postprocess_op.character)
|
|
|
+ print(f"\n📊 Character set loaded: {char_count} characters")
|
|
|
+ if char_count > 0:
|
|
|
+ print(f" Sample chars: {ocr.text_recognizer.postprocess_op.character[:20]}")
|
|
|
+ else:
|
|
|
+ print(f" ❌ ERROR: Character set is empty!")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ # 测试图像列表
|
|
|
+ test_images = [
|
|
|
+ {
|
|
|
+ 'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_001.png",
|
|
|
+ 'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_001_ocr_vis.png"
|
|
|
+ },
|
|
|
+ {
|
|
|
+ 'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
|
|
|
+ 'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_003_ocr_vis.png"
|
|
|
+ }
|
|
|
+ ]
|
|
|
+
|
|
|
+ for img_info in test_images:
|
|
|
+ img_path = img_info['path']
|
|
|
+ output_path = img_info['output']
|
|
|
+
|
|
|
+ print(f"\n{'='*60}")
|
|
|
+ print(f"📄 Processing: {Path(img_path).name}")
|
|
|
+ print(f"{'='*60}")
|
|
|
+
|
|
|
+ if not Path(img_path).exists():
|
|
|
+ print(f"❌ Image not found: {img_path}")
|
|
|
+ continue
|
|
|
+
|
|
|
+ img = cv2.imread(img_path)
|
|
|
+
|
|
|
+ if img is None:
|
|
|
+ print(f"❌ Failed to load image")
|
|
|
+ continue
|
|
|
+
|
|
|
+ print(f"📖 Original image: {img.shape}")
|
|
|
+
|
|
|
+ # 执行 OCR
|
|
|
+ results = ocr.ocr(img, det=True, rec=True)
|
|
|
+
|
|
|
+ if results and results[0]:
|
|
|
+ print(f"\n✅ OCR completed! Found {len(results[0])} text regions")
|
|
|
+
|
|
|
+ # 打印前 10 个结果
|
|
|
+ print(f"\n📝 First 10 results:")
|
|
|
+ for i, (box, (text, conf)) in enumerate(results[0][:10], 1):
|
|
|
+ print(f" [{i}] {text} (conf={conf:.3f})")
|
|
|
+
|
|
|
+ if len(results[0]) > 10:
|
|
|
+ print(f" ... and {len(results[0]) - 10} more")
|
|
|
+
|
|
|
+ # 可视化
|
|
|
+ print(f"\n🎨 Visualizing results...")
|
|
|
+ img_vis = ocr.visualize(
|
|
|
+ img,
|
|
|
+ results,
|
|
|
+ output_path=output_path,
|
|
|
+ show_text=True,
|
|
|
+ show_confidence=True
|
|
|
+ )
|
|
|
+
|
|
|
+ # 统计信息
|
|
|
+ total_boxes = len(results[0])
|
|
|
+ avg_conf = sum(conf for _, (_, conf) in results[0]) / total_boxes
|
|
|
+ high_conf = sum(1 for _, (_, conf) in results[0] if conf >= 0.9)
|
|
|
+ low_conf = sum(1 for _, (_, conf) in results[0] if conf < 0.7)
|
|
|
+ non_empty = sum(1 for _, (text, _) in results[0] if text)
|
|
|
+
|
|
|
+ print(f"\n📊 Statistics:")
|
|
|
+ print(f" Total boxes: {total_boxes}")
|
|
|
+ print(f" Non-empty: {non_empty}")
|
|
|
+ print(f" Average confidence: {avg_conf:.3f}")
|
|
|
+ print(f" High confidence (≥0.9): {high_conf}")
|
|
|
+ print(f" Low confidence (<0.7): {low_conf}")
|
|
|
+ else:
|
|
|
+ print(f"⚠️ No results found")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"\n❌ Test failed: {e}")
|
|
|
+ import traceback
|
|
|
+ traceback.print_exc()
|