Jelajahi Sumber

set device when init model

gaotingquan 1 tahun lalu
induk
melakukan
288e5f83e9

+ 14 - 4
paddlex/inference/pipelines/base.py

@@ -24,9 +24,10 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
 
     __is_base = True
 
-    def __init__(self, predictor_kwargs) -> None:
+    def __init__(self, device, predictor_kwargs={}) -> None:
         super().__init__()
-        self._predictor_kwargs = {} if predictor_kwargs is None else predictor_kwargs
+        self._predictor_kwargs = predictor_kwargs
+        self._device = device
 
     @abstractmethod
     def set_predictor():
@@ -41,9 +42,18 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
     def _create(self, model=None, pipeline=None, *args, **kwargs):
         if model:
             return create_predictor(
-                model=model, *args, **kwargs, **self._predictor_kwargs
+                *args,
+                model=model,
+                device=self._device,
+                **kwargs,
+                **self._predictor_kwargs
             )
         elif pipeline:
-            return pipeline(*args, **kwargs, predictor_kwargs=self._predictor_kwargs)
+            return pipeline(
+                *args,
+                device=self._device,
+                predictor_kwargs=self._predictor_kwargs,
+                **kwargs
+            )
         else:
             raise Exception()

+ 1 - 1
paddlex/inference/pipelines/formula_recognition.py

@@ -33,7 +33,7 @@ class FormulaRecognitionPipeline(BasePipeline):
         device=None,
         predictor_kwargs=None,
     ):
-        super().__init__(predictor_kwargs=predictor_kwargs)
+        super().__init__(device, predictor_kwargs)
         self._build_predictor(layout_model, formula_rec_model)
         self.set_predictor(
             layout_batch_size=layout_batch_size,

+ 1 - 1
paddlex/inference/pipelines/ocr.py

@@ -32,7 +32,7 @@ class OCRPipeline(BasePipeline):
         device=None,
         predictor_kwargs=None,
     ):
-        super().__init__(predictor_kwargs=predictor_kwargs)
+        super().__init__(device, predictor_kwargs)
         self._build_predictor(text_det_model, text_rec_model)
         self.set_predictor(
             text_det_batch_size=text_det_batch_size,

+ 1 - 3
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -62,9 +62,7 @@ class PPChatOCRPipeline(_TableRecPipeline):
         predictor_kwargs=None,
         _build_models=True,
     ):
-        super().__init__(
-            predictor_kwargs=predictor_kwargs,
-        )
+        super().__init__(device, predictor_kwargs)
         if _build_models:
             self._build_predictor(
                 layout_model=layout_model,

+ 1 - 1
paddlex/inference/pipelines/seal_recognition.py

@@ -50,7 +50,7 @@ class SealOCRPipeline(BasePipeline):
         device=None,
         predictor_kwargs=None,
     ):
-        super().__init__(predictor_kwargs=predictor_kwargs)
+        super().__init__(device, predictor_kwargs)
         self._build_predictor(
             layout_model=layout_model,
             text_det_model=text_det_model,

+ 1 - 1
paddlex/inference/pipelines/single_model_pipeline.py

@@ -18,7 +18,7 @@ from .base import BasePipeline
 class _SingleModelPipeline(BasePipeline):
 
     def __init__(self, model, batch_size=1, device=None, predictor_kwargs=None):
-        super().__init__(predictor_kwargs=predictor_kwargs)
+        super().__init__(device, predictor_kwargs)
         self._build_predictor(model)
         self.set_predictor(batch_size=batch_size, device=device)
 

+ 4 - 3
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -26,9 +26,10 @@ class _TableRecPipeline(BasePipeline):
 
     def __init__(
         self,
-        predictor_kwargs=None,
+        device,
+        predictor_kwargs,
     ):
-        super().__init__(predictor_kwargs=predictor_kwargs)
+        super().__init__(device, predictor_kwargs)
 
     def _build_predictor(
         self,
@@ -179,7 +180,7 @@ class TableRecPipeline(_TableRecPipeline):
         device=None,
         predictor_kwargs=None,
     ):
-        super().__init__(predictor_kwargs=predictor_kwargs)
+        super().__init__(device, predictor_kwargs)
         self._build_predictor(layout_model, text_det_model, text_rec_model, table_model)
         self.set_predictor(
             layout_batch_size=layout_batch_size,