|
|
@@ -1,5 +1,7 @@
|
|
|
import argparse
|
|
|
import sys
|
|
|
+import time
|
|
|
+import traceback
|
|
|
from multiprocessing import Manager, Process
|
|
|
from pathlib import Path
|
|
|
from queue import Empty
|
|
|
@@ -13,30 +15,78 @@ def worker(pipeline_name_or_config_path, device, task_queue, batch_size, output_
|
|
|
|
|
|
should_end = False
|
|
|
batch = []
|
|
|
+ processed_count = 0
|
|
|
|
|
|
while not should_end:
|
|
|
try:
|
|
|
input_path = task_queue.get_nowait()
|
|
|
except Empty:
|
|
|
should_end = True
|
|
|
+ except Exception as e:
|
|
|
+ # 处理其他可能的异常
|
|
|
+ print(f"Unexpected error while getting task: {e}", file=sys.stderr)
|
|
|
+ traceback.print_exc()
|
|
|
+ should_end = True
|
|
|
else:
|
|
|
- batch.append(input_path)
|
|
|
+ if input_path is None:
|
|
|
+ should_end = True
|
|
|
+ else:
|
|
|
+ batch.append(input_path)
|
|
|
|
|
|
if batch and (len(batch) == batch_size or should_end):
|
|
|
try:
|
|
|
- for result in pipeline.predict(batch):
|
|
|
- input_path = Path(result["input_path"])
|
|
|
- if result.get("page_index") is not None:
|
|
|
- output_path = f"{input_path.stem}_{result['page_index']}.json"
|
|
|
- else:
|
|
|
- output_path = f"{input_path.stem}.json"
|
|
|
- output_path = str(Path(output_dir, output_path))
|
|
|
- result.save_to_json(output_path)
|
|
|
- print(f"Processed {repr(str(input_path))}")
|
|
|
- except Exception as e:
|
|
|
+ start_time = time.time()
|
|
|
+
|
|
|
+ # 使用pipeline预测,添加PP-StructureV3的参数
|
|
|
+ results = pipeline.predict(
|
|
|
+ batch,
|
|
|
+ use_doc_orientation_classify=True,
|
|
|
+ use_doc_unwarping=False,
|
|
|
+ use_seal_recognition=True,
|
|
|
+ use_chart_recognition=True,
|
|
|
+ use_table_recognition=True,
|
|
|
+ use_formula_recognition=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ batch_processing_time = time.time() - start_time
|
|
|
+
|
|
|
+ for result in results:
|
|
|
+ try:
|
|
|
+ input_path = Path(result["input_path"])
|
|
|
+
|
|
|
+ # 保存结果 - 按照ppstructurev3的方式处理文件名
|
|
|
+ if result.get("page_index") is not None:
|
|
|
+ output_filename = f"{input_path.stem}_{result['page_index']}"
|
|
|
+ else:
|
|
|
+ output_filename = f"{input_path.stem}"
|
|
|
+
|
|
|
+ # 保存JSON和Markdown文件
|
|
|
+ json_output_path = str(Path(output_dir, f"{output_filename}.json"))
|
|
|
+ md_output_path = str(Path(output_dir, f"{output_filename}.md"))
|
|
|
+
|
|
|
+ result.save_to_json(json_output_path)
|
|
|
+ result.save_to_markdown(md_output_path)
|
|
|
+
|
|
|
+ processed_count += 1
|
|
|
+ print(
|
|
|
+ f"Processed {repr(str(input_path))} -> {json_output_path}, {md_output_path}"
|
|
|
+ )
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(
|
|
|
+ f"Error saving result for {result.get('input_path', 'unknown')}: {e}",
|
|
|
+ file=sys.stderr,
|
|
|
+ )
|
|
|
+ traceback.print_exc()
|
|
|
+
|
|
|
print(
|
|
|
- f"Error processing {batch} on {repr(device)}: {e}", file=sys.stderr
|
|
|
+ f"Batch processed: {len(batch)} files in {batch_processing_time:.2f}s on {device}"
|
|
|
)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error processing batch {batch} on {repr(device)}: {e}", file=sys.stderr)
|
|
|
+ traceback.print_exc()
|
|
|
+
|
|
|
batch.clear()
|
|
|
|
|
|
|
|
|
@@ -73,6 +123,11 @@ def main():
|
|
|
default="*",
|
|
|
help="Pattern to find the input files.",
|
|
|
)
|
|
|
+ parser.add_argument(
|
|
|
+ "--test_mode",
|
|
|
+ action="store_true",
|
|
|
+ help="Test mode (process only 20 images)"
|
|
|
+ )
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
input_dir = Path(args.input_dir).resolve()
|
|
|
@@ -103,9 +158,23 @@ def main():
|
|
|
print("Batch size must be greater than 0.", file=sys.stderr)
|
|
|
return 2
|
|
|
|
|
|
+ # 查找图像文件
|
|
|
+ image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"]
|
|
|
+ image_files = []
|
|
|
+ for ext in image_extensions:
|
|
|
+ image_files.extend(list(input_dir.glob(f"*{ext}")))
|
|
|
+ image_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
|
|
|
+ print(f"Found {len(image_files)} image files")
|
|
|
+
|
|
|
+ if args.test_mode:
|
|
|
+ image_files = image_files[:20]
|
|
|
+ print(f"Test mode: processing only {len(image_files)} images")
|
|
|
+
|
|
|
with Manager() as manager:
|
|
|
task_queue = manager.Queue()
|
|
|
- for img_path in input_dir.glob(args.input_glob_pattern):
|
|
|
+
|
|
|
+ # 将图像文件路径放入队列
|
|
|
+ for img_path in image_files:
|
|
|
task_queue.put(str(img_path))
|
|
|
|
|
|
processes = []
|
|
|
@@ -125,11 +194,14 @@ def main():
|
|
|
p.start()
|
|
|
processes.append(p)
|
|
|
|
|
|
+ # 发送结束信号
|
|
|
+ for _ in range(len(device_ids) * args.instances_per_device):
|
|
|
+ task_queue.put(None)
|
|
|
+
|
|
|
for p in processes:
|
|
|
p.join()
|
|
|
|
|
|
print("All done")
|
|
|
-
|
|
|
return 0
|
|
|
|
|
|
|