- """
- Layout Inference Web Application
- A Streamlit-based layout inference tool that supports image uploads and multiple backend inference engines.
- """
- import streamlit as st
- import json
- import os
- import io
- import tempfile
- from PIL import Image
- import requests
- # Local utility imports
- # from utils import infer
- from dots_ocr.utils import dict_promptmode_to_prompt
- from dots_ocr.utils.format_transformer import layoutjson2md
- from dots_ocr.utils.layout_utils import draw_layout_on_image, post_process_cells
- from dots_ocr.utils.image_utils import get_input_dimensions, get_image_by_fitz_doc
- from dots_ocr.model.inference import inference_with_vllm
- from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.demo_utils.display import read_image
# ==================== Configuration ====================
DEFAULT_CONFIG = {
    'ip': "127.0.0.1",
    'port_vllm': 8000,
    'min_pixels': MIN_PIXELS,
    'max_pixels': MAX_PIXELS,
    'test_images_dir': "./assets/showcase_origin",
}
# ==================== Utility Functions ====================
@st.cache_resource
def read_image_v2(img: str):
    """Load an image from an HTTP(S) URL or a local path and return a PIL image."""
    if img.startswith(("http://", "https://")):
        with requests.get(img, stream=True) as response:
            response.raise_for_status()
            img = Image.open(io.BytesIO(response.content))
    if isinstance(img, str):
        # img = transform_image_path(img)
        img, _, _ = read_image(img, use_native=True)
    elif isinstance(img, Image.Image):
        pass
    else:
        raise ValueError(f"Invalid image type: {type(img)}")
    return img
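# Illustrative usage only; both arguments below are hypothetical examples, not files
# shipped with this demo:
#   img = read_image_v2("https://example.com/page.png")            # remote URL
#   img = read_image_v2("./assets/showcase_origin/sample.jpg")     # local path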
# ==================== UI Components ====================
def create_config_sidebar():
    """Create configuration sidebar"""
    st.sidebar.header("Configuration Parameters")

    config = {}
    config['prompt_key'] = st.sidebar.selectbox("Prompt Mode", list(dict_promptmode_to_prompt.keys()))
    config['ip'] = st.sidebar.text_input("Server IP", DEFAULT_CONFIG['ip'])
    config['port'] = st.sidebar.number_input("Port", min_value=1000, max_value=9999, value=DEFAULT_CONFIG['port_vllm'])
    # config['eos_word'] = st.sidebar.text_input("EOS Word", DEFAULT_CONFIG['eos_word'])

    # Image configuration
    st.sidebar.subheader("Image Configuration")
    config['min_pixels'] = st.sidebar.number_input("Min Pixels", value=DEFAULT_CONFIG['min_pixels'])
    config['max_pixels'] = st.sidebar.number_input("Max Pixels", value=DEFAULT_CONFIG['max_pixels'])

    return config
def get_image_input():
    """Get image input from an upload, a URL/path, or a bundled test image."""
    st.markdown("#### Image Input")

    input_mode = st.pills(label="Select input method", options=["Upload Image", "Enter Image URL/Path", "Select Test Image"], key="input_mode", label_visibility="collapsed")
    if input_mode is None:
        # st.pills returns None until the user picks an option
        return None
    if input_mode == "Upload Image":
        # File uploader
        uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
        if uploaded_file is not None:
            # Persist the upload to a temporary file so downstream code can work with a path
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                return tmp_file.name
    elif input_mode == 'Enter Image URL/Path':
        # URL input
        img_url_input = st.text_input("Enter Image URL/Path")
        return img_url_input
    elif input_mode == 'Select Test Image':
        # Test image selection
        test_images = []
        test_dir = DEFAULT_CONFIG['test_images_dir']
        if os.path.exists(test_dir):
            test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)]
        img_url_test = st.selectbox("Select Test Image", [""] + test_images)
        return img_url_test
    else:
        raise ValueError(f"Invalid input mode: {input_mode}")
    return None
def process_and_display_results(output: dict, image: Image.Image, config: dict):
    """Process and display inference results"""
    prompt, response = output['prompt'], output['response']

    try:
        col1, col2 = st.columns(2)
        # st.markdown('---')
        cells = json.loads(response)
        # image = Image.open(img_url)

        # Post-processing
        cells = post_process_cells(
            image, cells,
            image.width, image.height,
            min_pixels=config['min_pixels'],
            max_pixels=config['max_pixels']
        )

        # Calculate input dimensions
        input_width, input_height = get_input_dimensions(
            image,
            min_pixels=config['min_pixels'],
            max_pixels=config['max_pixels']
        )
        st.markdown('---')
        st.write(f'Input Dimensions: {input_width} x {input_height}')
        # st.write(f'Prompt: {prompt}')
        # st.markdown(f'Raw model output: <span style="color:blue">{result}</span>', unsafe_allow_html=True)
        # st.write('Raw model output:')
        # st.write(response)
        # st.write('Post-processed result:', str(cells))
        st.text_area('Original Model Output', response, height=200)
        st.text_area('Post-processed Result', str(cells), height=200)
        # Display results
        # st.title("Layout Inference Results")

        with col1:
            # st.markdown("##### Visualization Result")
            new_image = draw_layout_on_image(
                image, cells,
                resized_height=None, resized_width=None,
                # text_key='text',
                fill_bbox=True, draw_bbox=True
            )
            st.markdown('##### Visualization Result')
            st.image(new_image, width=new_image.width)
            # st.write(f"Size: {new_image.width} x {new_image.height}")

        with col2:
            # st.markdown("##### Markdown Format")
            md_code = layoutjson2md(image, cells, text_key='text')
            # md_code = fix_streamlit_formula(md_code)
            st.markdown('##### Markdown Format')
            st.markdown(md_code, unsafe_allow_html=True)

    except json.JSONDecodeError:
        st.error("Model output is not valid JSON")
    except Exception as e:
        st.error(f"Error processing results: {e}")
# ==================== Main Application ====================
def main():
    """Main application function"""
    st.set_page_config(page_title="Layout Inference Tool", layout="wide")
    st.title("🔍 Layout Inference Tool")

    # Configuration
    config = create_config_sidebar()
    prompt = dict_promptmode_to_prompt[config['prompt_key']]
    st.sidebar.info(f"Current Prompt: {prompt}")

    # Image input
    img_url = get_image_input()
    start_button = st.button('🚀 Start Inference', type="primary")

    if img_url is not None and img_url.strip() != "":
        try:
            # processed_image = read_image_v2(img_url)
            origin_image = read_image_v2(img_url)
            st.write(f"Original Dimensions: {origin_image.width} x {origin_image.height}")
            # processed_image = get_image_by_fitz_doc(origin_image, target_dpi=200)
            processed_image = origin_image
        except Exception as e:
            st.error(f"Failed to read image: {e}")
            return
    else:
        st.info("Please enter an image URL/path or upload an image")
        return

    output = None
    # Inference button
    if start_button:
        with st.spinner(f"Inferring... Server: {config['ip']}:{config['port']}"):
            response = inference_with_vllm(
                processed_image, prompt, config['ip'], config['port'],
                # config['min_pixels'], config['max_pixels']
            )
            output = {
                'prompt': prompt,
                'response': response,
            }
    else:
        st.image(processed_image, width=500)

    # Process results
    if output:
        process_and_display_results(output, processed_image, config)


if __name__ == "__main__":
    main()
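# To launch the demo locally, run the script through Streamlit rather than plain Python
# (the filename below is an assumption; substitute this file's actual name):
#
#   streamlit run demo_layout.py
#
# The sidebar expects a vLLM server reachable at the configured IP/port
# (default 127.0.0.1:8000).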