Преглед на файлове

[Fix] Stop writing to stdout when downloading (#4237)

* Stop writing to stdout when downloading

* Fix bug
Lin Manhui преди 5 месеца
родител
ревизия
da8ed098aa
променени са 2 файла, в които са добавени 11 реда и са изтрити 7 реда
  1. 1 0
      paddlex/inference/serving/basic_serving/_app.py
  2. 10 7
      paddlex/utils/download.py

+ 1 - 0
paddlex/inference/serving/basic_serving/_app.py

@@ -114,6 +114,7 @@ class PipelineWrapper(Generic[PipelineT]):
         if not self._closed:
             self._queue.put(None)
             await call_async(self._thread.join)
+            self._closed = True
 
     def _worker(self):
         while not self._closed:

+ 10 - 7
paddlex/utils/download.py

@@ -39,14 +39,14 @@ class _ProgressPrinter(object):
             str_ += "\n"
             self._last_time = 0
         if time.time() - self._last_time >= self._flush_intvl:
-            sys.stdout.write(f"\r{str_}")
+            sys.stderr.write(f"\r{str_}")
             self._last_time = time.time()
-            sys.stdout.flush()
+            sys.stderr.flush()
 
 
 def _download(url, save_path, print_progress):
     if print_progress:
-        print(f"Connecting to {url} ...")
+        print(f"Connecting to {url} ...", file=sys.stderr)
 
     with requests.get(url, stream=True, timeout=15) as r:
         r.raise_for_status()
@@ -62,7 +62,10 @@ def _download(url, save_path, print_progress):
                 total_length = int(total_length)
                 if print_progress:
                     printer = _ProgressPrinter()
-                    print(f"Downloading {os.path.basename(save_path)} ...")
+                    print(
+                        f"Downloading {os.path.basename(save_path)} ...",
+                        file=sys.stderr,
+                    )
                 for data in r.iter_content(chunk_size=4096):
                     dl += len(data)
                     f.write(data)
@@ -95,17 +98,17 @@ def _extract_tar_file(file_path, extd_dir):
                 try:
                     f.extract(file, extd_dir)
                 except KeyError:
-                    print(f"File {file} not found in the archive.")
+                    print(f"File {file} not found in the archive.", file=sys.stderr)
                 yield total_num, index
     except Exception as e:
-        print(f"An error occurred: {e}")
+        print(f"An error occurred: {e}", file=sys.stderr)
 
 
 def _extract(file_path, extd_dir, print_progress):
     """extract"""
     if print_progress:
         printer = _ProgressPrinter()
-        print(f"Extracting {os.path.basename(file_path)}")
+        print(f"Extracting {os.path.basename(file_path)}", file=sys.stderr)
 
     if zipfile.is_zipfile(file_path):
         handler = _extract_zip_file