فهرست منبع

export pipeline info

FlyingQianMM 4 سال پیش
والد
کامیت
d2be43d633
1فایلهای تغییر یافته به همراه39 افزوده شده و 0 حذف شده
  1. 39 0
      dygraph/paddlex/cv/models/base.py

+ 39 - 0
dygraph/paddlex/cv/models/base.py

@@ -482,6 +482,39 @@ class BaseModel:
         logging.info("Model is ready for quantization-aware training.")
         self.status = 'Quantized'
 
+    def _get_pipeline_info(self, save_dir):
+        pipeline_info = {}
+        pipeline_info["pipeline_name"] = self.model_type
+        nodes = [{
+            "src0": {
+                "type": "Source",
+                "next": "decode0"
+            }
+        }, {
+            "decode0": {
+                "type": "Decode",
+                "next": "predict0"
+            }
+        }, {
+            "predict0": {
+                "type": "Predict",
+                "init_params": {
+                    "use_gpu": False,
+                    "gpu_id": 0,
+                    "use_trt": False,
+                    "model_dir": save_dir,
+                },
+                "next": "sink0"
+            }
+        }, {
+            "sink0": {
+                "type": "Sink"
+            }
+        }]
+        pipeline_info["pipeline_nodes"] = nodes
+        pipeline_info["version"] = "1.0.0"
+        return pipeline_info
+
     def _export_inference_model(self, save_dir, image_shape=None):
         save_dir = osp.join(save_dir, 'inference_model')
         self.net.eval()
@@ -515,6 +548,12 @@ class BaseModel:
                 mode='w') as f:
             yaml.dump(model_info, f)
 
+        pipeline_info = self._get_pipeline_info(save_dir)
+        with open(
+                osp.join(save_dir, 'pipeline.yml'), encoding='utf-8',
+                mode='w') as f:
+            yaml.dump(pipeline_info, f)
+
         # 模型保存成功的标志
         open(osp.join(save_dir, '.success'), 'w').close()
         logging.info("The model for the inference deployment is saved in {}.".