
[Docs] Update parallel inference documentation (#4072)

* Fix

* Use context manager
Lin Manhui, 5 months ago
parent commit 9e41543a5c

+ 27 - 27
docs/pipeline_usage/instructions/parallel_inference.en.md

@@ -130,36 +130,36 @@ def main():
             "Please specify at least two devices for performing parallel inference.",
             file=sys.stderr,
         )
-        sys.exit(2)
+        return 2
 
     if args.batch_size <= 0:
         print("Batch size must be greater than 0.", file=sys.stderr)
-        sys.exit(2)
-
-    manager = Manager()
-    task_queue = manager.Queue()
-    for img_path in input_dir.glob(args.input_glob_pattern):
-        task_queue.put(str(img_path))
-
-    processes = []
-    for device_id in device_ids:
-        for _ in range(args.instances_per_device):
-            device = constr_device(device_type, [device_id])
-            p = Process(
-                target=worker,
-                args=(
-                    args.pipeline,
-                    device,
-                    task_queue,
-                    args.batch_size,
-                    str(output_dir),
-                ),
-            )
-            p.start()
-            processes.append(p)
-
-    for p in processes:
-        p.join()
+        return 2
+
+    with Manager() as manager:
+        task_queue = manager.Queue()
+        for img_path in input_dir.glob(args.input_glob_pattern):
+            task_queue.put(str(img_path))
+
+        processes = []
+        for device_id in device_ids:
+            for _ in range(args.instances_per_device):
+                device = constr_device(device_type, [device_id])
+                p = Process(
+                    target=worker,
+                    args=(
+                        args.pipeline,
+                        device,
+                        task_queue,
+                        args.batch_size,
+                        str(output_dir),
+                    ),
+                )
+                p.start()
+                processes.append(p)
+
+        for p in processes:
+            p.join()
 
     print("All done")
 

+ 27 - 27
docs/pipeline_usage/instructions/parallel_inference.md

@@ -130,36 +130,36 @@ def main():
             "Please specify at least two devices for performing parallel inference.",
             file=sys.stderr,
         )
-        sys.exit(2)
+        return 2
 
     if args.batch_size <= 0:
         print("Batch size must be greater than 0.", file=sys.stderr)
-        sys.exit(2)
-
-    manager = Manager()
-    task_queue = manager.Queue()
-    for img_path in input_dir.glob(args.input_glob_pattern):
-        task_queue.put(str(img_path))
-
-    processes = []
-    for device_id in device_ids:
-        for _ in range(args.instances_per_device):
-            device = constr_device(device_type, [device_id])
-            p = Process(
-                target=worker,
-                args=(
-                    args.pipeline,
-                    device,
-                    task_queue,
-                    args.batch_size,
-                    str(output_dir),
-                ),
-            )
-            p.start()
-            processes.append(p)
-
-    for p in processes:
-        p.join()
+        return 2
+
+    with Manager() as manager:
+        task_queue = manager.Queue()
+        for img_path in input_dir.glob(args.input_glob_pattern):
+            task_queue.put(str(img_path))
+
+        processes = []
+        for device_id in device_ids:
+            for _ in range(args.instances_per_device):
+                device = constr_device(device_type, [device_id])
+                p = Process(
+                    target=worker,
+                    args=(
+                        args.pipeline,
+                        device,
+                        task_queue,
+                        args.batch_size,
+                        str(output_dir),
+                    ),
+                )
+                p.start()
+                processes.append(p)
+
+        for p in processes:
+            p.join()
 
     print("All done")