
pass batch_size when create model in pipeline (#2912)

Tingquan Gao 10 months ago
parent
commit
d1612a3f12

+ 1 - 0
paddlex/inference/pipelines_new/base.py

@@ -87,6 +87,7 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
             model_name=config["model_name"],
             model_dir=model_dir,
             device=self.device,
+            batch_size=config.get("batch_size", 1),
             pp_option=self.pp_option,
             use_hpip=self.use_hpip,
             **kwargs,
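
With this change, BasePipeline.create_model reads an optional batch_size key from each sub-module's config and forwards it to the predictor, falling back to 1 when the key is absent. A minimal sketch of that lookup, assuming an illustrative config dict (the model name below is only an example, not taken from the diff):

sub_module_config = {
    "model_name": "PP-OCRv4_mobile_rec",  # illustrative model name
    "batch_size": 8,                      # forwarded to the predictor by create_model
}

# Mirrors the added line: a missing key falls back to a batch size of 1.
print(sub_module_config.get("batch_size", 1))                      # -> 8
print({"model_name": "PP-OCRv4_mobile_rec"}.get("batch_size", 1))  # -> 1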

+ 0 - 9
paddlex/inference/pipelines_new/ocr/pipeline.py

@@ -71,11 +71,6 @@ class OCRPipeline(BasePipeline):
                 "TextLineOrientation",
                 {"model_config_error": "config error for textline_orientation_model!"},
             )
-            # TODO: add batch_size
-            # batch_size = textline_orientation_config.get("batch_size", 1)
-            # self.textline_orientation_model = self.create_model(
-            #     textline_orientation_config, batch_size=batch_size
-            # )
             self.textline_orientation_model = self.create_model(
                 textline_orientation_config
             )
@@ -116,10 +111,6 @@ class OCRPipeline(BasePipeline):
             "TextRecognition",
             {"model_config_error": "config error for text_rec_model!"},
         )
-        # TODO: add batch_size
-        # batch_size = text_rec_config.get("batch_size", 1)
-        # self.text_rec_model = self.create_model(text_rec_config,
-        #     batch_size=batch_size)
         self.text_rec_score_thresh = text_rec_config.get("score_thresh", 0)
         self.text_rec_model = self.create_model(text_rec_config)
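
The commented-out TODO blocks become redundant: since the lookup now lives in BasePipeline.create_model, the OCR pipeline can keep each model creation as a plain self.create_model(config) call and still honor a batch_size entry in the sub-config. A rough stand-in for that flow, where only the batch_size handling mirrors the diff and everything else is illustrative:

def create_model(config, **kwargs):
    # Stand-in for BasePipeline.create_model; the returned dict is a
    # placeholder for the real predictor object.
    return {
        "model_name": config["model_name"],
        "batch_size": config.get("batch_size", 1),
        **kwargs,
    }

text_rec_config = {"model_name": "PP-OCRv4_mobile_rec", "batch_size": 6, "score_thresh": 0}
text_rec_model = create_model(text_rec_config)
print(text_rec_model["batch_size"])  # -> 6, with no batch_size handling left in the OCR pipeline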