""" PytorchPaddleOCR - 从 MinerU 移植 提供完整的 OCR 功能(检测 + 识别) """ import copy import os import sys import warnings from pathlib import Path import cv2 import numpy as np import yaml from loguru import logger import argparse from typing import Optional from vendor import get_device # 当作为脚本运行时,添加父目录到 Python 路径 current_dir = Path(__file__).resolve().parent parent_dir = current_dir.parent if str(parent_dir) not in sys.path: sys.path.insert(0, str(parent_dir)) # ✅ 导入方向分类器 from orientation_classifier_v2 import OrientationClassifierV2 from vendor import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image from vendor.infer import TextSystem from vendor.infer import pytorchocr_utility as utility latin_lang = [ "af", "az", "bs", "cs", "cy", "da", "de", "es", "et", "fr", "ga", "hr", "hu", "id", "is", "it", "ku", "la", "lt", "lv", "mi", "ms", "mt", "nl", "no", "oc", "pi", "pl", "pt", "ro", "rs_latin", "sk", "sl", "sq", "sv", "sw", "tl", "tr", "uz", "vi", "french", "german", "fi", "eu", "gl", "lb", "rm", "ca", "qu", ] arabic_lang = ["ar", "fa", "ug", "ur", "ps", "ku", "sd", "bal"] cyrillic_lang = [ "ru", "rs_cyrillic", "be", "bg", "uk", "mn", "abq", "ady", "kbd", "ava", "dar", "inh", "che", "lbe", "lez", "tab", "kk", "ky", "tg", "mk", "tt", "cv", "ba", "mhr", "mo", "udm", "kv", "os", "bua", "xal", "tyv", "sah", "kaa", ] east_slavic_lang = ["ru", "be", "uk"] devanagari_lang = [ "hi", "mr", "ne", "bh", "mai", "ang", "bho", "mah", "sck", "new", "gom", "sa", "bgc", ] def get_model_params(lang, config): """从配置文件获取模型参数""" if lang in config['lang']: params = config['lang'][lang] det = params.get('det') rec = params.get('rec') dict_file = params.get('dict') return det, rec, dict_file else: raise Exception(f'Language {lang} not supported') def auto_download_and_get_model_root_path(model_path): """ 模拟 MinerU 的模型下载逻辑 """ modelscope_cache = os.getenv('MODELSCOPE_CACHE_DIR', str(Path.home() / '.cache' / 'modelscope')) return modelscope_cache # 当前文件所在目录 root_dir = Path(__file__).resolve().parent class PytorchPaddleOCR(TextSystem): """ PytorchPaddleOCR - OCR 引擎 继承 TextSystem,提供完整的 OCR 功能 """ def __init__(self, *args, **kwargs): """初始化 OCR 引擎 Args: lang (str): 语言 ('ch', 'en', 'latin', 'korean', 'japan', 等) det_db_thresh (float): 检测二值化阈值 det_db_box_thresh (float): 检测框过滤阈值 rec_batch_num (int): 识别批大小 enable_merge_det_boxes (bool): 是否合并检测框 device (str): 设备 ('cpu', 'cuda:0', 'mps') det_model_path (str): 自定义检测模型路径 rec_model_path (str): 自定义识别模型路径 rec_char_dict_path (str): 自定义字典路径 """ # 初始化参数解析器 parser = utility.init_args() args_list = list(args) if args else [] parsed_args = parser.parse_args(args_list) # 获取语言设置 self.lang = kwargs.get('lang', 'ch') self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True) # ✅ 新增:方向分类器配置 self.use_orientation_cls = kwargs.get("use_orientation_cls", False) self.orientation_model_path = kwargs.get( "orientation_model_path", None ) # 自动检测设备 device = kwargs.get('device', get_device()) # ✅ CPU 优化:自动切换到轻量模型 if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']: logger.warning("CPU device detected. 


def auto_download_and_get_model_root_path(model_path):
    """Mimic MinerU's model-download logic: return the local model cache root."""
    modelscope_cache = os.getenv('MODELSCOPE_CACHE_DIR', str(Path.home() / '.cache' / 'modelscope'))
    return modelscope_cache


# Directory containing this file
root_dir = Path(__file__).resolve().parent


class PytorchPaddleOCR(TextSystem):
    """
    PytorchPaddleOCR - OCR engine.

    Inherits from TextSystem and provides the full OCR pipeline.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the OCR engine (a construction sketch follows this method).

        Args:
            lang (str): Language ('ch', 'en', 'latin', 'korean', 'japan', etc.)
            det_db_thresh (float): Detection binarization threshold
            det_db_box_thresh (float): Detection box filtering threshold
            rec_batch_num (int): Recognition batch size
            enable_merge_det_boxes (bool): Whether to merge detection boxes
            device (str): Device ('cpu', 'cuda:0', 'mps')
            det_model_path (str): Custom detection model path
            rec_model_path (str): Custom recognition model path
            rec_char_dict_path (str): Custom dictionary path
        """
        # Initialize the argument parser
        parser = utility.init_args()
        args_list = list(args) if args else []
        parsed_args = parser.parse_args(args_list)

        # Language settings
        self.lang = kwargs.get('lang', 'ch')
        self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)

        # ✅ New: orientation classifier configuration
        self.use_orientation_cls = kwargs.get("use_orientation_cls", False)
        self.orientation_model_path = kwargs.get("orientation_model_path", None)

        # Auto-detect the device
        device = kwargs.get('device', get_device())

        # ✅ CPU optimization: automatically switch to the lightweight model
        if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
            logger.warning("CPU device detected. Switching to ch_lite for better performance.")
            self.lang = 'ch_lite'

        # ✅ Language mapping
        if self.lang in latin_lang:
            self.lang = 'latin'
        elif self.lang in east_slavic_lang:
            self.lang = 'east_slavic'
        elif self.lang in arabic_lang:
            self.lang = 'arabic'
        elif self.lang in cyrillic_lang:
            self.lang = 'cyrillic'
        elif self.lang in devanagari_lang:
            self.lang = 'devanagari'

        # ✅ Read the model configuration
        models_config_path = root_dir / 'vendor' / 'pytorchocr' / 'utils' / 'resources' / 'models_config.yml'
        if not models_config_path.exists():
            raise FileNotFoundError(f"❌ Config file not found: {models_config_path}")
        logger.info(f"📄 Reading config: {models_config_path}")
        with open(models_config_path) as file:
            config = yaml.safe_load(file)
        det, rec, dict_file = get_model_params(self.lang, config)
        logger.info(f"📋 Config for lang '{self.lang}':")
        logger.info(f"   Det model: {det}")
        logger.info(f"   Rec model: {rec}")
        logger.info(f"   Dict file: {dict_file}")

        # ✅ Model paths (custom paths passed via kwargs take precedence)
        if 'det_model_path' not in kwargs:
            ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch"
            det_model_path = f"{ocr_models_dir}/{det}"
            det_model_full_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
            kwargs['det_model_path'] = det_model_full_path
            logger.info(f"🔍 Detection model path: {det_model_full_path}")
            # ✅ Verify that the model file exists
            if not Path(det_model_full_path).exists():
                logger.warning(f"⚠️ Detection model file not found: {det_model_full_path}")
            else:
                logger.info(f"✅ Detection model file exists: {det_model_full_path}")

        # ✅ Recognition model path
        if 'rec_model_path' not in kwargs:
            ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch"
            rec_model_path = f"{ocr_models_dir}/{rec}"
            rec_model_full_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)
            kwargs['rec_model_path'] = rec_model_full_path
            logger.info(f"🔍 Recognition model path: {rec_model_full_path}")
            # ✅ Verify that the model file exists
            if not Path(rec_model_full_path).exists():
                logger.warning(f"⚠️ Recognition model file not found: {rec_model_full_path}")
            else:
                logger.info(f"✅ Recognition model file exists: {rec_model_full_path}")

        # ✅ Dictionary path
        if 'rec_char_dict_path' not in kwargs:
            dict_path = root_dir / 'vendor' / 'pytorchocr' / 'utils' / 'resources' / 'dict' / dict_file
            if not dict_path.exists():
                logger.error(f"❌ Dictionary file not found: {dict_path}")
                raise FileNotFoundError(f"Dictionary file not found: {dict_path}")
            kwargs['rec_char_dict_path'] = str(dict_path)
            logger.info(f"📖 Dictionary: {dict_path.name}")

        # ✅ Default parameters
        kwargs.setdefault('rec_batch_num', 6)
        kwargs.setdefault('device', device)

        # ✅ Merge parameters
        default_args = vars(parsed_args)
        default_args.update(kwargs)
        final_args = argparse.Namespace(**default_args)

        logger.info("🔧 Initializing TextSystem...")
        logger.info(f"   device: {final_args.device}")
        logger.info(f"   rec_char_dict_path: {final_args.rec_char_dict_path}")
        logger.info(f"   rec_batch_num: {final_args.rec_batch_num}")

        # ✅ Initialize TextSystem
        super().__init__(final_args)

        # ✅ Initialize the orientation classifier (after TextSystem, so self.text_detector is available)
        self.orientation_classifier = None
        if self.use_orientation_cls and self.orientation_model_path:
            try:
                logger.info("🔄 Initializing orientation classifier...")
                logger.info(f"   Model: {self.orientation_model_path}")

                # ✅ Build a small detector adapter
                class TextDetectorAdapter:
                    """Adapter: wraps TextDetector in the interface expected by OrientationClassifierV2."""

                    def __init__(self, text_detector):
                        self.text_detector = text_detector

                    def ocr(self, img, det=True, rec=False):
                        """Run text detection only."""
                        if not det:
                            return None
                        try:
                            # Call the detector
                            dt_boxes, _ = self.text_detector(img)
                            if dt_boxes is None or len(dt_boxes) == 0:
                                return [[]]
                            # Return format: [[box1, box2, ...]]
                            return [dt_boxes.tolist()]
                        except Exception as e:
                            logger.warning(f"Detection failed in adapter: {e}")
                            return [[]]

                # ✅ Pass the detector adapter in
                detector_adapter = TextDetectorAdapter(self.text_detector)
                self.orientation_classifier = OrientationClassifierV2(
                    model_path=self.orientation_model_path,
                    text_detector=detector_adapter,  # ✅ detector handle
                    aspect_ratio_threshold=kwargs.get("aspect_ratio_threshold", 1.2),
                    vertical_text_ratio=kwargs.get("vertical_text_ratio", 0.28),
                    vertical_text_min_count=kwargs.get("vertical_text_min_count", 3),
                    use_gpu=kwargs.get("device", "cpu") != "cpu"
                )
                logger.info("✅ Orientation classifier initialized")
            except Exception as e:
                logger.warning(f"⚠️ Failed to initialize orientation classifier: {e}")
                self.orientation_classifier = None

        # ✅ Validate and, if necessary, repair the character set
        logger.info("🔍 Validating text recognizer...")
        if hasattr(self, 'text_recognizer'):
            logger.info("   ✅ text_recognizer exists")
            # ❌ text_recognizer itself has no `character` attribute
            # ✅ The character set lives in postprocess_op.character
            if hasattr(self.text_recognizer, 'postprocess_op'):
                postprocess_op = self.text_recognizer.postprocess_op
                logger.info(f"   ✅ postprocess_op exists: {type(postprocess_op)}")
                if hasattr(postprocess_op, 'character'):
                    char_count = len(postprocess_op.character)
                    logger.info("   ✅ postprocess_op.character exists")
                    logger.info(f"   Character set size: {char_count}")
                    if char_count == 0:
                        logger.error("   ❌ Character set is EMPTY!")
                        logger.error("   Reloading dictionary...")
                        # Force-reload the dictionary
                        dict_path = final_args.rec_char_dict_path
                        with open(dict_path, 'rb') as f:
                            lines = f.readlines()
                        character = []
                        for line in lines:
                            line = line.decode('utf-8').strip('\n').strip('\r\n')
                            character.append(line)
                        # Update the character set
                        postprocess_op.character = character
                        logger.info(f"   ✅ Character set reloaded: {len(character)} chars")
                    else:
                        logger.info("   ✅ Character set loaded successfully")
                        # Show the first 10 characters (debugging aid)
                        sample_chars = postprocess_op.character[:10]
                        logger.info(f"   First 10 chars: {sample_chars}")
                else:
                    logger.error("   ❌ postprocess_op.character NOT found!")
                    logger.error(f"   Available attributes: {dir(postprocess_op)}")
            else:
                logger.error("   ❌ postprocess_op NOT found!")
                logger.error(f"   Available attributes: {dir(self.text_recognizer)}")
        else:
            logger.error("   ❌ text_recognizer NOT found!")

        logger.info("✅ OCR engine initialized")
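
    # Illustrative construction sketch (not executed). Every keyword argument is
    # optional and falls back to the defaults resolved in __init__ above; the model
    # and dictionary paths shown here are hypothetical placeholders.
    #
    #   engine = PytorchPaddleOCR(
    #       lang='en',
    #       device='cpu',
    #       det_model_path='/path/to/custom_det_model.pth',    # hypothetical
    #       rec_model_path='/path/to/custom_rec_model.pth',    # hypothetical
    #       rec_char_dict_path='/path/to/custom_dict.txt',     # hypothetical
    #       rec_batch_num=6,
    #   )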

    def ocr(
        self,
        img,
        det=True,
        rec=True,
        mfd_res=None,
        tqdm_enable=False,
        tqdm_desc="OCR-rec Predict",
    ):
        """
        Run OCR (a calling sketch follows this method).

        Args:
            img: BGR image, list of images, file path, or byte stream
            det (bool): Whether to run detection
            rec (bool): Whether to run recognition
            mfd_res (list): Formula detection results (used to filter text boxes)
            tqdm_enable (bool): Whether to show a progress bar
            tqdm_desc (str): Progress bar description

        Returns:
            det=True, rec=True:  [[[box, (text, conf)], ...]]  (one inner list per image)
            det=True, rec=False: [[box, ...]]
            det=False, rec=True: [[(text, conf), ...]]
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det:
            logger.error('When input is a list of images, det must be False')
            exit(1)

        # np.ndarray: image in BGR format
        img = check_img(img)
        imgs = [img]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            if det and rec:
                # Detection + recognition
                ocr_res = []
                for img in imgs:
                    img = preprocess_image(img)
                    dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
                    if not dt_boxes and not rec_res:
                        ocr_res.append(None)
                        continue
                    tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
                    ocr_res.append(tmp_res)
                return ocr_res
            elif det and not rec:
                # Detection only
                ocr_res = []
                for img in imgs:
                    img = preprocess_image(img)
                    dt_boxes, elapse = self.text_detector(img)
                    if dt_boxes is None:
                        ocr_res.append(None)
                        continue
                    dt_boxes = sorted_boxes(dt_boxes)
                    if self.enable_merge_det_boxes:
                        dt_boxes = merge_det_boxes(dt_boxes)
                    if mfd_res:
                        dt_boxes = update_det_boxes(dt_boxes, mfd_res)
                    tmp_res = [box.tolist() for box in dt_boxes]
                    ocr_res.append(tmp_res)
                return ocr_res
            elif not det and rec:
                # Recognition only
                ocr_res = []
                for img in imgs:
                    if not isinstance(img, list):
                        img = preprocess_image(img)
                        img = [img]
                    rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable, tqdm_desc=tqdm_desc)
                    ocr_res.append(rec_res)
                return ocr_res
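
    # Illustrative calling sketch (mirrors the return shapes documented in ocr();
    # `engine`, `page` and `crops` are hypothetical variables):
    #
    #   results = engine.ocr(page)                 # det + rec -> [[[box, (text, conf)], ...]]
    #   boxes   = engine.ocr(page, rec=False)      # det only  -> [[box, ...]]
    #   texts   = engine.ocr(crops, det=False)     # rec only  -> [[(text, conf), ...]]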

    def __call__(self, img, mfd_res=None):
        """
        OCR for a single image (detection + recognition).

        Args:
            img: Preprocessed image
            mfd_res: Formula detection results

        Returns:
            (dt_boxes, rec_res): detection boxes and recognition results
        """
        if img is None:
            logger.debug("no valid image provided")
            return None, None

        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        if dt_boxes is None:
            logger.debug(f"no dt_boxes found, elapsed: {elapse}")
            return None, None

        # Sort boxes
        dt_boxes = sorted_boxes(dt_boxes)

        # Merge adjacent boxes
        if self.enable_merge_det_boxes:
            dt_boxes = merge_det_boxes(dt_boxes)

        # Filter out boxes that overlap formulas
        if mfd_res:
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)

        # Crop text regions
        img_crop_list = []
        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            img_crop = get_rotate_crop_image(ori_im, tmp_box)
            img_crop_list.append(img_crop)

        logger.info(f"🔤 Recognizing {len(img_crop_list)} text regions...")
        rec_res, elapse = self.text_recognizer(img_crop_list)

        # Result statistics
        non_empty = sum(1 for text, _ in rec_res if text)
        logger.info(f"   Found {non_empty}/{len(rec_res)} non-empty results")
        if non_empty > 0:
            logger.info("   First 5 non-empty results:")
            count = 0
            for i, (text, conf) in enumerate(rec_res):
                if text:
                    logger.info(f"   [{i+1}] '{text}' (conf={conf:.3f})")
                    count += 1
                    if count >= 5:
                        break

        # Filter out low-score results
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res

    def visualize(
        self,
        img: np.ndarray,
        ocr_results: list,
        output_path: Optional[str] = None,
        show_text: bool = True,
        show_confidence: bool = True,
        font_scale: float = 0.5,
        thickness: int = 2
    ) -> np.ndarray:
        """
        Visualize OCR detection results.

        Args:
            img: Original image (BGR format)
            ocr_results: OCR results [[box, (text, conf)], ...]
            output_path: Output path (optional)
            show_text: Whether to draw the recognized text
            show_confidence: Whether to draw the confidence score
            font_scale: Font size
            thickness: Border thickness

        Returns:
            The annotated image
        """
        img_vis = img.copy()

        if not ocr_results or ocr_results[0] is None:
            logger.warning("No OCR results to visualize")
            return img_vis

        # Color mapping (by confidence)
        def get_color_by_confidence(conf: float) -> tuple:
            """Return a color based on confidence (green -> yellow -> orange)."""
            if conf >= 0.9:
                return (0, 255, 0)    # High confidence: green
            elif conf >= 0.7:
                return (0, 255, 255)  # Medium confidence: yellow
            else:
                return (0, 165, 255)  # Low confidence: orange

        for idx, result in enumerate(ocr_results[0], 1):
            box, (text, conf) = result

            # Convert coordinate format: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
            if isinstance(box, list):
                points = np.array(box, dtype=np.int32).reshape((-1, 2))
            else:
                points = box.astype(np.int32)

            # Border color
            color = get_color_by_confidence(conf)

            # Draw the polygon border
            cv2.polylines(img_vis, [points], True, color, thickness)

            # Text anchor (top-left corner)
            x1, y1 = points[0]

            # Build the label
            if show_text and show_confidence:
                label = f"[{idx}] {text} ({conf:.2f})"
            elif show_text:
                label = f"[{idx}] {text}"
            elif show_confidence:
                label = f"[{idx}] {conf:.2f}"
            else:
                label = f"[{idx}]"

            # Label size
            label_size, baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
            )
            label_w, label_h = label_size

            # Keep the label inside the image
            y_label = max(y1 - 5, label_h + 5)

            # Draw the label background (semi-transparent)
            overlay = img_vis.copy()
            cv2.rectangle(
                overlay,
                (x1, y_label - label_h - 5),
                (x1 + label_w + 10, y_label + baseline),
                color,
                -1
            )
            cv2.addWeighted(overlay, 0.6, img_vis, 0.4, 0, img_vis)

            # Draw the label text
            cv2.putText(
                img_vis,
                label,
                (x1 + 5, y_label - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (255, 255, 255),  # White text
                thickness,
                cv2.LINE_AA
            )

        # Add summary statistics
        total_boxes = len(ocr_results[0])
        avg_conf = sum(conf for _, (_, conf) in ocr_results[0]) / total_boxes if total_boxes > 0 else 0
        stats_text = f"Total: {total_boxes} boxes | Avg Conf: {avg_conf:.3f}"
        cv2.putText(
            img_vis,
            stats_text,
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (0, 0, 255),  # Red
            2,
            cv2.LINE_AA
        )

        # Save the image
        if output_path:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(output_path, img_vis)
            logger.info(f"✅ Visualization saved to: {output_path}")

        return img_vis


if __name__ == '__main__':
    from dotenv import load_dotenv
    load_dotenv(override=True)

    logger.remove()
    logger.add(sys.stderr, level="INFO")

    print("🚀 Testing PytorchPaddleOCR with Orientation Classifier...")

    try:
        # ✅ Enable the orientation classifier
        ocr = PytorchPaddleOCR(
            lang='ch',
            device='cpu',
            use_orientation_cls=True,  # ✅ enabled
            orientation_model_path="/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/PP-LCNet_x1_0_doc_ori.onnx"
        )

        # ✅ Verify the character set
        if hasattr(ocr.text_recognizer, 'postprocess_op'):
            char_count = len(ocr.text_recognizer.postprocess_op.character)
            print(f"\n📊 Character set loaded: {char_count} characters")

        # Test images
        test_images = [
            # {
            #     'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_001.png",
            #     'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_001_ocr_vis.png"
            # },
            {
                'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
                'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_003_ocr_vis.png"
            }
        ]

        for img_info in test_images:
            img_path = img_info['path']
            output_path = img_info['output']

            print(f"\n{'='*60}")
            print(f"📄 Processing: {Path(img_path).name}")
            print(f"{'='*60}")

            img = cv2.imread(img_path)
            print(f"📖 Original image: {img.shape}")

            # np.ndarray: image in BGR format
            img = check_img(img)

            # ✅ Use the orientation classifier (if enabled)
            if ocr.orientation_classifier is not None and isinstance(img, np.ndarray):
                logger.info("🔄 Checking image orientation...")

                # Predict the orientation
                orientation_result = ocr.orientation_classifier.predict(img, return_debug=True)
                logger.info(f"   Orientation result: {orientation_result}")

                # Rotate if necessary
                if orientation_result.needs_rotation:
                    logger.info(f"🔄 Rotating image by {orientation_result.rotation_angle}°...")
                    img = ocr.orientation_classifier.rotate_image(img, orientation_result.rotation_angle)
                    logger.info(f"   New shape: {img.shape}")

            # Run OCR on the (possibly rotated) image
            results = ocr.ocr(img, det=True, rec=True)

            if results and results[0]:
                print(f"\n✅ Found {len(results[0])} text regions:")
                for i, (box, (text, conf)) in enumerate(results[0][:10], 1):
                    print(f"   [{i}] {text} (conf={conf:.3f})")

                # Visualization
                img_vis = ocr.visualize(img, results, output_path=output_path, show_text=False, show_confidence=False)

                # Statistics
                total = len(results[0])
                non_empty = sum(1 for _, (text, _) in results[0] if text)
                avg_conf = sum(conf for _, (_, conf) in results[0]) / total
                print("\n📊 Statistics:")
                print(f"   Total: {total}")
                print(f"   Non-empty: {non_empty}")
                print(f"   Avg confidence: {avg_conf:.3f}")
            else:
                print("⚠️ No results found")

    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        import traceback
        traceback.print_exc()