demo_gradio_annotion.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666
  1. """
  2. Layout Inference Web Application with Gradio - Annotation Version
  3. A Gradio-based layout inference tool that supports image uploads and multiple backend inference engines.
  4. This version adds an image annotation feature, allowing users to draw bounding boxes on an image and send both the image and the boxes to the model.
  5. """
  6. import gradio as gr
  7. import json
  8. import os
  9. import io
  10. import tempfile
  11. import base64
  12. import zipfile
  13. import uuid
  14. import re
  15. from pathlib import Path
  16. from PIL import Image
  17. import requests
  18. from gradio_image_annotation import image_annotator
  19. # Local utility imports
  20. from dots_ocr.utils import dict_promptmode_to_prompt
  21. from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
  22. from dots_ocr.utils.demo_utils.display import read_image
  23. from dots_ocr.utils.doc_utils import load_images_from_pdf
  24. # Add DotsOCRParser import
  25. from dots_ocr.parser import DotsOCRParser
  26. # ==================== Configuration ====================
  27. DEFAULT_CONFIG = {
  28. 'ip': "127.0.0.1",
  29. 'port_vllm': 8000,
  30. 'min_pixels': MIN_PIXELS,
  31. 'max_pixels': MAX_PIXELS,
  32. 'test_images_dir': "./assets/showcase_origin",
  33. }
  34. # ==================== Global Variables ====================
  35. # Store the current configuration
  36. current_config = DEFAULT_CONFIG.copy()
  37. # Create a DotsOCRParser instance
  38. dots_parser = DotsOCRParser(
  39. ip=DEFAULT_CONFIG['ip'],
  40. port=DEFAULT_CONFIG['port_vllm'],
  41. dpi=200,
  42. min_pixels=DEFAULT_CONFIG['min_pixels'],
  43. max_pixels=DEFAULT_CONFIG['max_pixels']
  44. )
  45. # Store processing results
  46. processing_results = {
  47. 'original_image': None,
  48. 'processed_image': None,
  49. 'layout_result': None,
  50. 'markdown_content': None,
  51. 'cells_data': None,
  52. 'temp_dir': None,
  53. 'session_id': None,
  54. 'result_paths': None,
  55. 'annotation_data': None # Store annotation data
  56. }
  57. # ==================== Utility Functions ====================
  58. def read_image_v2(img):
  59. """Reads an image, supporting URLs and local paths."""
  60. if isinstance(img, str) and img.startswith(("http://", "https://")):
  61. with requests.get(img, stream=True) as response:
  62. response.raise_for_status()
  63. img = Image.open(io.BytesIO(response.content))
  64. elif isinstance(img, str):
  65. img, _, _ = read_image(img, use_native=True)
  66. elif isinstance(img, Image.Image):
  67. pass
  68. else:
  69. raise ValueError(f"Invalid image type: {type(img)}")
  70. return img
  71. def get_test_images():
  72. """Gets the list of test images."""
  73. test_images = []
  74. test_dir = current_config['test_images_dir']
  75. if os.path.exists(test_dir):
  76. test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)
  77. if name.lower().endswith(('.png', '.jpg', '.jpeg'))]
  78. return test_images
  79. def create_temp_session_dir():
  80. """Creates a unique temporary directory for each processing request."""
  81. session_id = uuid.uuid4().hex[:8]
  82. temp_dir = os.path.join(tempfile.gettempdir(), f"dots_ocr_demo_{session_id}")
  83. os.makedirs(temp_dir, exist_ok=True)
  84. return temp_dir, session_id
  85. def parse_image_with_bbox(parser, image, prompt_mode, bbox=None, fitz_preprocess=False):
  86. """
  87. Processes an image using DotsOCRParser, with support for the bbox parameter.
  88. """
  89. # Create a temporary session directory
  90. temp_dir, session_id = create_temp_session_dir()
  91. try:
  92. # Save the PIL Image to a temporary file
  93. temp_image_path = os.path.join(temp_dir, f"input_{session_id}.png")
  94. image.save(temp_image_path, "PNG")
  95. # Use the high-level parse_image interface, passing the bbox parameter
  96. filename = f"demo_{session_id}"
  97. results = parser.parse_image(
  98. input_path=temp_image_path,
  99. filename=filename,
  100. prompt_mode=prompt_mode,
  101. save_dir=temp_dir,
  102. bbox=bbox,
  103. fitz_preprocess=fitz_preprocess
  104. )
  105. # Parse the results
  106. if not results:
  107. raise ValueError("No results returned from parser")
  108. result = results[0] # parse_image returns a list with a single result
  109. # Read the result files
  110. layout_image = None
  111. cells_data = None
  112. md_content = None
  113. filtered = False
  114. # Read the layout image
  115. if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
  116. layout_image = Image.open(result['layout_image_path'])
  117. # Read the JSON data
  118. if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
  119. with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
  120. cells_data = json.load(f)
  121. # Read the Markdown content
  122. if 'md_content_path' in result and os.path.exists(result['md_content_path']):
  123. with open(result['md_content_path'], 'r', encoding='utf-8') as f:
  124. md_content = f.read()
  125. # Check for the original response file (if JSON parsing fails)
  126. if 'filtered' in result:
  127. filtered = result['filtered']
  128. return {
  129. 'layout_image': layout_image,
  130. 'cells_data': cells_data,
  131. 'md_content': md_content,
  132. 'filtered': filtered,
  133. 'temp_dir': temp_dir,
  134. 'session_id': session_id,
  135. 'result_paths': result
  136. }
  137. except Exception as e:
  138. # Clean up the temporary directory on error
  139. import shutil
  140. if os.path.exists(temp_dir):
  141. shutil.rmtree(temp_dir, ignore_errors=True)
  142. raise e
  143. def process_annotation_data(annotation_data):
  144. """Processes annotation data, converting it to the format required by the model."""
  145. if not annotation_data or not annotation_data.get('boxes'):
  146. return None, None
  147. # Get image and box data
  148. image = annotation_data.get('image')
  149. boxes = annotation_data.get('boxes', [])
  150. if not boxes:
  151. return image, None
  152. # Ensure the image is in PIL Image format
  153. if image is not None:
  154. import numpy as np
  155. if isinstance(image, np.ndarray):
  156. image = Image.fromarray(image)
  157. elif not isinstance(image, Image.Image):
  158. # If it's another format, try to convert it
  159. try:
  160. image = Image.open(image) if isinstance(image, str) else Image.fromarray(image)
  161. except Exception as e:
  162. print(f"Image format conversion failed: {e}")
  163. return None, None
  164. # Get the coordinate information of the box (only one box)
  165. box = boxes[0]
  166. bbox = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
  167. return image, bbox
  168. # ==================== Core Processing Function ====================
  169. def process_image_inference_with_annotation(annotation_data, test_image_input,
  170. prompt_mode, server_ip, server_port, min_pixels, max_pixels,
  171. fitz_preprocess=False
  172. ):
  173. """Core function for image inference, supporting annotation data."""
  174. global current_config, processing_results, dots_parser
  175. # First, clean up previous processing results
  176. if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
  177. import shutil
  178. try:
  179. shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
  180. except Exception as e:
  181. print(f"Failed to clean up previous temporary directory: {e}")
  182. # Reset processing results
  183. processing_results = {
  184. 'original_image': None,
  185. 'processed_image': None,
  186. 'layout_result': None,
  187. 'markdown_content': None,
  188. 'cells_data': None,
  189. 'temp_dir': None,
  190. 'session_id': None,
  191. 'result_paths': None,
  192. 'annotation_data': annotation_data
  193. }
  194. # Update configuration
  195. current_config.update({
  196. 'ip': server_ip,
  197. 'port_vllm': server_port,
  198. 'min_pixels': min_pixels,
  199. 'max_pixels': max_pixels
  200. })
  201. # Update parser configuration
  202. dots_parser.ip = server_ip
  203. dots_parser.port = server_port
  204. dots_parser.min_pixels = min_pixels
  205. dots_parser.max_pixels = max_pixels
  206. # Determine the input source and process annotation data
  207. image = None
  208. bbox = None
  209. # Prioritize processing annotation data
  210. if annotation_data and annotation_data.get('image') is not None:
  211. image, bbox = process_annotation_data(annotation_data)
  212. if image is not None:
  213. # If there's a bbox, force the use of 'prompt_grounding_ocr' mode
  214. assert bbox is not None
  215. prompt_mode = "prompt_grounding_ocr"
  216. # If there's no annotation data, check the test image input
  217. if image is None and test_image_input and test_image_input != "":
  218. try:
  219. image = read_image_v2(test_image_input)
  220. except Exception as e:
  221. return None, f"Failed to read test image: {e}", "", "", gr.update(value=None), ""
  222. if image is None:
  223. return None, "Please select a test image or add an image in the annotation component", "", "", gr.update(value=None), ""
  224. if bbox is None:
  225. return None, "Please select a bounding box by mouse", "Please select a bounding box by mouse", "", "", gr.update(value=None)
  226. try:
  227. # Process using DotsOCRParser, passing the bbox parameter
  228. original_image = image
  229. parse_result = parse_image_with_bbox(dots_parser, image, prompt_mode, bbox, fitz_preprocess)
  230. # Extract parsing results
  231. layout_image = parse_result['layout_image']
  232. cells_data = parse_result['cells_data']
  233. md_content = parse_result['md_content']
  234. filtered = parse_result['filtered']
  235. # Store the results
  236. processing_results.update({
  237. 'original_image': original_image,
  238. 'processed_image': None,
  239. 'layout_result': layout_image,
  240. 'markdown_content': md_content,
  241. 'cells_data': cells_data,
  242. 'temp_dir': parse_result['temp_dir'],
  243. 'session_id': parse_result['session_id'],
  244. 'result_paths': parse_result['result_paths'],
  245. 'annotation_data': annotation_data
  246. })
  247. # Handle the case where parsing fails
  248. if filtered:
  249. info_text = f"""
  250. **Image Information:**
  251. - Original Dimensions: {original_image.width} x {original_image.height}
  252. - Processing Mode: {'Region OCR' if bbox else 'Full Image OCR'}
  253. - Processing Status: JSON parsing failed, using cleaned text output
  254. - Server: {current_config['ip']}:{current_config['port_vllm']}
  255. - Session ID: {parse_result['session_id']}
  256. - Box Coordinates: {bbox if bbox else 'None'}
  257. """
  258. return (
  259. md_content or "No markdown content generated",
  260. info_text,
  261. md_content or "No markdown content generated",
  262. md_content or "No markdown content generated",
  263. gr.update(visible=False),
  264. ""
  265. )
  266. # Handle the case where JSON parsing succeeds
  267. num_elements = len(cells_data) if cells_data else 0
  268. info_text = f"""
  269. **Image Information:**
  270. - Original Dimensions: {original_image.width} x {original_image.height}
  271. - Processing Mode: {'Region OCR' if bbox else 'Full Image OCR'}
  272. - Server: {current_config['ip']}:{current_config['port_vllm']}
  273. - Detected {num_elements} layout elements
  274. - Session ID: {parse_result['session_id']}
  275. - Box Coordinates: {bbox if bbox else 'None'}
  276. """
  277. # Current page JSON output
  278. current_json = ""
  279. if cells_data:
  280. try:
  281. current_json = json.dumps(cells_data, ensure_ascii=False, indent=2)
  282. except:
  283. current_json = str(cells_data)
  284. # Create a downloadable ZIP file
  285. download_zip_path = None
  286. if parse_result['temp_dir']:
  287. download_zip_path = os.path.join(parse_result['temp_dir'], f"layout_results_{parse_result['session_id']}.zip")
  288. try:
  289. with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
  290. for root, dirs, files in os.walk(parse_result['temp_dir']):
  291. for file in files:
  292. if file.endswith('.zip'):
  293. continue
  294. file_path = os.path.join(root, file)
  295. arcname = os.path.relpath(file_path, parse_result['temp_dir'])
  296. zipf.write(file_path, arcname)
  297. except Exception as e:
  298. print(f"Failed to create download ZIP: {e}")
  299. download_zip_path = None
  300. return (
  301. md_content or "No markdown content generated",
  302. info_text,
  303. md_content or "No markdown content generated",
  304. md_content or "No markdown content generated",
  305. gr.update(value=download_zip_path, visible=True) if download_zip_path else gr.update(visible=False),
  306. current_json
  307. )
  308. except Exception as e:
  309. return f"An error occurred during processing: {e}", f"An error occurred during processing: {e}", "", "", gr.update(value=None), ""
  310. def load_image_to_annotator(test_image_input):
  311. """Loads an image into the annotation component."""
  312. image = None
  313. # Check the test image input
  314. if test_image_input and test_image_input != "":
  315. try:
  316. image = read_image_v2(test_image_input)
  317. except Exception as e:
  318. return None
  319. if image is None:
  320. return None
  321. # Return the format required by the annotation component
  322. return {
  323. "image": image,
  324. "boxes": []
  325. }
  326. def clear_all_data():
  327. """Clears all data."""
  328. global processing_results
  329. # Clean up the temporary directory
  330. if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
  331. import shutil
  332. try:
  333. shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
  334. except Exception as e:
  335. print(f"Failed to clean up temporary directory: {e}")
  336. # Reset processing results
  337. processing_results = {
  338. 'original_image': None,
  339. 'processed_image': None,
  340. 'layout_result': None,
  341. 'markdown_content': None,
  342. 'cells_data': None,
  343. 'temp_dir': None,
  344. 'session_id': None,
  345. 'result_paths': None,
  346. 'annotation_data': None
  347. }
  348. return (
  349. "", # Clear test image selection
  350. None, # Clear annotation component
  351. "Waiting for processing results...", # Reset info display
  352. "## Waiting for processing results...", # Reset Markdown display
  353. "🕐 Waiting for parsing results...", # Clear raw Markdown text
  354. gr.update(visible=False), # Hide download button
  355. "🕐 Waiting for parsing results..." # Clear JSON
  356. )
  357. def update_prompt_display(prompt_mode):
  358. """Updates the displayed prompt content."""
  359. return dict_promptmode_to_prompt[prompt_mode]
# ==================== Gradio Interface ====================
def create_gradio_interface():
    """Build the annotation demo UI and wire up its event handlers.

    Layout:
        - Left column: example picker, prompt selector, Parse/Clear buttons,
          and server / pixel-limit settings.
        - Right area: the image annotator with an info panel, plus result
          tabs (rendered Markdown, raw Markdown, JSON) and a download button.

    The example list comes from ``get_test_images()`` once, when the UI is
    built; images added to the directory later do not appear until a restart.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    # CSS styling to match the reference style
    css = """
footer {
visibility: hidden;
}
#info_box {
padding: 10px;
background-color: #f8f9fa;
border-radius: 8px;
border: 1px solid #dee2e6;
margin: 10px 0;
font-size: 14px;
}
#markdown_tabs {
height: 100%;
}
#annotation_component {
border-radius: 8px;
}
"""
    with gr.Blocks(theme="ocean", css=css, title='dots.ocr - Annotation') as demo:
        # Title
        gr.HTML("""
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 20px;">
<h1 style="margin: 0; font-size: 2em;">🔍 dots.ocr - Annotation Version</h1>
</div>
<div style="text-align: center; margin-bottom: 10px;">
<em>Supports image annotation, drawing boxes, and sending box information to the model for OCR.</em>
</div>
""")
        with gr.Row():
            # Left side: Input and Configuration
            with gr.Column(scale=1, variant="compact"):
                gr.Markdown("### 📁 Select Example")
                test_images = get_test_images()
                # "" is the "nothing selected" choice and the initial value.
                test_image_input = gr.Dropdown(
                    label="Select Example",
                    choices=[""] + test_images,
                    value="",
                    show_label=True
                )
                # Button to load image into the annotation component
                load_btn = gr.Button("📷 Load Image to Annotation Area", variant="secondary")
                prompt_mode = gr.Dropdown(
                    label="Select Prompt",
                    # choices=["prompt_layout_all_en", "prompt_layout_only_en", "prompt_ocr", "prompt_grounding_ocr"],
                    choices=["prompt_grounding_ocr"],
                    value="prompt_grounding_ocr",
                    show_label=True,
                    info="If a box is drawn, 'prompt_grounding_ocr' mode will be used automatically."
                )
                # Display the current prompt content (read-only)
                prompt_display = gr.Textbox(
                    label="Current Prompt Content",
                    # value=dict_promptmode_to_prompt[list(dict_promptmode_to_prompt.keys())[0]],
                    value=dict_promptmode_to_prompt["prompt_grounding_ocr"],
                    lines=4,
                    max_lines=8,
                    interactive=False,
                    show_copy_button=True
                )
                gr.Markdown("### ⚙️ Actions")
                process_btn = gr.Button("🔍 Parse", variant="primary")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")
                gr.Markdown("### 🛠️ Configuration")
                fitz_preprocess = gr.Checkbox(
                    label="Enable fitz_preprocess",
                    value=False,
                    info="Performs fitz preprocessing on the image input, converting the image to a PDF and then to a 200dpi image."
                )
                with gr.Row():
                    server_ip = gr.Textbox(
                        label="Server IP",
                        value=DEFAULT_CONFIG['ip']
                    )
                    server_port = gr.Number(
                        label="Port",
                        value=DEFAULT_CONFIG['port_vllm'],
                        precision=0
                    )
                with gr.Row():
                    min_pixels = gr.Number(
                        label="Min Pixels",
                        value=DEFAULT_CONFIG['min_pixels'],
                        precision=0
                    )
                    max_pixels = gr.Number(
                        label="Max Pixels",
                        value=DEFAULT_CONFIG['max_pixels'],
                        precision=0
                    )
            # Right side: Result Display
            with gr.Column(scale=6, variant="compact"):
                with gr.Row():
                    # Image Annotation Area
                    with gr.Column(scale=3):
                        gr.Markdown("### 🎯 Image Annotation Area")
                        gr.Markdown("""
**Instructions:**
- Method 1: Select an example image on the left and click "Load Image to Annotation Area".
- Method 2: Upload an image directly in the annotation area below (drag and drop or click to upload).
- Use the mouse to draw a box on the image to select the region for recognition.
- Only one box can be drawn. To draw a new one, please delete the old one first.
- **Hotkey: Press the Delete key to remove the selected box.**
- After drawing a box, clicking Parse will automatically use the Region OCR mode.
""")
                        # Payload is a dict with 'image' and 'boxes'; it is read by
                        # process_annotation_data.
                        annotator = image_annotator(
                            value=None,
                            label="Image Annotation",
                            height=600,
                            show_label=False,
                            elem_id="annotation_component",
                            # NOTE(review): the instructions above say the old box must be
                            # deleted before drawing a new one; confirm which behavior
                            # single_box actually gives.
                            single_box=True, # Only allow one box; a new box will replace the old one
                            box_min_size=10,
                            interactive=True,
                            disable_edit_boxes=True, # Disable the edit dialog
                            label_list=["OCR Region"], # Set the default label
                            label_colors=[(255, 0, 0)], # Set color to red
                            use_default_label=True, # Use the default label
                            image_type="pil" # Ensure it returns a PIL Image format
                        )
                        # Information Display
                        info_display = gr.Markdown(
                            "Waiting for processing results...",
                            elem_id="info_box"
                        )
                    # Result Display Area
                    with gr.Column(scale=3):
                        gr.Markdown("### ✅ Results")
                        # NOTE(review): the three result components below share
                        # elem_id="markdown_output"; HTML ids are expected to be unique.
                        with gr.Tabs(elem_id="markdown_tabs"):
                            with gr.TabItem("Markdown Rendered View"):
                                md_output = gr.Markdown(
                                    "## Please upload an image and click the Parse button for recognition...",
                                    label="Markdown Preview",
                                    max_height=1000,
                                    latex_delimiters=[
                                        {"left": "$$", "right": "$$", "display": True},
                                        {"left": "$", "right": "$", "display": False},
                                    ],
                                    show_copy_button=False,
                                    elem_id="markdown_output"
                                )
                            with gr.TabItem("Markdown Raw Text"):
                                md_raw_output = gr.Textbox(
                                    value="🕐 Waiting for parsing results...",
                                    label="Markdown Raw Text",
                                    max_lines=100,
                                    lines=38,
                                    show_copy_button=True,
                                    elem_id="markdown_output",
                                    show_label=False
                                )
                            with gr.TabItem("JSON Result"):
                                json_output = gr.Textbox(
                                    value="🕐 Waiting for parsing results...",
                                    label="JSON Result",
                                    max_lines=100,
                                    lines=38,
                                    show_copy_button=True,
                                    elem_id="markdown_output",
                                    show_label=False
                                )
                        # Download Button (hidden until a results ZIP exists)
                        with gr.Row():
                            download_btn = gr.DownloadButton(
                                "⬇️ Download Results",
                                visible=False
                            )
        # Event Binding
        # When the prompt mode changes, update the displayed content
        prompt_mode.change(
            fn=update_prompt_display,
            inputs=prompt_mode,
            outputs=prompt_display,
            show_progress=False
        )
        # Load image into the annotation component
        load_btn.click(
            fn=load_image_to_annotator,
            inputs=[test_image_input],
            outputs=annotator,
            show_progress=False
        )
        # Process Inference
        process_btn.click(
            fn=process_image_inference_with_annotation,
            inputs=[
                annotator, test_image_input,
                prompt_mode, server_ip, server_port, min_pixels, max_pixels,
                fitz_preprocess
            ],
            # NOTE(review): md_raw_output is listed twice, so the handler's 3rd and
            # 4th return values both target the raw-Markdown textbox; confirm
            # whether the 4th slot was meant for a different component.
            outputs=[
                md_output, info_display, md_raw_output, md_raw_output,
                download_btn, json_output
            ],
            show_progress=True
        )
        # Clear Data (7 outputs, matching the 7-tuple from clear_all_data)
        clear_btn.click(
            fn=clear_all_data,
            outputs=[
                test_image_input, annotator,
                info_display, md_output, md_raw_output,
                download_btn, json_output
            ],
            show_progress=False
        )
    return demo
  571. # ==================== Main Program ====================
  572. if __name__ == "__main__":
  573. demo = create_gradio_interface()
  574. demo.queue().launch(
  575. server_name="0.0.0.0",
  576. server_port=7861, # Use a different port to avoid conflicts
  577. debug=True
  578. )