demo_gradio.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950
  1. """
  2. Layout Inference Web Application with Gradio
  3. A Gradio-based layout inference tool that supports image uploads and multiple backend inference engines.
  4. It adopts a reference-style interface design while preserving the original inference logic.
  5. """
  6. import gradio as gr
  7. import json
  8. import os
  9. import io
  10. import tempfile
  11. import base64
  12. import zipfile
  13. import uuid
  14. import re
  15. from pathlib import Path
  16. from PIL import Image
  17. import requests
  18. # Local tool imports
  19. from dots_ocr.utils import dict_promptmode_to_prompt
  20. from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
  21. from dots_ocr.utils.demo_utils.display import read_image
  22. from dots_ocr.utils.doc_utils import load_images_from_pdf
  23. # Add DotsOCRParser import
  24. from dots_ocr.parser import DotsOCRParser
  25. # ==================== Configuration ====================
  26. DEFAULT_CONFIG = {
  27. 'ip': "127.0.0.1",
  28. 'port_vllm': 8101,
  29. 'min_pixels': MIN_PIXELS,
  30. 'max_pixels': MAX_PIXELS,
  31. 'test_images_dir': "../demo/assets/showcase_origin",
  32. 'model_name': "DotsOCR",
  33. }
  34. # ==================== Global Variables ====================
  35. # Store current configuration
  36. current_config = DEFAULT_CONFIG.copy()
  37. # Create DotsOCRParser instance
  38. dots_parser = DotsOCRParser(
  39. ip=DEFAULT_CONFIG['ip'],
  40. port=DEFAULT_CONFIG['port_vllm'],
  41. dpi=200,
  42. min_pixels=DEFAULT_CONFIG['min_pixels'],
  43. max_pixels=DEFAULT_CONFIG['max_pixels'],
  44. model_name= DEFAULT_CONFIG['model_name']
  45. )
  46. # Store processing results
  47. processing_results = {
  48. 'original_image': None,
  49. 'processed_image': None,
  50. 'layout_result': None,
  51. 'markdown_content': None,
  52. 'cells_data': None,
  53. 'temp_dir': None,
  54. 'session_id': None,
  55. 'result_paths': None,
  56. 'pdf_results': None # Store multi-page PDF results
  57. }
  58. # PDF caching mechanism
  59. pdf_cache = {
  60. "images": [],
  61. "current_page": 0,
  62. "total_pages": 0,
  63. "file_type": None, # 'image' or 'pdf'
  64. "is_parsed": False, # Whether it has been parsed
  65. "results": [] # Store parsing results for each page
  66. }
  67. def read_image_v2(img):
  68. """Reads an image, supports URLs and local paths"""
  69. if isinstance(img, str) and img.startswith(("http://", "https://")):
  70. with requests.get(img, stream=True) as response:
  71. response.raise_for_status()
  72. img = Image.open(io.BytesIO(response.content))
  73. elif isinstance(img, str):
  74. img, _, _ = read_image(img, use_native=True)
  75. elif isinstance(img, Image.Image):
  76. pass
  77. else:
  78. raise ValueError(f"Invalid image type: {type(img)}")
  79. return img
  80. def load_file_for_preview(file_path):
  81. """Loads a file for preview, supports PDF and image files"""
  82. global pdf_cache
  83. if not file_path or not os.path.exists(file_path):
  84. return None, "<div id='page_info_box'>0 / 0</div>"
  85. file_ext = os.path.splitext(file_path)[1].lower()
  86. if file_ext == '.pdf':
  87. try:
  88. # Read PDF and convert to images (one image per page)
  89. pages = load_images_from_pdf(file_path)
  90. pdf_cache["file_type"] = "pdf"
  91. except Exception as e:
  92. return None, f"<div id='page_info_box'>PDF loading failed: {str(e)}</div>"
  93. elif file_ext in ['.jpg', '.jpeg', '.png']:
  94. # For image files, read directly as a single-page image
  95. try:
  96. image = Image.open(file_path)
  97. pages = [image]
  98. pdf_cache["file_type"] = "image"
  99. except Exception as e:
  100. return None, f"<div id='page_info_box'>Image loading failed: {str(e)}</div>"
  101. else:
  102. return None, "<div id='page_info_box'>Unsupported file format</div>"
  103. pdf_cache["images"] = pages
  104. pdf_cache["current_page"] = 0
  105. pdf_cache["total_pages"] = len(pages)
  106. pdf_cache["is_parsed"] = False
  107. pdf_cache["results"] = []
  108. return pages[0], f"<div id='page_info_box'>1 / {len(pages)}</div>"
  109. def turn_page(direction):
  110. """Page turning function"""
  111. global pdf_cache
  112. if not pdf_cache["images"]:
  113. return None, "<div id='page_info_box'>0 / 0</div>", "", ""
  114. if direction == "prev":
  115. pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
  116. elif direction == "next":
  117. pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
  118. index = pdf_cache["current_page"]
  119. current_image = pdf_cache["images"][index] # Use the original image by default
  120. page_info = f"<div id='page_info_box'>{index + 1} / {pdf_cache['total_pages']}</div>"
  121. # If parsed, display the results for the current page
  122. current_md = ""
  123. current_md_raw = ""
  124. current_json = ""
  125. if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]):
  126. result = pdf_cache["results"][index]
  127. if 'md_content' in result:
  128. # Get the raw markdown content
  129. current_md_raw = result['md_content']
  130. # Process the content after LaTeX rendering
  131. current_md = result['md_content'] if result['md_content'] else ""
  132. if 'cells_data' in result:
  133. try:
  134. current_json = json.dumps(result['cells_data'], ensure_ascii=False, indent=2)
  135. except:
  136. current_json = str(result.get('cells_data', ''))
  137. # Use the image with layout boxes (if available)
  138. if 'layout_image' in result and result['layout_image']:
  139. current_image = result['layout_image']
  140. return current_image, page_info, current_json
  141. def get_test_images():
  142. """Gets the list of test images"""
  143. test_images = []
  144. test_dir = current_config['test_images_dir']
  145. if os.path.exists(test_dir):
  146. test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)
  147. if name.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf'))]
  148. return test_images
  149. def convert_image_to_base64(image):
  150. """Converts a PIL image to base64 encoding"""
  151. buffered = io.BytesIO()
  152. image.save(buffered, format="PNG")
  153. img_str = base64.b64encode(buffered.getvalue()).decode()
  154. return f"data:image/png;base64,{img_str}"
  155. def create_temp_session_dir():
  156. """Creates a unique temporary directory for each processing request"""
  157. session_id = uuid.uuid4().hex[:8]
  158. temp_dir = os.path.join(tempfile.gettempdir(), f"dots_ocr_demo_{session_id}")
  159. os.makedirs(temp_dir, exist_ok=True)
  160. return temp_dir, session_id
  161. def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=False):
  162. """
  163. Processes using the high-level API parse_image from DotsOCRParser
  164. """
  165. # Create a temporary session directory
  166. temp_dir, session_id = create_temp_session_dir()
  167. try:
  168. # Save the PIL Image as a temporary file
  169. temp_image_path = os.path.join(temp_dir, f"input_{session_id}.png")
  170. image.save(temp_image_path, "PNG")
  171. # Use the high-level API parse_image
  172. filename = f"demo_{session_id}"
  173. results = parser.parse_image(
  174. # input_path=temp_image_path,
  175. input_path=image,
  176. filename=filename,
  177. prompt_mode=prompt_mode,
  178. save_dir=temp_dir,
  179. fitz_preprocess=fitz_preprocess
  180. )
  181. # Parse the results
  182. if not results:
  183. raise ValueError("No results returned from parser")
  184. result = results[0] # parse_image returns a list with a single result
  185. # Read the result files
  186. layout_image = None
  187. cells_data = None
  188. md_content = None
  189. raw_response = None
  190. filtered = False
  191. # Read the layout image
  192. if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
  193. layout_image = Image.open(result['layout_image_path'])
  194. # Read the JSON data
  195. if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
  196. with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
  197. cells_data = json.load(f)
  198. # Read the Markdown content
  199. if 'md_content_path' in result and os.path.exists(result['md_content_path']):
  200. with open(result['md_content_path'], 'r', encoding='utf-8') as f:
  201. md_content = f.read()
  202. # Check for the raw response file (when JSON parsing fails)
  203. if 'filtered' in result:
  204. filtered = result['filtered']
  205. return {
  206. 'layout_image': layout_image,
  207. 'cells_data': cells_data,
  208. 'md_content': md_content,
  209. 'filtered': filtered,
  210. 'temp_dir': temp_dir,
  211. 'session_id': session_id,
  212. 'result_paths': result,
  213. 'input_width': result['input_width'],
  214. 'input_height': result['input_height'],
  215. }
  216. except Exception as e:
  217. # Clean up the temporary directory on error
  218. import shutil
  219. if os.path.exists(temp_dir):
  220. shutil.rmtree(temp_dir, ignore_errors=True)
  221. raise e
  222. def parse_pdf_with_high_level_api(parser, pdf_path, prompt_mode):
  223. """
  224. Processes using the high-level API parse_pdf from DotsOCRParser
  225. """
  226. # Create a temporary session directory
  227. temp_dir, session_id = create_temp_session_dir()
  228. try:
  229. # Use the high-level API parse_pdf
  230. filename = f"demo_{session_id}"
  231. results = parser.parse_pdf(
  232. input_path=pdf_path,
  233. filename=filename,
  234. prompt_mode=prompt_mode,
  235. save_dir=temp_dir
  236. )
  237. # Parse the results
  238. if not results:
  239. raise ValueError("No results returned from parser")
  240. # Handle multi-page results
  241. parsed_results = []
  242. all_md_content = []
  243. all_cells_data = []
  244. for i, result in enumerate(results):
  245. page_result = {
  246. 'page_no': result.get('page_no', i),
  247. 'layout_image': None,
  248. 'cells_data': None,
  249. 'md_content': None,
  250. 'filtered': False
  251. }
  252. # Read the layout image
  253. if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
  254. page_result['layout_image'] = Image.open(result['layout_image_path'])
  255. # Read the JSON data
  256. if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
  257. with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
  258. page_result['cells_data'] = json.load(f)
  259. all_cells_data.extend(page_result['cells_data'])
  260. # Read the Markdown content
  261. if 'md_content_path' in result and os.path.exists(result['md_content_path']):
  262. with open(result['md_content_path'], 'r', encoding='utf-8') as f:
  263. page_content = f.read()
  264. page_result['md_content'] = page_content
  265. all_md_content.append(page_content)
  266. # Check for the raw response file (when JSON parsing fails)
  267. page_result['filtered'] = False
  268. if 'filtered' in page_result:
  269. page_result['filtered'] = page_result['filtered']
  270. parsed_results.append(page_result)
  271. # Merge the content of all pages
  272. combined_md = "\n\n---\n\n".join(all_md_content) if all_md_content else ""
  273. return {
  274. 'parsed_results': parsed_results,
  275. 'combined_md_content': combined_md,
  276. 'combined_cells_data': all_cells_data,
  277. 'temp_dir': temp_dir,
  278. 'session_id': session_id,
  279. 'total_pages': len(results)
  280. }
  281. except Exception as e:
  282. # Clean up the temporary directory on error
  283. import shutil
  284. if os.path.exists(temp_dir):
  285. shutil.rmtree(temp_dir, ignore_errors=True)
  286. raise e
  287. # ==================== Core Processing Function ====================
  288. def process_image_inference(test_image_input, file_input,
  289. prompt_mode, server_ip, server_port, min_pixels, max_pixels,
  290. fitz_preprocess=False
  291. ):
  292. """Core function to handle image/PDF inference"""
  293. global current_config, processing_results, dots_parser, pdf_cache
  294. # First, clean up previous processing results to avoid confusion with the download button
  295. if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
  296. import shutil
  297. try:
  298. shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
  299. except Exception as e:
  300. print(f"Failed to clean up previous temporary directory: {e}")
  301. # Reset processing results
  302. processing_results = {
  303. 'original_image': None,
  304. 'processed_image': None,
  305. 'layout_result': None,
  306. 'markdown_content': None,
  307. 'cells_data': None,
  308. 'temp_dir': None,
  309. 'session_id': None,
  310. 'result_paths': None,
  311. 'pdf_results': None
  312. }
  313. # Update configuration
  314. current_config.update({
  315. 'ip': server_ip,
  316. 'port_vllm': server_port,
  317. 'min_pixels': min_pixels,
  318. 'max_pixels': max_pixels
  319. })
  320. # Update parser configuration
  321. dots_parser.ip = server_ip
  322. dots_parser.port = server_port
  323. dots_parser.min_pixels = min_pixels
  324. dots_parser.max_pixels = max_pixels
  325. # Determine the input source
  326. input_file_path = None
  327. image = None
  328. # Prioritize file input (supports PDF)
  329. if file_input is not None:
  330. input_file_path = file_input
  331. file_ext = os.path.splitext(input_file_path)[1].lower()
  332. if file_ext == '.pdf':
  333. # PDF file processing
  334. try:
  335. return process_pdf_file(input_file_path, prompt_mode)
  336. except Exception as e:
  337. return None, f"PDF processing failed: {e}", "", "", gr.update(value=None), None, ""
  338. elif file_ext in ['.jpg', '.jpeg', '.png']:
  339. # Image file processing
  340. try:
  341. image = Image.open(input_file_path)
  342. except Exception as e:
  343. return None, f"Failed to read image file: {e}", "", "", gr.update(value=None), None, ""
  344. # If no file input, check the test image input
  345. if image is None:
  346. if test_image_input and test_image_input != "":
  347. file_ext = os.path.splitext(test_image_input)[1].lower()
  348. if file_ext == '.pdf':
  349. return process_pdf_file(test_image_input, prompt_mode)
  350. else:
  351. try:
  352. image = read_image_v2(test_image_input)
  353. except Exception as e:
  354. return None, f"Failed to read test image: {e}", "", "", gr.update(value=None), gr.update(value=None), None, ""
  355. if image is None:
  356. return None, "Please upload image/PDF file or select test image", "", "", gr.update(value=None), None, ""
  357. try:
  358. # Clear PDF cache (for image processing)
  359. pdf_cache["images"] = []
  360. pdf_cache["current_page"] = 0
  361. pdf_cache["total_pages"] = 0
  362. pdf_cache["is_parsed"] = False
  363. pdf_cache["results"] = []
  364. # Process using the high-level API of DotsOCRParser
  365. original_image = image
  366. parse_result = parse_image_with_high_level_api(dots_parser, image, prompt_mode, fitz_preprocess)
  367. # Extract parsing results
  368. layout_image = parse_result['layout_image']
  369. cells_data = parse_result['cells_data']
  370. md_content = parse_result['md_content']
  371. filtered = parse_result['filtered']
  372. # Handle parsing failure case
  373. if filtered:
  374. # JSON parsing failed, only text content is available
  375. info_text = f"""
  376. **Image Information:**
  377. - Original Size: {original_image.width} x {original_image.height}
  378. - Processing: JSON parsing failed, using cleaned text output
  379. - Server: {current_config['ip']}:{current_config['port_vllm']}
  380. - Session ID: {parse_result['session_id']}
  381. """
  382. # Store results
  383. processing_results.update({
  384. 'original_image': original_image,
  385. 'processed_image': None,
  386. 'layout_result': None,
  387. 'markdown_content': md_content,
  388. 'cells_data': None,
  389. 'temp_dir': parse_result['temp_dir'],
  390. 'session_id': parse_result['session_id'],
  391. 'result_paths': parse_result['result_paths']
  392. })
  393. return (
  394. original_image, # No layout image
  395. info_text,
  396. md_content,
  397. md_content, # Display raw markdown text
  398. gr.update(visible=False), # Hide download button
  399. None, # Page info
  400. "" # Current page JSON output
  401. )
  402. # JSON parsing successful case
  403. # Save the raw markdown content (before LaTeX processing)
  404. md_content_raw = md_content or "No markdown content generated"
  405. # Store results
  406. processing_results.update({
  407. 'original_image': original_image,
  408. 'processed_image': None, # High-level API does not return processed_image
  409. 'layout_result': layout_image,
  410. 'markdown_content': md_content,
  411. 'cells_data': cells_data,
  412. 'temp_dir': parse_result['temp_dir'],
  413. 'session_id': parse_result['session_id'],
  414. 'result_paths': parse_result['result_paths']
  415. })
  416. # Prepare display information
  417. num_elements = len(cells_data) if cells_data else 0
  418. info_text = f"""
  419. **Image Information:**
  420. - Original Size: {original_image.width} x {original_image.height}
  421. - Model Input Size: {parse_result['input_width']} x {parse_result['input_height']}
  422. - Server: {current_config['ip']}:{current_config['port_vllm']}
  423. - Detected {num_elements} layout elements
  424. - Session ID: {parse_result['session_id']}
  425. """
  426. # Current page JSON output
  427. current_json = ""
  428. if cells_data:
  429. try:
  430. current_json = json.dumps(cells_data, ensure_ascii=False, indent=2)
  431. except:
  432. current_json = str(cells_data)
  433. # Create the download ZIP file
  434. download_zip_path = None
  435. if parse_result['temp_dir']:
  436. download_zip_path = os.path.join(parse_result['temp_dir'], f"layout_results_{parse_result['session_id']}.zip")
  437. try:
  438. with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
  439. for root, dirs, files in os.walk(parse_result['temp_dir']):
  440. for file in files:
  441. if file.endswith('.zip'):
  442. continue
  443. file_path = os.path.join(root, file)
  444. arcname = os.path.relpath(file_path, parse_result['temp_dir'])
  445. zipf.write(file_path, arcname)
  446. except Exception as e:
  447. print(f"Failed to create download ZIP: {e}")
  448. download_zip_path = None
  449. return (
  450. layout_image,
  451. info_text,
  452. md_content or "No markdown content generated",
  453. md_content_raw, # Raw markdown text
  454. gr.update(value=download_zip_path, visible=True) if download_zip_path else gr.update(visible=False), # Set the download file
  455. None, # Page info (not displayed for image processing)
  456. current_json # Current page JSON
  457. )
  458. except Exception as e:
  459. return None, f"Error during processing: {e}", "", "", gr.update(value=None), None, ""
  460. def process_pdf_file(pdf_path, prompt_mode):
  461. """Dedicated function for processing PDF files"""
  462. global pdf_cache, processing_results, dots_parser
  463. try:
  464. # First, load the PDF for preview
  465. preview_image, page_info = load_file_for_preview(pdf_path)
  466. # Parse the PDF using DotsOCRParser
  467. pdf_result = parse_pdf_with_high_level_api(dots_parser, pdf_path, prompt_mode)
  468. # Update the PDF cache
  469. pdf_cache["is_parsed"] = True
  470. pdf_cache["results"] = pdf_result['parsed_results']
  471. # Handle LaTeX table rendering
  472. combined_md = pdf_result['combined_md_content']
  473. combined_md_raw = combined_md or "No markdown content generated" # Save the raw content
  474. # Store results
  475. processing_results.update({
  476. 'original_image': None,
  477. 'processed_image': None,
  478. 'layout_result': None,
  479. 'markdown_content': combined_md,
  480. 'cells_data': pdf_result['combined_cells_data'],
  481. 'temp_dir': pdf_result['temp_dir'],
  482. 'session_id': pdf_result['session_id'],
  483. 'result_paths': None,
  484. 'pdf_results': pdf_result['parsed_results']
  485. })
  486. # Prepare display information
  487. total_elements = len(pdf_result['combined_cells_data'])
  488. info_text = f"""
  489. **PDF Information:**
  490. - Total Pages: {pdf_result['total_pages']}
  491. - Server: {current_config['ip']}:{current_config['port_vllm']}
  492. - Total Detected Elements: {total_elements}
  493. - Session ID: {pdf_result['session_id']}
  494. """
  495. # Content of the current page (first page)
  496. current_page_md = ""
  497. current_page_md_raw = ""
  498. current_page_json = ""
  499. current_page_layout_image = preview_image # Use the original preview image by default
  500. if pdf_cache["results"] and len(pdf_cache["results"]) > 0:
  501. current_result = pdf_cache["results"][0]
  502. if current_result['md_content']:
  503. # Raw markdown content
  504. current_page_md_raw = current_result['md_content']
  505. # Process the content after LaTeX rendering
  506. current_page_md = current_result['md_content']
  507. if current_result['cells_data']:
  508. try:
  509. current_page_json = json.dumps(current_result['cells_data'], ensure_ascii=False, indent=2)
  510. except:
  511. current_page_json = str(current_result['cells_data'])
  512. # Use the image with layout boxes (if available)
  513. if 'layout_image' in current_result and current_result['layout_image']:
  514. current_page_layout_image = current_result['layout_image']
  515. # Create the download ZIP file
  516. download_zip_path = None
  517. if pdf_result['temp_dir']:
  518. download_zip_path = os.path.join(pdf_result['temp_dir'], f"layout_results_{pdf_result['session_id']}.zip")
  519. try:
  520. with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
  521. for root, dirs, files in os.walk(pdf_result['temp_dir']):
  522. for file in files:
  523. if file.endswith('.zip'):
  524. continue
  525. file_path = os.path.join(root, file)
  526. arcname = os.path.relpath(file_path, pdf_result['temp_dir'])
  527. zipf.write(file_path, arcname)
  528. except Exception as e:
  529. print(f"Failed to create download ZIP: {e}")
  530. download_zip_path = None
  531. return (
  532. current_page_layout_image, # Use the image with layout boxes
  533. info_text,
  534. combined_md or "No markdown content generated", # Display the markdown for the entire PDF
  535. combined_md_raw or "No markdown content generated", # Display the raw markdown for the entire PDF
  536. gr.update(value=download_zip_path, visible=True) if download_zip_path else gr.update(visible=False), # Set the download file
  537. page_info,
  538. current_page_json
  539. )
  540. except Exception as e:
  541. # Reset the PDF cache
  542. pdf_cache["images"] = []
  543. pdf_cache["current_page"] = 0
  544. pdf_cache["total_pages"] = 0
  545. pdf_cache["is_parsed"] = False
  546. pdf_cache["results"] = []
  547. raise e
  548. def clear_all_data():
  549. """Clears all data"""
  550. global processing_results, pdf_cache
  551. # Clean up the temporary directory
  552. if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
  553. import shutil
  554. try:
  555. shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
  556. except Exception as e:
  557. print(f"Failed to clean up temporary directory: {e}")
  558. # Reset processing results
  559. processing_results = {
  560. 'original_image': None,
  561. 'processed_image': None,
  562. 'layout_result': None,
  563. 'markdown_content': None,
  564. 'cells_data': None,
  565. 'temp_dir': None,
  566. 'session_id': None,
  567. 'result_paths': None,
  568. 'pdf_results': None
  569. }
  570. # Reset the PDF cache
  571. pdf_cache = {
  572. "images": [],
  573. "current_page": 0,
  574. "total_pages": 0,
  575. "file_type": None,
  576. "is_parsed": False,
  577. "results": []
  578. }
  579. return (
  580. None, # Clear file input
  581. "", # Clear test image selection
  582. None, # Clear result image
  583. "Waiting for processing results...", # Reset info display
  584. "## Waiting for processing results...", # Reset Markdown display
  585. "🕐 Waiting for parsing result...", # Clear raw Markdown text
  586. gr.update(visible=False), # Hide download button
  587. "<div id='page_info_box'>0 / 0</div>", # Reset page info
  588. "🕐 Waiting for parsing result..." # Clear current page JSON
  589. )
  590. def update_prompt_display(prompt_mode):
  591. """Updates the prompt display content"""
  592. return dict_promptmode_to_prompt[prompt_mode]
  593. # ==================== Gradio Interface ====================
  594. def create_gradio_interface():
  595. """Creates the Gradio interface"""
  596. # CSS styles, matching the reference style
  597. css = """
  598. #parse_button {
  599. background: #FF576D !important; /* !important 确保覆盖主题默认样式 */
  600. border-color: #FF576D !important;
  601. }
  602. /* 鼠标悬停时的颜色 */
  603. #parse_button:hover {
  604. background: #F72C49 !important;
  605. border-color: #F72C49 !important;
  606. }
  607. #page_info_html {
  608. display: flex;
  609. align-items: center;
  610. justify-content: center;
  611. height: 100%;
  612. margin: 0 12px;
  613. }
  614. #page_info_box {
  615. padding: 8px 20px;
  616. font-size: 16px;
  617. border: 1px solid #bbb;
  618. border-radius: 8px;
  619. background-color: #f8f8f8;
  620. text-align: center;
  621. min-width: 80px;
  622. box-shadow: 0 1px 3px rgba(0,0,0,0.1);
  623. }
  624. #markdown_output {
  625. min-height: 800px;
  626. overflow: auto;
  627. }
  628. footer {
  629. visibility: hidden;
  630. }
  631. #info_box {
  632. padding: 10px;
  633. background-color: #f8f9fa;
  634. border-radius: 8px;
  635. border: 1px solid #dee2e6;
  636. margin: 10px 0;
  637. font-size: 14px;
  638. }
  639. #result_image {
  640. border-radius: 8px;
  641. }
  642. #markdown_tabs {
  643. height: 100%;
  644. }
  645. """
  646. with gr.Blocks(theme="ocean", css=css, title='dots.ocr') as demo:
  647. # Title
  648. gr.HTML("""
  649. <div style="display: flex; align-items: center; justify-content: center; margin-bottom: 20px;">
  650. <h1 style="margin: 0; font-size: 2em;">🔍 dots.ocr</h1>
  651. </div>
  652. <div style="text-align: center; margin-bottom: 10px;">
  653. <em>Supports image/PDF layout analysis and structured output</em>
  654. </div>
  655. """)
  656. with gr.Row():
  657. # Left side: Input and Configuration
  658. with gr.Column(scale=1, elem_id="left-panel"):
  659. gr.Markdown("### 📥 Upload & Select")
  660. file_input = gr.File(
  661. label="Upload PDF/Image",
  662. type="filepath",
  663. file_types=[".pdf", ".jpg", ".jpeg", ".png"],
  664. )
  665. test_images = get_test_images()
  666. test_image_input = gr.Dropdown(
  667. label="Or Select an Example",
  668. choices=[""] + test_images,
  669. value="",
  670. )
  671. gr.Markdown("### ⚙️ Prompt & Actions")
  672. prompt_mode = gr.Dropdown(
  673. label="Select Prompt",
  674. choices=["prompt_layout_all_en", "prompt_layout_only_en", "prompt_ocr"],
  675. value="prompt_layout_all_en",
  676. show_label=True
  677. )
  678. # Display current prompt content
  679. prompt_display = gr.Textbox(
  680. label="Current Prompt Content",
  681. value=dict_promptmode_to_prompt[list(dict_promptmode_to_prompt.keys())[0]],
  682. lines=4,
  683. max_lines=8,
  684. interactive=False,
  685. show_copy_button=True
  686. )
  687. with gr.Row():
  688. process_btn = gr.Button("🔍 Parse", variant="primary", scale=2, elem_id="parse_button")
  689. clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)
  690. with gr.Accordion("🛠️ Advanced Configuration", open=False):
  691. fitz_preprocess = gr.Checkbox(
  692. label="Enable fitz_preprocess for images",
  693. value=True,
  694. info="Processes image via a PDF-like pipeline (image->pdf->200dpi image). Recommended if your image DPI is low."
  695. )
  696. with gr.Row():
  697. server_ip = gr.Textbox(label="Server IP", value=DEFAULT_CONFIG['ip'])
  698. server_port = gr.Number(label="Port", value=DEFAULT_CONFIG['port_vllm'], precision=0)
  699. with gr.Row():
  700. min_pixels = gr.Number(label="Min Pixels", value=DEFAULT_CONFIG['min_pixels'], precision=0)
  701. max_pixels = gr.Number(label="Max Pixels", value=DEFAULT_CONFIG['max_pixels'], precision=0)
  702. # Right side: Result Display
  703. with gr.Column(scale=6, variant="compact"):
  704. with gr.Row():
  705. # Result Image
  706. with gr.Column(scale=3):
  707. gr.Markdown("### 👁️ File Preview")
  708. result_image = gr.Image(
  709. label="Layout Preview",
  710. visible=True,
  711. height=800,
  712. show_label=False
  713. )
  714. # Page navigation (shown during PDF preview)
  715. with gr.Row():
  716. prev_btn = gr.Button("⬅ Previous", size="sm")
  717. page_info = gr.HTML(
  718. value="<div id='page_info_box'>0 / 0</div>",
  719. elem_id="page_info_html"
  720. )
  721. next_btn = gr.Button("Next ➡", size="sm")
  722. # Info Display
  723. info_display = gr.Markdown(
  724. "Waiting for processing results...",
  725. elem_id="info_box"
  726. )
  727. # Markdown Result
  728. with gr.Column(scale=3):
  729. gr.Markdown("### ✔️ Result Display")
  730. with gr.Tabs(elem_id="markdown_tabs"):
  731. with gr.TabItem("Markdown Render Preview"):
  732. md_output = gr.Markdown(
  733. "## Please click the parse button to parse or select for single-task recognition...",
  734. label="Markdown Preview",
  735. max_height=600,
  736. latex_delimiters=[
  737. {"left": "$$", "right": "$$", "display": True},
  738. {"left": "$", "right": "$", "display": False},
  739. ],
  740. show_copy_button=False,
  741. elem_id="markdown_output"
  742. )
  743. with gr.TabItem("Markdown Raw Text"):
  744. md_raw_output = gr.Textbox(
  745. value="🕐 Waiting for parsing result...",
  746. label="Markdown Raw Text",
  747. max_lines=100,
  748. lines=38,
  749. show_copy_button=True,
  750. elem_id="markdown_output",
  751. show_label=False
  752. )
  753. with gr.TabItem("Current Page JSON"):
  754. current_page_json = gr.Textbox(
  755. value="🕐 Waiting for parsing result...",
  756. label="Current Page JSON",
  757. max_lines=100,
  758. lines=38,
  759. show_copy_button=True,
  760. elem_id="markdown_output",
  761. show_label=False
  762. )
  763. # Download Button
  764. with gr.Row():
  765. download_btn = gr.DownloadButton(
  766. "⬇️ Download Results",
  767. visible=False
  768. )
  769. # When the prompt mode changes, update the display content
  770. prompt_mode.change(
  771. fn=update_prompt_display,
  772. inputs=prompt_mode,
  773. outputs=prompt_display,
  774. show_progress=False
  775. )
  776. # Show preview on file upload
  777. file_input.upload(
  778. fn=load_file_for_preview,
  779. inputs=file_input,
  780. outputs=[result_image, page_info],
  781. show_progress=False
  782. )
  783. # Page navigation
  784. prev_btn.click(
  785. fn=lambda: turn_page("prev"),
  786. outputs=[result_image, page_info, current_page_json],
  787. show_progress=False
  788. )
  789. next_btn.click(
  790. fn=lambda: turn_page("next"),
  791. outputs=[result_image, page_info, current_page_json],
  792. show_progress=False
  793. )
  794. process_btn.click(
  795. fn=process_image_inference,
  796. inputs=[
  797. test_image_input, file_input,
  798. prompt_mode, server_ip, server_port, min_pixels, max_pixels,
  799. fitz_preprocess
  800. ],
  801. outputs=[
  802. result_image, info_display, md_output, md_raw_output,
  803. download_btn, page_info, current_page_json
  804. ],
  805. show_progress=True
  806. )
  807. clear_btn.click(
  808. fn=clear_all_data,
  809. outputs=[
  810. file_input, test_image_input,
  811. result_image, info_display, md_output, md_raw_output,
  812. download_btn, page_info, current_page_json
  813. ],
  814. show_progress=False
  815. )
  816. return demo
  817. # ==================== Main Program ====================
  818. if __name__ == "__main__":
  819. demo = create_gradio_interface()
  820. demo.queue().launch(
  821. server_name="0.0.0.0",
  822. server_port=7860,
  823. debug=True
  824. )