# batch_build_dataset.py
import concurrent.futures
import glob
import os
import threading

import fitz  # PyMuPDF

from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.utils import fitz_doc_to_image
  8. def partition_array_greedy(arr, k):
  9. """
  10. Partition an array into k parts using a simple greedy approach.
  11. Parameters:
  12. -----------
  13. arr : list
  14. The input array of integers
  15. k : int
  16. Number of partitions to create
  17. Returns:
  18. --------
  19. partitions : list of lists
  20. The k partitions of the array
  21. """
  22. # Handle edge cases
  23. if k <= 0:
  24. raise ValueError("k must be a positive integer")
  25. if k > len(arr):
  26. k = len(arr) # Adjust k if it's too large
  27. if k == 1:
  28. return [list(range(len(arr)))]
  29. if k == len(arr):
  30. return [[i] for i in range(len(arr))]
  31. # Sort the array in descending order
  32. sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)
  33. # Initialize k empty partitions
  34. partitions = [[] for _ in range(k)]
  35. partition_sums = [0] * k
  36. # Assign each element to the partition with the smallest current sum
  37. for idx in sorted_indices:
  38. # Find the partition with the smallest sum
  39. min_sum_idx = partition_sums.index(min(partition_sums))
  40. # Add the element to this partition
  41. partitions[min_sum_idx].append(idx) # Store the original index
  42. partition_sums[min_sum_idx] += arr[idx][1]
  43. return partitions
  44. def process_pdf_batch(pdf_jobs, idx):
  45. """
  46. Process a batch of PDF pages using multiple threads.
  47. Parameters:
  48. -----------
  49. pdf_jobs : list of tuples
  50. List of (pdf_path, page_num) tuples
  51. output_dir : str or None
  52. Directory to save images to
  53. num_threads : int
  54. Number of threads to use
  55. **kwargs :
  56. Additional arguments for process_pdf_page
  57. Returns:
  58. --------
  59. images : list
  60. List of processed images
  61. """
  62. images = []
  63. for pdf_path, _ in pdf_jobs:
  64. doc = fitz.open(pdf_path)
  65. tmp = []
  66. for page_num in range(len(doc)):
  67. page = doc[page_num]
  68. tmp.append(fitz_doc_to_image(page))
  69. images.append(tmp)
  70. return (idx, images)
  71. def batch_build_dataset(pdf_paths, k, lang=None):
  72. """
  73. Process multiple PDFs by partitioning them into k balanced parts and processing each part in parallel.
  74. Parameters:
  75. -----------
  76. pdf_paths : list
  77. List of paths to PDF files
  78. k : int
  79. Number of partitions to create
  80. output_dir : str or None
  81. Directory to save images to
  82. threads_per_worker : int
  83. Number of threads to use per worker
  84. **kwargs :
  85. Additional arguments for process_pdf_page
  86. Returns:
  87. --------
  88. all_images : list
  89. List of all processed images
  90. """
  91. # Get page counts for each PDF
  92. pdf_info = []
  93. total_pages = 0
  94. for pdf_path in pdf_paths:
  95. try:
  96. doc = fitz.open(pdf_path)
  97. num_pages = len(doc)
  98. pdf_info.append((pdf_path, num_pages))
  99. total_pages += num_pages
  100. doc.close()
  101. except Exception as e:
  102. print(f"Error opening {pdf_path}: {e}")
  103. # Partition the jobs based on page countEach job has 1 page
  104. partitions = partition_array_greedy(pdf_info, k)
  105. for i, partition in enumerate(partitions):
  106. print(f"Partition {i+1}: {len(partition)} pdfs")
  107. # Process each partition in parallel
  108. all_images_h = {}
  109. with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
  110. # Submit one task per partition
  111. futures = []
  112. for sn, partition in enumerate(partitions):
  113. # Get the jobs for this partition
  114. partition_jobs = [pdf_info[idx] for idx in partition]
  115. # Submit the task
  116. future = executor.submit(
  117. process_pdf_batch,
  118. partition_jobs,
  119. sn
  120. )
  121. futures.append(future)
  122. # Process results as they complete
  123. for i, future in enumerate(concurrent.futures.as_completed(futures)):
  124. try:
  125. idx, images = future.result()
  126. print(f"Partition {i+1} completed: processed {len(images)} images")
  127. all_images_h[idx] = images
  128. except Exception as e:
  129. print(f"Error processing partition: {e}")
  130. results = [None] * len(pdf_paths)
  131. for i in range(len(partitions)):
  132. partition = partitions[i]
  133. for j in range(len(partition)):
  134. with open(pdf_info[partition[j]][0], "rb") as f:
  135. pdf_bytes = f.read()
  136. dataset = PymuDocDataset(pdf_bytes, lang=lang)
  137. dataset.set_images(all_images_h[i][j])
  138. results[partition[j]] = dataset
  139. return results