client_example.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
"""
MinerU Tianshu - Client Example

Tianshu (天枢) client example.
Demonstrates how to submit tasks and query their status with the Python client.
"""
  6. import asyncio
  7. import aiohttp
  8. from pathlib import Path
  9. from loguru import logger
  10. import time
  11. from typing import Dict
  12. class TianshuClient:
  13. """天枢客户端"""
  14. def __init__(self, api_url='http://localhost:8000'):
  15. self.api_url = api_url
  16. self.base_url = f"{api_url}/api/v1"
  17. async def submit_task(
  18. self,
  19. session: aiohttp.ClientSession,
  20. file_path: str,
  21. backend: str = 'pipeline',
  22. lang: str = 'ch',
  23. method: str = 'auto',
  24. formula_enable: bool = True,
  25. table_enable: bool = True,
  26. priority: int = 0
  27. ) -> Dict:
  28. """
  29. 提交任务
  30. Args:
  31. session: aiohttp session
  32. file_path: 文件路径
  33. backend: 处理后端
  34. lang: 语言
  35. method: 解析方法
  36. formula_enable: 是否启用公式识别
  37. table_enable: 是否启用表格识别
  38. priority: 优先级
  39. Returns:
  40. 响应字典,包含 task_id
  41. """
  42. with open(file_path, 'rb') as f:
  43. data = aiohttp.FormData()
  44. data.add_field('file', f, filename=Path(file_path).name)
  45. data.add_field('backend', backend)
  46. data.add_field('lang', lang)
  47. data.add_field('method', method)
  48. data.add_field('formula_enable', str(formula_enable).lower())
  49. data.add_field('table_enable', str(table_enable).lower())
  50. data.add_field('priority', str(priority))
  51. async with session.post(f'{self.base_url}/tasks/submit', data=data) as resp:
  52. if resp.status == 200:
  53. result = await resp.json()
  54. logger.info(f"✅ Submitted: {file_path} -> Task ID: {result['task_id']}")
  55. return result
  56. else:
  57. error = await resp.text()
  58. logger.error(f"❌ Failed to submit {file_path}: {error}")
  59. return {'success': False, 'error': error}
  60. async def get_task_status(self, session: aiohttp.ClientSession, task_id: str) -> Dict:
  61. """
  62. 查询任务状态
  63. Args:
  64. session: aiohttp session
  65. task_id: 任务ID
  66. Returns:
  67. 任务状态字典
  68. """
  69. async with session.get(f'{self.base_url}/tasks/{task_id}') as resp:
  70. if resp.status == 200:
  71. return await resp.json()
  72. else:
  73. return {'success': False, 'error': 'Task not found'}
  74. async def wait_for_task(
  75. self,
  76. session: aiohttp.ClientSession,
  77. task_id: str,
  78. timeout: int = 600,
  79. poll_interval: int = 2
  80. ) -> Dict:
  81. """
  82. 等待任务完成
  83. Args:
  84. session: aiohttp session
  85. task_id: 任务ID
  86. timeout: 超时时间(秒)
  87. poll_interval: 轮询间隔(秒)
  88. Returns:
  89. 最终任务状态
  90. """
  91. start_time = time.time()
  92. while True:
  93. status = await self.get_task_status(session, task_id)
  94. if not status.get('success'):
  95. logger.error(f"❌ Failed to get status for task {task_id}")
  96. return status
  97. task_status = status.get('status')
  98. if task_status == 'completed':
  99. logger.info(f"✅ Task {task_id} completed!")
  100. logger.info(f" Output: {status.get('result_path')}")
  101. return status
  102. elif task_status == 'failed':
  103. logger.error(f"❌ Task {task_id} failed!")
  104. logger.error(f" Error: {status.get('error_message')}")
  105. return status
  106. elif task_status == 'cancelled':
  107. logger.warning(f"⚠️ Task {task_id} was cancelled")
  108. return status
  109. # 检查超时
  110. if time.time() - start_time > timeout:
  111. logger.error(f"⏱️ Task {task_id} timeout after {timeout}s")
  112. return {'success': False, 'error': 'timeout'}
  113. # 等待后继续轮询
  114. await asyncio.sleep(poll_interval)
  115. async def get_queue_stats(self, session: aiohttp.ClientSession) -> Dict:
  116. """获取队列统计"""
  117. async with session.get(f'{self.base_url}/queue/stats') as resp:
  118. return await resp.json()
  119. async def cancel_task(self, session: aiohttp.ClientSession, task_id: str) -> Dict:
  120. """取消任务"""
  121. async with session.delete(f'{self.base_url}/tasks/{task_id}') as resp:
  122. return await resp.json()
  123. async def example_single_task():
  124. """示例1:提交单个任务并等待完成"""
  125. logger.info("=" * 60)
  126. logger.info("示例1:提交单个任务")
  127. logger.info("=" * 60)
  128. client = TianshuClient()
  129. async with aiohttp.ClientSession() as session:
  130. # 提交任务
  131. result = await client.submit_task(
  132. session,
  133. file_path='../../demo/pdfs/demo1.pdf',
  134. backend='pipeline',
  135. lang='ch',
  136. formula_enable=True,
  137. table_enable=True
  138. )
  139. if result.get('success'):
  140. task_id = result['task_id']
  141. # 等待完成
  142. logger.info(f"⏳ Waiting for task {task_id} to complete...")
  143. final_status = await client.wait_for_task(session, task_id)
  144. return final_status
  145. async def example_batch_tasks():
  146. """示例2:批量提交多个任务并并发等待"""
  147. logger.info("=" * 60)
  148. logger.info("示例2:批量提交多个任务")
  149. logger.info("=" * 60)
  150. client = TianshuClient()
  151. # 准备任务列表
  152. files = [
  153. '../../demo/pdfs/demo1.pdf',
  154. '../../demo/pdfs/demo2.pdf',
  155. '../../demo/pdfs/demo3.pdf',
  156. ]
  157. async with aiohttp.ClientSession() as session:
  158. # 并发提交所有任务
  159. logger.info(f"📤 Submitting {len(files)} tasks...")
  160. submit_tasks = [
  161. client.submit_task(session, file)
  162. for file in files
  163. ]
  164. results = await asyncio.gather(*submit_tasks)
  165. # 提取 task_ids
  166. task_ids = [r['task_id'] for r in results if r.get('success')]
  167. logger.info(f"✅ Submitted {len(task_ids)} tasks successfully")
  168. # 并发等待所有任务完成
  169. logger.info(f"⏳ Waiting for all tasks to complete...")
  170. wait_tasks = [
  171. client.wait_for_task(session, task_id)
  172. for task_id in task_ids
  173. ]
  174. final_results = await asyncio.gather(*wait_tasks)
  175. # 统计结果
  176. completed = sum(1 for r in final_results if r.get('status') == 'completed')
  177. failed = sum(1 for r in final_results if r.get('status') == 'failed')
  178. logger.info("=" * 60)
  179. logger.info(f"📊 Results: {completed} completed, {failed} failed")
  180. logger.info("=" * 60)
  181. return final_results
  182. async def example_priority_tasks():
  183. """示例3:使用优先级队列"""
  184. logger.info("=" * 60)
  185. logger.info("示例3:优先级队列")
  186. logger.info("=" * 60)
  187. client = TianshuClient()
  188. async with aiohttp.ClientSession() as session:
  189. # 提交低优先级任务
  190. low_priority = await client.submit_task(
  191. session,
  192. file_path='../../demo/pdfs/demo1.pdf',
  193. priority=0
  194. )
  195. logger.info(f"📝 Low priority task: {low_priority['task_id']}")
  196. # 提交高优先级任务
  197. high_priority = await client.submit_task(
  198. session,
  199. file_path='../../demo/pdfs/demo2.pdf',
  200. priority=10
  201. )
  202. logger.info(f"🔥 High priority task: {high_priority['task_id']}")
  203. # 高优先级任务会先被处理
  204. logger.info("⏳ 高优先级任务将优先处理...")
  205. async def example_queue_monitoring():
  206. """示例4:监控队列状态"""
  207. logger.info("=" * 60)
  208. logger.info("示例4:监控队列状态")
  209. logger.info("=" * 60)
  210. client = TianshuClient()
  211. async with aiohttp.ClientSession() as session:
  212. # 获取队列统计
  213. stats = await client.get_queue_stats(session)
  214. logger.info("📊 Queue Statistics:")
  215. logger.info(f" Total: {stats.get('total', 0)}")
  216. for status, count in stats.get('stats', {}).items():
  217. logger.info(f" {status:12s}: {count}")
  218. async def main():
  219. """主函数"""
  220. import sys
  221. if len(sys.argv) > 1:
  222. example = sys.argv[1]
  223. else:
  224. example = 'all'
  225. try:
  226. if example == 'single' or example == 'all':
  227. await example_single_task()
  228. print()
  229. if example == 'batch' or example == 'all':
  230. await example_batch_tasks()
  231. print()
  232. if example == 'priority' or example == 'all':
  233. await example_priority_tasks()
  234. print()
  235. if example == 'monitor' or example == 'all':
  236. await example_queue_monitoring()
  237. print()
  238. except Exception as e:
  239. logger.error(f"Example failed: {e}")
  240. import traceback
  241. traceback.print_exc()
  242. if __name__ == '__main__':
  243. """
  244. 使用方法:
  245. # 运行所有示例
  246. python client_example.py
  247. # 运行特定示例
  248. python client_example.py single
  249. python client_example.py batch
  250. python client_example.py priority
  251. python client_example.py monitor
  252. """
  253. asyncio.run(main())