"""
Universal PaddlePaddle to PyTorch converter.

Supports:
1. Legacy inference models (inference.pdmodel + inference.pdiparams)
2. PaddleX 3.0 models (inference.json + inference.pdiparams)
"""
import json
import subprocess
from enum import Enum
from pathlib import Path

import onnx
import torch
from onnx2pytorch import ConvertModel
class PaddleModelFormat(Enum):
    """On-disk layout of a Paddle inference model directory.

    LEGACY:     inference.pdmodel + inference.pdiparams
    PADDLEX_V3: inference.json    + inference.pdiparams
    """

    LEGACY = "legacy"
    PADDLEX_V3 = "paddlex_v3"
class UniversalPaddleToPyTorchConverter:
    """Universal Paddle -> PyTorch converter.

    Converts a PaddlePaddle inference-model directory to a PyTorch
    checkpoint via an intermediate ONNX export:

        Paddle (legacy or PaddleX 3.0) --paddle2onnx/paddlex--> ONNX
        ONNX --onnx2pytorch--> PyTorch .pth

    The on-disk format is auto-detected at construction time.
    """

    def __init__(self, paddle_model_dir: str, output_dir: str = "./unified_pytorch_models"):
        """Create a converter for one model directory.

        Args:
            paddle_model_dir: Directory holding the Paddle inference files.
            output_dir: Destination for the generated .onnx / .pth files
                (created if missing).

        Raises:
            FileNotFoundError: If the directory lacks the parameter file or
                any recognizable model-structure file (via format detection).
        """
        self.paddle_model_dir = Path(paddle_model_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Detect the format eagerly so a bad directory fails fast,
        # before any conversion work is attempted.
        self.model_format = self._detect_model_format()

    def _detect_model_format(self) -> PaddleModelFormat:
        """Inspect the model directory and return its format.

        A .pdmodel file wins over a .json file when both are present.

        Raises:
            FileNotFoundError: If inference.pdiparams is missing, or no
                structure file (inference.pdmodel / inference.json) exists.
        """
        has_pdmodel = (self.paddle_model_dir / "inference.pdmodel").exists()
        has_json = (self.paddle_model_dir / "inference.json").exists()
        has_params = (self.paddle_model_dir / "inference.pdiparams").exists()

        # Both formats share the same weights file; without it nothing works.
        if not has_params:
            raise FileNotFoundError(
                f"❌ 参数文件不存在: {self.paddle_model_dir}/inference.pdiparams"
            )

        if has_pdmodel:
            print("✅ 检测到传统推理模型格式 (.pdmodel)")
            return PaddleModelFormat.LEGACY
        elif has_json:
            print("✅ 检测到 PaddleX 3.0 格式 (.json)")
            return PaddleModelFormat.PADDLEX_V3
        else:
            raise FileNotFoundError(
                "❌ 未找到模型结构文件(inference.pdmodel 或 inference.json)"
            )

    def convert(self, model_name: str) -> Path:
        """Run the full Paddle -> ONNX -> PyTorch pipeline.

        Args:
            model_name: Base name used for the generated output files.

        Returns:
            Path to the saved PyTorch checkpoint (.pth).

        Raises:
            RuntimeError: If the Paddle -> ONNX step fails.
            FileNotFoundError: If no ONNX file was produced.
        """
        print(f"\n{'='*60}")
        print(f"🔄 开始转换: {model_name}")
        print(f" 模型格式: {self.model_format.value}")
        print(f"{'='*60}")

        # Step 1: Paddle -> ONNX, dispatched on the detected format.
        if self.model_format == PaddleModelFormat.LEGACY:
            onnx_path = self._convert_legacy_to_onnx(model_name)
        else:
            onnx_path = self._convert_paddlex_to_onnx(model_name)

        # Step 2: ONNX -> PyTorch.
        pytorch_path = self._onnx_to_pytorch(onnx_path, model_name)

        print(f"{'='*60}")
        print(f"✅ 转换完成: {model_name}")
        print(f"{'='*60}\n")

        return pytorch_path

    def _convert_legacy_to_onnx(self, model_name: str) -> Path:
        """Export a legacy (.pdmodel) model to ONNX with paddle2onnx.

        Raises:
            RuntimeError: If the paddle2onnx CLI exits non-zero.
        """
        print("\n📦 步骤1: 传统格式 Paddle → ONNX")

        onnx_output_path = self.output_dir / f"{model_name}.onnx"

        cmd = [
            "paddle2onnx",
            "--model_dir", str(self.paddle_model_dir),
            "--model_filename", "inference.pdmodel",
            "--params_filename", "inference.pdiparams",
            "--save_file", str(onnx_output_path),
            "--opset_version", "11",
            "--enable_onnx_checker", "True",
        ]

        # List form (shell=False) avoids any shell quoting issues with paths.
        result = subprocess.run(cmd, capture_output=True, text=True)

        if result.returncode != 0:
            raise RuntimeError(f"Paddle2ONNX 转换失败:\n{result.stderr}")

        print(f" ✅ ONNX已保存: {onnx_output_path}")
        return onnx_output_path

    def _convert_paddlex_to_onnx(self, model_name: str) -> Path:
        """Export a PaddleX 3.0 (.json) model to ONNX.

        Tries paddle2onnx directly on the JSON graph first; if that fails,
        falls back to the PaddleX CLI and renames its output.

        Raises:
            RuntimeError: If both conversion attempts fail.
            FileNotFoundError: If no ONNX file exists after the fallback.
        """
        print("\n📦 步骤1: PaddleX 3.0 → ONNX")

        onnx_output_path = self.output_dir / f"{model_name}.onnx"

        # Method 1: paddle2onnx may accept the JSON graph directly.
        print(" 尝试方法1: paddle2onnx 直接转换 JSON...")
        cmd = [
            "paddle2onnx",
            "--model_dir", str(self.paddle_model_dir),
            "--model_filename", "inference.json",
            "--params_filename", "inference.pdiparams",
            "--save_file", str(onnx_output_path),
            "--opset_version", "11",
        ]

        result1 = subprocess.run(cmd, capture_output=True, text=True)

        if result1.returncode == 0 and onnx_output_path.exists():
            print(f" ✅ 方法1成功!ONNX已保存: {onnx_output_path}")
            return onnx_output_path

        # Method 2: fall back to the PaddleX CLI exporter.
        print(" 方法1失败,尝试方法2: PaddleX CLI...")
        cmd = [
            "paddlex",
            "--paddle2onnx",
            "--paddle_model_dir", str(self.paddle_model_dir),
            "--onnx_model_dir", str(self.output_dir),
            "--opset_version", "11",
        ]

        result2 = subprocess.run(cmd, capture_output=True, text=True)

        if result2.returncode != 0:
            # BUGFIX: the original reported `result.stdout`/`result.stderr`
            # of method 2 under BOTH labels, so method 1's diagnostics were
            # lost. Keep each attempt's stderr and attribute it correctly.
            raise RuntimeError(
                f"所有转换方法都失败了!\n"
                f"方法1输出: {result1.stderr}\n"
                f"方法2输出: {result2.stderr}\n"
                f"建议:请手动使用 PaddleX 导出 ONNX 模型"
            )

        # PaddleX names its output after the source files ("inference*.onnx");
        # rename the first match to the requested model name.
        generated_files = list(self.output_dir.glob("inference*.onnx"))
        if generated_files:
            generated_files[0].rename(onnx_output_path)

        if not onnx_output_path.exists():
            raise FileNotFoundError(
                f"ONNX 文件未生成: {onnx_output_path}\n"
                f"输出目录内容: {list(self.output_dir.iterdir())}"
            )

        print(f" ✅ ONNX已保存: {onnx_output_path}")
        return onnx_output_path

    def _onnx_to_pytorch(self, onnx_path: Path, model_name: str) -> Path:
        """Convert an ONNX file to a PyTorch checkpoint.

        Raises:
            onnx.checker.ValidationError: If the ONNX graph is invalid.
        """
        print("\n📦 步骤2: ONNX → PyTorch")

        # Validate the graph before conversion so failures point at the
        # ONNX export rather than at onnx2pytorch.
        onnx_model = onnx.load(str(onnx_path))
        onnx.checker.check_model(onnx_model)

        pytorch_model = ConvertModel(onnx_model, experimental=True)

        # NOTE: pickling the full module means loading this checkpoint
        # requires onnx2pytorch to be importable; the state_dict is also
        # stored so weights remain usable without it.
        pytorch_output_path = self.output_dir / f"{model_name}.pth"
        torch.save({
            'model_state_dict': pytorch_model.state_dict(),
            'model': pytorch_model,
            'source': f'converted_from_paddle_{self.model_format.value}',
            'original_dir': str(self.paddle_model_dir)
        }, pytorch_output_path)

        print(f" ✅ PyTorch已保存: {pytorch_output_path}")
        return pytorch_output_path
def batch_convert_all_models():
    """Convert every registered model directory, auto-detecting its format.

    Looks for each model under ~/.paddlex/official_models, skips entries
    whose directory is absent, and reports (without re-raising) any
    per-model conversion failure so the batch keeps going.
    """
    registry = [
        # ("PicoDet_layout_1x", "Layout"),
        ("PP-DocLayoutV2", "Layout"),
        ("RT-DETR-H_layout_17cls", "Layout"),
    ]

    source_root = Path("~/.paddlex/official_models").expanduser()
    target_root = Path("./")

    for name, category in registry:
        candidate_dir = source_root / name

        if not candidate_dir.exists():
            print(f"⚠️ 跳过 {name}: 目录不存在")
            continue

        try:
            worker = UniversalPaddleToPyTorchConverter(
                candidate_dir,
                target_root / category,
            )
            worker.convert(name)
        except Exception as exc:
            # Best-effort batch: log the failure and move on.
            print(f"❌ {name} 转换失败: {exc}\n")
# Script entry point: run the batch conversion when executed directly.
if __name__ == "__main__":
    batch_convert_all_models()