
Fix code review issues

Magic_yuan committed 1 month ago
parent commit 484ff5a6f9

+ 1 - 1
projects/mineru_tianshu/README.md

@@ -259,7 +259,7 @@ POST /api/v1/tasks/submit
 Parameters:
   file: file (required)
   backend: pipeline | vlm-transformers | vlm-vllm-engine (default: pipeline)
-  lang: ch | en | korean | japan | ... (default: auto)
+  lang: ch | en | korean | japan | ... (default: ch)
   priority: 0-100 (higher value means higher priority, default: 0)
 ```
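A minimal usage sketch of the submit call documented above; the host/port (localhost:8000) and file name (sample.pdf) are assumptions for illustration, not part of this change:

```python
# Hypothetical client sketch: submit a document to /api/v1/tasks/submit.
# localhost:8000 and sample.pdf are assumptions, not taken from this commit.
import asyncio
import aiohttp

async def submit(path: str = 'sample.pdf'):
    data = aiohttp.FormData()
    data.add_field('file', open(path, 'rb'), filename=path)
    data.add_field('backend', 'pipeline')
    data.add_field('lang', 'ch')       # matches the new documented default
    data.add_field('priority', '10')
    async with aiohttp.ClientSession() as session:
        async with session.post('http://localhost:8000/api/v1/tasks/submit', data=data) as resp:
            print(await resp.json())

asyncio.run(submit())
```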
 

+ 17 - 7
projects/mineru_tianshu/api_server.py

@@ -5,7 +5,7 @@ MinerU Tianshu - API Server
 Provides RESTful API endpoints for task submission, querying and management
 """
 from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Query
-from fastapi.responses import JSONResponse, FileResponse
+from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 import tempfile
 from pathlib import Path
@@ -105,7 +105,8 @@ def process_markdown_images(md_content: str, image_dir: Path, upload_images: boo
                     minio_client.fput_object(bucket_name, object_name, str(full_image_path))
                     
                     # Generate the MinIO access URL
-                    minio_url = f"https://{minio_endpoint}/{bucket_name}/{object_name}"
+                    scheme = 'https' if MINIO_CONFIG['secure'] else 'http'
+                    minio_url = f"{scheme}://{minio_endpoint}/{bucket_name}/{object_name}"
                     
                     # Return an HTML-formatted img tag
                     return f'<img src="{minio_url}" alt="{alt_text}">'
@@ -137,7 +138,7 @@ async def root():
 
 @app.post("/api/v1/tasks/submit")
 async def submit_task(
-    file: UploadFile = File(..., description="PDF file or image"),
+    file: UploadFile = File(..., description="Document file: PDF/image (parsed by MinerU) or Office/HTML/text etc. (parsed by MarkItDown)"),
     backend: str = Form('pipeline', description="Processing backend: pipeline/vlm-transformers/vlm-vllm-engine"),
     lang: str = Form('ch', description="Language: ch/en/korean/japan, etc."),
     method: str = Form('auto', description="Parsing method: auto/txt/ocr"),
@@ -153,8 +154,14 @@ async def submit_task(
     try:
         # Save the uploaded file to a temporary directory
         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix)
-        content = await file.read()
-        temp_file.write(content)
+        
+        # Stream the file to disk in chunks to avoid high memory usage
+        while True:
+            chunk = await file.read(1 << 23)  # 8MB chunks
+            if not chunk:
+                break
+            temp_file.write(chunk)
+        
         temp_file.close()
         
         # Create the task
@@ -405,13 +412,16 @@ async def health_check():
 
 
 if __name__ == '__main__':
+    # Read the port from the API_PORT environment variable (default: 8000)
+    api_port = int(os.getenv('API_PORT', '8000'))
+    
     logger.info("🚀 Starting MinerU Tianshu API Server...")
-    logger.info("📖 API Documentation: http://localhost:8000/docs")
+    logger.info(f"📖 API Documentation: http://localhost:{api_port}/docs")
     
     uvicorn.run(
         app, 
         host='0.0.0.0', 
-        port=8000,
+        port=api_port,
         log_level='info'
     )
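The chunked read above keeps memory bounded for large uploads. A roughly equivalent standalone sketch (a hypothetical helper, not part of this commit) using the upload's underlying file object:

```python
# Standalone sketch (assumed helper name, not project code): stream an UploadFile
# to disk with a bounded buffer instead of reading the whole body into memory.
import shutil
import tempfile
from pathlib import Path
from fastapi import UploadFile

def save_upload_to_temp(file: UploadFile) -> Path:
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix)
    with tmp:
        # copyfileobj moves data in fixed-size chunks (8 MB here), like the loop above
        shutil.copyfileobj(file.file, tmp, length=1 << 23)
    return Path(tmp.name)
```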
 

+ 1 - 1
projects/mineru_tianshu/client_example.py

@@ -9,7 +9,7 @@ import aiohttp
 from pathlib import Path
 from loguru import logger
 import time
-from typing import List, Dict
+from typing import Dict
 
 
 class TianshuClient:

+ 0 - 1
projects/mineru_tianshu/litserve_worker.py

@@ -11,7 +11,6 @@ import sys
 from pathlib import Path
 import litserve as ls
 from loguru import logger
-from typing import Optional
 
 # Add the parent directory to sys.path so MinerU can be imported
 sys.path.insert(0, str(Path(__file__).parent.parent.parent))

+ 5 - 1
projects/mineru_tianshu/start_all.py

@@ -8,6 +8,7 @@ import subprocess
 import signal
 import sys
 import time
+import os
 from loguru import logger
 from pathlib import Path
 import argparse
@@ -44,9 +45,12 @@ class TianshuLauncher:
         try:
             # 1. Start the API Server
             logger.info("📡 [1/3] Starting API Server...")
+            env = os.environ.copy()
+            env['API_PORT'] = str(self.api_port)
             api_proc = subprocess.Popen(
                 [sys.executable, 'api_server.py'],
-                cwd=Path(__file__).parent
+                cwd=Path(__file__).parent,
+                env=env
             )
             self.processes.append(('API Server', api_proc))
             time.sleep(3)
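Passing the port through a copied environment keeps the launcher and the API server decoupled. A tiny standalone sketch (not project code) of the same parent-to-child propagation, including the read-side fallback:

```python
# Standalone sketch: confirm a child process sees the variable the parent put
# into its copied environment, with the same default fallback as api_server.py.
import os
import subprocess
import sys

env = os.environ.copy()
env['API_PORT'] = '9000'
out = subprocess.run(
    [sys.executable, '-c', "import os; print(os.getenv('API_PORT', '8000'))"],
    env=env, capture_output=True, text=True,
)
print(out.stdout.strip())  # prints 9000
```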

+ 4 - 7
projects/mineru_tianshu/task_db.py

@@ -7,7 +7,6 @@ MinerU Tianshu - SQLite Task Database Manager
 import sqlite3
 import json
 import uuid
-from datetime import datetime
 from contextlib import contextmanager
 from typing import Optional, List, Dict
 from pathlib import Path
@@ -124,10 +123,10 @@ class TaskDB:
                 cursor.execute('''
                     UPDATE tasks 
                     SET status = 'processing', 
-                        started_at = ?, 
+                        started_at = CURRENT_TIMESTAMP, 
                         worker_id = ?
                     WHERE task_id = ?
-                ''', (datetime.now().isoformat(), worker_id, task['task_id']))
+                ''', (worker_id, task['task_id']))
                 
                 return dict(task)
             
@@ -149,8 +148,7 @@ class TaskDB:
             params = [status]
             
             if status == 'completed':
-                updates.append('completed_at = ?')
-                params.append(datetime.now().isoformat())
+                updates.append('completed_at = CURRENT_TIMESTAMP')
                 if result_path:
                     updates.append('result_path = ?')
                     params.append(result_path)
@@ -158,8 +156,7 @@ class TaskDB:
             if status == 'failed' and error_message:
                 updates.append('error_message = ?')
                 params.append(error_message)
-                updates.append('completed_at = ?')
-                params.append(datetime.now().isoformat())
+                updates.append('completed_at = CURRENT_TIMESTAMP')
             
             params.append(task_id)
             cursor.execute(f'''
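Switching to SQLite's CURRENT_TIMESTAMP lets the database fill in the time itself; note the values are recorded in UTC rather than local time as with datetime.now(). A minimal standalone sketch (not project code) of the behavior:

```python
# Standalone sketch: CURRENT_TIMESTAMP is evaluated by SQLite itself (in UTC),
# so the caller no longer passes datetime.now().isoformat() as a parameter.
import sqlite3

conn = sqlite3.connect(':memory:')
conn.execute("CREATE TABLE tasks (task_id TEXT PRIMARY KEY, started_at TIMESTAMP)")
conn.execute("INSERT INTO tasks (task_id) VALUES ('t1')")
conn.execute("UPDATE tasks SET started_at = CURRENT_TIMESTAMP WHERE task_id = 't1'")
print(conn.execute("SELECT started_at FROM tasks").fetchone())  # ('YYYY-MM-DD HH:MM:SS',)
```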

+ 30 - 32
projects/mineru_tianshu/task_scheduler.py

@@ -9,7 +9,6 @@ import aiohttp
 from loguru import logger
 from task_db import TaskDB
 import signal
-import sys
 
 
 class TaskScheduler:
@@ -41,40 +40,38 @@ class TaskScheduler:
         self.max_concurrent_polls = max_concurrent_polls
         self.db = TaskDB()
         self.running = True
-        self.active_polls = 0
+        self.semaphore = asyncio.Semaphore(max_concurrent_polls)
     
     async def trigger_worker_poll(self, session: aiohttp.ClientSession):
         """
         Trigger one worker to poll for a task
         """
-        self.active_polls += 1
-        try:
-            async with session.post(
-                self.litserve_url,
-                json={'action': 'poll'},
-                timeout=aiohttp.ClientTimeout(total=600)  # 10-minute timeout
-            ) as resp:
-                if resp.status == 200:
-                    result = await resp.json()
-                    
-                    if result.get('status') == 'completed':
-                        logger.info(f"✅ Task completed: {result.get('task_id')} by {result.get('worker_id')}")
-                    elif result.get('status') == 'failed':
-                        logger.error(f"❌ Task failed: {result.get('task_id')} - {result.get('error')}")
-                    elif result.get('status') == 'idle':
-                        # Worker is idle, no tasks available
-                        pass
-                    
-                    return result
-                else:
-                    logger.error(f"Worker poll failed with status {resp.status}")
-                    
-        except asyncio.TimeoutError:
-            logger.warning("Worker poll timeout")
-        except Exception as e:
-            logger.error(f"Worker poll error: {e}")
-        finally:
-            self.active_polls -= 1
+        async with self.semaphore:
+            try:
+                async with session.post(
+                    self.litserve_url,
+                    json={'action': 'poll'},
+                    timeout=aiohttp.ClientTimeout(total=600)  # 10-minute timeout
+                ) as resp:
+                    if resp.status == 200:
+                        result = await resp.json()
+                        
+                        if result.get('status') == 'completed':
+                            logger.info(f"✅ Task completed: {result.get('task_id')} by {result.get('worker_id')}")
+                        elif result.get('status') == 'failed':
+                            logger.error(f"❌ Task failed: {result.get('task_id')} - {result.get('error')}")
+                        elif result.get('status') == 'idle':
+                            # Worker is idle, no tasks available
+                            pass
+                        
+                        return result
+                    else:
+                        logger.error(f"Worker poll failed with status {resp.status}")
+                        
+            except asyncio.TimeoutError:
+                logger.warning("Worker poll timeout")
+            except Exception as e:
+                logger.error(f"Worker poll error: {e}")
     
     async def schedule_loop(self):
         """
@@ -97,14 +94,15 @@ class TaskScheduler:
                         logger.info(f"📋 Queue status: {pending_count} pending, {processing_count} processing")
                         
                         # Compute how many workers to trigger
-                        # Consider: pending tasks, tasks currently processing, and active polls
+                        # Consider: the number of pending tasks
                         needed_workers = min(
                             pending_count,  # number of pending tasks
-                            self.max_concurrent_polls - self.active_polls  # remaining concurrency slots
+                            self.max_concurrent_polls  # maximum concurrency
                         )
                         
                         if needed_workers > 0:
                             # Trigger multiple workers concurrently
+                            # the semaphore automatically caps the actual concurrency
                             tasks = [
                                 self.trigger_worker_poll(session) 
                                 for _ in range(needed_workers)
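Replacing the manual active_polls counter with asyncio.Semaphore moves concurrency control into the poll coroutine itself. A minimal standalone sketch (not project code) of that pattern:

```python
# Standalone sketch: a semaphore caps how many coroutines run the guarded
# section at once; the extra callers simply wait their turn.
import asyncio

async def poll(sem: asyncio.Semaphore, i: int) -> int:
    async with sem:
        await asyncio.sleep(0.1)  # stands in for the HTTP poll to a worker
        return i

async def main():
    sem = asyncio.Semaphore(4)  # like max_concurrent_polls
    results = await asyncio.gather(*(poll(sem, i) for i in range(10)))
    print(results)

asyncio.run(main())
```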