  1. """
  2. PytorchPaddleOCR - 从 MinerU 移植
  3. 提供完整的 OCR 功能(检测 + 识别)
  4. """
import copy
import os
import sys
import warnings
from pathlib import Path

import cv2
import numpy as np
import yaml
from loguru import logger
import argparse
# ✅ Adjusted imports
try:
    from .device_utils import get_device
except ImportError:
    from device_utils import get_device

# When run as a script, add the parent directory to the Python path
current_dir = Path(__file__).resolve().parent
parent_dir = current_dir.parent
if str(parent_dir) not in sys.path:
    sys.path.insert(0, str(parent_dir))

# ✅ Choose the import style based on how the module is run
try:
    # When imported as a module (from vendor.pytorch_paddle import ...)
    from .ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
    from .infer.predict_system import TextSystem
    from .infer import pytorchocr_utility as utility
except ImportError:
    # When run as a script (python pytorch_paddle.py)
    from vendor.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
    from vendor.infer.predict_system import TextSystem
    from vendor.infer import pytorchocr_utility as utility
latin_lang = [
    "af", "az", "bs", "cs", "cy", "da", "de", "es", "et", "fr", "ga", "hr", "hu",
    "id", "is", "it", "ku", "la", "lt", "lv", "mi", "ms", "mt", "nl", "no", "oc",
    "pi", "pl", "pt", "ro", "rs_latin", "sk", "sl", "sq", "sv", "sw", "tl", "tr",
    "uz", "vi", "french", "german", "fi", "eu", "gl", "lb", "rm", "ca", "qu",
]
arabic_lang = ["ar", "fa", "ug", "ur", "ps", "ku", "sd", "bal"]
cyrillic_lang = [
    "ru", "rs_cyrillic", "be", "bg", "uk", "mn", "abq", "ady", "kbd", "ava", "dar",
    "inh", "che", "lbe", "lez", "tab", "kk", "ky", "tg", "mk", "tt", "cv", "ba",
    "mhr", "mo", "udm", "kv", "os", "bua", "xal", "tyv", "sah", "kaa",
]
east_slavic_lang = ["ru", "be", "uk"]
devanagari_lang = [
    "hi", "mr", "ne", "bh", "mai", "ang", "bho", "mah", "sck", "new", "gom", "sa", "bgc",
]
def get_model_params(lang, config):
    """Look up the model parameters for a language in the config file."""
    if lang in config['lang']:
        params = config['lang'][lang]
        det = params.get('det')
        rec = params.get('rec')
        dict_file = params.get('dict')
        return det, rec, dict_file
    else:
        raise Exception(f'Language {lang} not supported')
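
# Assumed models_config.yml layout (illustrative sketch only; the real model file
# names live in the bundled config and are not reproduced here):
#
#   lang:
#     ch:
#       det: <detection model filename>
#       rec: <recognition model filename>
#       dict: <recognition dictionary filename>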
def auto_download_and_get_model_root_path(model_path):
    """
    Mimics MinerU's model download logic.
    Returns the local ModelScope cache directory used as the model root.
    """
    modelscope_cache = os.getenv('MODELSCOPE_CACHE_DIR', str(Path.home() / '.cache' / 'modelscope'))
    return modelscope_cache
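
# How this root is combined with a relative model path downstream (sketch; the
# default cache location follows the MODELSCOPE_CACHE_DIR fallback above):
#   root = auto_download_and_get_model_root_path(rel_path)
#   full = os.path.join(root, rel_path)
#   # e.g. ~/.cache/modelscope/models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch/<model>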
# Directory containing this file
root_dir = Path(__file__).resolve().parent
class PytorchPaddleOCR(TextSystem):
    """
    PytorchPaddleOCR - OCR engine
    Inherits from TextSystem and provides full OCR functionality.
    """
    def __init__(self, *args, **kwargs):
        """Initialize the OCR engine.

        Args:
            lang (str): language ('ch', 'en', 'latin', 'korean', 'japan', ...)
            det_db_thresh (float): detection binarization threshold
            det_db_box_thresh (float): detection box filtering threshold
            rec_batch_num (int): recognition batch size
            enable_merge_det_boxes (bool): whether to merge detection boxes
            device (str): device ('cpu', 'cuda:0', 'mps')
            det_model_path (str): custom detection model path
            rec_model_path (str): custom recognition model path
            rec_char_dict_path (str): custom dictionary path
        """
        # Initialize the argument parser
        parser = utility.init_args()
        args_list = list(args) if args else []
        parsed_args = parser.parse_args(args_list)

        # Get the language setting
        self.lang = kwargs.get('lang', 'ch')
        self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)

        # Auto-detect the device
        device = kwargs.get('device', get_device())

        # ✅ CPU optimization: switch to the lightweight model automatically
        if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
            logger.warning("CPU device detected. Switching to ch_lite for better performance.")
            self.lang = 'ch_lite'

        # ✅ Language mapping
        if self.lang in latin_lang:
            self.lang = 'latin'
        elif self.lang in east_slavic_lang:
            self.lang = 'east_slavic'
        elif self.lang in arabic_lang:
            self.lang = 'arabic'
        elif self.lang in cyrillic_lang:
            self.lang = 'cyrillic'
        elif self.lang in devanagari_lang:
            self.lang = 'devanagari'
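        # e.g. 'fr', 'de', 'it' collapse to the shared 'latin' model; 'ru' maps to
        # 'east_slavic' because that list is checked before cyrillic_lang.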
        # ✅ Read the model configuration
        models_config_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'models_config.yml'
        if not models_config_path.exists():
            raise FileNotFoundError(f"❌ Config file not found: {models_config_path}")
        logger.info(f"📄 Reading config: {models_config_path}")
        with open(models_config_path) as file:
            config = yaml.safe_load(file)
        det, rec, dict_file = get_model_params(self.lang, config)
        logger.info(f"📋 Config for lang '{self.lang}':")
        logger.info(f"  Det model: {det}")
        logger.info(f"  Rec model: {rec}")
        logger.info(f"  Dict file: {dict_file}")
        # ✅ Model paths (custom paths passed via kwargs take priority)
        if 'det_model_path' not in kwargs:
            ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch"
            det_model_path = f"{ocr_models_dir}/{det}"
            det_model_full_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
            kwargs['det_model_path'] = det_model_full_path
            logger.info(f"🔍 Detection model path: {det_model_full_path}")
            # ✅ Verify the model file exists
            if not Path(det_model_full_path).exists():
                logger.warning(f"⚠️ Detection model file not found: {det_model_full_path}")
            else:
                logger.info(f"✅ Detection model file exists: {det_model_full_path}")

        # ✅ Recognition model path
        if 'rec_model_path' not in kwargs:
            ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch"
            rec_model_path = f"{ocr_models_dir}/{rec}"
            rec_model_full_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)
            kwargs['rec_model_path'] = rec_model_full_path
            logger.info(f"🔍 Recognition model path: {rec_model_full_path}")
            # ✅ Verify the model file exists
            if not Path(rec_model_full_path).exists():
                logger.warning(f"⚠️ Recognition model file not found: {rec_model_full_path}")
            else:
                logger.info(f"✅ Recognition model file exists: {rec_model_full_path}")
        # ✅ Dictionary path
        if 'rec_char_dict_path' not in kwargs:
            dict_path = root_dir / 'pytorchocr' / 'utils' / 'resources' / 'dict' / dict_file
            if not dict_path.exists():
                logger.error(f"❌ Dictionary file not found: {dict_path}")
                raise FileNotFoundError(f"Dictionary file not found: {dict_path}")
            kwargs['rec_char_dict_path'] = str(dict_path)
            logger.info(f"📖 Dictionary: {dict_path.name}")

        # ✅ Default parameters
        kwargs.setdefault('rec_batch_num', 6)
        kwargs.setdefault('device', device)

        # ✅ Merge parameters
        default_args = vars(parsed_args)
        default_args.update(kwargs)
        final_args = argparse.Namespace(**default_args)

        logger.info("🔧 Initializing TextSystem...")
        logger.info(f"  device: {final_args.device}")
        logger.info(f"  rec_char_dict_path: {final_args.rec_char_dict_path}")
        logger.info(f"  rec_batch_num: {final_args.rec_batch_num}")

        # ✅ Initialize TextSystem
        super().__init__(final_args)
        # ✅ Validate and repair the character set
        logger.info("🔍 Validating text recognizer...")
        if hasattr(self, 'text_recognizer'):
            logger.info("  ✅ text_recognizer exists")
            # ❌ text_recognizer itself has no `character` attribute
            # ✅ the character set is stored in postprocess_op.character
            if hasattr(self.text_recognizer, 'postprocess_op'):
                postprocess_op = self.text_recognizer.postprocess_op
                logger.info(f"  ✅ postprocess_op exists: {type(postprocess_op)}")
                if hasattr(postprocess_op, 'character'):
                    char_count = len(postprocess_op.character)
                    logger.info("  ✅ postprocess_op.character exists")
                    logger.info(f"    Character set size: {char_count}")
                    if char_count == 0:
                        logger.error("    ❌ Character set is EMPTY!")
                        logger.error("    Reloading dictionary...")
                        # Force-reload the dictionary
                        dict_path = final_args.rec_char_dict_path
                        with open(dict_path, 'rb') as f:
                            lines = f.readlines()
                        character = []
                        for line in lines:
                            line = line.decode('utf-8').strip('\n').strip('\r\n')
                            character.append(line)
                        # Update the character set
                        postprocess_op.character = character
                        logger.info(f"    ✅ Character set reloaded: {len(character)} chars")
                    else:
                        logger.info("    ✅ Character set loaded successfully")
                        # Show the first 10 characters (for debugging)
                        sample_chars = postprocess_op.character[:10]
                        logger.info(f"    First 10 chars: {sample_chars}")
                else:
                    logger.error("  ❌ postprocess_op.character NOT found!")
                    logger.error(f"    Available attributes: {dir(postprocess_op)}")
            else:
                logger.error("  ❌ postprocess_op NOT found!")
                logger.error(f"    Available attributes: {dir(self.text_recognizer)}")
        else:
            logger.error("  ❌ text_recognizer NOT found!")

        logger.info("✅ OCR engine initialized")
    def ocr(
        self,
        img,
        det=True,
        rec=True,
        mfd_res=None,
        tqdm_enable=False,
        tqdm_desc="OCR-rec Predict",
    ):
  219. """
  220. 执行 OCR
  221. Args:
  222. img: BGR 图像、图像列表、文件路径或字节流
  223. det (bool): 是否执行检测
  224. rec (bool): 是否执行识别
  225. mfd_res (list): 公式检测结果(用于过滤文本框)
  226. tqdm_enable (bool): 是否显示进度条
  227. tqdm_desc (str): 进度条描述
  228. Returns:
  229. det=True, rec=True: [[[box], (text, conf)], ...]
  230. det=True, rec=False: [[boxes], ...]
  231. det=False, rec=True: [[(text, conf), ...]]
  232. """
        assert isinstance(img, (np.ndarray, list, str, bytes))
        if isinstance(img, list) and det:
            logger.error('When input is a list of images, det must be False')
            exit(1)

        img = check_img(img)
        imgs = [img]

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            if det and rec:
                # Detection + recognition
                ocr_res = []
                for img in imgs:
                    img = preprocess_image(img)
                    dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
                    if not dt_boxes and not rec_res:
                        ocr_res.append(None)
                        continue
                    tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
                    ocr_res.append(tmp_res)
                return ocr_res
            elif det and not rec:
                # Detection only
                ocr_res = []
                for img in imgs:
                    img = preprocess_image(img)
                    dt_boxes, elapse = self.text_detector(img)
                    if dt_boxes is None:
                        ocr_res.append(None)
                        continue
                    dt_boxes = sorted_boxes(dt_boxes)
                    if self.enable_merge_det_boxes:
                        dt_boxes = merge_det_boxes(dt_boxes)
                    if mfd_res:
                        dt_boxes = update_det_boxes(dt_boxes, mfd_res)
                    tmp_res = [box.tolist() for box in dt_boxes]
                    ocr_res.append(tmp_res)
                return ocr_res
            elif not det and rec:
                # Recognition only
                ocr_res = []
                for img in imgs:
                    if not isinstance(img, list):
                        img = preprocess_image(img)
                        img = [img]
                    rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable, tqdm_desc=tqdm_desc)
                    ocr_res.append(rec_res)
                return ocr_res
    def __call__(self, img, mfd_res=None):
        """
        OCR on a single image (detection + recognition).

        Args:
            img: preprocessed image
            mfd_res: formula detection results

        Returns:
            (dt_boxes, rec_res): detection boxes and recognition results
        """
        if img is None:
            logger.debug("no valid image provided")
            return None, None

        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        if dt_boxes is None:
            logger.debug(f"no dt_boxes found, elapsed: {elapse}")
            return None, None

        # Sort detection boxes
        dt_boxes = sorted_boxes(dt_boxes)

        # Merge adjacent boxes
        if self.enable_merge_det_boxes:
            dt_boxes = merge_det_boxes(dt_boxes)

        # Filter out boxes that overlap with formula regions
        if mfd_res:
            dt_boxes = update_det_boxes(dt_boxes, mfd_res)

        # Crop text regions
        img_crop_list = []
        for bno in range(len(dt_boxes)):
            tmp_box = copy.deepcopy(dt_boxes[bno])
            img_crop = get_rotate_crop_image(ori_im, tmp_box)
            img_crop_list.append(img_crop)

        logger.info(f"🔤 Recognizing {len(img_crop_list)} text regions...")
        rec_res, elapse = self.text_recognizer(img_crop_list)

        # Summarize results
        non_empty = sum(1 for text, _ in rec_res if text)
        logger.info(f"  Found {non_empty}/{len(rec_res)} non-empty results")
        if non_empty > 0:
            logger.info("  First 5 non-empty results:")
            count = 0
            for i, (text, conf) in enumerate(rec_res):
                if text:
                    logger.info(f"    [{i+1}] '{text}' (conf={conf:.3f})")
                    count += 1
                    if count >= 5:
                        break

        # Filter out low-score results
        filter_boxes, filter_rec_res = [], []
        for box, rec_result in zip(dt_boxes, rec_res):
            text, score = rec_result
            if score >= self.drop_score:
                filter_boxes.append(box)
                filter_rec_res.append(rec_result)
        return filter_boxes, filter_rec_res
    def visualize(
        self,
        img: np.ndarray,
        ocr_results: list,
        output_path: str = None,
        show_text: bool = True,
        show_confidence: bool = True,
        font_scale: float = 0.5,
        thickness: int = 2
    ) -> np.ndarray:
        """
        Visualize OCR detection results.

        Args:
            img: original image (BGR)
            ocr_results: OCR results [[box, (text, conf)], ...]
            output_path: output path (optional)
            show_text: whether to draw the recognized text
            show_confidence: whether to draw the confidence score
            font_scale: font size
            thickness: box line thickness

        Returns:
            The annotated image.
        """
        img_vis = img.copy()
        if not ocr_results or ocr_results[0] is None:
            logger.warning("No OCR results to visualize")
            return img_vis

        # Color mapping (by confidence)
        def get_color_by_confidence(conf: float) -> tuple:
            """Return a BGR color based on confidence (green -> yellow -> orange)."""
            if conf >= 0.9:
                return (0, 255, 0)      # high confidence: green
            elif conf >= 0.7:
                return (0, 255, 255)    # medium confidence: yellow
            else:
                return (0, 165, 255)    # low confidence: orange
        for idx, result in enumerate(ocr_results[0], 1):
            box, (text, conf) = result

            # Convert coordinates to integer points: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
            if isinstance(box, list):
                points = np.array(box, dtype=np.int32).reshape((-1, 2))
            else:
                points = box.astype(np.int32)

            # Pick the border color
            color = get_color_by_confidence(conf)

            # Draw the polygon border
            cv2.polylines(img_vis, [points], True, color, thickness)

            # Label position (top-left corner of the box)
            x1, y1 = points[0]

            # Build the label
            if show_text and show_confidence:
                label = f"[{idx}] {text} ({conf:.2f})"
            elif show_text:
                label = f"[{idx}] {text}"
            elif show_confidence:
                label = f"[{idx}] {conf:.2f}"
            else:
                label = f"[{idx}]"

            # Measure the label
            label_size, baseline = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
            )
            label_w, label_h = label_size

            # Keep the label inside the image
            y_label = max(y1 - 5, label_h + 5)

            # Draw a semi-transparent label background
            overlay = img_vis.copy()
            cv2.rectangle(
                overlay,
                (x1, y_label - label_h - 5),
                (x1 + label_w + 10, y_label + baseline),
                color,
                -1
            )
            cv2.addWeighted(overlay, 0.6, img_vis, 0.4, 0, img_vis)

            # Draw the label text
            cv2.putText(
                img_vis,
                label,
                (x1 + 5, y_label - 5),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (255, 255, 255),  # white text
                thickness,
                cv2.LINE_AA
            )
        # Add summary statistics
        total_boxes = len(ocr_results[0])
        avg_conf = sum(conf for _, (_, conf) in ocr_results[0]) / total_boxes if total_boxes > 0 else 0
        stats_text = f"Total: {total_boxes} boxes | Avg Conf: {avg_conf:.3f}"
        cv2.putText(
            img_vis,
            stats_text,
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            (0, 0, 255),  # red
            2,
            cv2.LINE_AA
        )

        # Save the image
        if output_path:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(output_path, img_vis)
            logger.info(f"✅ Visualization saved to: {output_path}")

        return img_vis
if __name__ == '__main__':
    from dotenv import load_dotenv
    load_dotenv(override=True)

    logger.remove()
    logger.add(sys.stderr, level="INFO")

    print("🚀 Testing PytorchPaddleOCR with Visualization...")
    try:
        ocr = PytorchPaddleOCR(lang='ch', device='cpu')

        # ✅ Verify the character set
        if hasattr(ocr.text_recognizer, 'postprocess_op'):
            char_count = len(ocr.text_recognizer.postprocess_op.character)
            print(f"\n📊 Character set loaded: {char_count} characters")
            if char_count > 0:
                print(f"  Sample chars: {ocr.text_recognizer.postprocess_op.character[:20]}")
            else:
                print("  ❌ ERROR: Character set is empty!")
                sys.exit(1)

        # List of test images
        test_images = [
            {
                'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_001.png",
                'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_001_ocr_vis.png"
            },
            {
                'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
                'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_003_ocr_vis.png"
            }
        ]

        for img_info in test_images:
            img_path = img_info['path']
            output_path = img_info['output']

            print(f"\n{'='*60}")
            print(f"📄 Processing: {Path(img_path).name}")
            print(f"{'='*60}")

            if not Path(img_path).exists():
                print(f"❌ Image not found: {img_path}")
                continue

            img = cv2.imread(img_path)
            if img is None:
                print("❌ Failed to load image")
                continue
            print(f"📖 Original image: {img.shape}")

            # Run OCR
            results = ocr.ocr(img, det=True, rec=True)

            if results and results[0]:
                print(f"\n✅ OCR completed! Found {len(results[0])} text regions")

                # Print the first 10 results
                print("\n📝 First 10 results:")
                for i, (box, (text, conf)) in enumerate(results[0][:10], 1):
                    print(f"  [{i}] {text} (conf={conf:.3f})")
                if len(results[0]) > 10:
                    print(f"  ... and {len(results[0]) - 10} more")

                # Visualize
                print("\n🎨 Visualizing results...")
                img_vis = ocr.visualize(
                    img,
                    results,
                    output_path=output_path,
                    show_text=True,
                    show_confidence=True
                )

                # Statistics
                total_boxes = len(results[0])
                avg_conf = sum(conf for _, (_, conf) in results[0]) / total_boxes
                high_conf = sum(1 for _, (_, conf) in results[0] if conf >= 0.9)
                low_conf = sum(1 for _, (_, conf) in results[0] if conf < 0.7)
                non_empty = sum(1 for _, (text, _) in results[0] if text)

                print("\n📊 Statistics:")
                print(f"  Total boxes: {total_boxes}")
                print(f"  Non-empty: {non_empty}")
                print(f"  Average confidence: {avg_conf:.3f}")
                print(f"  High confidence (≥0.9): {high_conf}")
                print(f"  Low confidence (<0.7): {low_conf}")
            else:
                print("⚠️ No results found")
    except Exception as e:
        print(f"\n❌ Test failed: {e}")
        import traceback
        traceback.print_exc()