doc_preprocessor_adapter.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. """
  2. 文档预处理适配器
  3. 使用 MinerU 的方向判断算法,但保留 PaddleX 的模型
  4. """
  5. import sys
  6. from pathlib import Path
  7. from typing import Any, Dict, List, Optional, Union
  8. import numpy as np
  9. import cv2
  10. from paddlex.inference.pipelines.doc_preprocessor.result import DocPreprocessorResult
  11. from paddlex.inference.common.reader import ReadImage
  12. from paddlex.inference.common.batch_sampler import ImageBatchSampler
  13. from paddlex.inference.pipelines.components import rotate_image
  14. class EnhancedDocPreprocessor:
  15. """
  16. 增强版文档预处理器
  17. 核心思路:采用 MinerU 的两阶段方向判断算法
  18. 1. 快速过滤:宽高比判断(纵向图片才需要方向分类)
  19. 2. OCR 引导:检测文本框,判断是否有大量垂直文本
  20. 3. 精确分类:仅对疑似旋转的图片调用分类模型
  21. """
  22. def __init__(
  23. self,
  24. doc_ori_classify_model,
  25. doc_unwarping_model,
  26. ocr_det_model=None, # 🎯 OCR 检测模型(可选)
  27. device: str = "cpu",
  28. use_doc_orientation_classify: bool = True,
  29. use_doc_unwarping: bool = False,
  30. batch_size: int = 1,
  31. ):
  32. """
  33. Args:
  34. doc_ori_classify_model: PaddleX 的方向分类模型
  35. doc_unwarping_model: PaddleX 的文档矫正模型
  36. ocr_det_model: OCR 文本检测模型(用于判断是否需要旋转,可选)
  37. device: 设备类型(cpu/gpu)
  38. use_doc_orientation_classify: 是否使用方向分类
  39. use_doc_unwarping: 是否使用文档矫正
  40. batch_size: 批处理大小
  41. """
  42. self.doc_ori_classify_model = doc_ori_classify_model
  43. self.doc_unwarping_model = doc_unwarping_model
  44. self.device = device
  45. self.use_doc_orientation_classify = use_doc_orientation_classify
  46. self.use_doc_unwarping = use_doc_unwarping
  47. self.batch_size = batch_size
  48. self.img_reader = ReadImage(format="BGR")
  49. self.batch_sampler = ImageBatchSampler(batch_size=batch_size)
  50. # 🎯 MinerU 算法参数
  51. self.portrait_threshold = 1.2 # 宽高比阈值
  52. self.vertical_ratio_threshold = 0.28 # 垂直文本框比例阈值
  53. self.min_vertical_count = 3 # 最少垂直文本框数量
  54. # 🎯 初始化 OCR 检测模型(只初始化一次)
  55. self.ocr_det_model = ocr_det_model
  56. if self.ocr_det_model is None:
  57. self._initialize_ocr_det_model()
  58. print(f"📐 Enhanced DocPreprocessor initialized")
  59. print(f" - Device: {self.device}")
  60. print(f" - Portrait threshold: {self.portrait_threshold}")
  61. print(f" - Vertical ratio threshold: {self.vertical_ratio_threshold}")
  62. print(f" - Min vertical count: {self.min_vertical_count}")
  63. print(f" - OCR detection model: {'✅ Available' if self.ocr_det_model else '❌ Not available'}")
  64. def _initialize_ocr_det_model(self):
  65. """初始化 OCR 检测模型(只执行一次)"""
  66. try:
  67. from paddlex import create_model
  68. print("🔧 Initializing OCR detection model...")
  69. self.ocr_det_model = create_model(
  70. 'PP-OCRv5_server_det',
  71. device=self.device
  72. )
  73. print("✅ OCR detection model initialized successfully")
  74. except Exception as e:
  75. print(f"⚠️ Failed to initialize OCR detection model: {e}")
  76. print(" Will skip OCR-guided filtering")
  77. self.ocr_det_model = None
  78. def _is_portrait_image(self, image: np.ndarray) -> bool:
  79. """判断是否为纵向图片"""
  80. img_height, img_width = image.shape[:2]
  81. aspect_ratio = img_height / img_width if img_width > 0 else 1.0
  82. is_portrait = aspect_ratio > self.portrait_threshold
  83. print(f" 📏 Image size: {img_width}x{img_height}, aspect_ratio: {aspect_ratio:.2f}, is_portrait: {is_portrait}")
  84. return is_portrait
  85. def _detect_vertical_text_boxes(self, image: np.ndarray) -> tuple[int, int]:
  86. """
  87. 检测图片中的垂直文本框
  88. Returns:
  89. (vertical_count, total_count): 垂直文本框数量和总数量
  90. """
  91. if self.ocr_det_model is None:
  92. print(" ⚠️ OCR detection model not available")
  93. return 0, 0
  94. try:
  95. # 🎯 调用 OCR 检测模型
  96. det_results = list(self.ocr_det_model([image]))
  97. if not det_results or len(det_results) == 0:
  98. print(" ℹ️ No OCR detection results")
  99. return 0, 0
  100. det_result = det_results[0]
  101. # 🎯 从检测结果中提取文本框
  102. # PaddleX 的检测结果格式: {"dt_polys": [...], ...}
  103. boxes = None
  104. if isinstance(det_result, dict):
  105. boxes = det_result.get('dt_polys', None)
  106. elif isinstance(det_result, np.ndarray):
  107. boxes = det_result
  108. if boxes is None or len(boxes) == 0:
  109. print(" ℹ️ No text boxes detected")
  110. return 0, 0
  111. # 🎯 统计垂直文本框
  112. vertical_count = 0
  113. total_count = len(boxes)
  114. # 🎯 处理 numpy 数组格式: shape=(N, 4, 2)
  115. if isinstance(boxes, np.ndarray):
  116. if len(boxes.shape) == 3 and boxes.shape[1] == 4 and boxes.shape[2] == 2:
  117. # 格式: (N, 4, 2) - 每个框有4个点,每个点有(x,y)坐标
  118. for box in boxes:
  119. # box: shape=(4, 2) - [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
  120. p1, p2, p3, p4 = box
  121. # 计算宽高
  122. width = abs(float(p2[0] - p1[0])) # x2 - x1
  123. height = abs(float(p3[1] - p2[1])) # y3 - y2
  124. if height == 0:
  125. continue
  126. aspect_ratio = width / height
  127. # 🎯 MinerU 的判断标准:宽高比 < 0.8 为垂直文本
  128. if aspect_ratio < 0.8:
  129. vertical_count += 1
  130. else:
  131. # 其他格式,尝试遍历处理
  132. for box in boxes:
  133. if isinstance(box, np.ndarray) and len(box) >= 4:
  134. self._process_single_box(box, vertical_count)
  135. else:
  136. # 处理列表格式
  137. for box in boxes:
  138. if isinstance(box, (list, tuple, np.ndarray)):
  139. if len(box) >= 4:
  140. # 格式: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
  141. if isinstance(box[0], (list, tuple, np.ndarray)) and len(box[0]) >= 2:
  142. p1, p2, p3, p4 = box[:4]
  143. width = abs(float(p2[0]) - float(p1[0]))
  144. height = abs(float(p3[1]) - float(p2[1]))
  145. # 格式: [x1,y1,x2,y2,x3,y3,x4,y4]
  146. elif len(box) >= 8:
  147. width = abs(float(box[2]) - float(box[0]))
  148. height = abs(float(box[5]) - float(box[3]))
  149. else:
  150. continue
  151. if height == 0:
  152. continue
  153. aspect_ratio = width / height
  154. # 🎯 MinerU 的判断标准:宽高比 < 0.8 为垂直文本
  155. if aspect_ratio < 0.8:
  156. vertical_count += 1
  157. print(f" 📊 OCR detection: {vertical_count}/{total_count} vertical boxes ({vertical_count/total_count:.1%} vertical)")
  158. return vertical_count, total_count
  159. except Exception as e:
  160. print(f" ⚠️ OCR detection failed: {e}")
  161. import traceback
  162. traceback.print_exc()
  163. return 0, 0
  164. def _should_classify_orientation(self, image: np.ndarray) -> bool:
  165. """
  166. 判断是否需要进行方向分类
  167. 参考 MinerU 的两阶段判断逻辑
  168. Returns:
  169. True: 需要分类
  170. False: 跳过分类(直接使用原图)
  171. """
  172. print("🔍 Checking if orientation classification is needed...")
  173. # 🎯 阶段 1: 快速过滤 - 宽高比检查
  174. if not self._is_portrait_image(image):
  175. print(" ⏭️ Skipped: Image is landscape")
  176. return False
  177. # 🎯 阶段 2: OCR 引导判断 - 检测垂直文本框
  178. vertical_count, total_count = self._detect_vertical_text_boxes(image)
  179. if total_count == 0:
  180. print(" ⏭️ Skipped: No text detected")
  181. return False
  182. # 🎯 MinerU 的判断标准:
  183. # 垂直文本框比例 >= 28% 且数量 >= 3,才认为可能需要旋转
  184. vertical_ratio = vertical_count / total_count
  185. is_rotated = (
  186. vertical_ratio >= self.vertical_ratio_threshold and
  187. vertical_count >= self.min_vertical_count
  188. )
  189. print(f" 📈 Vertical ratio: {vertical_ratio:.1%} (threshold: {self.vertical_ratio_threshold:.1%})")
  190. print(f" 📊 Vertical count: {vertical_count} (min: {self.min_vertical_count})")
  191. print(f" 🎯 Need classification: {is_rotated}")
  192. return is_rotated
  193. def _predict_orientation(self, image: np.ndarray) -> int:
  194. """
  195. 预测图像方向
  196. Args:
  197. image: BGR 格式的图像
  198. Returns:
  199. 旋转角度 (0, 90, 180, 270)
  200. """
  201. if not self.use_doc_orientation_classify or self.doc_ori_classify_model is None:
  202. return 0
  203. try:
  204. # 调用 PaddleX 的分类模型
  205. preds = list(self.doc_ori_classify_model([image]))
  206. if preds and len(preds) > 0:
  207. pred = preds[0]
  208. angle = int(pred["label_names"][0])
  209. print(f" 🔄 Orientation classification result: {angle}°")
  210. return angle
  211. return 0
  212. except Exception as e:
  213. print(f" ⚠️ Orientation prediction failed: {e}")
  214. return 0
  215. def predict(
  216. self,
  217. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  218. use_doc_orientation_classify: Optional[bool] = None,
  219. use_doc_unwarping: Optional[bool] = None,
  220. ):
  221. """
  222. 预测文档预处理结果
  223. Args:
  224. input: 输入图像路径、数组或列表
  225. use_doc_orientation_classify: 是否使用方向分类
  226. use_doc_unwarping: 是否使用文档矫正
  227. Yields:
  228. DocPreprocessorResult: 预处理结果
  229. """
  230. # 处理模型设置
  231. if use_doc_orientation_classify is None:
  232. use_doc_orientation_classify = self.use_doc_orientation_classify
  233. if use_doc_unwarping is None:
  234. use_doc_unwarping = self.use_doc_unwarping
  235. model_settings = {
  236. "use_doc_orientation_classify": use_doc_orientation_classify,
  237. "use_doc_unwarping": use_doc_unwarping,
  238. }
  239. print(f"\n{'='*60}")
  240. print(f"🎯 Enhanced DocPreprocessor - MinerU Algorithm")
  241. print(f" Settings: orientation={use_doc_orientation_classify}, unwarping={use_doc_unwarping}")
  242. print(f"{'='*60}\n")
  243. # 批处理
  244. for batch_data in self.batch_sampler(input):
  245. # 读取图像
  246. image_arrays = self.img_reader(batch_data.instances)
  247. # 🎯 增强的方向分类和旋转逻辑
  248. angles = []
  249. rot_imgs = []
  250. for idx, img in enumerate(image_arrays):
  251. print(f"\n📄 Processing image {idx + 1}/{len(image_arrays)}")
  252. if use_doc_orientation_classify:
  253. # 🎯 关键改进:先判断是否需要分类
  254. if self._should_classify_orientation(img):
  255. # 需要分类:调用模型预测角度
  256. angle = self._predict_orientation(img)
  257. else:
  258. # 跳过分类:直接使用 0 度
  259. angle = 0
  260. print(" ⏭️ Skipped orientation classification")
  261. angles.append(angle)
  262. if angle != 0:
  263. rot_img = rotate_image(img, angle)
  264. else:
  265. rot_img = img
  266. rot_imgs.append(rot_img)
  267. else:
  268. angles.append(-1) # -1 表示未进行方向分类
  269. rot_imgs.append(img)
  270. # 文档矫正
  271. if use_doc_unwarping and self.doc_unwarping_model is not None:
  272. output_imgs = [
  273. item["doctr_img"][:, :, ::-1]
  274. for item in self.doc_unwarping_model(rot_imgs)
  275. ]
  276. else:
  277. output_imgs = rot_imgs
  278. # 生成结果
  279. for input_path, page_index, image_array, angle, rot_img, output_img in zip(
  280. batch_data.input_paths,
  281. batch_data.page_indexes,
  282. image_arrays,
  283. angles,
  284. rot_imgs,
  285. output_imgs,
  286. ):
  287. single_img_res = {
  288. "input_path": input_path,
  289. "page_index": page_index,
  290. "input_img": image_array,
  291. "model_settings": model_settings,
  292. "angle": angle,
  293. "rot_img": rot_img,
  294. "output_img": output_img,
  295. }
  296. yield DocPreprocessorResult(single_img_res)
  297. def __call__(self, *args, **kwargs):
  298. """支持像函数一样调用"""
  299. return self.predict(*args, **kwargs)
  300. class DocPreprocessorAdapter:
  301. """
  302. 文档预处理适配器
  303. 替换 _DocPreprocessorPipeline 的 predict 方法
  304. """
  305. _original_predict = None
  306. _shared_ocr_det_model = None # 🎯 共享的 OCR 检测模型
  307. _enhanced_preprocessor_cache = {} # 🎯 缓存 enhanced_preprocessor 实例
  308. @classmethod
  309. def _get_cache_key(cls, device: str, use_doc_orientation_classify: bool,
  310. use_doc_unwarping: bool, batch_size: int) -> str:
  311. """生成缓存键"""
  312. return f"{device}_{use_doc_orientation_classify}_{use_doc_unwarping}_{batch_size}"
  313. @classmethod
  314. def apply(cls, use_enhanced: bool = True):
  315. """
  316. 应用适配器
  317. Args:
  318. use_enhanced: 是否使用增强版预处理器
  319. """
  320. if not use_enhanced:
  321. cls.restore()
  322. return False
  323. try:
  324. from paddlex.inference.pipelines.doc_preprocessor import pipeline
  325. # 保存原始方法
  326. if cls._original_predict is None:
  327. cls._original_predict = pipeline._DocPreprocessorPipeline.predict
  328. # 创建增强版 predict 方法
  329. def enhanced_predict(
  330. self,
  331. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  332. use_doc_orientation_classify: Optional[bool] = None,
  333. use_doc_unwarping: Optional[bool] = None,
  334. ):
  335. """增强版 predict 方法"""
  336. # 🎯 关键改进 1:初始化共享的 OCR 检测模型(只初始化一次)
  337. if cls._shared_ocr_det_model is None:
  338. print("\n" + "="*80)
  339. print(">>> [Adapter] Enhanced DocPreprocessor - First Time Initialization")
  340. print("="*80)
  341. print("🔧 Initializing shared OCR detection model...")
  342. try:
  343. from paddlex import create_model
  344. cls._shared_ocr_det_model = create_model(
  345. 'PP-OCRv5_server_det',
  346. device=self.device
  347. )
  348. print("✅ Shared OCR detection model initialized")
  349. except Exception as e:
  350. print(f"⚠️ Failed to initialize OCR detection model: {e}")
  351. cls._shared_ocr_det_model = None
  352. # 🎯 关键改进 2:使用缓存的 enhanced_preprocessor(只创建一次)
  353. cache_key = cls._get_cache_key(
  354. device=self.device,
  355. use_doc_orientation_classify=self.use_doc_orientation_classify,
  356. use_doc_unwarping=self.use_doc_unwarping,
  357. batch_size=self.batch_sampler.batch_size
  358. )
  359. if cache_key not in cls._enhanced_preprocessor_cache:
  360. print("🔧 Creating new enhanced preprocessor instance...")
  361. enhanced_preprocessor = EnhancedDocPreprocessor(
  362. doc_ori_classify_model=self.doc_ori_classify_model if self.use_doc_orientation_classify else None,
  363. doc_unwarping_model=self.doc_unwarping_model if self.use_doc_unwarping else None,
  364. ocr_det_model=cls._shared_ocr_det_model, # 使用共享的模型
  365. device=self.device,
  366. use_doc_orientation_classify=self.use_doc_orientation_classify,
  367. use_doc_unwarping=self.use_doc_unwarping,
  368. batch_size=self.batch_sampler.batch_size,
  369. )
  370. cls._enhanced_preprocessor_cache[cache_key] = enhanced_preprocessor
  371. print(f"✅ Enhanced preprocessor cached with key: {cache_key}")
  372. else:
  373. enhanced_preprocessor = cls._enhanced_preprocessor_cache[cache_key]
  374. print(f"♻️ Reusing cached enhanced preprocessor: {cache_key}")
  375. # 调用增强版处理逻辑
  376. return enhanced_preprocessor.predict(
  377. input,
  378. use_doc_orientation_classify,
  379. use_doc_unwarping,
  380. )
  381. # 替换方法
  382. pipeline._DocPreprocessorPipeline.predict = enhanced_predict
  383. print("✅ DocPreprocessor adapter applied successfully (MinerU algorithm)")
  384. return True
  385. except Exception as e:
  386. print(f"❌ Failed to apply DocPreprocessor adapter: {e}")
  387. import traceback
  388. traceback.print_exc()
  389. return False
  390. @classmethod
  391. def restore(cls):
  392. """恢复原始方法"""
  393. if cls._original_predict is None:
  394. return False
  395. try:
  396. from paddlex.inference.pipelines.doc_preprocessor import pipeline
  397. pipeline._DocPreprocessorPipeline.predict = cls._original_predict
  398. cls._original_predict = None
  399. # 🎯 清理共享资源
  400. cls._shared_ocr_det_model = None
  401. cls._enhanced_preprocessor_cache.clear()
  402. print("✅ DocPreprocessor adapter restored")
  403. return True
  404. except Exception as e:
  405. print(f"❌ Failed to restore DocPreprocessor adapter: {e}")
  406. return False
  407. # 🎯 便捷函数
  408. def apply_enhanced_doc_preprocessor():
  409. """应用增强版文档预处理器"""
  410. return DocPreprocessorAdapter.apply(use_enhanced=True)
  411. def restore_paddlex_doc_preprocessor():
  412. """恢复 PaddleX 原始文档预处理器"""
  413. return DocPreprocessorAdapter.restore()
  414. # 导出
  415. __all__ = [
  416. 'EnhancedDocPreprocessor',
  417. 'DocPreprocessorAdapter',
  418. 'apply_enhanced_doc_preprocessor',
  419. 'restore_paddlex_doc_preprocessor',
  420. ]