Browse Source

feat(zhch): 增强任务处理逻辑,添加异常处理和结果保存功能,支持测试模式

zhch158_admin 3 months ago
parent
commit
f2f88b38a6
1 changed file with 86 additions and 14 deletions
  1. zhch/mp_infer.py (+86, −14)

+ 86 - 14
zhch/mp_infer.py

@@ -1,5 +1,7 @@
 import argparse
 import sys
+import time
+import traceback
 from multiprocessing import Manager, Process
 from pathlib import Path
 from queue import Empty
@@ -13,30 +15,78 @@ def worker(pipeline_name_or_config_path, device, task_queue, batch_size, output_
 
     should_end = False
     batch = []
+    processed_count = 0
 
     while not should_end:
         try:
             input_path = task_queue.get_nowait()
         except Empty:
             should_end = True
+        except Exception as e:
+            # 处理其他可能的异常
+            print(f"Unexpected error while getting task: {e}", file=sys.stderr)
+            traceback.print_exc()
+            should_end = True
         else:
-            batch.append(input_path)
+            if input_path is None:
+                should_end = True
+            else:
+                batch.append(input_path)
 
         if batch and (len(batch) == batch_size or should_end):
             try:
-                for result in pipeline.predict(batch):
-                    input_path = Path(result["input_path"])
-                    if result.get("page_index") is not None:
-                        output_path = f"{input_path.stem}_{result['page_index']}.json"
-                    else:
-                        output_path = f"{input_path.stem}.json"
-                    output_path = str(Path(output_dir, output_path))
-                    result.save_to_json(output_path)
-                    print(f"Processed {repr(str(input_path))}")
-            except Exception as e:
+                start_time = time.time()
+
+                # 使用pipeline预测,添加PP-StructureV3的参数
+                results = pipeline.predict(
+                    batch,
+                    use_doc_orientation_classify=True,
+                    use_doc_unwarping=False,
+                    use_seal_recognition=True,
+                    use_chart_recognition=True,
+                    use_table_recognition=True,
+                    use_formula_recognition=True,
+                )
+
+                batch_processing_time = time.time() - start_time
+
+                for result in results:
+                    try:
+                        input_path = Path(result["input_path"])
+
+                        # 保存结果 - 按照ppstructurev3的方式处理文件名
+                        if result.get("page_index") is not None:
+                            output_filename = f"{input_path.stem}_{result['page_index']}"
+                        else:
+                            output_filename = f"{input_path.stem}"
+
+                        # 保存JSON和Markdown文件
+                        json_output_path = str(Path(output_dir, f"{output_filename}.json"))
+                        md_output_path = str(Path(output_dir, f"{output_filename}.md"))
+
+                        result.save_to_json(json_output_path)
+                        result.save_to_markdown(md_output_path)
+
+                        processed_count += 1
+                        print(
+                            f"Processed {repr(str(input_path))} -> {json_output_path}, {md_output_path}"
+                        )
+
+                    except Exception as e:
+                        print(
+                            f"Error saving result for {result.get('input_path', 'unknown')}: {e}",
+                            file=sys.stderr,
+                        )
+                        traceback.print_exc()
+
                 print(
-                    f"Error processing {batch} on {repr(device)}: {e}", file=sys.stderr
+                    f"Batch processed: {len(batch)} files in {batch_processing_time:.2f}s on {device}"
                 )
+
+            except Exception as e:
+                print(f"Error processing batch {batch} on {repr(device)}: {e}", file=sys.stderr)
+                traceback.print_exc()
+
             batch.clear()
 
 
@@ -73,6 +123,11 @@ def main():
         default="*",
         help="Pattern to find the input files.",
     )
+    parser.add_argument(
+        "--test_mode", 
+        action="store_true", 
+        help="Test mode (process only 20 images)"
+    )
     args = parser.parse_args()
 
     input_dir = Path(args.input_dir).resolve()
@@ -103,9 +158,23 @@ def main():
         print("Batch size must be greater than 0.", file=sys.stderr)
         return 2
 
+    # 查找图像文件
+    image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"]
+    image_files = []
+    for ext in image_extensions:
+        image_files.extend(list(input_dir.glob(f"*{ext}")))
+        image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
+    print(f"Found {len(image_files)} image files")
+
+    if args.test_mode:
+        image_files = image_files[:20]
+        print(f"Test mode: processing only {len(image_files)} images")
+
     with Manager() as manager:
         task_queue = manager.Queue()
-        for img_path in input_dir.glob(args.input_glob_pattern):
+
+        # 将图像文件路径放入队列
+        for img_path in image_files:
             task_queue.put(str(img_path))
 
         processes = []
@@ -125,11 +194,14 @@ def main():
                 p.start()
                 processes.append(p)
 
+        # 发送结束信号
+        for _ in range(len(device_ids) * args.instances_per_device):
+            task_queue.put(None)
+
         for p in processes:
             p.join()
 
     print("All done")
-
     return 0