Pārlūkot izejas kodu

support disable_print and disable_save

gaotingquan 1 gadu atpakaļ
vecāks
revīzija
a1dbd7bf7e

+ 4 - 0
paddlex/modules/base/predictor/predictor.py

@@ -42,12 +42,16 @@ class BasePredictor(ABC, FromDictMixin, Node):
         output,
         pre_transforms=None,
         post_transforms=None,
+        disable_print=False,
+        disable_save=False,
     ):
         super().__init__()
         self.model_name = model_name
         self.model_dir = model_dir
         self.kernel_option = kernel_option
         self.output = output
+        self.disable_print = disable_print
+        self.disable_save = disable_save
         self.other_src = self.load_other_src()
 
         logging.debug(

+ 4 - 3
paddlex/modules/image_classification/predictor/predictor.py

@@ -75,7 +75,8 @@ class ClsPredictor(BasePredictor):
     def _get_post_transforms_from_config(self):
         """get postprocess transforms"""
         post_transforms = self.other_src.post_transforms
-        post_transforms.extend(
-            [T.PrintResult(), T.SaveClsResults(self.output, self.other_src.labels)]
-        )
+        if not self.disable_print:
+            post_transforms.append(T.PrintResult())
+        if not self.disable_save:
+            post_transforms.append(T.SaveClsResults(self.output, self.other_src.labels))
         return post_transforms

+ 8 - 4
paddlex/modules/object_detection/predictor/predictor.py

@@ -87,7 +87,11 @@ class DetPredictor(BasePredictor):
 
     def _get_post_transforms_from_config(self):
         """get postprocess transforms"""
-        return [
-            T.SaveDetResults(save_dir=self.output, labels=self.other_src.labels),
-            T.PrintResult(),
-        ]
+        post_transforms = []
+        if not self.disable_print:
+            post_transforms.append(T.PrintResult())
+        if not self.disable_save:
+            post_transforms.append(
+                T.SaveDetResults(save_dir=self.output, labels=self.other_src.labels)
+            )
+        return post_transforms

+ 6 - 1
paddlex/modules/semantic_segmentation/predictor/predictor.py

@@ -110,4 +110,9 @@ class SegPredictor(BasePredictor):
 
     def _get_post_transforms_from_config(self):
         """_get_post_transforms_from_config"""
-        return [T.GeneratePCMap(), T.SaveSegResults(self.output), T.PrintResult()]
+        post_transforms = []
+        if not self.disable_print:
+            post_transforms.append(T.PrintResult())
+        if not self.disable_save:
+            post_transforms.extend([T.GeneratePCMap(), T.SaveSegResults(self.output)])
+        return post_transforms

+ 4 - 1
paddlex/modules/table_recognition/predictor/predictor.py

@@ -92,4 +92,7 @@ class TableRecPredictor(BasePredictor):
 
     def _get_post_transforms_from_config(self):
         """get postprocess transforms"""
-        return [T.TableLabelDecode(), T.SaveTableResults(self.output)]
+        post_transforms = [T.TableLabelDecode()]
+        if not self.disable_save:
+            post_transforms.append(T.SaveTableResults(self.output))
+        return post_transforms

+ 7 - 3
paddlex/modules/text_detection/predictor/predictor.py

@@ -83,8 +83,12 @@ class TextDetPredictor(BasePredictor):
                 use_dilation=False,
                 score_mode="fast",
                 box_type="quad",
-            ),
-            T.SaveTextDetResults(self.output),
-            T.PrintResult(),
+            )
         ]
+        if not self.disable_print:
+            post_transforms.append(T.PrintResult())
+        if not self.disable_save:
+            post_transforms.append(
+                T.SaveTextDetResults(self.output),
+            )
         return post_transforms

+ 4 - 8
paddlex/modules/text_recognition/predictor/predictor.py

@@ -83,13 +83,9 @@ class TextRecPredictor(BasePredictor):
     def _get_post_transforms_from_config(self):
         """get postprocess transforms"""
         if self.model_name == "LaTeX_OCR_rec":
-            post_transforms = [
-                T.LaTeXOCRDecode(self.other_src.PostProcess),
-                T.PrintResult(),
-            ]
+            post_transforms = [T.LaTeXOCRDecode(self.other_src.PostProcess)]
         else:
-            post_transforms = [
-                T.CTCLabelDecode(self.other_src.PostProcess),
-                T.PrintResult(),
-            ]
+            post_transforms = [T.CTCLabelDecode(self.other_src.PostProcess)]
+        if not self.disable_print:
+            post_transforms.append(T.PrintResult())
         return post_transforms

+ 3 - 0
paddlex/pipelines/OCR/pipeline.py

@@ -40,6 +40,7 @@ class OCRPipeline(BasePipeline):
         device="gpu",
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.text_det_model_name = text_det_model_name
         self.text_rec_model_name = text_rec_model_name
         self.text_det_model_dir = text_det_model_dir
@@ -103,6 +104,8 @@ Only support: {text_rec_models}."
             self.text_rec_model_name,
             self.text_rec_model_dir,
             kernel_option=text_rec_kernel_option,
+            disable_print=self.disable_print,
+            disable_save=self.disable_save,
         )
 
     def predict(self, input):

+ 3 - 1
paddlex/pipelines/base/pipeline.py

@@ -43,8 +43,10 @@ class BasePipeline(ABC, metaclass=AutoRegisterABCMetaClass):
 
     __is_base = True
 
-    def __init__(self):
+    def __init__(self, disable_print=False, disable_save=False):
         super().__init__()
+        self.disable_print = disable_print
+        self.disable_save = disable_save
 
     @abstractmethod
     def load_model(self):

+ 3 - 1
paddlex/pipelines/image_classification/pipeline.py

@@ -32,7 +32,7 @@ class ClsPipeline(BasePipeline):
         device="gpu",
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(**kwargs)
         self.model_name = model_name
         self.model_dir = model_dir
         self.output = output
@@ -64,6 +64,8 @@ class ClsPipeline(BasePipeline):
             model_dir=self.model_dir,
             output=self.output,
             kernel_option=kernel_option,
+            disable_print=self.disable_print,
+            disable_save=self.disable_save,
         )
 
     def get_kernel_option(self):

+ 3 - 0
paddlex/pipelines/instance_segmentation/pipeline.py

@@ -32,6 +32,7 @@ class InstanceSegPipeline(BasePipeline):
         device="gpu",
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.model_name = model_name
         self.model_dir = model_dir
         self.output = output
@@ -59,6 +60,8 @@ class InstanceSegPipeline(BasePipeline):
             model_dir=self.model_dir,
             output=self.output,
             kernel_option=kernel_option,
+            disable_print=self.disable_print,
+            disable_save=self.disable_save,
         )
 
     def predict(self, input):

+ 3 - 0
paddlex/pipelines/object_detection/pipeline.py

@@ -32,6 +32,7 @@ class DetPipeline(BasePipeline):
         device="gpu",
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.model_name = model_name
         self.model_dir = model_dir
         self.output = output
@@ -59,6 +60,8 @@ class DetPipeline(BasePipeline):
             model_dir=self.model_dir,
             output=self.output,
             kernel_option=kernel_option,
+            disable_print=self.disable_print,
+            disable_save=self.disable_save,
         )
 
     def predict(self, input):

+ 3 - 0
paddlex/pipelines/semantic_segmentation/pipeline.py

@@ -32,6 +32,7 @@ class SegPipeline(BasePipeline):
         device="gpu",
         **kwargs,
     ):
+        super().__init__(**kwargs)
         self.model_name = model_name
         self.model_dir = model_dir
         self.output = output
@@ -59,6 +60,8 @@ class SegPipeline(BasePipeline):
             model_dir=self.model_dir,
             output=self.output,
             kernel_option=kernel_option,
+            disable_print=self.disable_print,
+            disable_save=self.disable_save,
         )
 
     def predict(self, input):