paddle_layout_detector.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679
  1. """使用 ONNX Runtime 进行布局检测的统一接口 (符合 BaseLayoutDetector 规范)"""
  2. import cv2
  3. import numpy as np
  4. import onnxruntime as ort
  5. from pathlib import Path
  6. from typing import Dict, List, Tuple, Union, Any
  7. from PIL import Image
  8. import sys
  9. try:
  10. from .base import BaseLayoutDetector
  11. except ImportError:
  12. # 如果相对导入失败,尝试绝对导入(适用于测试环境)
  13. from base import BaseLayoutDetector
  14. class PaddleLayoutDetector(BaseLayoutDetector):
  15. """PaddleX RT-DETR 布局检测器 (ONNX 版本)"""
  16. # ⚠️ 修正:使用官方的 RT-DETR-H_layout_17cls 类别定义
  17. # 映射到 MinerU 的类别体系
  18. CATEGORY_MAP = {
  19. 0: 'title', # paragraph_title -> title
  20. 1: 'image_body', # image -> image_body
  21. 2: 'text', # text -> text
  22. 3: 'text', # number -> text (合并到text)
  23. 4: 'text', # abstract -> text
  24. 5: 'text', # content -> text
  25. 6: 'image_caption', # figure_title -> image_caption
  26. 7: 'interline_equation', # formula -> interline_equation
  27. 8: 'table_body', # table -> table_body
  28. 9: 'table_caption', # table_title -> table_caption
  29. 10: 'text', # reference -> text
  30. 11: 'title', # doc_title -> title
  31. 12: 'table_footnote', # footnote -> table_footnote
  32. 13: 'abandon', # header -> abandon (页眉通常不需要)
  33. 14: 'text', # algorithm -> text
  34. 15: 'abandon', # footer -> abandon (页脚通常不需要)
  35. 16: 'abandon' # seal -> abandon (印章通常不需要)
  36. }
  37. ORIGINAL_CATEGORY_NAMES = {
  38. 0: 'paragraph_title',
  39. 1: 'image',
  40. 2: 'text',
  41. 3: 'number',
  42. 4: 'abstract',
  43. 5: 'content',
  44. 6: 'figure_title',
  45. 7: 'formula',
  46. 8: 'table',
  47. 9: 'table_title',
  48. 10: 'reference',
  49. 11: 'doc_title',
  50. 12: 'footnote',
  51. 13: 'header',
  52. 14: 'algorithm',
  53. 15: 'footer',
  54. 16: 'seal'
  55. }
  56. def __init__(self, config: Dict[str, Any]):
  57. super().__init__(config)
  58. self.session = None
  59. self.inputs = {}
  60. self.outputs = {}
  61. self.target_size = 640
  62. def initialize(self):
  63. """初始化 ONNX 模型"""
  64. try:
  65. onnx_path = self.config.get('model_dir')
  66. if not onnx_path:
  67. raise ValueError("model_dir not specified in config")
  68. if not Path(onnx_path).exists():
  69. raise FileNotFoundError(f"ONNX model not found: {onnx_path}")
  70. # 根据配置选择执行提供器
  71. device = self.config.get('device', 'cpu')
  72. if device == 'gpu':
  73. # Mac 支持 CoreML
  74. providers = ['CoreMLExecutionProvider', 'CPUExecutionProvider']
  75. else:
  76. providers = ['CPUExecutionProvider']
  77. self.session = ort.InferenceSession(onnx_path, providers=providers)
  78. # 获取模型输入输出信息
  79. self.inputs = {inp.name: inp for inp in self.session.get_inputs()}
  80. self.outputs = {out.name: out for out in self.session.get_outputs()}
  81. # 自动检测输入尺寸
  82. self.target_size = self._detect_input_size()
  83. print(f"✅ PaddleX Layout Detector initialized")
  84. print(f" - Model: {Path(onnx_path).name}")
  85. print(f" - Target size: {self.target_size}")
  86. print(f" - Device: {device}")
  87. print(f" - Providers: {self.session.get_providers()}")
  88. except Exception as e:
  89. print(f"❌ Failed to initialize PaddleX Layout Detector: {e}")
  90. raise
  91. def cleanup(self):
  92. """清理资源"""
  93. self.session = None
  94. self.inputs = {}
  95. self.outputs = {}
  96. def detect(self, image: Union[np.ndarray, Image.Image]) -> List[Dict[str, Any]]:
  97. """
  98. 检测布局
  99. Args:
  100. image: 输入图像 (numpy数组或PIL图像)
  101. Returns:
  102. 检测结果列表,每个元素包含:
  103. - category: MinerU类别名称
  104. - bbox: [x1, y1, x2, y2]
  105. - confidence: 置信度
  106. - raw: 原始检测结果
  107. """
  108. if self.session is None:
  109. raise RuntimeError("Model not initialized. Call initialize() first.")
  110. # 转换为numpy数组
  111. if isinstance(image, Image.Image):
  112. image = np.array(image)
  113. if image.ndim == 2: # 灰度图
  114. image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
  115. elif image.shape[2] == 4: # RGBA
  116. image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
  117. elif image.shape[2] == 3: # RGB
  118. image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  119. # 执行预测
  120. conf_threshold = self.config.get('conf', 0.25)
  121. results = self._predict(image, conf_threshold)
  122. # 转换为 MinerU 格式
  123. formatted_results = []
  124. for result in results:
  125. # 映射类别
  126. original_category_id = result['category_id']
  127. mineru_category = self.CATEGORY_MAP.get(original_category_id, 'text')
  128. formatted_results.append({
  129. 'category': mineru_category,
  130. 'bbox': result['bbox'],
  131. 'confidence': result['score'],
  132. 'raw': {
  133. 'original_category_id': original_category_id,
  134. 'original_category_name': result['category_name'],
  135. 'poly': result['poly'],
  136. 'width': result['width'],
  137. 'height': result['height']
  138. }
  139. })
  140. return formatted_results
  141. def _detect_input_size(self) -> int:
  142. """自动检测模型的输入尺寸"""
  143. if 'image' in self.inputs:
  144. shape = self.inputs['image'].shape
  145. # shape 通常是 [batch, channels, height, width]
  146. if len(shape) >= 3:
  147. # 尝试从 shape[2] 或 shape[3] 获取尺寸
  148. for dim in shape[2:]:
  149. if isinstance(dim, int) and dim > 0:
  150. return dim
  151. return 640 # 默认值
  152. def _preprocess(
  153. self,
  154. img: np.ndarray
  155. ) -> Tuple[Dict[str, np.ndarray], Tuple[float, float], Tuple[int, int]]:
  156. """
  157. 预处理图像 (根据 RT-DETR 的配置)
  158. Returns:
  159. input_dict: 包含所有输入的字典
  160. scale: (scale_h, scale_w) 缩放因子
  161. orig_shape: (h, w) 原始图像尺寸
  162. """
  163. orig_h, orig_w = img.shape[:2]
  164. target_size = self.target_size # 640
  165. # 1. Resize 到目标尺寸 (不保持长宽比)
  166. img_resized = cv2.resize(
  167. img,
  168. (target_size, target_size),
  169. interpolation=cv2.INTER_LINEAR
  170. )
  171. # 2. 转换为 RGB
  172. img_rgb = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
  173. # ✅ 修正 3: 归一化 (mean=[0,0,0], std=[1,1,1], norm_type=none)
  174. # 只做 /255,不做均值减法和标准差除法
  175. img_normalized = img_rgb.astype(np.float32) / 255.0
  176. # 4. 转换为 CHW 格式
  177. img_chw = img_normalized.transpose(2, 0, 1)
  178. img_tensor = img_chw[None, ...].astype(np.float32) # [1, 3, H, W]
  179. # 5. 准备所有输入
  180. input_dict = {}
  181. # 主图像输入
  182. if 'image' in self.inputs:
  183. input_dict['image'] = img_tensor
  184. elif 'images' in self.inputs:
  185. input_dict['images'] = img_tensor
  186. else:
  187. # 使用第一个输入
  188. first_input_name = list(self.inputs.keys())[0]
  189. input_dict[first_input_name] = img_tensor
  190. # ✅ 修正 4: 计算缩放因子 (实际图像尺寸 / 目标尺寸)
  191. scale_h = orig_h / target_size
  192. scale_w = orig_w / target_size
  193. # im_shape 输入 (原始图像尺寸)
  194. if 'im_shape' in self.inputs:
  195. im_shape = np.array([[float(orig_h), float(orig_w)]], dtype=np.float32)
  196. input_dict['im_shape'] = im_shape
  197. # scale_factor 输入
  198. if 'scale_factor' in self.inputs:
  199. # ⚠️ 注意:这里是原始尺寸/目标尺寸的比例
  200. scale_factor = np.array([[scale_h, scale_w]], dtype=np.float32)
  201. input_dict['scale_factor'] = scale_factor
  202. # ✅ 返回的 scale 用于后处理坐标还原
  203. # 因为不保持长宽比,所以需要分别记录 x 和 y 的缩放
  204. return input_dict, (scale_h, scale_w), (orig_h, orig_w)
  205. def _postprocess(
  206. self,
  207. outputs: List[np.ndarray],
  208. scale: Tuple[float, float], # (scale_h, scale_w)
  209. orig_shape: Tuple[int, int],
  210. conf_threshold: float = 0.5
  211. ) -> List[Dict]:
  212. """
  213. 后处理模型输出
  214. Args:
  215. outputs: ONNX 模型输出
  216. scale: (scale_h, scale_w) 缩放因子
  217. orig_shape: (h, w) 原始图像尺寸
  218. conf_threshold: 置信度阈值
  219. Returns:
  220. 检测结果列表
  221. """
  222. scale_h, scale_w = scale
  223. orig_h, orig_w = orig_shape
  224. # 解析输出格式
  225. if len(outputs) >= 2:
  226. output0_shape = outputs[0].shape
  227. output1_shape = outputs[1].shape
  228. # RT-DETR ONNX 格式: (num_boxes, 6)
  229. # 格式: [label_id, score, x1, y1, x2, y2]
  230. if len(output0_shape) == 2 and output0_shape[1] == 6:
  231. pred = outputs[0]
  232. labels = pred[:, 0].astype(int)
  233. scores = pred[:, 1]
  234. bboxes = pred[:, 2:6].copy() # [x1, y1, x2, y2] - 在 640×640 尺度上
  235. # 情况2: output0 是 (batch, num_boxes, 6) - 带batch的合并格式
  236. elif len(output0_shape) == 3 and output0_shape[2] == 6:
  237. pred = outputs[0][0]
  238. labels = pred[:, 0].astype(int)
  239. scores = pred[:, 1]
  240. bboxes = pred[:, 2:6].copy()
  241. # 情况3: output0 是 bboxes, output1 是 scores (分离格式)
  242. elif len(output0_shape) == 2 and output0_shape[1] == 4:
  243. bboxes = outputs[0].copy()
  244. if len(output1_shape) == 1:
  245. scores = outputs[1]
  246. labels = np.zeros(len(scores), dtype=int)
  247. elif len(output1_shape) == 2:
  248. scores_all = outputs[1]
  249. scores = scores_all.max(axis=1)
  250. labels = scores_all.argmax(axis=1)
  251. else:
  252. raise ValueError(f"Unexpected output1 shape: {output1_shape}")
  253. # 情况4: RT-DETR 格式 (batch, num_boxes, 4) + (batch, num_boxes, num_classes)
  254. elif len(output0_shape) == 3 and output0_shape[2] == 4:
  255. bboxes = outputs[0][0].copy()
  256. scores_all = outputs[1][0]
  257. scores = scores_all.max(axis=1)
  258. labels = scores_all.argmax(axis=1)
  259. else:
  260. raise ValueError(f"Unexpected output format: {output0_shape}, {output1_shape}")
  261. elif len(outputs) == 1:
  262. # 单一输出
  263. output_shape = outputs[0].shape
  264. if len(output_shape) == 2 and output_shape[1] == 6:
  265. pred = outputs[0]
  266. labels = pred[:, 0].astype(int)
  267. scores = pred[:, 1]
  268. bboxes = pred[:, 2:6].copy()
  269. elif len(output_shape) == 3 and output_shape[2] == 6:
  270. pred = outputs[0][0]
  271. labels = pred[:, 0].astype(int)
  272. scores = pred[:, 1]
  273. bboxes = pred[:, 2:6].copy()
  274. else:
  275. raise ValueError(f"Unexpected single output shape: {output_shape}")
  276. else:
  277. raise ValueError(f"Unexpected number of outputs: {len(outputs)}")
  278. # 将坐标从 640×640 还原到原图尺度
  279. bboxes[:, [0, 2]] *= scale_w
  280. bboxes[:, [1, 3]] *= scale_h
  281. # 自适应阈值
  282. max_score = scores.max() if len(scores) > 0 else 0
  283. if max_score < conf_threshold:
  284. adjusted_threshold = max(max_score * 0.5, 0.05)
  285. conf_threshold = adjusted_threshold
  286. # 过滤低分框
  287. mask = scores > conf_threshold
  288. bboxes = bboxes[mask]
  289. scores = scores[mask]
  290. labels = labels[mask]
  291. # 过滤完全在图像外的框
  292. valid_mask = (
  293. (bboxes[:, 2] > 0) & # x2 > 0
  294. (bboxes[:, 3] > 0) & # y2 > 0
  295. (bboxes[:, 0] < orig_w) & # x1 < width
  296. (bboxes[:, 1] < orig_h) # y1 < height
  297. )
  298. bboxes = bboxes[valid_mask]
  299. scores = scores[valid_mask]
  300. labels = labels[valid_mask]
  301. # 裁剪坐标到图像范围
  302. bboxes[:, [0, 2]] = np.clip(bboxes[:, [0, 2]], 0, orig_w)
  303. bboxes[:, [1, 3]] = np.clip(bboxes[:, [1, 3]], 0, orig_h)
  304. # 构造结果
  305. results = []
  306. for box, score, label in zip(bboxes, scores, labels):
  307. x1, y1, x2, y2 = box
  308. # 过滤无效框
  309. width = x2 - x1
  310. height = y2 - y1
  311. # 过滤太小的框
  312. if width < 10 or height < 10:
  313. continue
  314. # 过滤面积异常大的框
  315. area = width * height
  316. img_area = orig_w * orig_h
  317. if area > img_area * 0.95:
  318. continue
  319. results.append({
  320. 'category_id': int(label),
  321. 'category_name': self.ORIGINAL_CATEGORY_NAMES.get(int(label), f'unknown_{label}'),
  322. 'bbox': [int(x1), int(y1), int(x2), int(y2)],
  323. 'poly': [int(x1), int(y1), int(x2), int(y1), int(x2), int(y2), int(x1), int(y2)],
  324. 'score': float(score),
  325. 'width': int(width),
  326. 'height': int(height)
  327. })
  328. return results
  329. def _predict(
  330. self,
  331. img: np.ndarray,
  332. conf_threshold: float = 0.25
  333. ) -> List[Dict]:
  334. """执行预测"""
  335. # 预处理
  336. input_dict, scale, orig_shape = self._preprocess(img)
  337. # ONNX 推理
  338. output_names = [out.name for out in self.session.get_outputs()]
  339. outputs = self.session.run(output_names, input_dict)
  340. # 后处理
  341. results = self._postprocess(outputs, scale, orig_shape, conf_threshold)
  342. return results
  343. def visualize(
  344. self,
  345. img: np.ndarray,
  346. results: List[Dict],
  347. output_path: str = None,
  348. show_confidence: bool = True,
  349. min_confidence: float = 0.0
  350. ) -> np.ndarray:
  351. """
  352. 可视化检测结果
  353. Args:
  354. img: 输入图像
  355. results: 检测结果 (MinerU格式)
  356. output_path: 输出路径(可选)
  357. show_confidence: 是否显示置信度
  358. min_confidence: 最小置信度阈值
  359. Returns:
  360. 标注后的图像
  361. """
  362. import random
  363. vis_img = img.copy()
  364. # 为每个类别分配固定颜色
  365. category_colors = {}
  366. # 预定义一些常用类别的颜色
  367. predefined_colors = {
  368. 'text': (0, 255, 0), # 绿色
  369. 'title': (255, 0, 0), # 红色
  370. 'table_body': (0, 0, 255), # 蓝色
  371. 'table_caption': (255, 255, 0), # 青色
  372. 'table_footnote': (255, 128, 0), # 橙色
  373. 'image_body': (255, 0, 255), # 洋红
  374. 'image_caption': (128, 0, 255), # 紫色
  375. 'interline_equation': (0, 255, 255), # 黄色
  376. 'abandon': (128, 128, 128), # 灰色
  377. }
  378. # 过滤低置信度结果
  379. filtered_results = [
  380. res for res in results
  381. if res['confidence'] >= min_confidence
  382. ]
  383. if not filtered_results:
  384. print(f"⚠️ No results to visualize (min_confidence={min_confidence})")
  385. return vis_img
  386. # 为每个出现的类别分配颜色
  387. for res in filtered_results:
  388. cat = res['category']
  389. if cat not in category_colors:
  390. if cat in predefined_colors:
  391. category_colors[cat] = predefined_colors[cat]
  392. else:
  393. # 随机生成颜色
  394. category_colors[cat] = (
  395. random.randint(50, 255),
  396. random.randint(50, 255),
  397. random.randint(50, 255)
  398. )
  399. # 绘制检测框
  400. for res in filtered_results:
  401. bbox = res['bbox']
  402. x1, y1, x2, y2 = bbox
  403. cat = res['category']
  404. confidence = res['confidence']
  405. color = category_colors[cat]
  406. # 绘制矩形边框
  407. cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
  408. # 构造标签文本
  409. if show_confidence:
  410. label = f"{cat} {confidence:.2f}"
  411. else:
  412. label = cat
  413. # 计算标签尺寸
  414. label_size, baseline = cv2.getTextSize(
  415. label,
  416. cv2.FONT_HERSHEY_SIMPLEX,
  417. 0.5,
  418. 1
  419. )
  420. label_w, label_h = label_size
  421. # 绘制标签背景 (填充矩形)
  422. cv2.rectangle(
  423. vis_img,
  424. (x1, y1 - label_h - 4),
  425. (x1 + label_w, y1),
  426. color,
  427. -1 # 填充
  428. )
  429. # 绘制标签文字 (白色)
  430. cv2.putText(
  431. vis_img,
  432. label,
  433. (x1, y1 - 2),
  434. cv2.FONT_HERSHEY_SIMPLEX,
  435. 0.5,
  436. (255, 255, 255), # 白色文字
  437. 1,
  438. cv2.LINE_AA
  439. )
  440. # 添加图例 (在图像右上角)
  441. if category_colors:
  442. self._draw_legend(vis_img, category_colors, len(filtered_results))
  443. # 保存可视化结果
  444. if output_path:
  445. output_path = Path(output_path)
  446. output_path.parent.mkdir(parents=True, exist_ok=True)
  447. cv2.imwrite(str(output_path), vis_img)
  448. print(f"💾 Visualization saved to: {output_path}")
  449. return vis_img
  450. def _draw_legend(
  451. self,
  452. img: np.ndarray,
  453. category_colors: Dict[str, tuple],
  454. total_count: int
  455. ):
  456. """
  457. 在图像上绘制图例
  458. Args:
  459. img: 图像
  460. category_colors: 类别颜色映射
  461. total_count: 总检测数量
  462. """
  463. legend_x = img.shape[1] - 200 # 右侧留200像素
  464. legend_y = 20
  465. line_height = 25
  466. # 绘制半透明背景
  467. overlay = img.copy()
  468. cv2.rectangle(
  469. overlay,
  470. (legend_x - 10, legend_y - 10),
  471. (img.shape[1] - 10, legend_y + len(category_colors) * line_height + 30),
  472. (255, 255, 255),
  473. -1
  474. )
  475. cv2.addWeighted(overlay, 0.7, img, 0.3, 0, img)
  476. # 绘制标题
  477. cv2.putText(
  478. img,
  479. f"Legend ({total_count} total)",
  480. (legend_x, legend_y),
  481. cv2.FONT_HERSHEY_SIMPLEX,
  482. 0.5,
  483. (0, 0, 0),
  484. 1,
  485. cv2.LINE_AA
  486. )
  487. # 绘制每个类别
  488. y_offset = legend_y + line_height
  489. for cat, color in sorted(category_colors.items()):
  490. # 绘制颜色方块
  491. cv2.rectangle(
  492. img,
  493. (legend_x, y_offset - 10),
  494. (legend_x + 15, y_offset),
  495. color,
  496. -1
  497. )
  498. cv2.rectangle(
  499. img,
  500. (legend_x, y_offset - 10),
  501. (legend_x + 15, y_offset),
  502. (0, 0, 0),
  503. 1
  504. )
  505. # 绘制类别名称
  506. cv2.putText(
  507. img,
  508. cat,
  509. (legend_x + 20, y_offset - 2),
  510. cv2.FONT_HERSHEY_SIMPLEX,
  511. 0.4,
  512. (0, 0, 0),
  513. 1,
  514. cv2.LINE_AA
  515. )
  516. y_offset += line_height
  517. # 测试代码
  518. if __name__ == "__main__":
  519. import yaml
  520. # 测试配置
  521. config = {
  522. 'model_dir': '/Users/zhch158/workspace/repository.git/PaddleX/zhch/unified_pytorch_models/Layout/RT-DETR-H_layout_17cls.onnx',
  523. 'device': 'cpu',
  524. 'conf': 0.25
  525. }
  526. # 初始化检测器
  527. print("🔧 Initializing detector...")
  528. detector = PaddleLayoutDetector(config)
  529. detector.initialize()
  530. # 读取测试图像
  531. img_path = "/Users/zhch158/workspace/data/流水分析/B用户_扫描流水/PaddleOCR_VL_Results/B用户_扫描流水/B用户_扫描流水_page_001.png"
  532. print(f"\n📖 Loading image: {img_path}")
  533. img = cv2.imread(img_path)
  534. if img is None:
  535. print(f"❌ Failed to load image: {img_path}")
  536. exit(1)
  537. print(f" Image shape: {img.shape}")
  538. # 执行检测
  539. print("\n🔍 Detecting layout...")
  540. results = detector.detect(img)
  541. print(f"\n✅ 检测到 {len(results)} 个区域:")
  542. for i, res in enumerate(results, 1):
  543. print(f" [{i}] {res['category']}: "
  544. f"score={res['confidence']:.3f}, "
  545. f"bbox={res['bbox']}, "
  546. f"original={res['raw']['original_category_name']}")
  547. # 统计各类别
  548. category_counts = {}
  549. for res in results:
  550. cat = res['category']
  551. category_counts[cat] = category_counts.get(cat, 0) + 1
  552. print(f"\n📊 类别统计 (MinerU格式):")
  553. for cat, count in sorted(category_counts.items()):
  554. print(f" - {cat}: {count}")
  555. # 使用新的可视化方法
  556. if len(results) > 0:
  557. print("\n🎨 Generating visualization...")
  558. # 创建输出目录
  559. output_dir = Path(__file__).parent.parent.parent / "tests" / "output"
  560. output_dir.mkdir(parents=True, exist_ok=True)
  561. output_path = output_dir / f"{Path(img_path).stem}_layout_vis.jpg"
  562. # 调用可视化方法
  563. vis_img = detector.visualize(
  564. img,
  565. results,
  566. output_path=str(output_path),
  567. show_confidence=True,
  568. min_confidence=0.0
  569. )
  570. print(f"💾 Visualization saved to: {output_path}")
  571. # 清理
  572. detector.cleanup()
  573. print("\n✅ 测试完成!")