gaotingquan 1 anno fa
parent
commit
684cd3ef73

+ 4 - 4
paddlex/inference/components/transforms/image/common.py

@@ -13,14 +13,13 @@
 # limitations under the License.
 
 import math
-import tempfile
 from pathlib import Path
 from copy import deepcopy
 
 import numpy as np
 import cv2
 
-from .....utils.cache import CACHE_DIR
+from .....utils.cache import CACHE_DIR, temp_file_manager
 from ....utils.io import ImageReader, ImageWriter, PDFReader
 from ...utils.mixin import BatchSizeMixin
 from ...base import BaseComponent
@@ -92,8 +91,9 @@ class ReadImage(_BaseRead):
     def apply(self, img):
         """apply"""
         if isinstance(img, np.ndarray):
-            # TODO(gaotingquan): set delete to True
-            with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
+            with temp_file_manager.temp_file_context(
+                delete=True, suffix=".png"
+            ) as temp_file:
                 img_path = Path(temp_file.name)
                 self._writer.write(img_path, img)
                 yield [

+ 4 - 2
paddlex/inference/components/transforms/ts/common.py

@@ -14,11 +14,11 @@
 
 from pathlib import Path
 from copy import deepcopy
-import tempfile
 import joblib
 import numpy as np
 import pandas as pd
 
+from .....utils.cache import CACHE_DIR, temp_file_manager
 from .....utils.download import download
 from .....utils.cache import CACHE_DIR
 from ....utils.io.readers import CSVReader
@@ -59,7 +59,9 @@ class ReadTS(_BaseRead):
 
     def apply(self, ts):
         if isinstance(ts, pd.DataFrame):
-            with tempfile.NamedTemporaryFile(suffix=".csv", delete=True) as temp_file:
+            with temp_file_manager.temp_file_context(
+                delete=True, suffix=".csv"
+            ) as temp_file:
                 input_path = Path(temp_file.name)
                 ts_path = input_path.as_posix()
                 self._writer.write(ts_path, ts)

+ 1 - 1
paddlex/inference/utils/pp_option.py

@@ -95,7 +95,7 @@ class PaddlePredictorOption(object):
         set_env_for_device(device)
         if device_type not in ("cpu"):
             if device_ids is None or len(device_ids) > 1:
-                logging.warning(f"The device ID has been set to {device_id}.")
+                logging.debug(f"The device ID has been set to {device_id}.")
 
     @register("min_subgraph_size")
     def set_min_subgraph_size(self, min_subgraph_size: int):

+ 48 - 1
paddlex/utils/cache.py

@@ -15,17 +15,23 @@
 
 import os
 import os.path as osp
+from pathlib import Path
 import inspect
 import functools
 import pickle
 import hashlib
-
+import tempfile
+import atexit
 import filelock
 
+from . import logging
+
+
 DEFAULT_CACHE_DIR = osp.abspath(osp.join(os.path.expanduser("~"), ".paddlex"))
 CACHE_DIR = os.environ.get("PADDLE_PDX_CACHE_HOME", DEFAULT_CACHE_DIR)
 FUNC_CACHE_DIR = osp.join(CACHE_DIR, "func_ret")
 FILE_LOCK_DIR = osp.join(CACHE_DIR, "locks")
+TEMP_DIR = osp.join(CACHE_DIR, "temp")
 
 
 def create_cache_dir(*args, **kwargs):
@@ -99,3 +105,44 @@ def persist(cond=None):
         return _wrapper
 
     return _deco
+
+
+class TempFileManager:
+    def __init__(self):
+        self.temp_files = []
+        Path(TEMP_DIR).mkdir(parents=True, exist_ok=True)
+        atexit.register(self.cleanup)
+
+    def create_temp_file(self, **kwargs):
+        temp_file = tempfile.NamedTemporaryFile(dir=TEMP_DIR, **kwargs)
+        self.temp_files.append(temp_file)
+        return temp_file
+
+    def cleanup(self):
+        for temp_file in self.temp_files:
+            try:
+                temp_file.close()
+                os.remove(temp_file.name)
+            except FileNotFoundError as e:
+                pass
+        self.temp_files = []
+
+    class TempFileContextManager:
+        def __init__(self, manager, **kwargs):
+            self.manager = manager
+            self.kwargs = kwargs
+            self.temp_file = None
+
+        def __enter__(self):
+            self.temp_file = self.manager.create_temp_file(**self.kwargs)
+            return self.temp_file
+
+        def __exit__(self, exc_type, exc_value, traceback):
+            if self.temp_file:
+                self.temp_file.close()
+
+    def temp_file_context(self, **kwargs):
+        return self.TempFileContextManager(self, **kwargs)
+
+
+temp_file_manager = TempFileManager()