pytorch_paddle.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673
  1. """
  2. PytorchPaddleOCR - 从 MinerU 移植
  3. 提供完整的 OCR 功能(检测 + 识别)
  4. """
  5. import copy
  6. import os
  7. import sys
  8. import warnings
  9. from pathlib import Path
  10. import cv2
  11. import numpy as np
  12. import yaml
  13. from loguru import logger
  14. import argparse
  15. from typing import Optional
  16. # 添加 ocr_platform 根目录到 Python 路径(用于导入 ocr_utils)
  17. ocr_platform_root = Path(__file__).parents[2] # pytorch_models -> ocr_tools -> ocr_platform -> repository.git
  18. if str(ocr_platform_root) not in sys.path:
  19. sys.path.insert(0, str(ocr_platform_root))
  20. # 添加当前目录到 Python 路径(用于相对导入)
  21. current_dir = Path(__file__).resolve().parent
  22. if str(current_dir) not in sys.path:
  23. sys.path.insert(0, str(current_dir))
  24. # 从 ocr_utils 导入设备工具
  25. from ocr_utils.device_utils import get_device
  26. # ✅ 导入方向分类器
  27. from orientation_classifier_v2 import OrientationClassifierV2
  28. # 从 vendor 导入 OCR 工具(vendor 会从 ocr_utils 导入图像处理工具)
  29. from vendor import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
  30. from vendor.infer import TextSystem
  31. from vendor.infer import pytorchocr_utility as utility
  32. latin_lang = [
  33. "af", "az", "bs", "cs", "cy", "da", "de", "es", "et", "fr", "ga", "hr", "hu",
  34. "id", "is", "it", "ku", "la", "lt", "lv", "mi", "ms", "mt", "nl", "no", "oc",
  35. "pi", "pl", "pt", "ro", "rs_latin", "sk", "sl", "sq", "sv", "sw", "tl", "tr",
  36. "uz", "vi", "french", "german", "fi", "eu", "gl", "lb", "rm", "ca", "qu",
  37. ]
  38. arabic_lang = ["ar", "fa", "ug", "ur", "ps", "ku", "sd", "bal"]
  39. cyrillic_lang = [
  40. "ru", "rs_cyrillic", "be", "bg", "uk", "mn", "abq", "ady", "kbd", "ava", "dar",
  41. "inh", "che", "lbe", "lez", "tab", "kk", "ky", "tg", "mk", "tt", "cv", "ba",
  42. "mhr", "mo", "udm", "kv", "os", "bua", "xal", "tyv", "sah", "kaa",
  43. ]
  44. east_slavic_lang = ["ru", "be", "uk"]
  45. devanagari_lang = [
  46. "hi", "mr", "ne", "bh", "mai", "ang", "bho", "mah", "sck", "new", "gom", "sa", "bgc",
  47. ]
  48. def get_model_params(lang, config):
  49. """从配置文件获取模型参数"""
  50. if lang in config['lang']:
  51. params = config['lang'][lang]
  52. det = params.get('det')
  53. rec = params.get('rec')
  54. dict_file = params.get('dict')
  55. return det, rec, dict_file
  56. else:
  57. raise Exception(f'Language {lang} not supported')
  58. def auto_download_and_get_model_root_path(model_path):
  59. """
  60. 模拟 MinerU 的模型下载逻辑
  61. """
  62. modelscope_cache = os.getenv('MODELSCOPE_CACHE_DIR', str(Path.home() / '.cache' / 'modelscope'))
  63. return modelscope_cache
# Directory containing this file; PytorchPaddleOCR uses it to locate the
# bundled vendor resources (models_config.yml and the dictionary files).
root_dir = Path(__file__).resolve().parent
  66. class PytorchPaddleOCR(TextSystem):
  67. """
  68. PytorchPaddleOCR - OCR 引擎
  69. 继承 TextSystem,提供完整的 OCR 功能
  70. """
  71. def __init__(self, *args, **kwargs):
  72. """初始化 OCR 引擎
  73. Args:
  74. lang (str): 语言 ('ch', 'en', 'latin', 'korean', 'japan', 等)
  75. det_db_thresh (float): 检测二值化阈值
  76. det_db_box_thresh (float): 检测框过滤阈值
  77. rec_batch_num (int): 识别批大小
  78. enable_merge_det_boxes (bool): 是否合并检测框
  79. device (str): 设备 ('cpu', 'cuda:0', 'mps')
  80. det_model_path (str): 自定义检测模型路径
  81. rec_model_path (str): 自定义识别模型路径
  82. rec_char_dict_path (str): 自定义字典路径
  83. """
  84. # 初始化参数解析器
  85. parser = utility.init_args()
  86. args_list = list(args) if args else []
  87. parsed_args = parser.parse_args(args_list)
  88. # 获取语言设置
  89. self.lang = kwargs.get('lang', 'ch')
  90. self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)
  91. # ✅ 新增:方向分类器配置
  92. self.use_orientation_cls = kwargs.get("use_orientation_cls", False)
  93. self.orientation_model_path = kwargs.get(
  94. "orientation_model_path",
  95. None
  96. )
  97. # 自动检测设备
  98. device = kwargs.get('device', get_device())
  99. # ✅ CPU 优化:自动切换到轻量模型
  100. if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
  101. logger.warning("CPU device detected. Switching to ch_lite for better performance.")
  102. self.lang = 'ch_lite'
  103. # ✅ 语言映射
  104. if self.lang in latin_lang:
  105. self.lang = 'latin'
  106. elif self.lang in east_slavic_lang:
  107. self.lang = 'east_slavic'
  108. elif self.lang in arabic_lang:
  109. self.lang = 'arabic'
  110. elif self.lang in cyrillic_lang:
  111. self.lang = 'cyrillic'
  112. elif self.lang in devanagari_lang:
  113. self.lang = 'devanagari'
  114. # ✅ 读取模型配置
  115. models_config_path = root_dir / 'vendor' / 'pytorchocr' / 'utils' / 'resources' / 'models_config.yml'
  116. if not models_config_path.exists():
  117. raise FileNotFoundError(f"❌ Config file not found: {models_config_path}")
  118. logger.info(f"📄 Reading config: {models_config_path}")
  119. with open(models_config_path) as file:
  120. config = yaml.safe_load(file)
  121. det, rec, dict_file = get_model_params(self.lang, config)
  122. logger.info(f"📋 Config for lang '{self.lang}':")
  123. logger.info(f" Det model: {det}")
  124. logger.info(f" Rec model: {rec}")
  125. logger.info(f" Dict file: {dict_file}")
  126. # ✅ 模型路径(优先使用 kwargs 中的自定义路径)
  127. if 'det_model_path' not in kwargs:
  128. ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch"
  129. det_model_path = f"{ocr_models_dir}/{det}"
  130. det_model_full_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
  131. kwargs['det_model_path'] = det_model_full_path
  132. logger.info(f"🔍 Detection model path: {det_model_full_path}")
  133. # ✅ 验证模型文件存在
  134. if not Path(det_model_full_path).exists():
  135. logger.warning(f"⚠️ Detection model file not found: {det_model_full_path}")
  136. else:
  137. logger.info(f"✅ Detection model file exists: {det_model_full_path}")
  138. # ✅ 识别模型路径
  139. if 'rec_model_path' not in kwargs:
  140. ocr_models_dir = "models/OpenDataLab/PDF-Extract-Kit-1.0/models/OCR/paddleocr_torch"
  141. rec_model_path = f"{ocr_models_dir}/{rec}"
  142. rec_model_full_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)
  143. kwargs['rec_model_path'] = rec_model_full_path
  144. logger.info(f"🔍 Recognition model path: {rec_model_full_path}")
  145. # ✅ 验证模型文件存在
  146. if not Path(rec_model_full_path).exists():
  147. logger.warning(f"⚠️ Recognition model file not found: {rec_model_full_path}")
  148. else:
  149. logger.info(f"✅ Recognition model file exists: {rec_model_full_path}")
  150. # ✅ 字典路径
  151. if 'rec_char_dict_path' not in kwargs:
  152. dict_path = root_dir / 'vendor' / 'pytorchocr' / 'utils' / 'resources' / 'dict' / dict_file
  153. if not dict_path.exists():
  154. logger.error(f"❌ Dictionary file not found: {dict_path}")
  155. raise FileNotFoundError(f"Dictionary file not found: {dict_path}")
  156. kwargs['rec_char_dict_path'] = str(dict_path)
  157. logger.info(f"📖 Dictionary: {dict_path.name}")
  158. # ✅ 默认参数
  159. kwargs.setdefault('rec_batch_num', 6)
  160. kwargs.setdefault('device', device)
  161. # ✅ 合并参数
  162. default_args = vars(parsed_args)
  163. default_args.update(kwargs)
  164. final_args = argparse.Namespace(**default_args)
  165. logger.info(f"🔧 Initializing TextSystem...")
  166. logger.info(f" device: {final_args.device}")
  167. logger.info(f" rec_char_dict_path: {final_args.rec_char_dict_path}")
  168. logger.info(f" rec_batch_num: {final_args.rec_batch_num}")
  169. # ✅ 初始化 TextSystem
  170. super().__init__(final_args)
  171. # ✅ 初始化方向分类器(在 TextSystem 之后,这样可以使用 self.text_detector)
  172. self.orientation_classifier = None
  173. if self.use_orientation_cls and self.orientation_model_path:
  174. try:
  175. logger.info(f"🔄 Initializing orientation classifier...")
  176. logger.info(f" Model: {self.orientation_model_path}")
  177. # ✅ 创建一个简单的检测器适配器
  178. class TextDetectorAdapter:
  179. """适配器:将 TextDetector 包装为 OrientationClassifierV2 需要的接口"""
  180. def __init__(self, text_detector):
  181. self.text_detector = text_detector
  182. def ocr(self, img, det=True, rec=False):
  183. """执行文本检测"""
  184. if not det:
  185. return None
  186. try:
  187. # 调用检测器
  188. dt_boxes, _ = self.text_detector(img)
  189. if dt_boxes is None or len(dt_boxes) == 0:
  190. return [[]]
  191. # 返回格式: [[box1], [box2], ...]
  192. return [dt_boxes.tolist()]
  193. except Exception as e:
  194. logger.warning(f"Detection failed in adapter: {e}")
  195. return [[]]
  196. # ✅ 传入检测器适配器
  197. detector_adapter = TextDetectorAdapter(self.text_detector)
  198. self.orientation_classifier = OrientationClassifierV2(
  199. model_path=self.orientation_model_path,
  200. text_detector=detector_adapter, # ✅ 传入检测器
  201. aspect_ratio_threshold=kwargs.get("aspect_ratio_threshold", 1.2),
  202. vertical_text_ratio=kwargs.get("vertical_text_ratio", 0.28),
  203. vertical_text_min_count=kwargs.get("vertical_text_min_count", 3),
  204. use_gpu=kwargs.get("device", "cpu") != "cpu"
  205. )
  206. logger.info(f"✅ Orientation classifier initialized")
  207. except Exception as e:
  208. logger.warning(f"⚠️ Failed to initialize orientation classifier: {e}")
  209. self.orientation_classifier = None
  210. # ✅ 验证并修复字符集
  211. logger.info(f"🔍 Validating text recognizer...")
  212. if hasattr(self, 'text_recognizer'):
  213. logger.info(f" ✅ text_recognizer exists")
  214. # ❌ text_recognizer 本身没有 character 属性
  215. # ✅ 字符集存储在 postprocess_op.character 中
  216. if hasattr(self.text_recognizer, 'postprocess_op'):
  217. postprocess_op = self.text_recognizer.postprocess_op
  218. logger.info(f" ✅ postprocess_op exists: {type(postprocess_op)}")
  219. if hasattr(postprocess_op, 'character'):
  220. char_count = len(postprocess_op.character)
  221. logger.info(f" ✅ postprocess_op.character exists")
  222. logger.info(f" Character set size: {char_count}")
  223. if char_count == 0:
  224. logger.error(f" ❌ Character set is EMPTY!")
  225. logger.error(f" Reloading dictionary...")
  226. # 强制重新加载字典
  227. dict_path = final_args.rec_char_dict_path
  228. with open(dict_path, 'rb') as f:
  229. lines = f.readlines()
  230. character = []
  231. for line in lines:
  232. line = line.decode('utf-8').strip('\n').strip('\r\n')
  233. character.append(line)
  234. # 更新字符集
  235. postprocess_op.character = character
  236. logger.info(f" ✅ Character set reloaded: {len(character)} chars")
  237. else:
  238. logger.info(f" ✅ Character set loaded successfully")
  239. # 显示前10个字符(调试用)
  240. sample_chars = postprocess_op.character[:10]
  241. logger.info(f" First 10 chars: {sample_chars}")
  242. else:
  243. logger.error(f" ❌ postprocess_op.character NOT found!")
  244. logger.error(f" Available attributes: {dir(postprocess_op)}")
  245. else:
  246. logger.error(f" ❌ postprocess_op NOT found!")
  247. logger.error(f" Available attributes: {dir(self.text_recognizer)}")
  248. else:
  249. logger.error(f" ❌ text_recognizer NOT found!")
  250. logger.info(f"✅ OCR engine initialized")
  251. def ocr(
  252. self,
  253. img,
  254. det=True,
  255. rec=True,
  256. mfd_res=None,
  257. tqdm_enable=False,
  258. tqdm_desc="OCR-rec Predict",
  259. ):
  260. """
  261. 执行 OCR
  262. Args:
  263. img: BGR 图像、图像列表、文件路径或字节流
  264. det (bool): 是否执行检测
  265. rec (bool): 是否执行识别
  266. mfd_res (list): 公式检测结果(用于过滤文本框)
  267. tqdm_enable (bool): 是否显示进度条
  268. tqdm_desc (str): 进度条描述
  269. Returns:
  270. det=True, rec=True: [[[box], (text, conf)], ...]
  271. det=True, rec=False: [[boxes], ...]
  272. det=False, rec=True: [[(text, conf), ...]]
  273. """
  274. assert isinstance(img, (np.ndarray, list, str, bytes))
  275. if isinstance(img, list) and det:
  276. logger.error('When input is a list of images, det must be False')
  277. exit(1)
  278. # np.ndarray: BGR 格式图像
  279. img = check_img(img)
  280. imgs = [img]
  281. with warnings.catch_warnings():
  282. warnings.simplefilter("ignore", category=RuntimeWarning)
  283. if det and rec:
  284. # 检测 + 识别
  285. ocr_res = []
  286. for img in imgs:
  287. img = preprocess_image(img)
  288. dt_boxes, rec_res = self.__call__(img, mfd_res=mfd_res)
  289. if not dt_boxes and not rec_res:
  290. ocr_res.append(None)
  291. continue
  292. tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
  293. ocr_res.append(tmp_res)
  294. return ocr_res
  295. elif det and not rec:
  296. # 仅检测
  297. ocr_res = []
  298. for img in imgs:
  299. img = preprocess_image(img)
  300. dt_boxes, elapse = self.text_detector(img)
  301. if dt_boxes is None:
  302. ocr_res.append(None)
  303. continue
  304. dt_boxes = sorted_boxes(dt_boxes)
  305. if self.enable_merge_det_boxes:
  306. dt_boxes = merge_det_boxes(dt_boxes)
  307. if mfd_res:
  308. dt_boxes = update_det_boxes(dt_boxes, mfd_res)
  309. tmp_res = [box.tolist() for box in dt_boxes]
  310. ocr_res.append(tmp_res)
  311. return ocr_res
  312. elif not det and rec:
  313. # 仅识别
  314. ocr_res = []
  315. for img in imgs:
  316. if not isinstance(img, list):
  317. img = preprocess_image(img)
  318. img = [img]
  319. rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable, tqdm_desc=tqdm_desc)
  320. ocr_res.append(rec_res)
  321. return ocr_res
  322. def __call__(self, img, mfd_res=None):
  323. """
  324. 单张图像的 OCR(检测 + 识别)
  325. Args:
  326. img: 预处理后的图像
  327. mfd_res: 公式检测结果
  328. Returns:
  329. (dt_boxes, rec_res): 检测框和识别结果
  330. """
  331. if img is None:
  332. logger.debug("no valid image provided")
  333. return None, None
  334. ori_im = img.copy()
  335. dt_boxes, elapse = self.text_detector(img)
  336. if dt_boxes is None:
  337. logger.debug(f"no dt_boxes found, elapsed: {elapse}")
  338. return None, None
  339. # 排序
  340. dt_boxes = sorted_boxes(dt_boxes)
  341. # 合并相邻框
  342. if self.enable_merge_det_boxes:
  343. dt_boxes = merge_det_boxes(dt_boxes)
  344. # 过滤与公式重叠的框
  345. if mfd_res:
  346. dt_boxes = update_det_boxes(dt_boxes, mfd_res)
  347. # 裁剪文本区域
  348. img_crop_list = []
  349. for bno in range(len(dt_boxes)):
  350. tmp_box = copy.deepcopy(dt_boxes[bno])
  351. img_crop = get_rotate_crop_image(ori_im, tmp_box)
  352. img_crop_list.append(img_crop)
  353. logger.info(f"🔤 Recognizing {len(img_crop_list)} text regions...")
  354. rec_res, elapse = self.text_recognizer(img_crop_list)
  355. # 统计结果
  356. non_empty = sum(1 for text, _ in rec_res if text)
  357. logger.info(f" Found {non_empty}/{len(rec_res)} non-empty results")
  358. if non_empty > 0:
  359. logger.info(f" First 5 non-empty results:")
  360. count = 0
  361. for i, (text, conf) in enumerate(rec_res):
  362. if text:
  363. logger.info(f" [{i+1}] '{text}' (conf={conf:.3f})")
  364. count += 1
  365. if count >= 5:
  366. break
  367. # 过滤低分结果
  368. filter_boxes, filter_rec_res = [], []
  369. for box, rec_result in zip(dt_boxes, rec_res):
  370. text, score = rec_result
  371. if score >= self.drop_score:
  372. filter_boxes.append(box)
  373. filter_rec_res.append(rec_result)
  374. return filter_boxes, filter_rec_res
  375. def visualize(
  376. self,
  377. img: np.ndarray,
  378. ocr_results: list,
  379. output_path: Optional[str] = None,
  380. show_text: bool = True,
  381. show_confidence: bool = True,
  382. font_scale: float = 0.5,
  383. thickness: int = 2
  384. ) -> np.ndarray:
  385. """
  386. 可视化 OCR 检测结果
  387. Args:
  388. img: 原始图像(BGR 格式)
  389. ocr_results: OCR 结果 [[box, (text, conf)], ...]
  390. output_path: 输出路径(可选)
  391. show_text: 是否显示识别的文字
  392. show_confidence: 是否显示置信度
  393. font_scale: 字体大小
  394. thickness: 边框粗细
  395. Returns:
  396. 标注后的图像
  397. """
  398. img_vis = img.copy()
  399. if not ocr_results or ocr_results[0] is None:
  400. logger.warning("No OCR results to visualize")
  401. return img_vis
  402. # 颜色映射(根据置信度)
  403. def get_color_by_confidence(conf: float) -> tuple:
  404. """根据置信度返回颜色 (绿色->黄色->红色)"""
  405. if conf >= 0.9:
  406. return (0, 255, 0) # 高置信度:绿色
  407. elif conf >= 0.7:
  408. return (0, 255, 255) # 中置信度:黄色
  409. else:
  410. return (0, 165, 255) # 低置信度:橙色
  411. for idx, result in enumerate(ocr_results[0], 1):
  412. box, (text, conf) = result
  413. # 转换坐标格式:[[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
  414. if isinstance(box, list):
  415. points = np.array(box, dtype=np.int32).reshape((-1, 2))
  416. else:
  417. points = box.astype(np.int32)
  418. # 获取边框颜色
  419. color = get_color_by_confidence(conf)
  420. # 绘制多边形边框
  421. cv2.polylines(img_vis, [points], True, color, thickness)
  422. # 计算文本位置(左上角)
  423. x1, y1 = points[0]
  424. # 构造标签
  425. if show_text and show_confidence:
  426. label = f"[{idx}] {text} ({conf:.2f})"
  427. elif show_text:
  428. label = f"[{idx}] {text}"
  429. elif show_confidence:
  430. label = f"[{idx}] {conf:.2f}"
  431. else:
  432. label = f"[{idx}]"
  433. # 计算标签尺寸
  434. label_size, baseline = cv2.getTextSize(
  435. label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness
  436. )
  437. label_w, label_h = label_size
  438. # 确保标签不超出图像边界
  439. y_label = max(y1 - 5, label_h + 5)
  440. # 绘制标签背景(半透明)
  441. overlay = img_vis.copy()
  442. cv2.rectangle(
  443. overlay,
  444. (x1, y_label - label_h - 5),
  445. (x1 + label_w + 10, y_label + baseline),
  446. color,
  447. -1
  448. )
  449. cv2.addWeighted(overlay, 0.6, img_vis, 0.4, 0, img_vis)
  450. # 绘制标签文字
  451. cv2.putText(
  452. img_vis,
  453. label,
  454. (x1 + 5, y_label - 5),
  455. cv2.FONT_HERSHEY_SIMPLEX,
  456. font_scale,
  457. (255, 255, 255), # 白色文字
  458. thickness,
  459. cv2.LINE_AA
  460. )
  461. # 添加统计信息
  462. total_boxes = len(ocr_results[0])
  463. avg_conf = sum(conf for _, (_, conf) in ocr_results[0]) / total_boxes if total_boxes > 0 else 0
  464. stats_text = f"Total: {total_boxes} boxes | Avg Conf: {avg_conf:.3f}"
  465. cv2.putText(
  466. img_vis,
  467. stats_text,
  468. (10, 30),
  469. cv2.FONT_HERSHEY_SIMPLEX,
  470. 0.7,
  471. (0, 0, 255), # 红色
  472. 2,
  473. cv2.LINE_AA
  474. )
  475. # 保存图像
  476. if output_path:
  477. Path(output_path).parent.mkdir(parents=True, exist_ok=True)
  478. cv2.imwrite(output_path, img_vis)
  479. logger.info(f"✅ Visualization saved to: {output_path}")
  480. return img_vis
  481. if __name__ == '__main__':
  482. from dotenv import load_dotenv
  483. load_dotenv(override=True)
  484. logger.remove()
  485. logger.add(sys.stderr, level="INFO")
  486. print("🚀 Testing PytorchPaddleOCR with Orientation Classifier...")
  487. try:
  488. # ✅ 启用方向分类器
  489. ocr = PytorchPaddleOCR(
  490. lang='ch',
  491. device='cpu',
  492. use_orientation_cls=True, # ✅ 启用
  493. orientation_model_path="/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/PP-LCNet_x1_0_doc_ori.onnx"
  494. )
  495. # ✅ 验证字符集
  496. if hasattr(ocr.text_recognizer, 'postprocess_op'):
  497. char_count = len(ocr.text_recognizer.postprocess_op.character)
  498. print(f"\n📊 Character set loaded: {char_count} characters")
  499. # 测试图像列表
  500. test_images = [
  501. # {
  502. # 'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_001.png",
  503. # 'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_001_ocr_vis.png"
  504. # },
  505. {
  506. 'path': "/Users/zhch158/workspace/data/至远彩色印刷工业有限公司/2023年度报告母公司.img/2023年度报告母公司_page_003.png",
  507. 'output': "/Users/zhch158/workspace/repository.git/PaddleX/zhch/sample_data/pytorch_paddle_ocr/page_003_ocr_vis.png"
  508. }
  509. ]
  510. for img_info in test_images:
  511. img_path = img_info['path']
  512. output_path = img_info['output']
  513. print(f"\n{'='*60}")
  514. print(f"📄 Processing: {Path(img_path).name}")
  515. print(f"{'='*60}")
  516. img = cv2.imread(img_path)
  517. print(f"📖 Original image: {img.shape}")
  518. # np.ndarray: BGR 格式图像
  519. img = check_img(img)
  520. # ✅ 使用方向分类器(如果启用)
  521. if ocr.orientation_classifier is not None and isinstance(img, np.ndarray):
  522. logger.info(f"🔄 Checking image orientation...")
  523. # 预测方向
  524. orientation_result = ocr.orientation_classifier.predict(img, return_debug=True)
  525. logger.info(f" Orientation result: {orientation_result}")
  526. # 如果需要旋转
  527. if orientation_result.needs_rotation:
  528. logger.info(f"🔄 Rotating image by {orientation_result.rotation_angle}°...")
  529. img = ocr.orientation_classifier.rotate_image(img, orientation_result.rotation_angle)
  530. logger.info(f" New shape: {img.shape}")
  531. # 执行 OCR(会自动检测并旋转)
  532. results = ocr.ocr(img, det=True, rec=True)
  533. if results and results[0]:
  534. print(f"\n✅ Found {len(results[0])} text regions:")
  535. for i, (box, (text, conf)) in enumerate(results[0][:10], 1):
  536. print(f" [{i}] {text} (conf={conf:.3f})")
  537. # 可视化
  538. img_vis = ocr.visualize(img, results, output_path=output_path, show_text=False, show_confidence=False)
  539. # 统计
  540. total = len(results[0])
  541. non_empty = sum(1 for _, (text, _) in results[0] if text)
  542. avg_conf = sum(conf for _, (_, conf) in results[0]) / total
  543. print(f"\n📊 Statistics:")
  544. print(f" Total: {total}")
  545. print(f" Non-empty: {non_empty}")
  546. print(f" Avg confidence: {avg_conf:.3f}")
  547. else:
  548. print(f"⚠️ No results found")
  549. except Exception as e:
  550. print(f"\n❌ Test failed: {e}")
  551. import traceback
  552. traceback.print_exc()