""" PytorchPaddleOCR - 从 MinerU 移植 提供完整的 OCR 功能(检测 + 识别) """ import copy import os import sys import warnings from pathlib import Path import cv2 import numpy as np import yaml from loguru import logger import argparse # ✅ 修改导入 try: from .device_utils import get_device except ImportError: from device_utils import get_device # 当作为脚本运行时,添加父目录到 Python 路径 current_dir = Path(__file__).resolve().parent parent_dir = current_dir.parent if str(parent_dir) not in sys.path: sys.path.insert(0, str(parent_dir)) # ✅ 根据运行模式选择导入方式 try: # 作为模块导入时(from vendor.pytorch_paddle import ...) from .ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image from .infer.predict_system import TextSystem from .infer import pytorchocr_utility as utility except ImportError: # 作为脚本运行时(python pytorch_paddle.py) from vendor.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image from vendor.infer.predict_system import TextSystem from vendor.infer import pytorchocr_utility as utility latin_lang = [ "af", "az", "bs", "cs", "cy", "da", "de", "es", "et", "fr", "ga", "hr", "hu", "id", "is", "it", "ku", "la", "lt", "lv", "mi", "ms", "mt", "nl", "no", "oc", "pi", "pl", "pt", "ro", "rs_latin", "sk", "sl", "sq", "sv", "sw", "tl", "tr", "uz", "vi", "french", "german", "fi", "eu", "gl", "lb", "rm", "ca", "qu", ] arabic_lang = ["ar", "fa", "ug", "ur", "ps", "ku", "sd", "bal"] cyrillic_lang = [ "ru", "rs_cyrillic", "be", "bg", "uk", "mn", "abq", "ady", "kbd", "ava", "dar", "inh", "che", "lbe", "lez", "tab", "kk", "ky", "tg", "mk", "tt", "cv", "ba", "mhr", "mo", "udm", "kv", "os", "bua", "xal", "tyv", "sah", "kaa", ] east_slavic_lang = ["ru", "be", "uk"] devanagari_lang = [ "hi", "mr", "ne", "bh", "mai", "ang", "bho", "mah", "sck", "new", "gom", "sa", "bgc", ] def get_model_params(lang, config): """从配置文件获取模型参数""" if lang in config['lang']: params = config['lang'][lang] det = params.get('det') rec = params.get('rec') dict_file = params.get('dict') return det, rec, dict_file else: raise Exception(f'Language {lang} not supported') def auto_download_and_get_model_root_path(model_path): """ 模拟 MinerU 的模型下载逻辑 """ modelscope_cache = os.getenv('MODELSCOPE_CACHE_DIR', str(Path.home() / '.cache' / 'modelscope')) return modelscope_cache # 当前文件所在目录 root_dir = Path(__file__).resolve().parent class PytorchPaddleOCR(TextSystem): """ PytorchPaddleOCR - OCR 引擎 继承 TextSystem,提供完整的 OCR 功能 """ def __init__(self, *args, **kwargs): """初始化 OCR 引擎 Args: lang (str): 语言 ('ch', 'en', 'latin', 'korean', 'japan', 等) det_db_thresh (float): 检测二值化阈值 det_db_box_thresh (float): 检测框过滤阈值 rec_batch_num (int): 识别批大小 enable_merge_det_boxes (bool): 是否合并检测框 device (str): 设备 ('cpu', 'cuda:0', 'mps') det_model_path (str): 自定义检测模型路径 rec_model_path (str): 自定义识别模型路径 rec_char_dict_path (str): 自定义字典路径 """ # 初始化参数解析器 parser = utility.init_args() args_list = list(args) if args else [] parsed_args = parser.parse_args(args_list) # 获取语言设置 self.lang = kwargs.get('lang', 'ch') self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True) # 自动检测设备 device = kwargs.get('device', get_device()) # ✅ CPU 优化:自动切换到轻量模型 if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']: logger.warning("CPU device detected. Switching to ch_lite for better performance.") self.lang = 'ch_lite' # ✅ 语言映射 if self.lang in latin_lang: self.lang = 'latin' elif self.lang in east_slavic_lang: self.lang = 'east_slavic' elif self.lang in arabic_lang: self.lang = 'arabic' elif self.lang in cyrillic_lang: self.lang = 'cyrillic' elif self.lang in devanagari_lang: self.lang = 'devanagari' # ✅ 读取模型配置 models_config_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'models_config.yml' if not models_config_path.exists(): raise FileNotFoundError(f"❌ Config file not found: {models_config_path}") logger.info(f"📄 Reading config: {models_config_path}") with open(models_config_path) as file: config = yaml.safe_load(file) det, rec, dict_file = get_model_params(self.lang, config) logger.info(f"📋 Config for lang '{self.lang}':") logger.info(f" Det model: {det}") logger.info(f" Rec model: {rec}") logger.info(f" Dict file: {dict_file}") # ✅ 模型路径(优先使用 kwargs 中的自定义路径) if 'det_model_path' not in kwargs: ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch" det_model_path = f"{ocr_models_dir}/{det}" det_model_full_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path) kwargs['det_model_path'] = det_model_full_path logger.info(f"🔍 Detection model path: {det_model_full_path}") # ✅ 验证模型文件存在 if not Path(det_model_full_path).exists(): logger.warning(f"⚠️ Detection model file not found: {det_model_full_path}") else: logger.info(f"✅ Detection model file exists: {det_model_full_path}") # ✅ 识别模型路径 if 'rec_model_path' not in kwargs: ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch" rec_model_path = f"{ocr_models_dir}/{rec}" rec_model_full_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path) kwargs['rec_model_path'] = rec_model_full_path logger.info(f"🔍 Recognition model path: {rec_model_full_path}") # ✅ 验证模型文件存在 if not Path(rec_model_full_path).exists(): logger.warning(f"⚠️ Recognition model file not found: {rec_model_full_path}") else: logger.info(f"✅ Recognition model file exists: {rec_model_full_path}") # ✅ 字典路径 if 'rec_char_dict_path' not in kwargs: dict_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'dict' / dict_file if not dict_path.exists(): logger.error(f"❌ Dictionary file not found: {dict_path}") raise FileNotFoundError(f"Dictionary file not found: {dict_path}") kwargs['rec_char_dict_path'] = str(dict_path) logger.info(f"📖 Dictionary: {dict_path.name}") # ✅ 默认参数 kwargs.setdefault('rec_batch_num', 6) kwargs.setdefault('device', device) # ✅ 合并参数 default_args = vars(parsed_args) default_args.update(kwargs) final_args = argparse.Namespace(**default_args) logger.info(f"🔧 Initializing TextSystem...") logger.info(f" device: {final_args.device}") logger.info(f" rec_char_dict_path: {final_args.rec_char_dict_path}") logger.info(f" rec_batch_num: {final_args.rec_batch_num}") # ✅ 初始化 TextSystem super().__init__(final_args) # ✅ 验证并修复字符集 logger.info(f"🔍 Validating text recognizer...") if hasattr(self, 'text_recognizer'): logger.info(f" ✅ text_recognizer exists") # ❌ text_recognizer 本身没有 character 属性 # ✅ 字符集存储在 postprocess_op.character 中 if hasattr(self.text_recognizer, 'postprocess_op'): postprocess_op = self.text_recognizer.postprocess_op logger.info(f" ✅ postprocess_op exists: {type(postprocess_op)}") if hasattr(postprocess_op, 'character'): char_count = len(postprocess_op.character) logger.info(f" ✅ postprocess_op.character exists") logger.info(f" Character set size: {char_count}") if char_count == 0: logger.error(f" ❌ Character set is EMPTY!") logger.error(f" Reloading dictionary...") # 强制重新加载字典 dict_path = final_args.rec_char_dict_path with open(dict_path, 'rb') as f: lines = f.readlines() character = [] for line in lines: line = line.decode('utf-8').strip('\n').strip('\r\n') character.append(line) # 更新字符集 postprocess_op.character = character logger.info(f" ✅ Character set reloaded: {len(character)} chars") else: logger.info(f" ✅ Character set loaded successfully") # 显示前10个字符(调试用) sample_chars = postprocess_op.character[:10] logger.info(f" First 10 chars: {sample_chars}") else: logger.error(f" ❌ postprocess_op.character NOT found!") logger.error(f" Available attributes: {dir(postprocess_op)}") else: logger.error(f" ❌ postprocess_op NOT found!") logger.error(f" Available attributes: {dir(self.text_recognizer)}") else: logger.error(f" ❌ text_recognizer NOT found!") logger.info(f"✅ OCR engine initialized") def ocr( self, img, det=True, rec=True, mfd_res=None, tqdm_enable=False, tqdm_desc="OCR-rec Predict", ): """ 执行 OCR Args: img: BGR 图像、图像列表、文件路径或字节流 det (bool): 是否执行检测 rec (bool): 是否执行识别 mfd_res (list): 公式检测结果(用于过滤文本框) tqdm_enable (bool): 是否显示进度条 tqdm_desc (str): 进度条描述 Returns: det=True, rec=True: [[[box], (text, conf)], ...] det=True, rec=False: [[boxes], ...] det=False, rec=True: [[(text, conf), ...]] """ assert isinstance(img, (np.ndarray, list, str, bytes)) if isinstance(img, list) and det: logger.error('When input is a list of images, det must be False') exit(1) img = check_img(img) imgs = [img] with warnings.catch_warnings(): warnings.simplefilter("ignore", category=RuntimeWarning) if det and rec: # 检测 + 识别 ocr_res = [] for img in imgs: img = preprocess_image(img) dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res) if not dt_boxes and not rec_res: ocr_res.append(None) continue tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] ocr_res.append(tmp_res) return ocr_res elif det and not rec: # 仅检测 ocr_res = [] for img in imgs: img = preprocess_image(img) dt_boxes, elapse = self.text_detector(img) if dt_boxes is None: ocr_res.append(None) continue dt_boxes = sorted_boxes(dt_boxes) if self.enable_merge_det_boxes: dt_boxes = merge_det_boxes(dt_boxes) if mfd_res: dt_boxes = update_det_boxes(dt_boxes, mfd_res) tmp_res = [box.tolist() for box in dt_boxes] ocr_res.append(tmp_res) return ocr_res elif not det and rec: # 仅识别 ocr_res = [] for img in imgs: if not isinstance(img, list): img = preprocess_image(img) img = [img] rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable, tqdm_desc=tqdm_desc) ocr_res.append(rec_res) return ocr_res def __call__(self, img, mfd_res=None): """ 单张图像的 OCR(检测 + 识别) Args: img: 预处理后的图像 mfd_res: 公式检测结果 Returns: (dt_boxes, rec_res): 检测框和识别结果 """ if img is None: logger.debug("no valid image provided") return None, None ori_im = img.copy() dt_boxes, elapse = self.text_detector(img) if dt_boxes is None: logger.debug(f"no dt_boxes found, elapsed: {elapse}") return None, None # 排序 dt_boxes = sorted_boxes(dt_boxes) # 合并相邻框 if self.enable_merge_det_boxes: dt_boxes = merge_det_boxes(dt_boxes) # 过滤与公式重叠的框 if mfd_res: dt_boxes = update_det_boxes(dt_boxes, mfd_res) # 裁剪文本区域 img_crop_list = [] for bno in range(len(dt_boxes)): tmp_box = copy.deepcopy(dt_boxes[bno]) img_crop = get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) logger.info(f"🔤 Recognizing {len(img_crop_list)} text regions...") rec_res, elapse = self.text_recognizer(img_crop_list) # 统计结果 non_empty = sum(1 for text, _ in rec_res if text) logger.info(f" Found {non_empty}/{len(rec_res)} non-empty results") if non_empty > 0: logger.info(f" First 5 non-empty results:") count = 0 for i, (text, conf) in enumerate(rec_res): if text: logger.info(f" [{i+1}] '{text}' (conf={conf:.3f})") count += 1 if count >= 5: break # 过滤低分结果 filter_boxes, filter_rec_res = [], [] for box, rec_result in zip(dt_boxes, rec_res): text, score = rec_result if score >= self.drop_score: filter_boxes.append(box) filter_rec_res.append(rec_result) return filter_boxes, filter_rec_res def visualize( self, img: np.ndarray, ocr_results: list, output_path: str = None, show_text: bool = True, show_confidence: bool = True, font_scale: float = 0.5, thickness: int = 2 ) -> np.ndarray: """ 可视化 OCR 检测结果 Args: img: 原始图像(BGR 格式) ocr_results: OCR 结果 [[box, (text, conf)], ...] output_path: 输出路径(可选) show_text: 是否显示识别的文字 show_confidence: 是否显示置信度 font_scale: 字体大小 thickness: 边框粗细 Returns: 标注后的图像 """ img_vis = img.copy() if not ocr_results or ocr_results[0] is None: logger.warning("No OCR results to visualize") return img_vis # 颜色映射(根据置信度) def get_color_by_confidence(conf: float) -> tuple: """根据置信度返回颜色 (绿色->黄色->红色)""" if conf >= 0.9: return (0, 255, 0) # 高置信度:绿色 elif conf >= 0.7: return (0, 255, 255) # 中置信度:黄色 else: return (0, 165, 255) # 低置信度:橙色 for idx, result in enumerate(ocr_results[0], 1): box, (text, conf) = result # 转换坐标格式:[[x1,y1], [x2,y2], [x3,y3], [x4,y4]] if isinstance(box, list): points = np.array(box, dtype=np.int32).reshape((-1, 2)) else: points = box.astype(np.int32) # 获取边框颜色 color = get_color_by_confidence(conf) # 绘制多边形边框 cv2.polylines(img_vis, [points], True, color, thickness) # 计算文本位置(左上角) x1, y1 = points[0] # 构造标签 if show_text and show_confidence: label = f"[{idx}] {text} ({conf:.2f})" elif show_text: label = f"[{idx}] {text}" elif show_confidence: label = f"[{idx}] {conf:.2f}" else: label = f"[{idx}]" # 计算标签尺寸 label_size, baseline = cv2.getTextSize( label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness ) label_w, label_h = label_size # 确保标签不超出图像边界 y_label = max(y1 - 5, label_h + 5) # 绘制标签背景(半透明) overlay = img_vis.copy() cv2.rectangle( overlay, (x1, y_label - label_h - 5), (x1 + label_w + 10, y_label + baseline), color, -1 ) cv2.addWeighted(overlay, 0.6, img_vis, 0.4, 0, img_vis) # 绘制标签文字 cv2.putText( img_vis, label, (x1 + 5, y_label - 5), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), # 白色文字 thickness, cv2.LINE_AA ) # 添加统计信息 total_boxes = len(ocr_results[0]) avg_conf = sum(conf for _, (_, conf) in ocr_results[0]) / total_boxes if total_boxes > 0 else 0 stats_text = f"Total: {total_boxes} boxes | Avg Conf: {avg_conf:.3f}" cv2.putText( img_vis, stats_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), # 红色 2, cv2.LINE_AA ) # 保存图像 if output_path: Path(output_path).parent.mkdir(parents=True, exist_ok=True) cv2.imwrite(output_path, img_vis) logger.info(f"✅ Visualization saved to: {output_path}") return img_vis if __name__ == '__main__': from dotenv import load_dotenv load_dotenv(override=True) logger.remove() logger.add(sys.stderr, level="INFO") print("🚀 Testing PytorchPaddleOCR with Visualization...") try: ocr = PytorchPaddleOCR(lang='ch', device='cpu') # ✅ 验证字符集 if hasattr(ocr.text_recognizer, 'postprocess_op'): char_count = len(ocr.text_recognizer.postprocess_op.character) print(f"\n📊 Character set loaded: {char_count} characters") if char_count > 0: print(f" Sample chars: {ocr.text_recognizer.postprocess_op.character[:20]}") else: print(f" ❌ ERROR: Character set is empty!") sys.exit(1) # 测试图像列表 test_images = [ { 'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_001.png", 'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_001_ocr_vis.png" }, { 'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png", 'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_003_ocr_vis.png" } ] for img_info in test_images: img_path = img_info['path'] output_path = img_info['output'] print(f"\n{'='*60}") print(f"📄 Processing: {Path(img_path).name}") print(f"{'='*60}") if not Path(img_path).exists(): print(f"❌ Image not found: {img_path}") continue img = cv2.imread(img_path) if img is None: print(f"❌ Failed to load image") continue print(f"📖 Original image: {img.shape}") # 执行 OCR results = ocr.ocr(img, det=True, rec=True) if results and results[0]: print(f"\n✅ OCR completed! Found {len(results[0])} text regions") # 打印前 10 个结果 print(f"\n📝 First 10 results:") for i, (box, (text, conf)) in enumerate(results[0][:10], 1): print(f" [{i}] {text} (conf={conf:.3f})") if len(results[0]) > 10: print(f" ... and {len(results[0]) - 10} more") # 可视化 print(f"\n🎨 Visualizing results...") img_vis = ocr.visualize( img, results, output_path=output_path, show_text=True, show_confidence=True ) # 统计信息 total_boxes = len(results[0]) avg_conf = sum(conf for _, (_, conf) in results[0]) / total_boxes high_conf = sum(1 for _, (_, conf) in results[0] if conf >= 0.9) low_conf = sum(1 for _, (_, conf) in results[0] if conf < 0.7) non_empty = sum(1 for _, (text, _) in results[0] if text) print(f"\n📊 Statistics:") print(f" Total boxes: {total_boxes}") print(f" Non-empty: {non_empty}") print(f" Average confidence: {avg_conf:.3f}") print(f" High confidence (≥0.9): {high_conf}") print(f" Low confidence (<0.7): {low_conf}") else: print(f"⚠️ No results found") except Exception as e: print(f"\n❌ Test failed: {e}") import traceback traceback.print_exc()