paddle_to_pytorch_universal.py

  1. """
  2. 通用 PaddlePaddle 到 PyTorch 转换器
  3. 支持:
  4. 1. 传统推理模型(inference.pdmodel + inference.pdiparams)
  5. 2. PaddleX 3.0 模型(inference.json + inference.pdiparams)
  6. """
import subprocess
from enum import Enum
from pathlib import Path

import onnx
import torch
from onnx2pytorch import ConvertModel


class PaddleModelFormat(Enum):
    """Enumeration of supported Paddle model formats."""
    LEGACY = "legacy"            # inference.pdmodel + inference.pdiparams
    PADDLEX_V3 = "paddlex_v3"    # inference.json + inference.pdiparams


class UniversalPaddleToPyTorchConverter:
    """Universal Paddle -> PyTorch converter."""

    def __init__(self, paddle_model_dir: str, output_dir: str = "./unified_pytorch_models"):
        self.paddle_model_dir = Path(paddle_model_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Auto-detect the model format
        self.model_format = self._detect_model_format()

    def _detect_model_format(self) -> PaddleModelFormat:
        """Auto-detect the model format from the files in the model directory."""
        has_pdmodel = (self.paddle_model_dir / "inference.pdmodel").exists()
        has_json = (self.paddle_model_dir / "inference.json").exists()
        has_params = (self.paddle_model_dir / "inference.pdiparams").exists()
        if not has_params:
            raise FileNotFoundError(
                f"❌ Parameter file not found: {self.paddle_model_dir}/inference.pdiparams"
            )
        if has_pdmodel:
            print("✅ Detected legacy inference model format (.pdmodel)")
            return PaddleModelFormat.LEGACY
        elif has_json:
            print("✅ Detected PaddleX 3.0 format (.json)")
            return PaddleModelFormat.PADDLEX_V3
        else:
            raise FileNotFoundError(
                "❌ No model structure file found (inference.pdmodel or inference.json)"
            )

    def convert(self, model_name: str) -> Path:
        """Run the conversion and return the path of the exported ONNX file."""
        print(f"\n{'='*60}")
        print(f"🔄 Starting conversion: {model_name}")
        print(f"   Model format: {self.model_format.value}")
        print(f"{'='*60}")
        # Pick the conversion path based on the detected format
        if self.model_format == PaddleModelFormat.LEGACY:
            onnx_path = self._convert_legacy_to_onnx(model_name)
        else:
            onnx_path = self._convert_paddlex_to_onnx(model_name)
        # Optional second step, ONNX -> PyTorch, disabled by default:
        # pytorch_path = self._onnx_to_pytorch(onnx_path, model_name)
        print(f"{'='*60}")
        print(f"✅ Conversion finished: {model_name}")
        print(f"{'='*60}\n")
        return onnx_path

    def _convert_legacy_to_onnx(self, model_name: str) -> Path:
        """Legacy format -> ONNX via the paddle2onnx CLI."""
        print("\n📦 Step 1: legacy Paddle format -> ONNX")
        onnx_output_path = self.output_dir / f"{model_name}.onnx"
        cmd = [
            "paddle2onnx",
            "--model_dir", str(self.paddle_model_dir),
            "--model_filename", "inference.pdmodel",
            "--params_filename", "inference.pdiparams",
            "--save_file", str(onnx_output_path),
            # "--opset_version", "11",  # not needed; let paddle2onnx pick the opset
            "--enable_onnx_checker", "True",
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Paddle2ONNX conversion failed:\n{result.stderr}")
        print(f"   ✅ ONNX saved: {onnx_output_path}")
        return onnx_output_path

    def _convert_paddlex_to_onnx(self, model_name: str) -> Path:
        """PaddleX 3.0 format -> ONNX via the PaddleX CLI."""
        print("   Running PaddleX CLI...")
        onnx_output_path = self.output_dir / f"{model_name}.onnx"
        cmd = [
            "paddlex",
            "--paddle2onnx",
            "--paddle_model_dir", str(self.paddle_model_dir),
            "--onnx_model_dir", str(self.output_dir),
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(
                f"PaddleX ONNX export failed!\n"
                f"stdout: {result.stdout}\n"
                f"stderr: {result.stderr}\n"
                f"Suggestion: export the ONNX model manually with PaddleX"
            )
        # Rename the generated inference*.onnx file to <model_name>.onnx
        generated_files = list(self.output_dir.glob("inference*.onnx"))
        if generated_files:
            generated_files[0].rename(onnx_output_path)
        if not onnx_output_path.exists():
            raise FileNotFoundError(
                f"ONNX file was not generated: {onnx_output_path}\n"
                f"Output directory contents: {list(self.output_dir.iterdir())}"
            )
        print(f"   ✅ ONNX saved: {onnx_output_path}")
        return onnx_output_path

    def _onnx_to_pytorch(self, onnx_path: Path, model_name: str) -> Path:
        """ONNX -> PyTorch via onnx2pytorch."""
        print("\n📦 Step 2: ONNX -> PyTorch")
        # Load and validate the ONNX model
        onnx_model = onnx.load(str(onnx_path))
        onnx.checker.check_model(onnx_model)
        # Convert to a PyTorch module
        pytorch_model = ConvertModel(onnx_model, experimental=True)
        # Save a checkpoint with both the state dict and the full model object
        pytorch_output_path = self.output_dir / f"{model_name}.pth"
        torch.save({
            'model_state_dict': pytorch_model.state_dict(),
            'model': pytorch_model,
            'source': f'converted_from_paddle_{self.model_format.value}',
            'original_dir': str(self.paddle_model_dir),
        }, pytorch_output_path)
        print(f"   ✅ PyTorch model saved: {pytorch_output_path}")
        return pytorch_output_path
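

# Optional sanity check for an exported ONNX file (not part of the original script):
# a minimal sketch using onnxruntime, assuming a single float32 image input. Dynamic
# dimensions are filled with guesses (batch -> 1, everything else -> `fallback`), so
# adjust `fallback` if the dummy forward pass fails for your model.
def verify_onnx_with_onnxruntime(onnx_path, fallback: int = 224):
    """Run one dummy forward pass with onnxruntime to confirm the ONNX file executes."""
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    inp = session.get_inputs()[0]
    # Replace dynamic/unknown dims with concrete guesses
    shape = [d if isinstance(d, int) and d > 0 else (1 if i == 0 else fallback)
             for i, d in enumerate(inp.shape)]
    dummy = np.random.rand(*shape).astype(np.float32)
    outputs = session.run(None, {inp.name: dummy})
    print(f"   ✅ ONNX runs: {inp.name} {shape} -> {[o.shape for o in outputs]}")
    return outputs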


def batch_convert_all_models():
    """Batch-convert all models listed below (format is auto-detected per model)."""
    MODELS = [
        # ("PicoDet_layout_1x", "Layout"),
        ("PP-LCNet_x1_0_doc_ori", "Layout"),
        ("RT-DETR-H_layout_17cls", "Layout"),
    ]
    base_dir = Path("~/.paddlex/official_models").expanduser()
    output_base = Path("./")
    for model_name, category in MODELS:
        model_dir = base_dir / model_name
        if not model_dir.exists():
            print(f"⚠️ Skipping {model_name}: directory does not exist")
            continue
        try:
            converter = UniversalPaddleToPyTorchConverter(
                model_dir,
                output_base / category,
            )
            converter.convert(model_name)
        except Exception as e:
            print(f"❌ {model_name} conversion failed: {e}\n")
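

# Loading the checkpoint written by _onnx_to_pytorch (a sketch, not in the original
# script): the checkpoint pickles the full onnx2pytorch model object, so onnx2pytorch
# must be importable when loading, and on recent PyTorch versions torch.load may need
# weights_only=False to unpickle it (drop that argument on versions that lack it).
def load_converted_pytorch_model(pth_path):
    """Restore the converted model from a .pth checkpoint and switch it to eval mode."""
    checkpoint = torch.load(pth_path, map_location="cpu", weights_only=False)
    model = checkpoint["model"]
    model.eval()
    print(f"Loaded model converted from: {checkpoint['original_dir']}")
    return model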


if __name__ == "__main__":
    batch_convert_all_models()
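

# Example: converting a single model instead of the batch list above (a sketch; the
# paths and model name are the PaddleX defaults already used in this script):
#
#   converter = UniversalPaddleToPyTorchConverter(
#       Path("~/.paddlex/official_models/PP-LCNet_x1_0_doc_ori").expanduser(),
#       "./Layout",
#   )
#   onnx_path = converter.convert("PP-LCNet_x1_0_doc_ori")
#   verify_onnx_with_onnxruntime(onnx_path)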