
rm unnecessary device setting (#2247)

Tingquan Gao 1 year ago
parent · commit 4647847dc6

+ 3 - 2
paddlex/inference/components/paddle_predictor/predictor.py

@@ -37,7 +37,7 @@ class BasePaddlePredictor(BaseComponent, PPEngineMixin):
         self.model_prefix = model_prefix
         self._is_initialized = False
 
-    def _reset(self):
+    def reset(self):
         if not self.option:
             self.option = PaddlePredictorOption()
         (
@@ -164,7 +164,7 @@ No need to generate again."
 
     def apply(self, **kwargs):
         if not self._is_initialized:
-            self._reset()
+            self.reset()
 
         x = self.to_batch(**kwargs)
         for idx in range(len(x)):
@@ -196,6 +196,7 @@ class ImagePredictor(BasePaddlePredictor):
 
 
 class ImageDetPredictor(BasePaddlePredictor):
+
     INPUT_KEYS = [["img", "scale_factors"], ["img", "scale_factors", "img_size"]]
     OUTPUT_KEYS = [["boxes"], ["boxes", "masks"]]
     DEAULT_INPUTS = {"img": "img", "scale_factors": "scale_factors"}
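
This hunk renames the private _reset to a public reset, so code outside the component (such as set_predictor below) can force the Paddle engine to be rebuilt, while apply keeps lazily initializing it on first use. A minimal sketch of that lazy-init/reset pattern, using placeholder Option and Engine classes rather than PaddleX's real components:

class Option:
    def __init__(self, device="cpu"):
        self.device = device


class Engine:
    def __init__(self, device):
        self.device = device

    def run(self, batch):
        return [f"ran {item} on {self.device}" for item in batch]


class LazyPredictor:
    def __init__(self):
        self.option = None
        self._engine = None
        self._is_initialized = False

    def reset(self):
        # Public so owners can rebuild the engine after changing options.
        if not self.option:
            self.option = Option()
        self._engine = Engine(self.option.device)
        self._is_initialized = True

    def apply(self, batch):
        # Lazy initialization: build the engine on the first call only.
        if not self._is_initialized:
            self.reset()
        return self._engine.run(batch)


p = LazyPredictor()
print(p.apply(["img0"]))       # first call builds the engine on "cpu"
p.option = Option("gpu:0")
p.reset()                      # explicit rebuild picks up the new device
print(p.apply(["img1"]))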

+ 4 - 2
paddlex/inference/models/base/basic_predictor.py

@@ -72,10 +72,12 @@ class BasicPredictor(
     def set_predictor(self, batch_size=None, device=None, pp_option=None):
         if batch_size:
             self.components["ReadCmp"].batch_size = batch_size
-        if device:
+        if device and device != self.pp_option.device:
             self.pp_option.device = device
-        if pp_option:
+            self.components["PPEngineCmp"].reset()
+        if pp_option and pp_option != self.pp_option:
             self.pp_option = pp_option
+            self.components["PPEngineCmp"].reset()
 
     def _has_setter(self, attr):
         prop = getattr(self.__class__, attr, None)
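
This is the caller side of the new reset: set_predictor now rebuilds the inference engine only when the requested device or pp_option actually differs from the current value, instead of reassigning unconditionally. A hedged, self-contained sketch of that guard, with stand-in classes rather than the real PPEngineCmp and PaddlePredictorOption:

class PPEngine:
    def __init__(self):
        self.builds = 0

    def reset(self):
        # Stand-in for re-creating the Paddle inference engine.
        self.builds += 1


class Options:
    def __init__(self, device="cpu"):
        self.device = device

    def __eq__(self, other):
        return isinstance(other, Options) and self.device == other.device


class Predictor:
    def __init__(self):
        self.pp_option = Options()
        self.engine = PPEngine()

    def set_predictor(self, device=None, pp_option=None):
        # Reset only on an actual change, avoiding redundant rebuilds.
        if device and device != self.pp_option.device:
            self.pp_option.device = device
            self.engine.reset()
        if pp_option and pp_option != self.pp_option:
            self.pp_option = pp_option
            self.engine.reset()


p = Predictor()
p.set_predictor(device="cpu")     # unchanged -> no rebuild
p.set_predictor(device="gpu:0")   # changed -> one rebuild
print(p.engine.builds)            # 1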

+ 0 - 1
paddlex/inference/pipelines/formula_recognition.py

@@ -38,7 +38,6 @@ class FormulaRecognitionPipeline(BasePipeline):
         self.set_predictor(
             layout_batch_size=layout_batch_size,
             formula_rec_batch_size=formula_rec_batch_size,
-            device=device,
         )
 
     def _build_predictor(self, layout_model, formula_rec_model):

+ 0 - 1
paddlex/inference/pipelines/ocr.py

@@ -37,7 +37,6 @@ class OCRPipeline(BasePipeline):
         self.set_predictor(
             text_det_batch_size=text_det_batch_size,
             text_rec_batch_size=text_rec_batch_size,
-            device=device,
         )
 
     def _build_predictor(self, text_det_model, text_rec_model):

+ 0 - 1
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -83,7 +83,6 @@ class PPChatOCRPipeline(_TableRecPipeline):
                 doc_image_ori_cls_batch_size=doc_image_ori_cls_batch_size,
                 doc_image_unwarp_batch_size=doc_image_unwarp_batch_size,
                 seal_text_det_batch_size=seal_text_det_batch_size,
-                device=device,
             )
 
         # get base prompt from yaml info

+ 0 - 1
paddlex/inference/pipelines/seal_recognition.py

@@ -63,7 +63,6 @@ class SealOCRPipeline(BasePipeline):
             layout_batch_size=layout_batch_size,
             text_det_batch_size=text_det_batch_size,
             text_rec_batch_size=text_rec_batch_size,
-            device=device,
         )
 
     def _build_predictor(

+ 1 - 1
paddlex/inference/pipelines/single_model_pipeline.py

@@ -20,7 +20,7 @@ class _SingleModelPipeline(BasePipeline):
     def __init__(self, model, batch_size=1, device=None, predictor_kwargs=None):
         super().__init__(device, predictor_kwargs)
         self._build_predictor(model)
-        self.set_predictor(batch_size=batch_size, device=device)
+        self.set_predictor(batch_size=batch_size)
 
     def _build_predictor(self, model):
         self.model = self._create(model)

+ 0 - 1
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -187,5 +187,4 @@ class TableRecPipeline(_TableRecPipeline):
             text_det_batch_size=text_det_batch_size,
             text_rec_batch_size=text_rec_batch_size,
             table_batch_size=table_batch_size,
-            device=device,
         )
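
The pipeline hunks above (formula_recognition, ocr, ppchatocrv3, seal_recognition, single_model_pipeline, table_recognition) all drop device=device from their set_predictor calls: the device is already passed to the base constructor (visible in _SingleModelPipeline's super().__init__ call), so setting it again here is the "unnecessary device setting" the commit title refers to. A simplified sketch of that division of responsibility, with stand-in classes that are not the actual BasePipeline or predictor:

class Predictor:
    def __init__(self):
        self.batch_size = 1
        self.device = "cpu"
        self.resets = 0

    def set_predictor(self, batch_size=None, device=None):
        if batch_size:
            self.batch_size = batch_size
        if device and device != self.device:
            self.device = device
            self.resets += 1


class BasePipeline:
    def __init__(self, device=None):
        self.predictor = Predictor()
        if device:
            # The device is applied once, centrally, by the base class.
            self.predictor.set_predictor(device=device)


class OCRLikePipeline(BasePipeline):
    def __init__(self, batch_size=1, device=None):
        super().__init__(device)
        # Only batch sizes are forwarded here; device is no longer repeated.
        self.predictor.set_predictor(batch_size=batch_size)


p = OCRLikePipeline(batch_size=8, device="gpu:0")
print(p.predictor.resets)  # 1, a single device-driven reset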

+ 2 - 2
paddlex/model.py

@@ -59,8 +59,8 @@ class _ModelBasedInference(_BaseModel):
     def predict(self, *args, **kwargs):
         yield from self._predictor(*args, **kwargs)
 
-    def set_predict(self, **kwargs):
-        self._predictor.set_predict(**kwargs)
+    def set_predictor(self, **kwargs):
+        self._predictor.set_predictor(**kwargs)
 
 
 class _ModelBasedConfig(_BaseModel):
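
This last hunk fixes the wrapper in paddlex/model.py so it exposes and forwards the method as set_predictor, matching the predictor's method name shown in basic_predictor.py above, instead of the old set_predict spelling. A small illustrative sketch of the delegation, with stand-in classes rather than the real PaddleX objects:

class _FakePredictor:
    def set_predictor(self, **kwargs):
        print(f"predictor reconfigured with {kwargs}")


class _FakeModelBasedInference:
    def __init__(self, predictor):
        self._predictor = predictor

    def set_predictor(self, **kwargs):
        # Forward under the matching name instead of the old set_predict.
        self._predictor.set_predictor(**kwargs)


m = _FakeModelBasedInference(_FakePredictor())
m.set_predictor(batch_size=4, device="gpu:0")  # prints the forwarded kwargs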