| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667 |
"""
PytorchPaddleOCR - ported from MinerU.

Provides a complete OCR pipeline (text detection + text recognition).
"""
- import copy
- import os
- import sys
- import warnings
- from pathlib import Path
- import cv2
- import numpy as np
- import yaml
- from loguru import logger
- import argparse
- from typing import Optional
- from vendor import get_device
- # 当作为脚本运行时,添加父目录到 Python 路径
- current_dir = Path(__file__).resolve().parent
- parent_dir = current_dir.parent
- if str(parent_dir) not in sys.path:
- sys.path.insert(0, str(parent_dir))
- # ✅ 导入方向分类器
- from orientation_classifier_v2 import OrientationClassifierV2
- from vendor import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
- from vendor.infer import TextSystem
- from vendor.infer import pytorchocr_utility as utility
# Language codes grouped by the shared recognition model that serves them.
# PytorchPaddleOCR.__init__ tests these lists in the order
# latin -> east_slavic -> arabic -> cyrillic -> devanagari and stops at the
# first match, so a code listed in several groups resolves to the earliest one.

# Codes mapped to the 'latin' model. "ku" is also present in arabic_lang, but
# latin_lang is consulted first.
latin_lang = [
    "af", "az", "bs", "cs", "cy", "da", "de", "es", "et", "fr",
    "ga", "hr", "hu", "id", "is", "it", "ku", "la", "lt", "lv",
    "mi", "ms", "mt", "nl", "no", "oc", "pi", "pl", "pt", "ro",
    "rs_latin", "sk", "sl", "sq", "sv", "sw", "tl", "tr", "uz", "vi",
    "french", "german", "fi", "eu", "gl", "lb", "rm", "ca", "qu",
]

# Codes mapped to the 'arabic' model.
arabic_lang = ["ar", "fa", "ug", "ur", "ps", "ku", "sd", "bal"]

# Codes mapped to the 'cyrillic' model. "ru", "be" and "uk" are also listed in
# east_slavic_lang, which is consulted before this list.
cyrillic_lang = [
    "ru", "rs_cyrillic", "be", "bg", "uk", "mn", "abq", "ady", "kbd", "ava",
    "dar", "inh", "che", "lbe", "lez", "tab", "kk", "ky", "tg", "mk",
    "tt", "cv", "ba", "mhr", "mo", "udm", "kv", "os", "bua", "xal",
    "tyv", "sah", "kaa",
]

# Codes mapped to the 'east_slavic' model.
east_slavic_lang = ["ru", "be", "uk"]

# Codes mapped to the 'devanagari' model.
devanagari_lang = [
    "hi", "mr", "ne", "bh", "mai", "ang", "bho",
    "mah", "sck", "new", "gom", "sa", "bgc",
]
def get_model_params(lang, config):
    """Look up the model file names configured for a language.

    Args:
        lang: Language key to look up under ``config['lang']``.
        config: Parsed models configuration (a mapping with a ``'lang'``
            section). ``None`` or a mapping without a ``'lang'`` section is
            treated as "no languages configured".

    Returns:
        A ``(det, rec, dict_file)`` tuple taken from the ``'det'``, ``'rec'``
        and ``'dict'`` entries of the language section; an entry that is
        missing from the section is returned as ``None``.

    Raises:
        ValueError: If ``lang`` has no section in the configuration.
    """
    # yaml.safe_load() yields None for an empty file, and the 'lang' section
    # may be absent or null; all of these mean the language is unsupported.
    lang_table = (config or {}).get('lang') or {}
    if lang not in lang_table:
        raise ValueError(f'Language {lang} not supported')

    params = lang_table[lang]
    return params.get('det'), params.get('rec'), params.get('dict')
def auto_download_and_get_model_root_path(model_path):
    """Return the root directory of the local ModelScope cache.

    Stand-in for MinerU's model download logic: nothing is downloaded here
    and ``model_path`` is not consulted.

    Args:
        model_path: Relative model path (unused; kept for call compatibility).

    Returns:
        The value of the ``MODELSCOPE_CACHE_DIR`` environment variable when it
        is set, otherwise ``~/.cache/modelscope`` as a string.
    """
    fallback_root = str(Path.home() / '.cache' / 'modelscope')
    return os.environ.get('MODELSCOPE_CACHE_DIR', fallback_root)
# Absolute directory that contains this file; the bundled 'vendor' resources
# (models config, dictionaries) are located relative to it.
root_dir = Path(__file__).resolve().parents[0]
class PytorchPaddleOCR(TextSystem):
    """OCR engine providing text detection and text recognition.

    Subclasses ``TextSystem`` and adds language-based model selection, an
    optional page-orientation classifier and a result visualizer.
    """

    # Directory of the OCR weights, relative to the model cache root.
    _OCR_MODELS_DIR = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch"

    def __init__(self, *args, **kwargs):
        """Initialize the OCR engine.

        Positional arguments are parsed as command-line style options by the
        parser returned from ``utility.init_args()``; keyword arguments
        override the parsed values.

        Args:
            lang (str): Language ('ch', 'en', 'latin', 'korean', 'japan', ...).
            det_db_thresh (float): Binarization threshold for detection.
            det_db_box_thresh (float): Box filtering threshold for detection.
            rec_batch_num (int): Recognition batch size (default 6).
            enable_merge_det_boxes (bool): Whether to merge detected boxes.
            device (str): Device ('cpu', 'cuda:0', 'mps'); defaults to
                ``get_device()``.
            det_model_path (str): Custom detection model path.
            rec_model_path (str): Custom recognition model path.
            rec_char_dict_path (str): Custom character dictionary path.
            use_orientation_cls (bool): Enable the orientation classifier.
            orientation_model_path (str): Orientation model path; the
                classifier is only created when this is set as well.
            aspect_ratio_threshold, vertical_text_ratio,
            vertical_text_min_count: Forwarded to ``OrientationClassifierV2``.

        Raises:
            FileNotFoundError: If the models config or the default character
                dictionary cannot be found.
        """
        parser = utility.init_args()
        parsed_args = parser.parse_args(list(args))

        self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)
        self.use_orientation_cls = kwargs.get("use_orientation_cls", False)
        self.orientation_model_path = kwargs.get("orientation_model_path", None)

        device = kwargs.get('device', get_device())
        self.lang = self._resolve_lang(kwargs.get('lang', 'ch'), device)

        # Read the per-language model configuration.
        models_config_path = root_dir / 'vendor' / 'pytorchocr' / 'utils' / 'resources' / 'models_config.yml'
        if not models_config_path.exists():
            raise FileNotFoundError(f"❌ Config file not found: {models_config_path}")

        logger.info(f"📄 Reading config: {models_config_path}")
        # Explicit encoding: the default would depend on the platform locale.
        with open(models_config_path, encoding='utf-8') as file:
            config = yaml.safe_load(file)
        det, rec, dict_file = get_model_params(self.lang, config)

        logger.info(f"📋 Config for lang '{self.lang}':")
        logger.info(f" Det model: {det}")
        logger.info(f" Rec model: {rec}")
        logger.info(f" Dict file: {dict_file}")

        # Caller-supplied paths take precedence over the configured defaults.
        if 'det_model_path' not in kwargs:
            kwargs['det_model_path'] = self._default_model_path(det, "Detection")
        if 'rec_model_path' not in kwargs:
            kwargs['rec_model_path'] = self._default_model_path(rec, "Recognition")

        if 'rec_char_dict_path' not in kwargs:
            dict_path = root_dir / 'vendor' / 'pytorchocr' / 'utils' / 'resources' / 'dict' / dict_file
            if not dict_path.exists():
                logger.error(f"❌ Dictionary file not found: {dict_path}")
                raise FileNotFoundError(f"Dictionary file not found: {dict_path}")
            kwargs['rec_char_dict_path'] = str(dict_path)
            logger.info(f"📖 Dictionary: {dict_path.name}")

        kwargs.setdefault('rec_batch_num', 6)
        kwargs.setdefault('device', device)

        # Keyword arguments override whatever the argument parser produced.
        merged_args = vars(parsed_args)
        merged_args.update(kwargs)
        final_args = argparse.Namespace(**merged_args)

        logger.info("🔧 Initializing TextSystem...")
        logger.info(f" device: {final_args.device}")
        logger.info(f" rec_char_dict_path: {final_args.rec_char_dict_path}")
        logger.info(f" rec_batch_num: {final_args.rec_batch_num}")

        super().__init__(final_args)

        # Must run after TextSystem.__init__ because it wraps self.text_detector.
        self.orientation_classifier = self._init_orientation_classifier(kwargs)

        self._validate_character_set(final_args.rec_char_dict_path)
        logger.info("✅ OCR engine initialized")

    @staticmethod
    def _resolve_lang(lang, device):
        """Map a requested language code to the key used in the models config."""
        # On CPU the heavier CJK models are swapped for the lite model.
        if device == 'cpu' and lang in ('ch', 'ch_server', 'japan', 'chinese_cht'):
            logger.warning("CPU device detected. Switching to ch_lite for better performance.")
            lang = 'ch_lite'

        # Order matters: the first family containing the code wins.
        families = (
            ('latin', latin_lang),
            ('east_slavic', east_slavic_lang),
            ('arabic', arabic_lang),
            ('cyrillic', cyrillic_lang),
            ('devanagari', devanagari_lang),
        )
        for family, members in families:
            if lang in members:
                return family
        return lang

    @classmethod
    def _default_model_path(cls, model_file, label):
        """Build the default path of a model file and log whether it exists.

        ``label`` ("Detection" / "Recognition") is only used in log messages.
        A missing file is reported with a warning, not an exception.
        """
        relative_path = f"{cls._OCR_MODELS_DIR}/{model_file}"
        full_path = os.path.join(auto_download_and_get_model_root_path(relative_path), relative_path)

        logger.info(f"🔍 {label} model path: {full_path}")
        if Path(full_path).exists():
            logger.info(f"✅ {label} model file exists: {full_path}")
        else:
            logger.warning(f"⚠️ {label} model file not found: {full_path}")
        return full_path

    def _init_orientation_classifier(self, kwargs):
        """Create the orientation classifier, or return None when disabled or on failure."""
        if not (self.use_orientation_cls and self.orientation_model_path):
            return None

        # Best effort: OCR stays usable even if the classifier cannot be built.
        try:
            logger.info("🔄 Initializing orientation classifier...")
            logger.info(f" Model: {self.orientation_model_path}")

            classifier = OrientationClassifierV2(
                model_path=self.orientation_model_path,
                text_detector=_TextDetectorAdapter(self.text_detector),
                aspect_ratio_threshold=kwargs.get("aspect_ratio_threshold", 1.2),
                vertical_text_ratio=kwargs.get("vertical_text_ratio", 0.28),
                vertical_text_min_count=kwargs.get("vertical_text_min_count", 3),
                use_gpu=kwargs.get("device", "cpu") != "cpu",
            )
            logger.info("✅ Orientation classifier initialized")
            return classifier
        except Exception as e:
            logger.warning(f"⚠️ Failed to initialize orientation classifier: {e}")
            return None

    def _validate_character_set(self, dict_path):
        """Log the state of the recognizer's character set; reload it if empty.

        The character set is read from ``text_recognizer.postprocess_op.character``
        (the recognizer itself has no ``character`` attribute).
        """
        logger.info("🔍 Validating text recognizer...")

        if not hasattr(self, 'text_recognizer'):
            logger.error(" ❌ text_recognizer NOT found!")
            return
        logger.info(" ✅ text_recognizer exists")

        if not hasattr(self.text_recognizer, 'postprocess_op'):
            logger.error(" ❌ postprocess_op NOT found!")
            logger.error(f" Available attributes: {dir(self.text_recognizer)}")
            return
        postprocess_op = self.text_recognizer.postprocess_op
        logger.info(f" ✅ postprocess_op exists: {type(postprocess_op)}")

        if not hasattr(postprocess_op, 'character'):
            logger.error(" ❌ postprocess_op.character NOT found!")
            logger.error(f" Available attributes: {dir(postprocess_op)}")
            return

        char_count = len(postprocess_op.character)
        logger.info(" ✅ postprocess_op.character exists")
        logger.info(f" Character set size: {char_count}")

        if char_count > 0:
            logger.info(" ✅ Character set loaded successfully")
            # First few characters, for debugging.
            logger.info(f" First 10 chars: {postprocess_op.character[:10]}")
            return

        logger.error(" ❌ Character set is EMPTY!")
        logger.error(" Reloading dictionary...")
        # NOTE(review): only the raw dictionary lines are restored here; any
        # extra tokens the post-processor normally adds are not - confirm
        # against the vendor decoder.
        with open(dict_path, 'rb') as f:
            character = [raw.decode('utf-8').strip('\r\n') for raw in f]
        postprocess_op.character = character
        logger.info(f" ✅ Character set reloaded: {len(character)} chars")

    def _postprocess_det_boxes(self, dt_boxes, mfd_res):
        """Sort detected boxes, optionally merge them and apply the formula results."""
        dt_boxes = sorted_boxes(dt_boxes)
        if self.enable_merge_det_boxes:
            dt_boxes = merge_det_boxes(dt_boxes)
        if mfd_res:
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)
        return dt_boxes

    def ocr(
        self,
        img,
        det=True,
        rec=True,
        mfd_res=None,
        tqdm_enable=False,
        tqdm_desc="OCR-rec Predict",
    ):
        """Run OCR on an input.

        Args:
            img: BGR image, list of images, file path or raw bytes.
            det (bool): Run text detection.
            rec (bool): Run text recognition.
            mfd_res (list): Formula detection results used to adjust text boxes.
            tqdm_enable (bool): Show a progress bar (recognition-only mode).
            tqdm_desc (str): Progress bar description.

        Returns:
            det=True, rec=True: ``[[[box], (text, conf)], ...]``
            det=True, rec=False: ``[[boxes], ...]``
            det=False, rec=True: ``[[(text, conf), ...]]``
            det=False, rec=False: ``None``
            A per-image entry is ``None`` when nothing was detected.

        Raises:
            ValueError: If ``img`` is a list while ``det`` is true.
        """
        assert isinstance(img, (np.ndarray, list, str, bytes))

        if isinstance(img, list) and det:
            message = 'When input is a list of images, det must be False'
            logger.error(message)
            # Raise instead of exit(1): library code must not terminate the
            # host process, and the `exit` builtin is not always available.
            raise ValueError(message)

        # After this call a non-list input is a BGR np.ndarray.
        images = [check_img(img)]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)

            if det and rec:
                # Detection + recognition.
                ocr_res = []
                for image in images:
                    image = preprocess_image(image)
                    dt_boxes, rec_res = self.__call__(image, mfd_res=mfd_res)
                    if not dt_boxes and not rec_res:
                        ocr_res.append(None)
                        continue
                    ocr_res.append([[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)])
                return ocr_res

            if det:
                # Detection only.
                ocr_res = []
                for image in images:
                    image = preprocess_image(image)
                    dt_boxes, _elapse = self.text_detector(image)
                    if dt_boxes is None:
                        ocr_res.append(None)
                        continue
                    dt_boxes = self._postprocess_det_boxes(dt_boxes, mfd_res)
                    ocr_res.append([box.tolist() for box in dt_boxes])
                return ocr_res

            if rec:
                # Recognition only: the input is one crop or a list of crops.
                ocr_res = []
                for image in images:
                    if not isinstance(image, list):
                        image = [preprocess_image(image)]
                    rec_res, _elapse = self.text_recognizer(
                        image, tqdm_enable=tqdm_enable, tqdm_desc=tqdm_desc
                    )
                    ocr_res.append(rec_res)
                return ocr_res

        return None

    def __call__(self, img, mfd_res=None):
        """Run detection and recognition on a single image.

        Args:
            img: Preprocessed image.
            mfd_res: Formula detection results.

        Returns:
            ``(dt_boxes, rec_res)``: the boxes and ``(text, score)`` results
            whose score is at least ``self.drop_score``. ``(None, None)`` when
            there is no image or the detector returns no boxes.
        """
        if img is None:
            logger.debug("no valid image provided")
            return None, None

        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        if dt_boxes is None:
            logger.debug(f"no dt_boxes found, elapsed: {elapse}")
            return None, None

        dt_boxes = self._postprocess_det_boxes(dt_boxes, mfd_res)
        if len(dt_boxes) == 0:
            # Nothing left to recognize; skip the recognizer entirely.
            return [], []

        # Crop every text region from the untouched copy of the image.
        img_crop_list = [
            get_rotate_crop_image(ori_im, copy.deepcopy(box)) for box in dt_boxes
        ]

        logger.info(f"🔤 Recognizing {len(img_crop_list)} text regions...")
        rec_res, elapse = self.text_recognizer(img_crop_list)

        # Summary logging.
        non_empty = sum(1 for text, _ in rec_res if text)
        logger.info(f" Found {non_empty}/{len(rec_res)} non-empty results")
        if non_empty > 0:
            logger.info(" First 5 non-empty results:")
            shown = 0
            for i, (text, conf) in enumerate(rec_res):
                if not text:
                    continue
                logger.info(f" [{i+1}] '{text}' (conf={conf:.3f})")
                shown += 1
                if shown >= 5:
                    break

        # Drop results scoring below the threshold.
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            _text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res

    @staticmethod
    def _color_for_confidence(conf: float) -> tuple:
        """Return the BGR box color for a confidence (green / yellow / orange)."""
        if conf >= 0.9:
            return (0, 255, 0)      # high confidence: green
        if conf >= 0.7:
            return (0, 255, 255)    # medium confidence: yellow
        return (0, 165, 255)        # low confidence: orange

    def visualize(
        self,
        img: np.ndarray,
        ocr_results: list,
        output_path: Optional[str] = None,
        show_text: bool = True,
        show_confidence: bool = True,
        font_scale: float = 0.5,
        thickness: int = 2
    ) -> np.ndarray:
        """Draw OCR results on a copy of the image.

        Args:
            img: Original image (BGR).
            ocr_results: OCR results as returned by ``ocr(det=True, rec=True)``;
                only the first entry (``ocr_results[0]``) is drawn.
            output_path: Where to save the annotated image (optional).
            show_text: Include the recognized text in each label.
            show_confidence: Include the confidence in each label.
            font_scale: Label font scale.
            thickness: Box and label line thickness.

        Returns:
            The annotated image. The input image is not modified.
        """
        img_vis = img.copy()

        if not ocr_results or ocr_results[0] is None:
            logger.warning("No OCR results to visualize")
            return img_vis

        page = ocr_results[0]
        font = cv2.FONT_HERSHEY_SIMPLEX

        for idx, (box, (text, conf)) in enumerate(page, 1):
            # Box corners as an (N, 2) int32 array: [[x1, y1], ..., [x4, y4]].
            if isinstance(box, list):
                points = np.array(box, dtype=np.int32).reshape((-1, 2))
            else:
                points = box.astype(np.int32)

            color = self._color_for_confidence(conf)
            cv2.polylines(img_vis, [points], True, color, thickness)

            # Anchor the label at the first corner; plain ints keep OpenCV's
            # point parsing independent of numpy scalar handling.
            x1, y1 = (int(v) for v in points[0])

            if show_text and show_confidence:
                label = f"[{idx}] {text} ({conf:.2f})"
            elif show_text:
                label = f"[{idx}] {text}"
            elif show_confidence:
                label = f"[{idx}] {conf:.2f}"
            else:
                label = f"[{idx}]"

            (label_w, label_h), baseline = cv2.getTextSize(label, font, font_scale, thickness)

            # Keep the label inside the top edge of the image.
            y_label = max(y1 - 5, label_h + 5)

            # Semi-transparent label background.
            overlay = img_vis.copy()
            cv2.rectangle(
                overlay,
                (x1, y_label - label_h - 5),
                (x1 + label_w + 10, y_label + baseline),
                color,
                -1
            )
            cv2.addWeighted(overlay, 0.6, img_vis, 0.4, 0, img_vis)

            cv2.putText(
                img_vis,
                label,
                (x1 + 5, y_label - 5),
                font,
                font_scale,
                (255, 255, 255),  # white text
                thickness,
                cv2.LINE_AA
            )

        # Overall statistics in the top-left corner.
        total_boxes = len(page)
        avg_conf = sum(conf for _, (_, conf) in page) / total_boxes if total_boxes > 0 else 0
        stats_text = f"Total: {total_boxes} boxes | Avg Conf: {avg_conf:.3f}"
        cv2.putText(
            img_vis,
            stats_text,
            (10, 30),
            font,
            0.7,
            (0, 0, 255),  # red
            2,
            cv2.LINE_AA
        )

        if output_path:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            # cv2.imwrite reports failure through its return value, not an exception.
            if cv2.imwrite(output_path, img_vis):
                logger.info(f"✅ Visualization saved to: {output_path}")
            else:
                logger.error(f"❌ Failed to save visualization to: {output_path}")

        return img_vis


class _TextDetectorAdapter:
    """Expose a text detector through the ``ocr(img, det, rec)`` call used by
    ``OrientationClassifierV2``."""

    def __init__(self, text_detector):
        self.text_detector = text_detector

    def ocr(self, img, det=True, rec=False):
        """Run detection only.

        Returns ``None`` when ``det`` is false, ``[[]]`` when nothing is
        detected or detection fails, otherwise ``[boxes]`` with the boxes
        converted to nested lists. ``rec`` is accepted but ignored.
        """
        if not det:
            return None

        try:
            dt_boxes, _ = self.text_detector(img)
            if dt_boxes is None or len(dt_boxes) == 0:
                return [[]]
            return [dt_boxes.tolist()]
        except Exception as e:
            # Best effort: a detector failure is reported as "no boxes".
            logger.warning(f"Detection failed in adapter: {e}")
            return [[]]
if __name__ == '__main__':
    # Manual smoke test: runs OCR (with the orientation classifier enabled) on
    # the sample images below and writes an annotated copy of each.
    from dotenv import load_dotenv
    load_dotenv(override=True)

    logger.remove()
    logger.add(sys.stderr, level="INFO")

    print("🚀 Testing PytorchPaddleOCR with Orientation Classifier...")

    try:
        ocr = PytorchPaddleOCR(
            lang='ch',
            device='cpu',
            use_orientation_cls=True,
            orientation_model_path="/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/PP-LCNet_x1_0_doc_ori.onnx"
        )

        # Report the size of the loaded character set.
        if hasattr(ocr.text_recognizer, 'postprocess_op'):
            char_count = len(ocr.text_recognizer.postprocess_op.character)
            print(f"\n📊 Character set loaded: {char_count} characters")

        # Sample images to process.
        test_images = [
            # {
            #     'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_001.png",
            #     'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_001_ocr_vis.png"
            # },
            {
                'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
                'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_003_ocr_vis.png"
            }
        ]

        for img_info in test_images:
            img_path = img_info['path']
            output_path = img_info['output']

            print(f"\n{'='*60}")
            print(f"📄 Processing: {Path(img_path).name}")
            print(f"{'='*60}")

            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None (no exception) for a missing or
                # unreadable file.
                print(f"⚠️ Could not read image: {img_path}")
                continue
            print(f"📖 Original image: {img.shape}")

            # np.ndarray in BGR format.
            img = check_img(img)

            # Correct the page orientation first, when the classifier is available.
            if ocr.orientation_classifier is not None and isinstance(img, np.ndarray):
                logger.info("🔄 Checking image orientation...")

                orientation_result = ocr.orientation_classifier.predict(img, return_debug=True)
                logger.info(f" Orientation result: {orientation_result}")

                if orientation_result.needs_rotation:
                    logger.info(f"🔄 Rotating image by {orientation_result.rotation_angle}°...")
                    img = ocr.orientation_classifier.rotate_image(img, orientation_result.rotation_angle)
                    logger.info(f" New shape: {img.shape}")

            # Run OCR on the (possibly rotated) image; ocr() itself does not rotate.
            results = ocr.ocr(img, det=True, rec=True)

            if results and results[0]:
                print(f"\n✅ Found {len(results[0])} text regions:")
                for i, (box, (text, conf)) in enumerate(results[0][:10], 1):
                    print(f" [{i}] {text} (conf={conf:.3f})")

                # Save the annotated image.
                ocr.visualize(img, results, output_path=output_path, show_text=False, show_confidence=False)

                # Summary statistics.
                total = len(results[0])
                non_empty = sum(1 for _, (text, _) in results[0] if text)
                avg_conf = sum(conf for _, (_, conf) in results[0]) / total

                print("\n📊 Statistics:")
                print(f" Total: {total}")
                print(f" Non-empty: {non_empty}")
                print(f" Avg confidence: {avg_conf:.3f}")
            else:
                print("⚠️ No results found")

    except Exception as e:
        # Top-level boundary of the manual test: report and show the traceback.
        print(f"\n❌ Test failed: {e}")
        import traceback
        traceback.print_exc()
|