# batch_build_dataset.py
  1. import concurrent.futures
  2. import glob
  3. import os
  4. import threading
  5. import fitz
  6. from magic_pdf.data.dataset import PymuDocDataset
  7. from magic_pdf.data.utils import fitz_doc_to_image # PyMuPDF
  8. def partition_array_greedy(arr, k):
  9. """Partition an array into k parts using a simple greedy approach.
  10. Parameters:
  11. -----------
  12. arr : list
  13. The input array of integers
  14. k : int
  15. Number of partitions to create
  16. Returns:
  17. --------
  18. partitions : list of lists
  19. The k partitions of the array
  20. """
  21. # Handle edge cases
  22. if k <= 0:
  23. raise ValueError('k must be a positive integer')
  24. if k > len(arr):
  25. k = len(arr) # Adjust k if it's too large
  26. if k == 1:
  27. return [list(range(len(arr)))]
  28. if k == len(arr):
  29. return [[i] for i in range(len(arr))]
  30. # Sort the array in descending order
  31. sorted_indices = sorted(range(len(arr)), key=lambda i: arr[i][1], reverse=True)
  32. # Initialize k empty partitions
  33. partitions = [[] for _ in range(k)]
  34. partition_sums = [0] * k
  35. # Assign each element to the partition with the smallest current sum
  36. for idx in sorted_indices:
  37. # Find the partition with the smallest sum
  38. min_sum_idx = partition_sums.index(min(partition_sums))
  39. # Add the element to this partition
  40. partitions[min_sum_idx].append(idx) # Store the original index
  41. partition_sums[min_sum_idx] += arr[idx][1]
  42. return partitions
  43. def process_pdf_batch(pdf_jobs, idx):
  44. """Process a batch of PDF pages using multiple threads.
  45. Parameters:
  46. -----------
  47. pdf_jobs : list of tuples
  48. List of (pdf_path, page_num) tuples
  49. output_dir : str or None
  50. Directory to save images to
  51. num_threads : int
  52. Number of threads to use
  53. **kwargs :
  54. Additional arguments for process_pdf_page
  55. Returns:
  56. --------
  57. images : list
  58. List of processed images
  59. """
  60. images = []
  61. for pdf_path, _ in pdf_jobs:
  62. doc = fitz.open(pdf_path)
  63. tmp = []
  64. for page_num in range(len(doc)):
  65. page = doc[page_num]
  66. tmp.append(fitz_doc_to_image(page))
  67. images.append(tmp)
  68. return (idx, images)
  69. def batch_build_dataset(pdf_paths, k, lang=None):
  70. """Process multiple PDFs by partitioning them into k balanced parts and
  71. processing each part in parallel.
  72. Parameters:
  73. -----------
  74. pdf_paths : list
  75. List of paths to PDF files
  76. k : int
  77. Number of partitions to create
  78. output_dir : str or None
  79. Directory to save images to
  80. threads_per_worker : int
  81. Number of threads to use per worker
  82. **kwargs :
  83. Additional arguments for process_pdf_page
  84. Returns:
  85. --------
  86. all_images : list
  87. List of all processed images
  88. """
  89. # Get page counts for each PDF
  90. pdf_info = []
  91. total_pages = 0
  92. for pdf_path in pdf_paths:
  93. try:
  94. doc = fitz.open(pdf_path)
  95. num_pages = len(doc)
  96. pdf_info.append((pdf_path, num_pages))
  97. total_pages += num_pages
  98. doc.close()
  99. except Exception as e:
  100. print(f'Error opening {pdf_path}: {e}')
  101. # Partition the jobs based on page countEach job has 1 page
  102. partitions = partition_array_greedy(pdf_info, k)
  103. for i, partition in enumerate(partitions):
  104. print(f'Partition {i+1}: {len(partition)} pdfs')
  105. # Process each partition in parallel
  106. all_images_h = {}
  107. with concurrent.futures.ProcessPoolExecutor(max_workers=k) as executor:
  108. # Submit one task per partition
  109. futures = []
  110. for sn, partition in enumerate(partitions):
  111. # Get the jobs for this partition
  112. partition_jobs = [pdf_info[idx] for idx in partition]
  113. # Submit the task
  114. future = executor.submit(
  115. process_pdf_batch,
  116. partition_jobs,
  117. sn
  118. )
  119. futures.append(future)
  120. # Process results as they complete
  121. for i, future in enumerate(concurrent.futures.as_completed(futures)):
  122. try:
  123. idx, images = future.result()
  124. print(f'Partition {i+1} completed: processed {len(images)} images')
  125. all_images_h[idx] = images
  126. except Exception as e:
  127. print(f'Error processing partition: {e}')
  128. results = [None] * len(pdf_paths)
  129. for i in range(len(partitions)):
  130. partition = partitions[i]
  131. for j in range(len(partition)):
  132. with open(pdf_info[partition[j]][0], 'rb') as f:
  133. pdf_bytes = f.read()
  134. dataset = PymuDocDataset(pdf_bytes, lang=lang)
  135. dataset.set_images(all_images_h[i][j])
  136. results[partition[j]] = dataset
  137. return results