
Add support for domestic acceleration cards in documentation

myhloli · 1 month ago · commit 14c38101f7

+ 365 - 0
docs/zh/usage/acceleration_cards/AMD.md

@@ -0,0 +1,365 @@
+## Triton-based optimizations for different backends on ROCm: working vllm-backend inference, plus DocLayout-YOLO used in the first (layout) step of the pipeline backend
+
+**If you already have a complete Python environment with vllm and MinerU, jump straight to step 5!**
+**Problems on other GPUs can be handled the same way: profile first to locate the problematic operator, then reimplement it as a Triton kernel (see the profiling sketch below).**
+I tested it and the results are basically on par with the official MinerU output. Not many people run this on AMD, so I am simply sharing it here for everyone.
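+
+As a hedged illustration of that profiling step (not from the original write-up; shapes mirror the Conv3d patched in section 5, and the batch size is arbitrary):
+```
+import torch
+import torch.nn as nn
+from torch.profiler import profile, ProfilerActivity
+
+dev = "cuda"  # ROCm builds of PyTorch expose the HIP device under the "cuda" name
+conv = nn.Conv3d(3, 1152, kernel_size=(2, 14, 14), stride=(2, 14, 14),
+                 bias=False).to(dev, dtype=torch.bfloat16)
+x = torch.randn(1024, 3, 2, 14, 14, device=dev, dtype=torch.bfloat16)
+
+with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+    with torch.no_grad():
+        conv(x)
+    torch.cuda.synchronize()
+
+# The operator with the largest device time is the one worth reimplementing in Triton.
+print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+```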
+
+### 1. Results
+**Additional speed test on a 200-page Python programming book (PDF), reaching 1.99 it/s:**
+Two Step Extraction: 100%|████████████████████████████████████████| 200/200 [01:40<00:00,  1.99it/s]
+
+**Below are the earlier test results on 14 academic papers:**
+On a 7900 XTX, mineru-gradio --server-name 0.0.0.0 --server-port 7860 --enable-vllm-engine true runs at roughly **1.6-1.8 s/it** (not measured rigorously, just two quick documents). The second approach, which replaces the original dot products with a matrix multiplication, pushes this further to about 1.3 s/it. After optimization the remaining operator time is split between hipBLAS (nothing more to gain there) and the vllm Triton backend, each around 25% of the total; the vllm Triton backend part can only wait for upstream optimization.
+DocLayout-YOLO's layout speed rises from 1.6 it/s to 15 it/s. Note that the input PDF page sizes have to be cached: Triton needs fixed shapes, so there is no way around caching them. This is done mainly to preserve the model's input/output interface with minimal code changes.
+An example run using the -b vlm-vllm-engine mode:
+
+---
+**Test results with the optimization that replaces the original dot products with a 5-D-tensor matrix multiplication:**
+2025-10-05 15:45:12.985 | INFO     | mineru.backend.vlm.vlm_analyze:get_model:128 - get vllm-engine predictor cost: 18.45s
+Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00, 12.20it/s]
+Processed prompts: 100%|█████████████████████| 14/14 [00:08<00:00,  1.56it/s, est. speed input: 2174.18 toks/s, output: 791.87 toks/s]
+Adding requests: 100%|█████████████████████████████████████████████████████████████████████████████| 278/278 [00:00<00:00, 323.03it/s]
+Processed prompts: 100%|██████████████████| 278/278 [00:07<00:00, 37.63it/s, est. speed input: 5264.66 toks/s, output: 2733.31 toks/s]
+
+mineru-gradio --server-name 0.0.0.0 --server-port 7860 --enable-vllm-engine true test:
+2025-10-05 15:46:55.953 | WARNING  | mineru.cli.common:convert_pdf_bytes_to_bytes_by_pypdfium2:54 - end_page_id is out of range, use pdf_docs length
+Two Step Extraction: 100%|████████████████████████████████████████████████████████████████████████████| 14/14 [00:18<00:00,  1.30s/it]
+
+---
+
+### 2. Root cause
+AMD RDNA has a severe performance problem with the vllm backend: one operator in vllm's **qwen2_vl.py** has no matching ROCm kernel implementation, so the convolution falls back to an extremely slow path, taking about 12 s per execution. In short, **the MIOpen library lacks an optimized kernel for the specific Conv3d(bfloat16) used by the model**.
+The dilated convolution in DocLayout-YOLO's **g2l_crm.py** hits the same problem, and even the professional CDNA MI210 does not avoid it,
+so both are handled here together.
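+
+Why both fixes below work: with stride equal to kernel size and no padding, this Conv3d is mathematically just one dot product per patch per output channel, i.e. a plain GEMM. A small equivalence check (a sketch; the shapes are illustrative, not taken from the model config):
+```
+import torch
+import torch.nn.functional as F
+
+L, C_in, T, P = 64, 3, 2, 14        # patches, channels, temporal / spatial patch sizes
+C_out = 1152
+x = torch.randn(L, C_in, T, P, P)
+w = torch.randn(C_out, C_in, T, P, P)
+
+ref = F.conv3d(x, w, stride=(T, P, P)).view(L, C_out)   # convolution path
+gemm = x.view(L, -1) @ w.view(C_out, -1).t()            # reshape + matmul path
+print(torch.allclose(ref, gemm, atol=1e-4))             # True
+```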
+
+---
+
+### 3. Environment
+System: Ubuntu 24.04.3        Kernel: Linux 6.14.0-33-generic      ROCm version: 7.0.1
+Python environment:
+python                3.12
+pytorch-triton-rocm   3.5.0+gitbbb06c03
+torch                 2.10.0.dev20251001+rocm7.0
+torchvision           0.25.0.dev20251003+rocm7.0
+vllm                  0.11.0rc2.dev198+g736fbf4c8.rocm701
+Exact versions do not matter; the fix is the same either way.
+
+---
+
+### 4. Prerequisite installation
+```
+uv venv --python python3.12
+source .venv/bin/activate
+uv pip install --pre torch torchvision   -i https://pypi.tuna.tsinghua.edu.cn/simple/   --extra-index-url https://download.pytorch.org/whl/nightly/rocm7.0
+uv pip install pip
+# Use pip here instead of uv pip to avoid overwriting the locally installed PyTorch
+pip install -U "mineru[core]" -i https://pypi.mirrors.ustc.edu.cn/simple/
+```
+For vllm installation, refer to the official manual: [vLLM](https://docs.vllm.com.cn/en/latest/getting_started/installation/gpu.html#amd-rocm)
+```
+# Manually install aiter, vllm, amd-smi, etc.; pick a location to clone into and cd there
+git clone --recursive https://github.com/ROCm/aiter.git
+cd aiter
+git submodule sync; git submodule update --init --recursive
+python setup.py develop
+cd ..
+git clone https://github.com/vllm-project/vllm.git
+cd vllm/
+cp -r /opt/rocm/share/amd_smi ~/Pytorch/vllm/
+pip install amd_smi/
+pip install --upgrade numba \
+    scipy \
+    huggingface-hub[cli,hf_transfer] \
+    setuptools_scm
+pip install -r requirements/rocm.txt
+export PYTORCH_ROCM_ARCH="gfx1100"   # match your GPU architecture: rocminfo | grep gfx
+python setup.py develop
+```
+---
+
+### 5. Adding the key Triton operator to vllm
+#### Two solutions are given here. The first is the one mentioned above that reaches 1.5-1.8 s/it. The second manually rewrites the operator as a matrix multiplication: it definitely applies to the 7900 XTX at roughly 1.3 s/it, and other AMD GPUs should also see a speedup over the first option, though it is not necessarily the fastest possible implementation and the hand-tuned parts may need adjustment.
+**Note: uninstall the Triton-backend flash_attn with pip. After a lot of attempts it still throws errors and is too problematic, so simply do without it.**
+```
+# Locate your vllm installation path (referred to as XXX below)
+pip show vllm
+```
+**Key changes**
+In the file XXX/vllm/model_executor/models/qwen2_vl.py:
+**1. Below line 33 of qwen2_vl.py, add `from .qwen2_vl_vision_kernels import triton_conv3d_patchify`:**
+```
+from collections.abc import Iterable, Mapping, Sequence
+from functools import partial
+from typing import Annotated, Any, Callable, Literal, Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .qwen2_vl_vision_kernels import triton_conv3d_patchify
+```
+**What follows is split into Option 1 (2.1 and 3.1) and Option 2 (2.2 and 3.2); pick one of the two.**
+
+---
+**Option 1**
+**2.1 Around line 498 of qwen2_vl.py, replace class Qwen2VisionPatchEmbed(nn.Module). PS: this is the module AMD has no ready-made kernel for, which causes the fallback.**
+```
+class Qwen2VisionPatchEmbed(nn.Module):
+
+    def __init__(
+        self,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        in_channels: int = 3,
+        embed_dim: int = 1152,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.embed_dim = embed_dim
+
+        kernel_size = (temporal_patch_size, patch_size, patch_size)
+        self.proj = nn.Conv3d(in_channels,
+                              embed_dim,
+                              kernel_size=kernel_size,
+                              stride=kernel_size,
+                              bias=False)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        L, C = x.shape
+        x_reshaped = x.view(L, -1, self.temporal_patch_size, self.patch_size,
+                            self.patch_size)
+        
+        # Call your custom Triton kernel instead of self.proj
+        x_out = triton_conv3d_patchify(x_reshaped, self.proj.weight)
+        
+        # The output of our kernel is already the correct shape [L, embed_dim]
+        return x_out
+```
+**3.1 Create qwen2_vl_vision_kernels.py under XXX/vllm/model_executor/models/ with the Triton implementation:**
+```
+import torch
+from vllm.triton_utils import tl, triton
+
+@triton.jit
+def _conv3d_patchify_kernel(
+    # Pointers to tensors
+    X, W, Y,
+    # Tensor dimensions
+    N, C_in, D_in, H_in, W_in,
+    C_out, KD, KH, KW,
+    # Stride and padding for memory access
+    stride_xn, stride_xc, stride_xd, stride_xh, stride_xw,
+    stride_wn, stride_wc, stride_wd, stride_wh, stride_ww,
+    stride_yn, stride_yc,
+    # Triton-specific metaparameters
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Triton kernel for a non-overlapping 3D patching convolution.
+    Each kernel instance computes one output value for one patch.
+    """
+    # Get the program IDs for the N (patch) and C_out (output channel) dimensions
+    pid_n = tl.program_id(0)  # The index of the patch we are processing
+    pid_cout = tl.program_id(1) # The index of the output channel we are computing
+
+    # --- Calculate memory pointers ---
+    # Pointer to the start of the current input patch
+    x_ptr = X + (pid_n * stride_xn)
+    # Pointer to the start of the current filter (weight)
+    w_ptr = W + (pid_cout * stride_wn)
+    # Pointer to where the output will be stored
+    y_ptr = Y + (pid_n * stride_yn + pid_cout * stride_yc)
+
+    # --- Perform the convolution (element-wise product and sum) ---
+    # This is a dot product between the flattened patch and the flattened filter.
+    accumulator = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
+
+    # Iterate over the elements of the patch/filter
+    for c_offset in range(0, C_in):
+        for d_offset in range(0, KD):
+            for h_offset in range(0, KH):
+                # Unrolled loop for the innermost dimension (width) for performance
+                for w_offset in range(0, KW, BLOCK_SIZE):
+                    # Create masks to handle cases where KW is not a multiple of BLOCK_SIZE
+                    w_range = w_offset + tl.arange(0, BLOCK_SIZE)
+                    w_mask = w_range < KW
+
+                    # Calculate offsets to load data
+                    patch_offset = (c_offset * stride_xc + d_offset * stride_xd +
+                                    h_offset * stride_xh + w_range * stride_xw)
+                    filter_offset = (c_offset * stride_wc + d_offset * stride_wd +
+                                     h_offset * stride_wh + w_range * stride_ww)
+
+                    # Load patch and filter data, applying masks
+                    patch_vals = tl.load(x_ptr + patch_offset, mask=w_mask, other=0.0)
+                    filter_vals = tl.load(w_ptr + filter_offset, mask=w_mask, other=0.0)
+
+                    # Multiply and accumulate
+                    accumulator += patch_vals.to(tl.float32) * filter_vals.to(tl.float32)
+
+    # Sum the accumulator block and store the single output value
+    output_val = tl.sum(accumulator, axis=0)
+    tl.store(y_ptr, output_val)
+
+
+def triton_conv3d_patchify(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+    """
+    Python wrapper for the 3D patching convolution Triton kernel.
+    """
+    # Get tensor dimensions
+    N, C_in, D_in, H_in, W_in = x.shape
+    C_out, _, KD, KH, KW = weight.shape
+
+    # Create the output tensor
+    # The output of this specific conv is (N, C_out, 1, 1, 1), which we squeeze
+    Y = torch.empty((N, C_out), dtype=x.dtype, device=x.device)
+
+    # Define the grid for launching the Triton kernel
+    # Each kernel instance handles one patch (N) for one output channel (C_out)
+    grid = (N, C_out)
+
+    # Launch the kernel
+    # We pass all strides to make the kernel flexible
+    _conv3d_patchify_kernel[grid](
+        x, weight, Y,
+        N, C_in, D_in, H_in, W_in,
+        C_out, KD, KH, KW,
+        x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
+        weight.stride(0), weight.stride(1), weight.stride(2), weight.stride(3), weight.stride(4),
+        Y.stride(0), Y.stride(1),
+        BLOCK_SIZE=16, # A reasonable default, can be tuned
+    )
+
+    return Y
+```
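+A quick correctness check for the kernel above (a sketch, not from the original post; the import path assumes the file was created as in 3.1 and that vllm is installed in the current environment):
+```
+import torch
+import torch.nn.functional as F
+from vllm.model_executor.models.qwen2_vl_vision_kernels import triton_conv3d_patchify
+
+dev = "cuda"
+x = torch.randn(32, 3, 2, 14, 14, device=dev, dtype=torch.bfloat16)
+w = torch.randn(1152, 3, 2, 14, 14, device=dev, dtype=torch.bfloat16)
+
+ref = F.conv3d(x, w, stride=(2, 14, 14)).view(32, 1152).float()  # reference Conv3d path
+out = triton_conv3d_patchify(x, w).float()                       # Triton kernel path
+print((ref - out).abs().max())  # expect only bf16 rounding noise
+```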
+---
+**Option 2**
+**2.2 Around line 498 of qwen2_vl.py, replace class Qwen2VisionPatchEmbed(nn.Module). PS: this is the same module that causes the fallback; here we go straight from the 5-D tensor to a matrix multiplication.**
+```
+class Qwen2VisionPatchEmbed(nn.Module):
+
+    def __init__(
+        self,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        in_channels: int = 3,
+        embed_dim: int = 1152,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.temporal_patch_size = temporal_patch_size
+        self.embed_dim = embed_dim
+
+        kernel_size = (temporal_patch_size, patch_size, patch_size)
+
+        self.proj = nn.Conv3d(in_channels,
+                              embed_dim,
+                              kernel_size=kernel_size,
+                              stride=kernel_size,
+                              bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        L, C = x.shape
+        x_reshaped_5d = x.view(L, -1, self.temporal_patch_size, self.patch_size,
+                               self.patch_size)
+
+        return triton_conv3d_patchify(x_reshaped_5d, self.proj.weight)
+```
+**3.2 Create qwen2_vl_vision_kernels.py under XXX/vllm/model_executor/models/ with the Triton implementation:**
+```
+import torch
+from vllm.triton_utils import tl, triton
+
+@triton.jit
+def _conv_gemm_kernel(
+    A, B, C, M, N, K,
+    stride_am, stride_ak,
+    stride_bk, stride_bn,
+    stride_cm, stride_cn,
+    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
+):
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    offs_k = tl.arange(0, BLOCK_K)
+    a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
+    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
+    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+    for k in range(0, K, BLOCK_K):
+        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
+        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
+        accumulator += tl.dot(a, b)
+        a_ptrs += BLOCK_K * stride_ak
+        b_ptrs += BLOCK_K * stride_bk
+        offs_k += BLOCK_K
+    c = accumulator.to(C.dtype.element_ty)
+    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+def triton_conv3d_patchify(x_5d: torch.Tensor, weight_5d: torch.Tensor) -> torch.Tensor:
+    N_patches, _, _, _, _ = x_5d.shape
+    C_out, _, _, _, _ = weight_5d.shape
+    A = x_5d.view(N_patches, -1)
+    B = weight_5d.view(C_out, -1).transpose(0, 1).contiguous()
+    M, K = A.shape
+    _K, N = B.shape
+    assert K == _K
+    C = torch.empty((M, N), device=A.device, dtype=A.dtype)
+
+    # --- Hand-tuned configuration for the 7900 XTX; other GPUs may need to search for their own best combination. Triton autotune on AMD was simply ineffective in my tests. ---
+    best_config = {
+        'BLOCK_M': 128,
+        'BLOCK_N': 128,
+        'BLOCK_K': 32,
+    }
+    num_stages = 4
+    num_warps = 8
+
+    grid = (triton.cdiv(M, best_config['BLOCK_M']),
+            triton.cdiv(N, best_config['BLOCK_N']))
+
+    _conv_gemm_kernel[grid](
+        A, B, C,
+        M, N, K,
+        A.stride(0), A.stride(1),
+        B.stride(0), B.stride(1),
+        C.stride(0), C.stride(1),
+        **best_config,
+        num_stages=num_stages,
+        num_warps=num_warps
+    )
+
+    return C
+```
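+To see whether the hand-tuned configuration actually beats the hipBLAS path that torch.matmul would take, a rough benchmark sketch (again an illustration, not from the original post; the import path assumes the file from 3.2, and the patch count of 4096 is arbitrary):
+```
+import torch
+import triton
+from vllm.model_executor.models.qwen2_vl_vision_kernels import triton_conv3d_patchify
+
+dev = "cuda"
+x = torch.randn(4096, 3, 2, 14, 14, device=dev, dtype=torch.bfloat16)
+w = torch.randn(1152, 3, 2, 14, 14, device=dev, dtype=torch.bfloat16)
+
+t_triton = triton.testing.do_bench(lambda: triton_conv3d_patchify(x, w))
+t_matmul = triton.testing.do_bench(lambda: x.view(4096, -1) @ w.view(1152, -1).t())
+print(f"triton GEMM: {t_triton:.3f} ms   torch.matmul (hipBLAS): {t_matmul:.3f} ms")
+```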
+---
+**4. After closing the terminal, running mineru-gradio again raises a LoRA error; patch the code to skip it**
+```
+pip show mineru_vl_utils
+```
+
+Open XXX/mineru_vl_utils/vlm_client/vllm_async_engine_client.py and change line 58, self.tokenizer = vllm_async_llm.tokenizer.get_lora_tokenizer(), to:
+```
+        try:
+            self.tokenizer = vllm_async_llm.tokenizer.get_lora_tokenizer()
+        except AttributeError:
+            # If there is no get_lora_tokenizer method, use the original tokenizer directly
+            self.tokenizer = vllm_async_llm.tokenizer
+```
+
+**Finally, set these two environment variables and enjoy:**
+```
+export MINERU_MODEL_SOURCE=modelscope
+export TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
+```
+---
+
+### 6. The vllm backend is now fine; what remains is the dilated-convolution issue in the DocLayout-YOLO model used for layout in the pipeline backend
+I posted an answer under [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO/issues/120#issuecomment-3368144275), so the pipeline dilated-convolution fix is not repeated here; follow the link for the details.
+Find your doclayout-yolo installation location as shown below, then go in and modify the file described in the linked reply:
+```
+pip show doclayout-yolo
+```
+

+ 64 - 0
docs/zh/usage/acceleration_cards/Ascend.md

@@ -0,0 +1,64 @@
+#### 1 System
+NAME="Ubuntu"
+VERSION="20.04.6 LTS (Focal Fossa)"
+Ascend 910B2
+Driver 23.0.6.2
+CANN 7.5.X
+MinerU 2.1.9
+#### 2 Pitfalls and fixes
+Pitfall 1: **graphics-library problems; in short, a shared library makes TLS memory allocation fail (an OpenCV compatibility issue on the ARM64 architecture)**
+⭐ The error ImportError: /lib/aarch64-linux-gnu/libGLdispatch.so.0: cannot allocate memory in static TLS block is caused by an OpenCV compatibility issue on ARM64. The stack trace shows it occurs while importing the cv2 module, during initialization of MinerU's VLM backend.
+Fixes:
+1. Install an OpenCV build that avoids the memory problem
+```
+pip install --upgrade albumentations albucore simsimd
+# Uninstall current opencv
+pip uninstall opencv-python opencv-contrib-python
+
+# Install headless version (no GUI dependencies)
+pip install opencv-python-headless
+
+python -c "import cv2; print(cv2.__version__)"
+```
+2. Install a few packages with apt-get
+Switch to the Tsinghua mirror: save the list below as sources.list.tuna and place it where apt can find it (e.g. /etc/apt/; the commands below reference it by name)
+```
+deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/ focal main restricted universe multiverse
+deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/ focal-updates main restricted universe multiverse
+deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/ focal-backports main restricted universe multiverse
+deb https://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/ focal-security main restricted universe multiverse
+sudo apt-get update -o Dir::Etc::sourcelist="sources.list.tuna" -o Dir::Etc::sourceparts="-" -o APT::Get::List-Cleanup="0"
+sudo apt-get install libgl1-mesa-glx -o Dir::Etc::sourcelist="sources.list.tuna" -o Dir::Etc::sourceparts="-" -o APT::Get::List-Cleanup="0"
+sudo apt-get install libglib2.0-0 libsm6 libxext6 libxrender-dev libgomp1 -o Dir::Etc::sourcelist="sources.list.tuna" -o Dir::Etc::sourceparts="-" -o APT::Get::List-Cleanup="0"
+sudo apt-get install libgl1-mesa-dev libgles2-mesa-dev -o Dir::Etc::sourcelist="sources.list.tuna" -o Dir::Etc::sourceparts="-" -o APT::Get::List-Cleanup="0"
+sudo apt-get install libgomp1 -o Dir::Etc::sourcelist="sources.list.tuna" -o Dir::Etc::sourceparts="-" -o APT::Get::List-Cleanup="0"
+export OPENCV_IO_ENABLE_OPENEXR=0
+export QT_QPA_PLATFORM=offscreen
+```
+↑ Not sure which of these actually help, or whether any of them do.
+
+3. Force-override the shared libraries bundled with the conda environment (they conflict with the system ones)
+```
+# Locate the libraries first:
+find /usr/lib /lib /root/.local/conda -name "libgomp.so*" 2>/dev/null
+export LD_PRELOAD="/usr/lib/aarch64-linux-gnu/libstdc++.so.6:/usr/lib/aarch64-linux-gnu/libgomp.so.1"
+export LD_PRELOAD=/lib/aarch64-linux-gnu/libGLdispatch.so.0:$LD_PRELOAD
+```
+In addition, you can forcibly move aside the copies bundled with the conda environment:
+```
+mv $CONDA_PREFIX/lib/libstdc++.so.6 $CONDA_PREFIX/lib/libstdc++.so.6.bak
+mv $CONDA_PREFIX/lib/libgomp.so.1 $CONDA_PREFIX/lib/libgomp.so.1.bak
+mv $CONDA_PREFIX/lib/libGLdispatch.so.0 $CONDA_PREFIX/lib/libGLdispatch.so.0.bak  # 如果有的话
+# For the simsimd package:
+mv /root/.local/conda/envs/pdfparser/lib/python3.10/site-packages/simsimd./libgomp-947d5fa1.so.1.0.0 /root/.local/conda/envs/pdfparser/lib/python3.10/site-packages/simsimd./libgomp-947d5fa1.so.1.0.0.bak
+```
+Or instead:
+downgrade simsimd to 3.7.2
+downgrade albumentations to 1.3.1
+For the scikit-learn package:
+```
+# Find the libgomp bundled inside scikit-learn
+SKLEARN_LIBGOMP="/root/.local/conda/envs/pdfparser/lib/python3.10/site-packages/scikit_learn.libs/libgomp-947d5fa1.so.1.0.0"
+
+# Preload this specific libgomp version
+export LD_PRELOAD="$SKLEARN_LIBGOMP:$LD_PRELOAD"
+```
+4. Other notes
+torch / torch_npu 2.5.1
+pip install "numpy<2.0"  (NumPy 2.0 is incompatible with Ascend)
+export MINERU_MODEL_SOURCE=modelscope
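+
+A quick sanity check of the NPU environment (a sketch, assuming torch_npu is installed; not part of the original notes):
+```
+import torch
+import torch_npu  # registers the "npu" device backend with torch
+
+print(torch.__version__, torch_npu.__version__)
+print(torch.npu.is_available())   # True if the 910B and the CANN stack are visible
+print(torch.npu.device_count())
+```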

+ 117 - 0
docs/zh/usage/acceleration_cards/METAX.md

@@ -0,0 +1,117 @@
+## Deploying and using MinerU on C500 + MACA
+
+### Obtain the MACA image (includes torch-maca, maca, sglang-maca)
+
+Image download address: https://developer.metax-tech.com/softnova/docker ;
+select maca-c500-pytorch:2.33.0.6-ubuntu22.04-amd64
+
+If you deploy the image with Docker, GPU device access must be enabled:
+```bash
+docker run --device=/dev/dri --device=/dev/mxcd....
+```
+
+#### Notes
+
+This image enables TORCH_ALLOW_TF32_CUBLAS_OVERRIDE by default, which causes incorrect inference results with the vlm-transformers backend, so unset it:
+
+```bash
+unset TORCH_ALLOW_TF32_CUBLAS_OVERRIDE
+```
+
+### Install MinerU
+
+Use --no-deps to drop the dependencies on CUDA-specific packages; the remaining dependencies are installed afterwards with pip install -r requirements.txt
+```bash
+pip install -U "mineru[core]" --no-deps
+```
+
+```tex
+boto3>=1.28.43
+click>=8.1.7
+loguru>=0.7.2
+numpy==1.26.4
+pdfminer.six==20250506
+tqdm>=4.67.1
+requests
+httpx
+pillow>=11.0.0
+pypdfium2>=4.30.0
+pypdf>=5.6.0
+reportlab
+pdftext>=0.6.2
+modelscope>=1.26.0
+huggingface-hub>=0.32.4
+json-repair>=0.46.2
+opencv-python>=4.11.0.86
+fast-langdetect>=0.2.3,<0.3.0
+transformers>=4.51.1
+accelerate>=1.5.1
+pydantic
+matplotlib>=3.10,<4
+ultralytics>=8.3.48,<9
+dill>=0.3.8,<1
+rapid_table>=1.0.5,<2.0.0
+PyYAML>=6.0.2,<7 
+ftfy>=6.3.1,<7
+openai>=1.70.0,<2
+shapely>=2.0.7,<3
+pyclipper>=1.3.0,<2
+omegaconf>=2.3.0,<3
+transformers>=4.49.0,!=4.51.0,<5.0.0
+fastapi
+python-multipart
+uvicorn
+gradio>=5.34,<6
+gradio-pdf>=0.0.22
+albumentations
+beautifulsoup4
+scikit-image==0.25.0
+outlines==0.1.11
+magika>=0.6.2,<0.7.0
+mineru-vl-utils>=0.1.6,<1
+```
+Save the above as requirements.txt and install it:
+```bash
+pip install -r requirements.txt
+```
+Install doclayout_yolo; since it pulls in a CUDA build of torch, use --no-deps:
+```bash
+pip install doclayout-yolo --no-deps
+```
+### Online usage
+**Basic usage: mineru -p <input_path> -o <output_path> -b vlm-transformers**
+
+- `<input_path>`: Local PDF/image file or directory
+- `<output_path>`: Output directory
+- `-b, --backend`: [pipeline|vlm-transformers|vlm-vllm-engine|vlm-http-client] (default: pipeline)<br/>
+
+For other detailed usage, see the official documentation: [Quick Usage - MinerU](https://opendatalab.github.io/MinerU/usage/quick_usage/#quick-model-source-configuration)
+
+### Offline usage
+
+**Offline use relies on local models, so the environment variables and the config file must be set**<br/>
+#### Download the models locally
+Download them with MinerU's interactive command-line tool; after the download, the mineru.json config file is updated automatically:
+```bash
+mineru-models-download
+```
+You can also download the required models (PDF-Extract-Kit-1.0 and MinerU2.5-2509-1.2B) from [HuggingFace](http://www.huggingface.co.) or [ModelScope](https://www.modelscope.cn/home).
+After downloading, create a mineru.json file and edit it as follows:
+```json
+{
+    "models-dir": {
+        "pipeline": "/path/pdf-extract-kit-1.0/",
+        "vlm": "/path/MinerU2.5-2509-1.2B"
+    },
+    "config_version": "1.3.0"
+}
+```
+Here path is the local storage location of the models: models-dir holds the local model paths, pipeline is the model path used when the backend is pipeline, and vlm is the model path used when the backend name starts with vlm-.
+
+#### Set the environment variables
+
+```bash
+export MINERU_MODEL_SOURCE=local
+export MINERU_TOOLS_CONFIG_JSON=/path/mineru.json   # this variable is the path to the config file
+```
+Once these are set, MinerU can be used normally.<br/>

+ 73 - 0
docs/zh/usage/acceleration_cards/Tecorigin.md

@@ -0,0 +1,73 @@
+# TECO adaptation
+
+## Quick start
+The main steps to run inference with this tool are:
+1. Base environment installation: the environment checks and installation to complete before inference.
+2. Build the Docker environment: how to create the Docker environment needed for model inference.
+3. Start inference: how to launch inference.
+
+### 1 Base environment installation
+Refer to the [installation preparation chapter of the Teco user manual](http://docs.tecorigin.com/release/torch_2.4/v2.2.0/#fc980a30f1125aa88bad4246ff0cedcc) to complete the basic environment checks and installation beforehand.
+
+### 2 Build the Docker environment
+#### 2.1 Run the following command to download the Docker image locally (image package: pytorch-3.0.0-torch_sdaa3.0.0.tar)
+
+    wget <image download link>    # contact Tecorigin to obtain the link
+
+#### 2.2 Verify the Docker image package: run the following command and check that the generated MD5 matches the official value b2a7f60508c0d199a99b8b6b35da3954:
+
+    md5sum pytorch-3.0.0-torch_sdaa3.0.0.tar
+
+#### 2.3 Run the following command to load the Docker image
+
+    docker load < pytorch-3.0.0-torch_sdaa3.0.0.tar
+
+#### 2.4 Run the following command to create a Docker container named MinerU
+
+    docker run -itd --name="MinerU" --net=host --device=/dev/tcaicard0 --device=/dev/tcaicard1 --device=/dev/tcaicard2 --device=/dev/tcaicard3 --cap-add SYS_PTRACE --cap-add SYS_ADMIN --shm-size 64g jfrog.tecorigin.net/tecotp-docker/release/ubuntu22.04/x86_64/pytorch:3.0.0-torch_sdaa3.0.0 /bin/bash
+
+#### 2.5 Run the following command to enter the MinerU Docker container created above.
+
+    docker exec -it MinerU bash
+
+
+### 3 Run the following commands to install MinerU
+- Preparation before installing
+    ```
+    cd <MinerU>
+    pip install --upgrade pip
+    pip install uv
+    ```    
+- Since the image already ships with torch and packages such as nvidia-nccl-cu12 and nvidia-cudnn-cu12 are not needed, part of the install dependencies must be commented out.
+- Comment out every "doclayout_yolo==0.0.4" dependency in <MinerU>/pyproject.toml, and also comment out the packages whose names start with torch.
+- Run the following command to install MinerU
+    ```
+    uv pip install -e .[core]
+    ``` 
+- Download and install doclayout_yolo==0.0.4
+    ```
+    pip install doclayout_yolo==0.0.4 --no-deps
+    ``` 
+- Download and install the remaining packages (dependencies of doclayout_yolo==0.0.4)
+    ```
+    pip install albumentations py-cpuinfo seaborn thop numpy==1.24.4
+    ``` 
+- Because some tensors are not contiguous in memory internally, the following two files need to be modified (see the sketch below for why the substitution works):
+    <ultralytics install path>/ultralytics/utils/tal.py (around line 330, change view --> reshape)
+    <doclayout_yolo install path>/doclayout_yolo/utils/tal.py (around line 375, change view --> reshape)
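+- Why the view --> reshape substitution works (a minimal, self-contained illustration, not code from either file): .view() requires contiguous memory, while .reshape() falls back to making a copy when the tensor is not contiguous.
+    ```
+    import torch
+
+    x = torch.randn(4, 6)
+    y = x.transpose(0, 1)      # a non-contiguous view of x
+    # y.view(-1)               # would raise: view size is not compatible with input tensor's size and stride
+    z = y.reshape(-1)          # reshape copies when necessary, so it succeeds
+    print(z.is_contiguous())   # True
+    ```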
+### 4 Run inference
+- Enable the SDAA environment
+    ```
+    export TORCH_SDAA_AUTOLOAD=cuda_migrate
+    ```
+- Before the first inference run, set the following environment variable so the model weights can be downloaded
+    ```
+    export HF_ENDPOINT=https://hf-mirror.com
+    ```
+- Run the following command to perform inference
+    ```
+    mineru -p 'input_path' -o 'output_path' --lang 'model_name'
+    ```
+where 'model_name' can be chosen from 'ch', 'ch_server', 'ch_lite', 'en', 'korean', 'japan', 'chinese_cht', 'ta', 'te', 'ka', 'latin', 'arabic', 'east_slavic', 'cyrillic', 'devanagari'
+### 5 Software stack versions used for this adaptation
+The adaptation uses software stack version v3.0.0; contact Tecorigin to obtain it.

+ 0 - 0
docs/zh/usage/domestic_acceleration_cards/METAX.md


+ 0 - 0
docs/zh/usage/domestic_acceleration_cards/Tecorigin.md


+ 5 - 3
docs/zh/usage/index.md

@@ -20,9 +20,11 @@
     * [DataFlow](plugin/DataFlow.md)
     * [BISHENG](plugin/BISHENG.md)
     * [RagFlow](plugin/RagFlow.md)
- Domestic acceleration card support
-    * [沐曦 METAX](./domestic_acceleration_cards/METAX.md)
-    * [太初元碁 Tecorigin](./domestic_acceleration_cards/Tecorigin.md)
+- Other acceleration cards (community contributed)
+    * [昇腾 Ascend](acceleration_cards/Ascend.md) #3233
+    * [AMD](acceleration_cards/AMD.md)  #3662
+    * [沐曦 METAX](acceleration_cards/METAX.md) @147phoenix
+    * [太初元碁 Tecorigin](acceleration_cards/Tecorigin.md) @Tecorigin
 
 ## Getting started