# paddle_to_pytorch_universal.py

  1. """
  2. 通用 PaddlePaddle 到 PyTorch 转换器
  3. 支持:
  4. 1. 传统推理模型(inference.pdmodel + inference.pdiparams)
  5. 2. PaddleX 3.0 模型(inference.json + inference.pdiparams)
  6. """
  7. import json
  8. import subprocess
  9. import torch
  10. import onnx
  11. from pathlib import Path
  12. from onnx2pytorch import ConvertModel
  13. from enum import Enum


class PaddleModelFormat(Enum):
    """Paddle model format enum."""
    LEGACY = "legacy"          # inference.pdmodel + inference.pdiparams
    PADDLEX_V3 = "paddlex_v3"  # inference.json + inference.pdiparams


class UniversalPaddleToPyTorchConverter:
    """Universal Paddle -> PyTorch converter."""

    def __init__(self, paddle_model_dir: str, output_dir: str = "./unified_pytorch_models"):
        self.paddle_model_dir = Path(paddle_model_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Auto-detect the model format
        self.model_format = self._detect_model_format()

    def _detect_model_format(self) -> PaddleModelFormat:
        """Auto-detect the model format."""
        has_pdmodel = (self.paddle_model_dir / "inference.pdmodel").exists()
        has_json = (self.paddle_model_dir / "inference.json").exists()
        has_params = (self.paddle_model_dir / "inference.pdiparams").exists()
        if not has_params:
            raise FileNotFoundError(
                f"❌ Parameter file not found: {self.paddle_model_dir}/inference.pdiparams"
            )
        if has_pdmodel:
            print("✅ Detected legacy inference model format (.pdmodel)")
            return PaddleModelFormat.LEGACY
        elif has_json:
            print("✅ Detected PaddleX 3.0 format (.json)")
            return PaddleModelFormat.PADDLEX_V3
        else:
            raise FileNotFoundError(
                "❌ Model structure file not found (inference.pdmodel or inference.json)"
            )

    def convert(self, model_name: str) -> Path:
        """Run the full conversion for one model."""
        print(f"\n{'='*60}")
        print(f"🔄 Converting: {model_name}")
        print(f"   Model format: {self.model_format.value}")
        print(f"{'='*60}")
        # Pick the Paddle -> ONNX path based on the detected format
        if self.model_format == PaddleModelFormat.LEGACY:
            onnx_path = self._convert_legacy_to_onnx(model_name)
        else:
            onnx_path = self._convert_paddlex_to_onnx(model_name)
        # ONNX -> PyTorch
        pytorch_path = self._onnx_to_pytorch(onnx_path, model_name)
        print(f"{'='*60}")
        print(f"✅ Conversion finished: {model_name}")
        print(f"{'='*60}\n")
        return pytorch_path

    def _convert_legacy_to_onnx(self, model_name: str) -> Path:
        """Legacy format -> ONNX."""
        print("\n📦 Step 1: legacy Paddle format -> ONNX")
        onnx_output_path = self.output_dir / f"{model_name}.onnx"
        cmd = [
            "paddle2onnx",
            "--model_dir", str(self.paddle_model_dir),
            "--model_filename", "inference.pdmodel",
            "--params_filename", "inference.pdiparams",
            "--save_file", str(onnx_output_path),
            "--opset_version", "11",
            "--enable_onnx_checker", "True",
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Paddle2ONNX conversion failed:\n{result.stderr}")
        print(f"   ✅ ONNX saved: {onnx_output_path}")
        return onnx_output_path

    def _convert_paddlex_to_onnx(self, model_name: str) -> Path:
        """PaddleX 3.0 format -> ONNX."""
        print("\n📦 Step 1: PaddleX 3.0 -> ONNX")
        onnx_output_path = self.output_dir / f"{model_name}.onnx"

        # Method 1: try paddle2onnx directly on the JSON program
        print("   Trying method 1: paddle2onnx on inference.json ...")
        cmd = [
            "paddle2onnx",
            "--model_dir", str(self.paddle_model_dir),
            "--model_filename", "inference.json",
            "--params_filename", "inference.pdiparams",
            "--save_file", str(onnx_output_path),
            "--opset_version", "11",
        ]
        result_p2o = subprocess.run(cmd, capture_output=True, text=True)
        if result_p2o.returncode == 0 and onnx_output_path.exists():
            print(f"   ✅ Method 1 succeeded! ONNX saved: {onnx_output_path}")
            return onnx_output_path

        # Method 2: fall back to the PaddleX CLI
        print("   Method 1 failed, trying method 2: PaddleX CLI ...")
        cmd = [
            "paddlex",
            "--paddle2onnx",
            "--paddle_model_dir", str(self.paddle_model_dir),
            "--onnx_model_dir", str(self.output_dir),
            "--opset_version", "11",
        ]
        result_px = subprocess.run(cmd, capture_output=True, text=True)
        if result_px.returncode != 0:
            raise RuntimeError(
                "All conversion methods failed!\n"
                f"Method 1 output: {result_p2o.stderr}\n"
                f"Method 2 output: {result_px.stderr}\n"
                "Suggestion: export the ONNX model manually with PaddleX."
            )
        # Rename the generated file to the model name
        generated_files = list(self.output_dir.glob("inference*.onnx"))
        if generated_files:
            generated_files[0].rename(onnx_output_path)
        if not onnx_output_path.exists():
            raise FileNotFoundError(
                f"ONNX file was not generated: {onnx_output_path}\n"
                f"Output directory contents: {list(self.output_dir.iterdir())}"
            )
        print(f"   ✅ ONNX saved: {onnx_output_path}")
        return onnx_output_path

    def _onnx_to_pytorch(self, onnx_path: Path, model_name: str) -> Path:
        """ONNX -> PyTorch."""
        print("\n📦 Step 2: ONNX -> PyTorch")
        # Load and validate the ONNX graph
        onnx_model = onnx.load(str(onnx_path))
        onnx.checker.check_model(onnx_model)
        # Convert to a PyTorch nn.Module
        pytorch_model = ConvertModel(onnx_model, experimental=True)
        # Save both the state dict and the full pickled module
        pytorch_output_path = self.output_dir / f"{model_name}.pth"
        torch.save({
            'model_state_dict': pytorch_model.state_dict(),
            'model': pytorch_model,
            'source': f'converted_from_paddle_{self.model_format.value}',
            'original_dir': str(self.paddle_model_dir),
        }, pytorch_output_path)
        print(f"   ✅ PyTorch model saved: {pytorch_output_path}")
        return pytorch_output_path
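

# Optional sanity check -- a minimal sketch, not part of the converter itself.
# Assumptions: the extra `onnxruntime` package is installed, the model takes a single
# NCHW float32 image input, and `input_shape` matches what the model expects (the
# default below is only a placeholder). Detection models that also expect scale/shape
# tensors would need those fed as additional inputs.
def verify_conversion(onnx_path: Path, pytorch_path: Path,
                      input_shape=(1, 3, 800, 608)) -> float:
    """Compare ONNX Runtime and converted PyTorch outputs on one random input."""
    import numpy as np
    import onnxruntime as ort

    x = np.random.rand(*input_shape).astype(np.float32)

    # Reference output from ONNX Runtime (first output only)
    sess = ort.InferenceSession(str(onnx_path))
    input_name = sess.get_inputs()[0].name
    onnx_out = sess.run(None, {input_name: x})[0]

    # Output from the converted PyTorch module (weights_only=False is needed on
    # recent PyTorch versions because the checkpoint stores a pickled module)
    checkpoint = torch.load(pytorch_path, weights_only=False)
    model = checkpoint['model']
    model.eval()
    with torch.no_grad():
        torch_out = model(torch.from_numpy(x))
    if isinstance(torch_out, (list, tuple)):
        torch_out = torch_out[0]

    diff = float(np.max(np.abs(onnx_out - torch_out.numpy())))
    print(f"   Max |ONNX - PyTorch| difference: {diff:.6f}")
    return diff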


def batch_convert_all_models():
    """Batch-convert all models (format is detected automatically)."""
    MODELS = [
        # ("PicoDet_layout_1x", "Layout"),
        ("PP-DocLayoutV2", "Layout"),
        ("RT-DETR-H_layout_17cls", "Layout"),
    ]
    base_dir = Path("~/.paddlex/official_models").expanduser()
    output_base = Path("./")
    for model_name, category in MODELS:
        model_dir = base_dir / model_name
        if not model_dir.exists():
            print(f"⚠️ Skipping {model_name}: directory does not exist")
            continue
        try:
            converter = UniversalPaddleToPyTorchConverter(
                model_dir,
                output_base / category,
            )
            converter.convert(model_name)
        except Exception as e:
            print(f"❌ {model_name} conversion failed: {e}\n")


if __name__ == "__main__":
    batch_convert_all_models()
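
# Loading the converted model later -- a minimal usage sketch, assumptions noted below.
# The checkpoint pickles the full module, so `onnx2pytorch` must also be installed in
# the loading environment. On recent PyTorch releases torch.load defaults to
# weights_only=True, which cannot unpickle a module; pass weights_only=False explicitly.
# The path and input shape below are illustrative only.
#
#     checkpoint = torch.load("Layout/PP-DocLayoutV2.pth", weights_only=False)
#     model = checkpoint["model"]
#     model.eval()
#     with torch.no_grad():
#         preds = model(torch.rand(1, 3, 800, 608))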