Jelajahi Sumber

修复codereview问题

Magic_yuan 1 bulan lalu
induk
melakukan
484ff5a6f9

+ 1 - 1
projects/mineru_tianshu/README.md

@@ -259,7 +259,7 @@ POST /api/v1/tasks/submit
 参数:
 参数:
   file: 文件 (必需)
   file: 文件 (必需)
   backend: pipeline | vlm-transformers | vlm-vllm-engine (默认: pipeline)
   backend: pipeline | vlm-transformers | vlm-vllm-engine (默认: pipeline)
-  lang: ch | en | korean | japan | ... (默认: auto)
+  lang: ch | en | korean | japan | ... (默认: ch)
   priority: 0-100 (数字越大越优先,默认: 0)
   priority: 0-100 (数字越大越优先,默认: 0)
 ```
 ```
 
 

+ 17 - 7
projects/mineru_tianshu/api_server.py

@@ -5,7 +5,7 @@ MinerU Tianshu - API Server
 提供RESTful API接口用于任务提交、查询和管理
 提供RESTful API接口用于任务提交、查询和管理
 """
 """
 from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Query
 from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Query
-from fastapi.responses import JSONResponse, FileResponse
+from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 import tempfile
 import tempfile
 from pathlib import Path
 from pathlib import Path
@@ -105,7 +105,8 @@ def process_markdown_images(md_content: str, image_dir: Path, upload_images: boo
                     minio_client.fput_object(bucket_name, object_name, str(full_image_path))
                     minio_client.fput_object(bucket_name, object_name, str(full_image_path))
                     
                     
                     # 生成 MinIO 访问 URL
                     # 生成 MinIO 访问 URL
-                    minio_url = f"https://{minio_endpoint}/{bucket_name}/{object_name}"
+                    scheme = 'https' if MINIO_CONFIG['secure'] else 'http'
+                    minio_url = f"{scheme}://{minio_endpoint}/{bucket_name}/{object_name}"
                     
                     
                     # 返回 HTML 格式的 img 标签
                     # 返回 HTML 格式的 img 标签
                     return f'<img src="{minio_url}" alt="{alt_text}">'
                     return f'<img src="{minio_url}" alt="{alt_text}">'
@@ -137,7 +138,7 @@ async def root():
 
 
 @app.post("/api/v1/tasks/submit")
 @app.post("/api/v1/tasks/submit")
 async def submit_task(
 async def submit_task(
-    file: UploadFile = File(..., description="PDF文件或图片"),
+    file: UploadFile = File(..., description="文档文件: PDF/图片(MinerU解析) 或 Office/HTML/文本等(MarkItDown解析)"),
     backend: str = Form('pipeline', description="处理后端: pipeline/vlm-transformers/vlm-vllm-engine"),
     backend: str = Form('pipeline', description="处理后端: pipeline/vlm-transformers/vlm-vllm-engine"),
     lang: str = Form('ch', description="语言: ch/en/korean/japan等"),
     lang: str = Form('ch', description="语言: ch/en/korean/japan等"),
     method: str = Form('auto', description="解析方法: auto/txt/ocr"),
     method: str = Form('auto', description="解析方法: auto/txt/ocr"),
@@ -153,8 +154,14 @@ async def submit_task(
     try:
     try:
         # 保存上传的文件到临时目录
         # 保存上传的文件到临时目录
         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix)
         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix)
-        content = await file.read()
-        temp_file.write(content)
+        
+        # 流式写入文件到磁盘,避免高内存使用
+        while True:
+            chunk = await file.read(1 << 23)  # 8MB chunks
+            if not chunk:
+                break
+            temp_file.write(chunk)
+        
         temp_file.close()
         temp_file.close()
         
         
         # 创建任务
         # 创建任务
@@ -405,13 +412,16 @@ async def health_check():
 
 
 
 
 if __name__ == '__main__':
 if __name__ == '__main__':
+    # 从环境变量读取端口,默认为8000
+    api_port = int(os.getenv('API_PORT', '8000'))
+    
     logger.info("🚀 Starting MinerU Tianshu API Server...")
     logger.info("🚀 Starting MinerU Tianshu API Server...")
-    logger.info("📖 API Documentation: http://localhost:8000/docs")
+    logger.info(f"📖 API Documentation: http://localhost:{api_port}/docs")
     
     
     uvicorn.run(
     uvicorn.run(
         app, 
         app, 
         host='0.0.0.0', 
         host='0.0.0.0', 
-        port=8000,
+        port=api_port,
         log_level='info'
         log_level='info'
     )
     )
 
 

+ 1 - 1
projects/mineru_tianshu/client_example.py

@@ -9,7 +9,7 @@ import aiohttp
 from pathlib import Path
 from pathlib import Path
 from loguru import logger
 from loguru import logger
 import time
 import time
-from typing import List, Dict
+from typing import Dict
 
 
 
 
 class TianshuClient:
 class TianshuClient:

+ 0 - 1
projects/mineru_tianshu/litserve_worker.py

@@ -11,7 +11,6 @@ import sys
 from pathlib import Path
 from pathlib import Path
 import litserve as ls
 import litserve as ls
 from loguru import logger
 from loguru import logger
-from typing import Optional
 
 
 # 添加父目录到路径以导入 MinerU
 # 添加父目录到路径以导入 MinerU
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))

+ 5 - 1
projects/mineru_tianshu/start_all.py

@@ -8,6 +8,7 @@ import subprocess
 import signal
 import signal
 import sys
 import sys
 import time
 import time
+import os
 from loguru import logger
 from loguru import logger
 from pathlib import Path
 from pathlib import Path
 import argparse
 import argparse
@@ -44,9 +45,12 @@ class TianshuLauncher:
         try:
         try:
             # 1. 启动 API Server
             # 1. 启动 API Server
             logger.info("📡 [1/3] Starting API Server...")
             logger.info("📡 [1/3] Starting API Server...")
+            env = os.environ.copy()
+            env['API_PORT'] = str(self.api_port)
             api_proc = subprocess.Popen(
             api_proc = subprocess.Popen(
                 [sys.executable, 'api_server.py'],
                 [sys.executable, 'api_server.py'],
-                cwd=Path(__file__).parent
+                cwd=Path(__file__).parent,
+                env=env
             )
             )
             self.processes.append(('API Server', api_proc))
             self.processes.append(('API Server', api_proc))
             time.sleep(3)
             time.sleep(3)

+ 4 - 7
projects/mineru_tianshu/task_db.py

@@ -7,7 +7,6 @@ MinerU Tianshu - SQLite Task Database Manager
 import sqlite3
 import sqlite3
 import json
 import json
 import uuid
 import uuid
-from datetime import datetime
 from contextlib import contextmanager
 from contextlib import contextmanager
 from typing import Optional, List, Dict
 from typing import Optional, List, Dict
 from pathlib import Path
 from pathlib import Path
@@ -124,10 +123,10 @@ class TaskDB:
                 cursor.execute('''
                 cursor.execute('''
                     UPDATE tasks 
                     UPDATE tasks 
                     SET status = 'processing', 
                     SET status = 'processing', 
-                        started_at = ?, 
+                        started_at = CURRENT_TIMESTAMP, 
                         worker_id = ?
                         worker_id = ?
                     WHERE task_id = ?
                     WHERE task_id = ?
-                ''', (datetime.now().isoformat(), worker_id, task['task_id']))
+                ''', (worker_id, task['task_id']))
                 
                 
                 return dict(task)
                 return dict(task)
             
             
@@ -149,8 +148,7 @@ class TaskDB:
             params = [status]
             params = [status]
             
             
             if status == 'completed':
             if status == 'completed':
-                updates.append('completed_at = ?')
-                params.append(datetime.now().isoformat())
+                updates.append('completed_at = CURRENT_TIMESTAMP')
                 if result_path:
                 if result_path:
                     updates.append('result_path = ?')
                     updates.append('result_path = ?')
                     params.append(result_path)
                     params.append(result_path)
@@ -158,8 +156,7 @@ class TaskDB:
             if status == 'failed' and error_message:
             if status == 'failed' and error_message:
                 updates.append('error_message = ?')
                 updates.append('error_message = ?')
                 params.append(error_message)
                 params.append(error_message)
-                updates.append('completed_at = ?')
-                params.append(datetime.now().isoformat())
+                updates.append('completed_at = CURRENT_TIMESTAMP')
             
             
             params.append(task_id)
             params.append(task_id)
             cursor.execute(f'''
             cursor.execute(f'''

+ 30 - 32
projects/mineru_tianshu/task_scheduler.py

@@ -9,7 +9,6 @@ import aiohttp
 from loguru import logger
 from loguru import logger
 from task_db import TaskDB
 from task_db import TaskDB
 import signal
 import signal
-import sys
 
 
 
 
 class TaskScheduler:
 class TaskScheduler:
@@ -41,40 +40,38 @@ class TaskScheduler:
         self.max_concurrent_polls = max_concurrent_polls
         self.max_concurrent_polls = max_concurrent_polls
         self.db = TaskDB()
         self.db = TaskDB()
         self.running = True
         self.running = True
-        self.active_polls = 0
+        self.semaphore = asyncio.Semaphore(max_concurrent_polls)
     
     
     async def trigger_worker_poll(self, session: aiohttp.ClientSession):
     async def trigger_worker_poll(self, session: aiohttp.ClientSession):
         """
         """
         触发一个 worker 拉取任务
         触发一个 worker 拉取任务
         """
         """
-        self.active_polls += 1
-        try:
-            async with session.post(
-                self.litserve_url,
-                json={'action': 'poll'},
-                timeout=aiohttp.ClientTimeout(total=600)  # 10分钟超时
-            ) as resp:
-                if resp.status == 200:
-                    result = await resp.json()
-                    
-                    if result.get('status') == 'completed':
-                        logger.info(f"✅ Task completed: {result.get('task_id')} by {result.get('worker_id')}")
-                    elif result.get('status') == 'failed':
-                        logger.error(f"❌ Task failed: {result.get('task_id')} - {result.get('error')}")
-                    elif result.get('status') == 'idle':
-                        # Worker 空闲,没有任务
-                        pass
-                    
-                    return result
-                else:
-                    logger.error(f"Worker poll failed with status {resp.status}")
-                    
-        except asyncio.TimeoutError:
-            logger.warning("Worker poll timeout")
-        except Exception as e:
-            logger.error(f"Worker poll error: {e}")
-        finally:
-            self.active_polls -= 1
+        async with self.semaphore:
+            try:
+                async with session.post(
+                    self.litserve_url,
+                    json={'action': 'poll'},
+                    timeout=aiohttp.ClientTimeout(total=600)  # 10分钟超时
+                ) as resp:
+                    if resp.status == 200:
+                        result = await resp.json()
+                        
+                        if result.get('status') == 'completed':
+                            logger.info(f"✅ Task completed: {result.get('task_id')} by {result.get('worker_id')}")
+                        elif result.get('status') == 'failed':
+                            logger.error(f"❌ Task failed: {result.get('task_id')} - {result.get('error')}")
+                        elif result.get('status') == 'idle':
+                            # Worker 空闲,没有任务
+                            pass
+                        
+                        return result
+                    else:
+                        logger.error(f"Worker poll failed with status {resp.status}")
+                        
+            except asyncio.TimeoutError:
+                logger.warning("Worker poll timeout")
+            except Exception as e:
+                logger.error(f"Worker poll error: {e}")
     
     
     async def schedule_loop(self):
     async def schedule_loop(self):
         """
         """
@@ -97,14 +94,15 @@ class TaskScheduler:
                         logger.info(f"📋 Queue status: {pending_count} pending, {processing_count} processing")
                         logger.info(f"📋 Queue status: {pending_count} pending, {processing_count} processing")
                         
                         
                         # 计算需要触发的 worker 数量
                         # 计算需要触发的 worker 数量
-                        # 考虑:待处理任务数、当前处理中的任务数、活跃的轮询数
+                        # 考虑:待处理任务数
                         needed_workers = min(
                         needed_workers = min(
                             pending_count,  # 待处理任务数
                             pending_count,  # 待处理任务数
-                            self.max_concurrent_polls - self.active_polls  # 剩余并发数
+                            self.max_concurrent_polls  # 最大并发数
                         )
                         )
                         
                         
                         if needed_workers > 0:
                         if needed_workers > 0:
                             # 并发触发多个 worker
                             # 并发触发多个 worker
+                            # semaphore 会自动控制实际并发数
                             tasks = [
                             tasks = [
                                 self.trigger_worker_poll(session) 
                                 self.trigger_worker_poll(session) 
                                 for _ in range(needed_workers)
                                 for _ in range(needed_workers)