Explorar el Código

feat(adapter): 添加 GLM-OCR VL 识别适配器,支持表格、公式、文本和印章识别

zhch158_admin hace 1 semana
padre
commit
cb2803c537
Se han modificado 1 ficheros con 413 adiciones y 0 borrados
  1. 413 0
      ocr_tools/universal_doc_parser/models/adapters/glmocr_vl_adapter.py

+ 413 - 0
ocr_tools/universal_doc_parser/models/adapters/glmocr_vl_adapter.py

@@ -0,0 +1,413 @@
+"""GLM-OCR VL识别适配器
+
+直接通过 HTTP 调用 GLM-OCR API(OpenAI 兼容格式)。
+支持表格、公式、文本和印章(seal)识别。
+
+架构说明:
+- 使用 requests 库直接调用 GLM-OCR HTTP API
+- 无需依赖 glmocr 包
+- 通过 task_prompt_mapping 配置不同任务的提示词
+- 支持图片预处理(尺寸控制)
+"""
+
+import sys
+from pathlib import Path
+from typing import Dict, Any, List, Union, Optional
+import numpy as np
+from PIL import Image
+from loguru import logger
+import requests
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
+import base64
+from io import BytesIO
+import json
+
+# 导入基类
+from .base import BaseVLRecognizer
+
+
+class GLMOCRVLRecognizer(BaseVLRecognizer):
+    """
+    GLM-OCR VL识别适配器
+    
+    配置示例:
+    ```yaml
+    vl_recognition:
+      module: "glmocr"
+      api_url: "http://10.192.72.11:20036/v1/chat/completions"
+      api_key: null  # 可选
+      model: "glm-ocr"
+      max_image_size: 3500
+      resize_mode: 'max'
+      task_prompt_mapping:
+        text: "Text Recognition:"
+        table: "Table Recognition:"
+        formula: "Formula Recognition:"
+        seal: "Seal Recognition:"
+      model_params:
+        connection_pool_size: 128
+        http_timeout: 300
+        retry_max_attempts: 2
+    ```
+    """
+    
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        
+        self.session = None
+        
+        # API 配置
+        self.api_url = config.get('api_url', 'http://127.0.0.1:8000/v1/chat/completions')
+        self.api_key = config.get('api_key')
+        self.model = config.get('model', 'glm-ocr')
+        self.verify_ssl = config.get('verify_ssl', False)
+        
+        # 图片尺寸限制配置
+        self.max_image_size = config.get('max_image_size', 3500)
+        self.resize_mode = config.get('resize_mode', 'max')
+        
+        # Task prompt mapping(任务提示词映射)
+        self.task_prompt_mapping = config.get('task_prompt_mapping', {
+            'text': 'Text Recognition:',
+            'table': 'Table Recognition:',
+            'formula': 'Formula Recognition:',
+            'seal': 'Seal Recognition:',
+        })
+        
+        # 模型参数
+        model_params = config.get('model_params', {})
+        self.connection_pool_size = model_params.get('connection_pool_size', 128)
+        self.http_timeout = model_params.get('http_timeout', 300)
+        self.connect_timeout = model_params.get('connect_timeout', 30)
+        self.retry_max_attempts = model_params.get('retry_max_attempts', 2)
+        
+        # 生成参数
+        self.max_tokens = model_params.get('max_tokens', 4096)
+        self.temperature = model_params.get('temperature', 0.8)
+        self.top_p = model_params.get('top_p', 0.9)
+        self.top_k = model_params.get('top_k', 50)
+        self.repetition_penalty = model_params.get('repetition_penalty', 1.1)
+        
+        logger.info(f"GLM-OCR VL Recognizer configured with max_image_size={self.max_image_size}")
+        logger.debug(f"Task prompt mapping: {self.task_prompt_mapping}")
+    
+    def initialize(self):
+        """初始化 HTTP 会话"""
+        try:
+            # 创建会话
+            self.session = requests.Session()
+            
+            # 配置连接池
+            adapter = HTTPAdapter(
+                pool_connections=self.connection_pool_size,
+                pool_maxsize=self.connection_pool_size,
+                max_retries=Retry(
+                    total=self.retry_max_attempts,
+                    backoff_factor=0.5,
+                    status_forcelist=[429, 500, 502, 503, 504],
+                )
+            )
+            self.session.mount('http://', adapter)
+            self.session.mount('https://', adapter)
+            
+            # 设置默认 headers
+            self.session.headers.update({
+                'Content-Type': 'application/json',
+            })
+            
+            if self.api_key:
+                self.session.headers.update({
+                    'Authorization': f'Bearer {self.api_key}'
+                })
+            
+            logger.success(f"✅ GLM-OCR VL recognizer initialized: {self.api_url}")
+            
+        except Exception as e:
+            logger.error(f"❌ Failed to initialize GLM-OCR VL recognizer: {e}")
+            raise
+    
+    def cleanup(self):
+        """清理资源"""
+        if self.session:
+            self.session.close()
+            self.session = None
+        logger.debug("GLM-OCR VL recognizer cleaned up")
+    
+    def _preprocess_image(self, image: Union[np.ndarray, Image.Image]) -> Image.Image:
+        """
+        预处理图片,控制尺寸避免序列长度超限
+        
+        Args:
+            image: 输入图片
+            
+        Returns:
+            处理后的PIL图片
+        """
+        # 转换为PIL图像
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+        
+        # 获取原始尺寸
+        orig_w, orig_h = image.size
+        
+        # 计算缩放比例
+        if self.resize_mode == 'max':
+            # 保持宽高比,最长边不超过 max_image_size
+            max_dim = max(orig_w, orig_h)
+            if max_dim > self.max_image_size:
+                scale = self.max_image_size / max_dim
+                new_w = int(orig_w * scale)
+                new_h = int(orig_h * scale)
+                
+                logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {new_w}x{new_h} (scale={scale:.3f})")
+                image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
+        
+        elif self.resize_mode == 'fixed':
+            # 固定尺寸(可能改变宽高比)
+            if orig_w != self.max_image_size or orig_h != self.max_image_size:
+                logger.debug(f"🔄 Resizing image: {orig_w}x{orig_h} → {self.max_image_size}x{self.max_image_size}")
+                image = image.resize((self.max_image_size, self.max_image_size), Image.Resampling.LANCZOS)
+        
+        return image
+    
+    def _build_request_for_image(
+        self, 
+        image: Image.Image, 
+        task_type: str = 'text'
+    ) -> Dict[str, Any]:
+        """
+        为单张图片构建 GLM-OCR API 请求
+        
+        Args:
+            image: PIL图片
+            task_type: 任务类型 ('text', 'table', 'formula', 'seal')
+            
+        Returns:
+            请求字典
+        """
+        # 获取任务对应的提示词
+        prompt_text = self.task_prompt_mapping.get(task_type, self.task_prompt_mapping.get('text', ''))
+        
+        # 将图片转为 base64
+        buffered = BytesIO()
+        image.save(buffered, format="JPEG")
+        img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+        img_url = f"data:image/jpeg;base64,{img_base64}"
+        
+        # 构建请求(OpenAI 兼容格式)
+        request_data = {
+            "model": self.model,
+            "messages": [
+                {
+                    "role": "user",
+                    "content": [
+                        {"type": "image_url", "image_url": {"url": img_url}},
+                        {"type": "text", "text": prompt_text},
+                    ]
+                }
+            ],
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "repetition_penalty": self.repetition_penalty,
+        }
+        
+        return request_data
+    
+    def _call_ocr_api(self, image: Image.Image, task_type: str) -> str:
+        """
+        调用 GLM-OCR API 进行识别
+        
+        Args:
+            image: PIL图片
+            task_type: 任务类型
+            
+        Returns:
+            识别结果文本
+        """
+        if self.session is None:
+            raise RuntimeError("HTTP session not initialized")
+        
+        try:
+            # 构建请求
+            request_data = self._build_request_for_image(image, task_type)
+            
+            # 调用 API
+            response = self.session.post(
+                self.api_url,
+                json=request_data,
+                timeout=(self.connect_timeout, self.http_timeout),
+                verify=self.verify_ssl
+            )
+            
+            if response.status_code != 200:
+                logger.error(f"OCR API returned status {response.status_code}: {response.text}")
+                return ""
+            
+            # 解析响应
+            result = response.json()
+            
+            # 提取识别结果
+            if 'choices' in result and len(result['choices']) > 0:
+                content = result['choices'][0].get('message', {}).get('content', '')
+                return content
+            
+            logger.warning(f"No content in OCR response: {result}")
+            return ""
+            
+        except requests.exceptions.Timeout:
+            logger.error(f"OCR API timeout after {self.http_timeout}s")
+            return ""
+        except requests.exceptions.RequestException as e:
+            logger.error(f"OCR API request failed: {e}")
+            return ""
+        except json.JSONDecodeError as e:
+            logger.error(f"Failed to parse OCR response: {e}")
+            return ""
+        except Exception as e:
+            logger.error(f"OCR API call failed: {e}")
+            return ""
+    
+    def recognize_table(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
+        """
+        识别表格
+        
+        Args:
+            image: 输入图片
+            **kwargs: 额外参数(未使用)
+            
+        Returns:
+            包含 'html' 和 'markdown' 的字典
+        """
+        try:
+            # 预处理图片
+            image = self._preprocess_image(image)
+            
+            # 调用 API
+            result_text = self._call_ocr_api(image, 'table')
+            
+            if not result_text:
+                return {'html': '', 'markdown': '', 'cells': []}
+            
+            # GLM-OCR 默认返回 Markdown 格式
+            # 如果需要 HTML,可以使用简单的转换(或保持 Markdown)
+            return {
+                'html': result_text,  # GLM-OCR 可能返回 HTML 或 Markdown
+                'markdown': result_text,
+                'cells': [],  # GLM-OCR 不直接返回单元格坐标
+            }
+            
+        except Exception as e:
+            logger.error(f"❌ Table recognition failed: {e}")
+            return {'html': '', 'markdown': '', 'cells': []}
+    
+    def recognize_formula(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
+        """
+        识别公式
+        
+        Args:
+            image: 输入图片
+            **kwargs: 额外参数(未使用)
+            
+        Returns:
+            包含 'latex' 的字典
+        """
+        try:
+            # 预处理图片
+            image = self._preprocess_image(image)
+            
+            # 调用 API
+            result_text = self._call_ocr_api(image, 'formula')
+            
+            if not result_text:
+                return {'latex': '', 'confidence': 0.0, 'raw': {}}
+            
+            # 清理 LaTeX 格式(移除 markdown 代码块标记)
+            latex = self._clean_latex(result_text)
+            
+            return {
+                'latex': latex,
+                'confidence': 0.9 if latex else 0.0,
+                'raw': {'raw_output': result_text}
+            }
+            
+        except Exception as e:
+            logger.error(f"❌ Formula recognition failed: {e}")
+            return {'latex': '', 'confidence': 0.0, 'raw': {}}
+    
+    def recognize_text(self, image: Union[np.ndarray, Image.Image], **kwargs) -> Dict[str, Any]:
+        """
+        识别文本区域(包括普通文本和印章)
+        
+        Args:
+            image: 输入图片
+            **kwargs: 额外参数,可包含 'element_type' 指定类型(如 'seal')
+            
+        Returns:
+            包含 'text' 的字典
+        """
+        try:
+            # 预处理图片
+            image = self._preprocess_image(image)
+            
+            # 确定任务类型(如果是 seal,使用 seal 提示词)
+            element_type = kwargs.get('element_type', 'text')
+            task_type = 'seal' if element_type == 'seal' else 'text'
+            
+            # 调用 API
+            result_text = self._call_ocr_api(image, task_type)
+            
+            return {
+                'text': result_text or '',
+                'confidence': 0.9 if result_text else 0.0
+            }
+            
+        except Exception as e:
+            logger.error(f"❌ Text recognition failed: {e}")
+            return {'text': '', 'confidence': 0.0}
+    
+    def _clean_latex(self, latex_str: str) -> str:
+        """
+        清理 LaTeX 字符串,移除 Markdown 代码块标记
+        
+        Args:
+            latex_str: 原始 LaTeX 字符串
+            
+        Returns:
+            清理后的 LaTeX
+        """
+        if not latex_str:
+            return ""
+        
+        # 移除 Markdown 代码块标记
+        latex_str = latex_str.strip()
+        if latex_str.startswith('```'):
+            lines = latex_str.split('\n')
+            # 移除第一行的 ```latex 或 ```
+            if lines[0].startswith('```'):
+                lines = lines[1:]
+            # 移除最后一行的 ```
+            if lines and lines[-1].strip() == '```':
+                lines = lines[:-1]
+            latex_str = '\n'.join(lines)
+        
+        # 移除行内代码标记
+        if latex_str.startswith('`') and latex_str.endswith('`'):
+            latex_str = latex_str[1:-1]
+        
+        # 移除常见的 LaTeX 包裹符号
+        latex_str = latex_str.strip()
+        if latex_str.startswith('$') and latex_str.endswith('$'):
+            # 移除单个 $ 或 $$
+            if latex_str.startswith('$$') and latex_str.endswith('$$'):
+                latex_str = latex_str[2:-2]
+            else:
+                latex_str = latex_str[1:-1]
+        
+        return latex_str.strip()
+
+
+# 导出适配器类
+__all__ = ['GLMOCRVLRecognizer']