
feat: add PytorchPaddleOCR module, providing full OCR functionality (detection + recognition)

zhch158_admin · 2 weeks ago
commit 43ca18e223
1 file changed with 625 additions and 0 deletions

+ 625 - 0
zhch/unified_pytorch_models/vendor/pytorch_paddle.py

@@ -0,0 +1,625 @@
+"""
+PytorchPaddleOCR - 从 MinerU 移植
+提供完整的 OCR 功能(检测 + 识别)
+"""
+import copy
+import os
+import sys
+import warnings
+from pathlib import Path
+
+import cv2
+import numpy as np
+import yaml
+from loguru import logger
+import argparse
+
+# ✅ Modified import: package-relative, with a fallback for script mode
+try:
+    from .device_utils import get_device
+except ImportError:
+    from device_utils import get_device
+
+# When run as a script, add the parent directory to the Python path
+current_dir = Path(__file__).resolve().parent
+parent_dir = current_dir.parent
+if str(parent_dir) not in sys.path:
+    sys.path.insert(0, str(parent_dir))
+
+# ✅ Choose the import style based on how the file is run
+try:
+    # Imported as a module (from vendor.pytorch_paddle import ...)
+    from .ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
+    from .infer.predict_system import TextSystem
+    from .infer import pytorchocr_utility as utility
+except ImportError:
+    # Run as a script (python pytorch_paddle.py)
+    from vendor.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
+    from vendor.infer.predict_system import TextSystem
+    from vendor.infer import pytorchocr_utility as utility
+
+latin_lang = [
+    "af", "az", "bs", "cs", "cy", "da", "de", "es", "et", "fr", "ga", "hr", "hu", 
+    "id", "is", "it", "ku", "la", "lt", "lv", "mi", "ms", "mt", "nl", "no", "oc", 
+    "pi", "pl", "pt", "ro", "rs_latin", "sk", "sl", "sq", "sv", "sw", "tl", "tr", 
+    "uz", "vi", "french", "german", "fi", "eu", "gl", "lb", "rm", "ca", "qu",
+]
+arabic_lang = ["ar", "fa", "ug", "ur", "ps", "ku", "sd", "bal"]
+cyrillic_lang = [
+    "ru", "rs_cyrillic", "be", "bg", "uk", "mn", "abq", "ady", "kbd", "ava", "dar", 
+    "inh", "che", "lbe", "lez", "tab", "kk", "ky", "tg", "mk", "tt", "cv", "ba", 
+    "mhr", "mo", "udm", "kv", "os", "bua", "xal", "tyv", "sah", "kaa",
+]
+east_slavic_lang = ["ru", "be", "uk"]
+devanagari_lang = [
+    "hi", "mr", "ne", "bh", "mai", "ang", "bho", "mah", "sck", "new", "gom", "sa", "bgc",
+]
+
+
+def get_model_params(lang, config):
+    """从配置文件获取模型参数"""
+    if lang in config['lang']:
+        params = config['lang'][lang]
+        det = params.get('det')
+        rec = params.get('rec')
+        dict_file = params.get('dict')
+        return det, rec, dict_file
+    else:
+        raise Exception(f'Language {lang} not supported')
+
+
+def auto_download_and_get_model_root_path(model_path):
+    """
+    Mimic MinerU's model-download logic: return the local model root path (ModelScope cache)
+    """
+    modelscope_cache = os.getenv('MODELSCOPE_CACHE_DIR', str(Path.home() / '.cache' / 'modelscope'))
+    return modelscope_cache
+
+
+# Directory containing this file
+root_dir = Path(__file__).resolve().parent
+
+
+class PytorchPaddleOCR(TextSystem):
+    """
+    PytorchPaddleOCR - OCR engine
+    Inherits from TextSystem and provides full OCR functionality
+    """
+    
+    def __init__(self, *args, **kwargs):
+        """初始化 OCR 引擎
+        
+        Args:
+            lang (str): 语言 ('ch', 'en', 'latin', 'korean', 'japan', 等)
+            det_db_thresh (float): 检测二值化阈值
+            det_db_box_thresh (float): 检测框过滤阈值
+            rec_batch_num (int): 识别批大小
+            enable_merge_det_boxes (bool): 是否合并检测框
+            device (str): 设备 ('cpu', 'cuda:0', 'mps')
+            det_model_path (str): 自定义检测模型路径
+            rec_model_path (str): 自定义识别模型路径
+            rec_char_dict_path (str): 自定义字典路径
+        """
+        # Initialize the argument parser
+        parser = utility.init_args()
+        args_list = list(args) if args else []
+        parsed_args = parser.parse_args(args_list)
+
+        # Language setting
+        self.lang = kwargs.get('lang', 'ch')
+        self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)
+
+        # Auto-detect the device
+        device = kwargs.get('device', get_device())
+        
+        # ✅ CPU optimization: automatically switch to the lightweight model
+        if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
+            logger.warning("CPU device detected. Switching to ch_lite for better performance.")
+            self.lang = 'ch_lite'
+
+        # ✅ Map the language to a supported model family
+        if self.lang in latin_lang:
+            self.lang = 'latin'
+        elif self.lang in east_slavic_lang:
+            self.lang = 'east_slavic'
+        elif self.lang in arabic_lang:
+            self.lang = 'arabic'
+        elif self.lang in cyrillic_lang:
+            self.lang = 'cyrillic'
+        elif self.lang in devanagari_lang:
+            self.lang = 'devanagari'
+
+        # ✅ Load the model configuration
+        models_config_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'models_config.yml'
+        
+        if not models_config_path.exists():
+            raise FileNotFoundError(f"❌ Config file not found: {models_config_path}")
+        
+        logger.info(f"📄 Reading config: {models_config_path}")
+        
+        with open(models_config_path) as file:
+            config = yaml.safe_load(file)
+            det, rec, dict_file = get_model_params(self.lang, config)
+            
+            logger.info(f"📋 Config for lang '{self.lang}':")
+            logger.info(f"   Det model: {det}")
+            logger.info(f"   Rec model: {rec}")
+            logger.info(f"   Dict file: {dict_file}")
+
+        # ✅ Model paths (custom paths passed via kwargs take precedence)
+        if 'det_model_path' not in kwargs:
+            ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch"
+            det_model_path = f"{ocr_models_dir}/{det}"
+            det_model_full_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
+            kwargs['det_model_path'] = det_model_full_path
+            
+            logger.info(f"🔍 Detection model path: {det_model_full_path}")
+            
+            # ✅ Verify that the model file exists
+            if not Path(det_model_full_path).exists():
+                logger.warning(f"⚠️  Detection model file not found: {det_model_full_path}")
+            else:
+                logger.info(f"✅ Detection model file exists: {det_model_full_path}")
+
+        # ✅ Recognition model path
+        if 'rec_model_path' not in kwargs:
+            ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch"
+            rec_model_path = f"{ocr_models_dir}/{rec}"
+            rec_model_full_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)
+            kwargs['rec_model_path'] = rec_model_full_path
+            
+            logger.info(f"🔍 Recognition model path: {rec_model_full_path}")
+            
+            # ✅ Verify that the model file exists
+            if not Path(rec_model_full_path).exists():
+                logger.warning(f"⚠️  Recognition model file not found: {rec_model_full_path}")
+            else:
+                logger.info(f"✅ Recognition model file exists: {rec_model_full_path}")
+
+        # ✅ Dictionary path
+        if 'rec_char_dict_path' not in kwargs:
+            dict_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'dict' / dict_file
+            
+            if not dict_path.exists():
+                logger.error(f"❌ Dictionary file not found: {dict_path}")
+                raise FileNotFoundError(f"Dictionary file not found: {dict_path}")
+            
+            kwargs['rec_char_dict_path'] = str(dict_path)
+            logger.info(f"📖 Dictionary: {dict_path.name}")
+
+        # ✅ Default parameters
+        kwargs.setdefault('rec_batch_num', 6)
+        kwargs.setdefault('device', device)
+
+        # ✅ Merge parameters
+        default_args = vars(parsed_args)
+        default_args.update(kwargs)
+        final_args = argparse.Namespace(**default_args)
+
+        logger.info(f"🔧 Initializing TextSystem...")
+        logger.info(f"   device: {final_args.device}")
+        logger.info(f"   rec_char_dict_path: {final_args.rec_char_dict_path}")
+        logger.info(f"   rec_batch_num: {final_args.rec_batch_num}")
+
+        # ✅ Initialize TextSystem
+        super().__init__(final_args)
+        
+        # ✅ Validate and repair the character set
+        logger.info(f"🔍 Validating text recognizer...")
+        
+        if hasattr(self, 'text_recognizer'):
+            logger.info(f"   ✅ text_recognizer exists")
+            
+            # ❌ text_recognizer itself has no `character` attribute
+            # ✅ The character set is stored in postprocess_op.character
+            if hasattr(self.text_recognizer, 'postprocess_op'):
+                postprocess_op = self.text_recognizer.postprocess_op
+                logger.info(f"   ✅ postprocess_op exists: {type(postprocess_op)}")
+                
+                if hasattr(postprocess_op, 'character'):
+                    char_count = len(postprocess_op.character)
+                    logger.info(f"   ✅ postprocess_op.character exists")
+                    logger.info(f"   Character set size: {char_count}")
+                    
+                    if char_count == 0:
+                        logger.error(f"   ❌ Character set is EMPTY!")
+                        logger.error(f"   Reloading dictionary...")
+                        
+                        # Force-reload the dictionary
+                        dict_path = final_args.rec_char_dict_path
+                        with open(dict_path, 'rb') as f:
+                            lines = f.readlines()
+                            character = []
+                            for line in lines:
+                                line = line.decode('utf-8').strip('\n').strip('\r\n')
+                                character.append(line)
+                        
+                        # Update the character set
+                        postprocess_op.character = character
+                        logger.info(f"   ✅ Character set reloaded: {len(character)} chars")
+                    else:
+                        logger.info(f"   ✅ Character set loaded successfully")
+                        # Show the first 10 characters (for debugging)
+                        sample_chars = postprocess_op.character[:10]
+                        logger.info(f"   First 10 chars: {sample_chars}")
+                else:
+                    logger.error(f"   ❌ postprocess_op.character NOT found!")
+                    logger.error(f"   Available attributes: {dir(postprocess_op)}")
+            else:
+                logger.error(f"   ❌ postprocess_op NOT found!")
+                logger.error(f"   Available attributes: {dir(self.text_recognizer)}")
+        else:
+            logger.error(f"   ❌ text_recognizer NOT found!")
+        
+        logger.info(f"✅ OCR engine initialized")
+
+    def ocr(
+        self,
+        img,
+        det=True,
+        rec=True,
+        mfd_res=None,
+        tqdm_enable=False,
+        tqdm_desc="OCR-rec Predict",
+    ):
+        """
+        Run OCR.
+
+        Args:
+            img: BGR image, list of images, file path, or byte stream
+            det (bool): Whether to run text detection
+            rec (bool): Whether to run text recognition
+            mfd_res (list): Formula detection results (used to filter text boxes)
+            tqdm_enable (bool): Whether to show a progress bar
+            tqdm_desc (str): Progress bar description
+            
+        Returns:
+            det=True, rec=True: [[[box], (text, conf)], ...]
+            det=True, rec=False: [[boxes], ...]
+            det=False, rec=True: [[(text, conf), ...]]
+        """
+        assert isinstance(img, (np.ndarray, list, str, bytes))
+        
+        if isinstance(img, list) and det:
+            logger.error('When input is a list of images, det must be False')
+            exit(1)
+        
+        img = check_img(img)
+        imgs = [img]
+        
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=RuntimeWarning)
+            
+            if det and rec:
+                # Detection + recognition
+                ocr_res = []
+                for img in imgs:
+                    img = preprocess_image(img)
+                    dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
+                    if not dt_boxes and not rec_res:
+                        ocr_res.append(None)
+                        continue
+                    tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
+                    ocr_res.append(tmp_res)
+                return ocr_res
+            
+            elif det and not rec:
+                # Detection only
+                ocr_res = []
+                for img in imgs:
+                    img = preprocess_image(img)
+                    dt_boxes, elapse = self.text_detector(img)
+                    if dt_boxes is None:
+                        ocr_res.append(None)
+                        continue
+                    dt_boxes = sorted_boxes(dt_boxes)
+                    if self.enable_merge_det_boxes:
+                        dt_boxes = merge_det_boxes(dt_boxes)
+                    if mfd_res:
+                        dt_boxes = update_det_boxes(dt_boxes, mfd_res)
+                    tmp_res = [box.tolist() for box in dt_boxes]
+                    ocr_res.append(tmp_res)
+                return ocr_res
+            
+            elif not det and rec:
+                # Recognition only
+                ocr_res = []
+                for img in imgs:
+                    if not isinstance(img, list):
+                        img = preprocess_image(img)
+                        img = [img]
+                    rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable, tqdm_desc=tqdm_desc)
+                    ocr_res.append(rec_res)
+                return ocr_res
+
+    def __call__(self, img, mfd_res=None):
+        """
+        OCR on a single image (detection + recognition).
+
+        Args:
+            img: Preprocessed image
+            mfd_res: Formula detection results
+
+        Returns:
+            (dt_boxes, rec_res): detection boxes and recognition results
+        """
+        if img is None:
+            logger.debug("no valid image provided")
+            return None, None
+
+        ori_im = img.copy()
+        dt_boxes, elapse = self.text_detector(img)
+
+        if dt_boxes is None:
+            logger.debug(f"no dt_boxes found, elapsed: {elapse}")
+            return None, None
+
+        # Sort the detection boxes
+        dt_boxes = sorted_boxes(dt_boxes)
+
+        # Merge adjacent boxes
+        if self.enable_merge_det_boxes:
+            dt_boxes = merge_det_boxes(dt_boxes)
+
+        # Filter out boxes that overlap with detected formulas
+        if mfd_res:
+            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
+
+        # Crop the text regions
+        img_crop_list = []
+        for bno in range(len(dt_boxes)):
+            tmp_box = copy.deepcopy(dt_boxes[bno])
+            img_crop = get_rotate_crop_image(ori_im, tmp_box)
+            img_crop_list.append(img_crop)
+
+        logger.info(f"🔤 Recognizing {len(img_crop_list)} text regions...")
+        rec_res, elapse = self.text_recognizer(img_crop_list)
+        
+        # Summarize the results
+        non_empty = sum(1 for text, _ in rec_res if text)
+        logger.info(f"   Found {non_empty}/{len(rec_res)} non-empty results")
+        
+        if non_empty > 0:
+            logger.info(f"   First 5 non-empty results:")
+            count = 0
+            for i, (text, conf) in enumerate(rec_res):
+                if text:
+                    logger.info(f"      [{i+1}] '{text}' (conf={conf:.3f})")
+                    count += 1
+                    if count >= 5:
+                        break
+
+        # Filter out low-score results
+        filter_boxes, filter_rec_res = [], []
+        for box, rec_result in zip(dt_boxes, rec_res):
+            text, score = rec_result
+            if score >= self.drop_score:
+                filter_boxes.append(box)
+                filter_rec_res.append(rec_result)
+
+        return filter_boxes, filter_rec_res
+
+    def visualize(
+        self,
+        img: np.ndarray,
+        ocr_results: list,
+        output_path: str = None,
+        show_text: bool = True,
+        show_confidence: bool = True,
+        font_scale: float = 0.5,
+        thickness: int = 2
+    ) -> np.ndarray:
+        """
+        Visualize OCR detection results.
+
+        Args:
+            img: Original image (BGR format)
+            ocr_results: OCR results [[box, (text, conf)], ...]
+            output_path: Output path (optional)
+            show_text: Whether to draw the recognized text
+            show_confidence: Whether to draw the confidence score
+            font_scale: Font scale
+            thickness: Border thickness
+
+        Returns:
+            The annotated image
+        """
+        img_vis = img.copy()
+        
+        if not ocr_results or ocr_results[0] is None:
+            logger.warning("No OCR results to visualize")
+            return img_vis
+        
+        # Color mapping (by confidence)
+        def get_color_by_confidence(conf: float) -> tuple:
+            """Return a color based on confidence (green -> yellow -> orange)"""
+            if conf >= 0.9:
+                return (0, 255, 0)      # high confidence: green
+            elif conf >= 0.7:
+                return (0, 255, 255)    # medium confidence: yellow
+            else:
+                return (0, 165, 255)    # low confidence: orange
+        
+        for idx, result in enumerate(ocr_results[0], 1):
+            box, (text, conf) = result
+            
+            # Convert the coordinate format: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
+            if isinstance(box, list):
+                points = np.array(box, dtype=np.int32).reshape((-1, 2))
+            else:
+                points = box.astype(np.int32)
+            
+            # Pick the border color
+            color = get_color_by_confidence(conf)
+            
+            # Draw the polygon border
+            cv2.polylines(img_vis, [points], True, color, thickness)
+            
+            # Compute the text position (top-left corner)
+            x1, y1 = points[0]
+            
+            # Build the label
+            if show_text and show_confidence:
+                label = f"[{idx}] {text} ({conf:.2f})"
+            elif show_text:
+                label = f"[{idx}] {text}"
+            elif show_confidence:
+                label = f"[{idx}] {conf:.2f}"
+            else:
+                label = f"[{idx}]"
+            
+            # Measure the label size
+            label_size, baseline = cv2.getTextSize(
+                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
+            )
+            label_w, label_h = label_size
+            
+            # Keep the label inside the image bounds
+            y_label = max(y1 - 5, label_h + 5)
+            
+            # Draw the label background (semi-transparent)
+            overlay = img_vis.copy()
+            cv2.rectangle(
+                overlay, 
+                (x1, y_label - label_h - 5), 
+                (x1 + label_w + 10, y_label + baseline), 
+                color, 
+                -1
+            )
+            cv2.addWeighted(overlay, 0.6, img_vis, 0.4, 0, img_vis)
+            
+            # Draw the label text
+            cv2.putText(
+                img_vis, 
+                label, 
+                (x1 + 5, y_label - 5), 
+                cv2.FONT_HERSHEY_SIMPLEX, 
+                font_scale, 
+                (255, 255, 255),  # white text
+                thickness,
+                cv2.LINE_AA
+            )
+        
+        # Add summary statistics
+        total_boxes = len(ocr_results[0])
+        avg_conf = sum(conf for _, (_, conf) in ocr_results[0]) / total_boxes if total_boxes > 0 else 0
+        
+        stats_text = f"Total: {total_boxes} boxes | Avg Conf: {avg_conf:.3f}"
+        cv2.putText(
+            img_vis,
+            stats_text,
+            (10, 30),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.7,
+            (0, 0, 255),  # red
+            2,
+            cv2.LINE_AA
+        )
+        
+        # Save the image
+        if output_path:
+            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+            cv2.imwrite(output_path, img_vis)
+            logger.info(f"✅ Visualization saved to: {output_path}")
+        
+        return img_vis
+
+
+if __name__ == '__main__':
+    from dotenv import load_dotenv
+    load_dotenv(override=True)
+    
+    logger.remove()
+    logger.add(sys.stderr, level="INFO")
+ 
+    print("🚀 Testing PytorchPaddleOCR with Visualization...")
+    
+    try:
+        ocr = PytorchPaddleOCR(lang='ch', device='cpu')
+        
+        # ✅ Verify the character set
+        if hasattr(ocr.text_recognizer, 'postprocess_op'):
+            char_count = len(ocr.text_recognizer.postprocess_op.character)
+            print(f"\n📊 Character set loaded: {char_count} characters")
+            if char_count > 0:
+                print(f"   Sample chars: {ocr.text_recognizer.postprocess_op.character[:20]}")
+            else:
+                print(f"   ❌ ERROR: Character set is empty!")
+                sys.exit(1)
+        
+        # Test image list
+        test_images = [
+            {
+                'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_001.png",
+                'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_001_ocr_vis.png"
+            },
+            {
+                'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
+                'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_003_ocr_vis.png"
+            }
+        ]
+        
+        for img_info in test_images:
+            img_path = img_info['path']
+            output_path = img_info['output']
+            
+            print(f"\n{'='*60}")
+            print(f"📄 Processing: {Path(img_path).name}")
+            print(f"{'='*60}")
+            
+            if not Path(img_path).exists():
+                print(f"❌ Image not found: {img_path}")
+                continue
+            
+            img = cv2.imread(img_path)
+            
+            if img is None:
+                print(f"❌ Failed to load image")
+                continue
+            
+            print(f"📖 Original image: {img.shape}")
+            
+            # Run OCR
+            results = ocr.ocr(img, det=True, rec=True)
+            
+            if results and results[0]:
+                print(f"\n✅ OCR completed! Found {len(results[0])} text regions")
+                
+                # Print the first 10 results
+                print(f"\n📝 First 10 results:")
+                for i, (box, (text, conf)) in enumerate(results[0][:10], 1):
+                    print(f"  [{i}] {text} (conf={conf:.3f})")
+                
+                if len(results[0]) > 10:
+                    print(f"  ... and {len(results[0]) - 10} more")
+                
+                # Visualize
+                print(f"\n🎨 Visualizing results...")
+                img_vis = ocr.visualize(
+                    img, 
+                    results, 
+                    output_path=output_path,
+                    show_text=True,
+                    show_confidence=True
+                )
+                
+                # Statistics
+                total_boxes = len(results[0])
+                avg_conf = sum(conf for _, (_, conf) in results[0]) / total_boxes
+                high_conf = sum(1 for _, (_, conf) in results[0] if conf >= 0.9)
+                low_conf = sum(1 for _, (_, conf) in results[0] if conf < 0.7)
+                non_empty = sum(1 for _, (text, _) in results[0] if text)
+                
+                print(f"\n📊 Statistics:")
+                print(f"  Total boxes: {total_boxes}")
+                print(f"  Non-empty: {non_empty}")
+                print(f"  Average confidence: {avg_conf:.3f}")
+                print(f"  High confidence (≥0.9): {high_conf}")
+                print(f"  Low confidence (<0.7): {low_conf}")
+            else:
+                print(f"⚠️  No results found")
+            
+    except Exception as e:
+        print(f"\n❌ Test failed: {e}")
+        import traceback
+        traceback.print_exc()
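
For quick reference, a minimal usage sketch of the module-import path (as opposed to the script-mode demo above). It assumes the vendor package is importable as vendor.pytorch_paddle and that the paddleocr_torch detection/recognition models already exist under the ModelScope cache resolved by auto_download_and_get_model_root_path; the image and output paths below are placeholders, not paths from this commit.

import cv2
from vendor.pytorch_paddle import PytorchPaddleOCR

# On CPU, lang='ch' automatically falls back to the lightweight ch_lite models
ocr = PytorchPaddleOCR(lang='ch', device='cpu')

img = cv2.imread('sample_page.png')  # placeholder path; must load as a BGR image

# Detection + recognition: one [[box, (text, conf)], ...] list per input image
results = ocr.ocr(img, det=True, rec=True)

# Detection only: quadrilateral boxes without text
boxes_only = ocr.ocr(img, det=True, rec=False)

if results and results[0]:
    for box, (text, conf) in results[0][:5]:
        print(f"{text} ({conf:.3f})")
    ocr.visualize(img, results, output_path='ocr_vis.png')  # placeholder output path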