|
@@ -0,0 +1,413 @@
|
|
|
|
|
+"""GLM-OCR VL识别适配器
|
|
|
|
|
+
|
|
|
|
|
+直接通过 HTTP 调用 GLM-OCR API(OpenAI 兼容格式)。
|
|
|
|
|
+支持表格、公式、文本和印章(seal)识别。
|
|
|
|
|
+
|
|
|
|
|
+架构说明:
|
|
|
|
|
+- 使用 requests 库直接调用 GLM-OCR HTTP API
|
|
|
|
|
+- 无需依赖 glmocr 包
|
|
|
|
|
+- 通过 task_prompt_mapping 配置不同任务的提示词
|
|
|
|
|
+- 支持图片预处理(尺寸控制)
|
|
|
|
|
+"""
|
|
|
|
|
+
|
|
|
|
|
+import sys
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+from typing import Dict, Any, List, Union, Optional
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+from PIL import Image
|
|
|
|
|
+from loguru import logger
|
|
|
|
|
+import requests
|
|
|
|
|
+from requests.adapters import HTTPAdapter
|
|
|
|
|
+from urllib3.util.retry import Retry
|
|
|
|
|
+import base64
|
|
|
|
|
+from io import BytesIO
|
|
|
|
|
+import json
|
|
|
|
|
+
|
|
|
|
|
+# 导入基类
|
|
|
|
|
+from .base import BaseVLRecognizer
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class GLMOCRVLRecognizer(BaseVLRecognizer):
|
|
|
|
|
+ """
|
|
|
|
|
+ GLM-OCR VL识别适配器
|
|
|
|
|
+
|
|
|
|
|
+ 配置示例:
|
|
|
|
|
+ ```yaml
|
|
|
|
|
+ vl_recognition:
|
|
|
|
|
+ module: "glmocr"
|
|
|
|
|
+ api_url: "http://10.192.72.11:20036/v1/chat/completions"
|
|
|
|
|
+ api_key: null # 可选
|
|
|
|
|
+ model: "glm-ocr"
|
|
|
|
|
+ max_image_size: 3500
|
|
|
|
|
+ resize_mode: 'max'
|
|
|
|
|
+ task_prompt_mapping:
|
|
|
|
|
+ text: "Text Recognition:"
|
|
|
|
|
+ table: "Table Recognition:"
|
|
|
|
|
+ formula: "Formula Recognition:"
|
|
|
|
|
+ seal: "Seal Recognition:"
|
|
|
|
|
+ model_params:
|
|
|
|
|
+ connection_pool_size: 128
|
|
|
|
|
+ http_timeout: 300
|
|
|
|
|
+ retry_max_attempts: 2
|
|
|
|
|
+ ```
|
|
|
|
|
+ """
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(self, config: Dict[str, Any]):
|
|
|
|
|
+ super().__init__(config)
|
|
|
|
|
+
|
|
|
|
|
+ self.session = None
|
|
|
|
|
+
|
|
|
|
|
+ # API 配置
|
|
|
|
|
+ self.api_url = config.get('api_url', 'http://127.0.0.1:8000/v1/chat/completions')
|
|
|
|
|
+ self.api_key = config.get('api_key')
|
|
|
|
|
+ self.model = config.get('model', 'glm-ocr')
|
|
|
|
|
+ self.verify_ssl = config.get('verify_ssl', False)
|
|
|
|
|
+
|
|
|
|
|
+ # 图片尺寸限制配置
|
|
|
|
|
+ self.max_image_size = config.get('max_image_size', 3500)
|
|
|
|
|
+ self.resize_mode = config.get('resize_mode', 'max')
|
|
|
|
|
+
|
|
|
|
|
+ # Task prompt mapping(任务提示词映射)
|
|
|
|
|
+ self.task_prompt_mapping = config.get('task_prompt_mapping', {
|
|
|
|
|
+ 'text': 'Text Recognition:',
|
|
|
|
|
+ 'table': 'Table Recognition:',
|
|
|
|
|
+ 'formula': 'Formula Recognition:',
|
|
|
|
|
+ 'seal': 'Seal Recognition:',
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ # 模型参数
|
|
|
|
|
+ model_params = config.get('model_params', {})
|
|
|
|
|
+ self.connection_pool_size = model_params.get('connection_pool_size', 128)
|
|
|
|
|
+ self.http_timeout = model_params.get('http_timeout', 300)
|
|
|
|
|
+ self.connect_timeout = model_params.get('connect_timeout', 30)
|
|
|
|
|
+ self.retry_max_attempts = model_params.get('retry_max_attempts', 2)
|
|
|
|
|
+
|
|
|
|
|
+ # 生成参数
|
|
|
|
|
+ self.max_tokens = model_params.get('max_tokens', 4096)
|
|
|
|
|
+ self.temperature = model_params.get('temperature', 0.8)
|
|
|
|
|
+ self.top_p = model_params.get('top_p', 0.9)
|
|
|
|
|
+ self.top_k = model_params.get('top_k', 50)
|
|
|
|
|
+ self.repetition_penalty = model_params.get('repetition_penalty', 1.1)
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"GLM-OCR VL Recognizer configured with max_image_size={self.max_image_size}")
|
|
|
|
|
+ logger.debug(f"Task prompt mapping: {self.task_prompt_mapping}")
|
|
|
|
|
+
|
|
|
|
|
+ def initialize(self):
|
|
|
|
|
+ """初始化 HTTP 会话"""
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 创建会话
|
|
|
|
|
+ self.session = requests.Session()
|
|
|
|
|
+
|
|
|
|
|
+ # 配置连接池
|
|
|
|
|
+ adapter = HTTPAdapter(
|
|
|
|
|
+ pool_connections=self.connection_pool_size,
|
|
|
|
|
+ pool_maxsize=self.connection_pool_size,
|
|
|
|
|
+ max_retries=Retry(
|
|
|
|
|
+ total=self.retry_max_attempts,
|
|
|
|
|
+ backoff_factor=0.5,
|
|
|
|
|
+ status_forcelist=[429, 500, 502, 503, 504],
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ self.session.mount('http://', adapter)
|
|
|
|
|
+ self.session.mount('https://', adapter)
|
|
|
|
|
+
|
|
|
|
|
+ # 设置默认 headers
|
|
|
|
|
+ self.session.headers.update({
|
|
|
|
|
+ 'Content-Type': 'application/json',
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ if self.api_key:
|
|
|
|
|
+ self.session.headers.update({
|
|
|
|
|
+ 'Authorization': f'Bearer {self.api_key}'
|
|
|
|
|
+ })
|
|
|
|
|
+
|
|
|
|
|
+ logger.success(f"✅ GLM-OCR VL recognizer initialized: {self.api_url}")
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"❌ Failed to initialize GLM-OCR VL recognizer: {e}")
|
|
|
|
|
+ raise
|
|
|
|
|
+
|
|
|
|
|
+ def cleanup(self):
|
|
|
|
|
+ """清理资源"""
|
|
|
|
|
+ if self.session:
|
|
|
|
|
+ self.session.close()
|
|
|
|
|
+ self.session = None
|
|
|
|
|
+ logger.debug("GLM-OCR VL recognizer cleaned up")
|
|
|
|
|
+
|
|
|
|
|
+ def _preprocess_image(self, image: Union[np.ndarray, Image.Image]) -> Image.Image:
|
|
|
|
|
+ """
|
|
|
|
|
+ 预处理图片,控制尺寸避免序列长度超限
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ image: 输入图片
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 处理后的PIL图片
|
|
|
|
|
+ """
|
|
|
|
|
+ # 转换为PIL图像
|
|
|
|
|
+ if isinstance(image, np.ndarray):
|
|
|
|
|
+ image = Image.fromarray(image)
|
|
|
|
|
+
|
|
|
|
|
+ # 获取原始尺寸
|
|
|
|
|
+ orig_w, orig_h = image.size
|
|
|
|
|
+
|
|
|
|
|
+ # 计算缩放比例
|
|
|
|
|
+ if self.resize_mode == 'max':
|
|
|
|
|
+ # 保持宽高比,最长边不超过 max_image_size
|
|
|
|
|
+ max_dim = max(orig_w, orig_h)
|
|
|
|
|
+ if max_dim > self.max_image_size:
|
|
|
|
|
+ scale = self.max_image_size / max_dim
|
|
|
|
|
+ new_w = int(orig_w * scale)
|
|
|
|
|
+ new_h = int(orig_h * scale)
|
|
|
|
|
+
|
|
|
|
|
+ logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {new_w}x{new_h} (scale={scale:.3f})")
|
|
|
|
|
+ image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
|
|
|
|
+
|
|
|
|
|
+ elif self.resize_mode == 'fixed':
|
|
|
|
|
+ # 固定尺寸(可能改变宽高比)
|
|
|
|
|
+ if orig_w != self.max_image_size or orig_h != self.max_image_size:
|
|
|
|
|
+ logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {self.max_image_size}x{self.max_image_size}")
|
|
|
|
|
+ image = image.resize((self.max_image_size, self.max_image_size), Image.Resampling.LANCZOS)
|
|
|
|
|
+
|
|
|
|
|
+ return image
|
|
|
|
|
+
|
|
|
|
|
+ def _build_request_for_image(
|
|
|
|
|
+ self,
|
|
|
|
|
+ image: Image.Image,
|
|
|
|
|
+ task_type: str = 'text'
|
|
|
|
|
+ ) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 为单张图片构建 GLM-OCR API 请求
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ image: PIL图片
|
|
|
|
|
+ task_type: 任务类型 ('text', 'table', 'formula', 'seal')
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 请求字典
|
|
|
|
|
+ """
|
|
|
|
|
+ # 获取任务对应的提示词
|
|
|
|
|
+ prompt_text = self.task_prompt_mapping.get(task_type, self.task_prompt_mapping.get('text', ''))
|
|
|
|
|
+
|
|
|
|
|
+ # 将图片转为 base64
|
|
|
|
|
+ buffered = BytesIO()
|
|
|
|
|
+ image.save(buffered, format="JPEG")
|
|
|
|
|
+ img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
|
|
|
|
+ img_url = f"data:image/jpeg;base64,{img_base64}"
|
|
|
|
|
+
|
|
|
|
|
+ # 构建请求(OpenAI 兼容格式)
|
|
|
|
|
+ request_data = {
|
|
|
|
|
+ "model": self.model,
|
|
|
|
|
+ "messages": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "role": "user",
|
|
|
|
|
+ "content": [
|
|
|
|
|
+ {"type": "image_url", "image_url": {"url": img_url}},
|
|
|
|
|
+ {"type": "text", "text": prompt_text},
|
|
|
|
|
+ ]
|
|
|
|
|
+ }
|
|
|
|
|
+ ],
|
|
|
|
|
+ "max_tokens": self.max_tokens,
|
|
|
|
|
+ "temperature": self.temperature,
|
|
|
|
|
+ "top_p": self.top_p,
|
|
|
|
|
+ "top_k": self.top_k,
|
|
|
|
|
+ "repetition_penalty": self.repetition_penalty,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return request_data
|
|
|
|
|
+
|
|
|
|
|
+ def _call_ocr_api(self, image: Image.Image, task_type: str) -> str:
|
|
|
|
|
+ """
|
|
|
|
|
+ 调用 GLM-OCR API 进行识别
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ image: PIL图片
|
|
|
|
|
+ task_type: 任务类型
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 识别结果文本
|
|
|
|
|
+ """
|
|
|
|
|
+ if self.session is None:
|
|
|
|
|
+ raise RuntimeError("HTTP session not initialized")
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 构建请求
|
|
|
|
|
+ request_data = self._build_request_for_image(image, task_type)
|
|
|
|
|
+
|
|
|
|
|
+ # 调用 API
|
|
|
|
|
+ response = self.session.post(
|
|
|
|
|
+ self.api_url,
|
|
|
|
|
+ json=request_data,
|
|
|
|
|
+ timeout=(self.connect_timeout, self.http_timeout),
|
|
|
|
|
+ verify=self.verify_ssl
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if response.status_code != 200:
|
|
|
|
|
+ logger.error(f"OCR API returned status {response.status_code}: {response.text}")
|
|
|
|
|
+ return ""
|
|
|
|
|
+
|
|
|
|
|
+ # 解析响应
|
|
|
|
|
+ result = response.json()
|
|
|
|
|
+
|
|
|
|
|
+ # 提取识别结果
|
|
|
|
|
+ if 'choices' in result and len(result['choices']) > 0:
|
|
|
|
|
+ content = result['choices'][0].get('message', {}).get('content', '')
|
|
|
|
|
+ return content
|
|
|
|
|
+
|
|
|
|
|
+ logger.warning(f"No content in OCR response: {result}")
|
|
|
|
|
+ return ""
|
|
|
|
|
+
|
|
|
|
|
+ except requests.exceptions.Timeout:
|
|
|
|
|
+ logger.error(f"OCR API timeout after {self.http_timeout}s")
|
|
|
|
|
+ return ""
|
|
|
|
|
+ except requests.exceptions.RequestException as e:
|
|
|
|
|
+ logger.error(f"OCR API request failed: {e}")
|
|
|
|
|
+ return ""
|
|
|
|
|
+ except json.JSONDecodeError as e:
|
|
|
|
|
+ logger.error(f"Failed to parse OCR response: {e}")
|
|
|
|
|
+ return ""
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"OCR API call failed: {e}")
|
|
|
|
|
+ return ""
|
|
|
|
|
+
|
|
|
|
|
+ def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 识别表格
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ image: 输入图片
|
|
|
|
|
+ **kwargs: 额外参数(未使用)
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 包含 'html' 和 'markdown' 的字典
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 预处理图片
|
|
|
|
|
+ image = self._preprocess_image(image)
|
|
|
|
|
+
|
|
|
|
|
+ # 调用 API
|
|
|
|
|
+ result_text = self._call_ocr_api(image, 'table')
|
|
|
|
|
+
|
|
|
|
|
+ if not result_text:
|
|
|
|
|
+ return {'html': '', 'markdown': '', 'cells': []}
|
|
|
|
|
+
|
|
|
|
|
+ # GLM-OCR 默认返回 Markdown 格式
|
|
|
|
|
+ # 如果需要 HTML,可以使用简单的转换(或保持 Markdown)
|
|
|
|
|
+ return {
|
|
|
|
|
+ 'html': result_text, # GLM-OCR 可能返回 HTML 或 Markdown
|
|
|
|
|
+ 'markdown': result_text,
|
|
|
|
|
+ 'cells': [], # GLM-OCR 不直接返回单元格坐标
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"❌ Table recognition failed: {e}")
|
|
|
|
|
+ return {'html': '', 'markdown': '', 'cells': []}
|
|
|
|
|
+
|
|
|
|
|
+ def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 识别公式
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ image: 输入图片
|
|
|
|
|
+ **kwargs: 额外参数(未使用)
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 包含 'latex' 的字典
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 预处理图片
|
|
|
|
|
+ image = self._preprocess_image(image)
|
|
|
|
|
+
|
|
|
|
|
+ # 调用 API
|
|
|
|
|
+ result_text = self._call_ocr_api(image, 'formula')
|
|
|
|
|
+
|
|
|
|
|
+ if not result_text:
|
|
|
|
|
+ return {'latex': '', 'confidence': 0.0, 'raw': {}}
|
|
|
|
|
+
|
|
|
|
|
+ # 清理 LaTeX 格式(移除 markdown 代码块标记)
|
|
|
|
|
+ latex = self._clean_latex(result_text)
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ 'latex': latex,
|
|
|
|
|
+ 'confidence': 0.9 if latex else 0.0,
|
|
|
|
|
+ 'raw': {'raw_output': result_text}
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"❌ Formula recognition failed: {e}")
|
|
|
|
|
+ return {'latex': '', 'confidence': 0.0, 'raw': {}}
|
|
|
|
|
+
|
|
|
|
|
+ def recognize_text(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
|
|
|
|
|
+ """
|
|
|
|
|
+ 识别文本区域(包括普通文本和印章)
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ image: 输入图片
|
|
|
|
|
+ **kwargs: 额外参数,可包含 'element_type' 指定类型(如 'seal')
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 包含 'text' 的字典
|
|
|
|
|
+ """
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 预处理图片
|
|
|
|
|
+ image = self._preprocess_image(image)
|
|
|
|
|
+
|
|
|
|
|
+ # 确定任务类型(如果是 seal,使用 seal 提示词)
|
|
|
|
|
+ element_type = kwargs.get('element_type', 'text')
|
|
|
|
|
+ task_type = 'seal' if element_type == 'seal' else 'text'
|
|
|
|
|
+
|
|
|
|
|
+ # 调用 API
|
|
|
|
|
+ result_text = self._call_ocr_api(image, task_type)
|
|
|
|
|
+
|
|
|
|
|
+ return {
|
|
|
|
|
+ 'text': result_text or '',
|
|
|
|
|
+ 'confidence': 0.9 if result_text else 0.0
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"❌ Text recognition failed: {e}")
|
|
|
|
|
+ return {'text': '', 'confidence': 0.0}
|
|
|
|
|
+
|
|
|
|
|
+ def _clean_latex(self, latex_str: str) -> str:
|
|
|
|
|
+ """
|
|
|
|
|
+ 清理 LaTeX 字符串,移除 Markdown 代码块标记
|
|
|
|
|
+
|
|
|
|
|
+ Args:
|
|
|
|
|
+ latex_str: 原始 LaTeX 字符串
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ 清理后的 LaTeX
|
|
|
|
|
+ """
|
|
|
|
|
+ if not latex_str:
|
|
|
|
|
+ return ""
|
|
|
|
|
+
|
|
|
|
|
+ # 移除 Markdown 代码块标记
|
|
|
|
|
+ latex_str = latex_str.strip()
|
|
|
|
|
+ if latex_str.startswith('```'):
|
|
|
|
|
+ lines = latex_str.split('\n')
|
|
|
|
|
+ # 移除第一行的 ```latex 或 ```
|
|
|
|
|
+ if lines[0].startswith('```'):
|
|
|
|
|
+ lines = lines[1:]
|
|
|
|
|
+ # 移除最后一行的 ```
|
|
|
|
|
+ if lines and lines[-1].strip() == '```':
|
|
|
|
|
+ lines = lines[:-1]
|
|
|
|
|
+ latex_str = '\n'.join(lines)
|
|
|
|
|
+
|
|
|
|
|
+ # 移除行内代码标记
|
|
|
|
|
+ if latex_str.startswith('`') and latex_str.endswith('`'):
|
|
|
|
|
+ latex_str = latex_str[1:-1]
|
|
|
|
|
+
|
|
|
|
|
+ # 移除常见的 LaTeX 包裹符号
|
|
|
|
|
+ latex_str = latex_str.strip()
|
|
|
|
|
+ if latex_str.startswith('$') and latex_str.endswith('$'):
|
|
|
|
|
+ # 移除单个 $ 或 $$
|
|
|
|
|
+ if latex_str.startswith('$$') and latex_str.endswith('$$'):
|
|
|
|
|
+ latex_str = latex_str[2:-2]
|
|
|
|
|
+ else:
|
|
|
|
|
+ latex_str = latex_str[1:-1]
|
|
|
|
|
+
|
|
|
|
|
+ return latex_str.strip()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+# 导出适配器类
|
|
|
|
|
+__all__ = ['GLMOCRVLRecognizer']
|