@@ -0,0 +1,212 @@
+"""
+Universal PaddlePaddle to PyTorch converter.
+Supports:
+1. Legacy inference models (inference.pdmodel + inference.pdiparams)
+2. PaddleX 3.0 models (inference.json + inference.pdiparams)
+"""
+import subprocess
+import torch
+import onnx
+from pathlib import Path
+from onnx2pytorch import ConvertModel
+from enum import Enum
+
+
+class PaddleModelFormat(Enum):
+    """Enumeration of supported Paddle model formats."""
+    LEGACY = "legacy"          # inference.pdmodel + inference.pdiparams
+    PADDLEX_V3 = "paddlex_v3"  # inference.json + inference.pdiparams
+
+
+class UniversalPaddleToPyTorchConverter:
+    """Universal Paddle → PyTorch converter."""
+
+    def __init__(self, paddle_model_dir: str, output_dir: str = "./unified_pytorch_models"):
+        self.paddle_model_dir = Path(paddle_model_dir)
+        self.output_dir = Path(output_dir)
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Auto-detect the model format
+        self.model_format = self._detect_model_format()
+
+    def _detect_model_format(self) -> PaddleModelFormat:
+        """Auto-detect the model format."""
+        has_pdmodel = (self.paddle_model_dir / "inference.pdmodel").exists()
+        has_json = (self.paddle_model_dir / "inference.json").exists()
+        has_params = (self.paddle_model_dir / "inference.pdiparams").exists()
+
+        if not has_params:
+            raise FileNotFoundError(
+                f"❌ Parameter file not found: {self.paddle_model_dir}/inference.pdiparams"
+            )
+
+        if has_pdmodel:
+            print("✅ Detected legacy inference model format (.pdmodel)")
+            return PaddleModelFormat.LEGACY
+        elif has_json:
+            print("✅ Detected PaddleX 3.0 format (.json)")
+            return PaddleModelFormat.PADDLEX_V3
+        else:
+            raise FileNotFoundError(
+                "❌ No model structure file found (inference.pdmodel or inference.json)"
+            )
+
+    def convert(self, model_name: str) -> Path:
+        """Run the conversion."""
+        print(f"\n{'='*60}")
+        print(f"🔄 Starting conversion: {model_name}")
+        print(f"   Model format: {self.model_format.value}")
+        print(f"{'='*60}")
+
+        # Choose the conversion path based on the detected format
+        if self.model_format == PaddleModelFormat.LEGACY:
+            onnx_path = self._convert_legacy_to_onnx(model_name)
+        else:
+            onnx_path = self._convert_paddlex_to_onnx(model_name)
+
+        # ONNX → PyTorch
+        pytorch_path = self._onnx_to_pytorch(onnx_path, model_name)
+
+        print(f"{'='*60}")
+        print(f"✅ Conversion finished: {model_name}")
+        print(f"{'='*60}\n")
+
+        return pytorch_path
+
+    def _convert_legacy_to_onnx(self, model_name: str) -> Path:
+        """Legacy format → ONNX."""
+        print("\n📦 Step 1: legacy Paddle format → ONNX")
+
+        onnx_output_path = self.output_dir / f"{model_name}.onnx"
+
+        cmd = [
+            "paddle2onnx",
+            "--model_dir", str(self.paddle_model_dir),
+            "--model_filename", "inference.pdmodel",
+            "--params_filename", "inference.pdiparams",
+            "--save_file", str(onnx_output_path),
+            "--opset_version", "11",
+            "--enable_onnx_checker", "True"
+        ]
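+        # Opset 11 is a conservative default; if paddle2onnx reports unsupported
+        # operators, raising --opset_version is usually the first thing to try.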
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            raise RuntimeError(f"Paddle2ONNX conversion failed:\n{result.stderr}")
+
+        print(f"   ✅ ONNX saved: {onnx_output_path}")
+        return onnx_output_path
+
+    def _convert_paddlex_to_onnx(self, model_name: str) -> Path:
+        """PaddleX 3.0 format → ONNX."""
+        print("\n📦 Step 1: PaddleX 3.0 → ONNX")
+
+        onnx_output_path = self.output_dir / f"{model_name}.onnx"
+
+        # Method 1: try paddle2onnx directly on the JSON program
+        print("   Trying method 1: paddle2onnx on inference.json ...")
+        cmd = [
+            "paddle2onnx",
+            "--model_dir", str(self.paddle_model_dir),
+            "--model_filename", "inference.json",
+            "--params_filename", "inference.pdiparams",
+            "--save_file", str(onnx_output_path),
+            "--opset_version", "11",
+        ]
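+        # Newer paddle2onnx builds can typically read the JSON program format
+        # directly; older ones fail here, hence the PaddleX CLI fallback below.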
+
+        method1_result = subprocess.run(cmd, capture_output=True, text=True)
+
+        if method1_result.returncode == 0 and onnx_output_path.exists():
+            print(f"   ✅ Method 1 succeeded! ONNX saved: {onnx_output_path}")
+            return onnx_output_path
+
+        # Method 2: fall back to the PaddleX CLI
+        print("   Method 1 failed, trying method 2: PaddleX CLI...")
+        cmd = [
+            "paddlex",
+            "--paddle2onnx",
+            "--paddle_model_dir", str(self.paddle_model_dir),
+            "--onnx_model_dir", str(self.output_dir),
+            "--opset_version", "11"
+        ]
+
+        result = subprocess.run(cmd, capture_output=True, text=True)
+
+        if result.returncode != 0:
+            raise RuntimeError(
+                f"All conversion methods failed!\n"
+                f"Method 1 stderr: {method1_result.stderr}\n"
+                f"Method 2 stderr: {result.stderr}\n"
+                f"Suggestion: export the ONNX model manually with PaddleX"
+            )
+
+        # Rename the file generated by the PaddleX CLI
+        generated_files = list(self.output_dir.glob("inference*.onnx"))
+        if generated_files:
+            generated_files[0].rename(onnx_output_path)
+
+        if not onnx_output_path.exists():
+            raise FileNotFoundError(
+                f"ONNX file was not generated: {onnx_output_path}\n"
+                f"Output directory contents: {list(self.output_dir.iterdir())}"
+            )
+
+        print(f"   ✅ ONNX saved: {onnx_output_path}")
+        return onnx_output_path
+
+    def _onnx_to_pytorch(self, onnx_path: Path, model_name: str) -> Path:
+        """ONNX → PyTorch."""
+        print("\n📦 Step 2: ONNX → PyTorch")
+
+        # Load and validate the ONNX model
+        onnx_model = onnx.load(str(onnx_path))
+        onnx.checker.check_model(onnx_model)
+
+        # Convert to a PyTorch module
+        pytorch_model = ConvertModel(onnx_model, experimental=True)
+
+        # Save the checkpoint
+        pytorch_output_path = self.output_dir / f"{model_name}.pth"
+        torch.save({
+            'model_state_dict': pytorch_model.state_dict(),
+            'model': pytorch_model,
+            'source': f'converted_from_paddle_{self.model_format.value}',
+            'original_dir': str(self.paddle_model_dir)
+        }, pytorch_output_path)
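+        # Because the full module object is pickled under 'model', reloading the
+        # checkpoint later needs onnx2pytorch importable and, on newer PyTorch,
+        # torch.load(..., weights_only=False). Illustrative load:
+        #   ckpt = torch.load("unified_pytorch_models/MyModel.pth", weights_only=False)
+        #   model = ckpt["model"].eval()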
+
+        print(f"   ✅ PyTorch model saved: {pytorch_output_path}")
+        return pytorch_output_path
+
+
+def batch_convert_all_models():
+    """Batch-convert all models (format auto-detected)."""
+
+    MODELS = [
+        # ("PicoDet_layout_1x", "Layout"),
+        ("PP-DocLayoutV2", "Layout"),
+        ("RT-DETR-H_layout_17cls", "Layout"),
+    ]
+
+    base_dir = Path("~/.paddlex/official_models").expanduser()
+    output_base = Path("./")
+
+    for model_name, category in MODELS:
+        model_dir = base_dir / model_name
+
+        if not model_dir.exists():
+            print(f"⚠️  Skipping {model_name}: directory does not exist")
+            continue
+
+        try:
+            converter = UniversalPaddleToPyTorchConverter(
+                model_dir,
+                output_base / category
+            )
+            converter.convert(model_name)
+        except Exception as e:
+            print(f"❌ {model_name} conversion failed: {e}\n")
+
+
+if __name__ == "__main__":
+    batch_convert_all_models()