## Triton-based optimization of different backends on ROCm: getting the vllm backend to run inference normally, plus DocLayout-YOLO used for the first layout step of the pipeline backend

**If you already have a complete Python vllm + mineru environment, jump straight to Step 5!!!**

**Execution problems on other GPUs can follow the same recipe: profile first to locate which operator is the problem, then reimplement it as a Triton backend.**

I ran some tests and the output is basically on par with the official MinerU results. Not many people run this on AMD, so I'm just sharing it here in the comments.

### 1. Results

**Addendum: a 200-page Python programming book (PDF) as a speed test; it reaches 1.99 it/s:**

Two Step Extraction: 100%|████████████████████████████████████████| 200/200 [01:40<00:00, 1.99it/s]

**Below are the earlier test results on 14 academic papers:**

On a 7900 XTX, `mineru-gradio --server-name 0.0.0.0 --server-port 7860 --enable-vllm-engine true` runs at roughly **1.6-1.8 s/it** (not carefully benchmarked, just two quick documents). The second scheme, which replaces the original dot products with a matrix multiplication, speeds this up further to about 1.3 s/it. After the optimization the main remaining cost sits in hipBLAS (nothing more to gain there) and vllm's own Triton backend, each around 25% of the runtime; the latter will have to wait for upstream optimization.

The DocLayout-YOLO layout stage improves from the original 1.6 it/s to 15 it/s. Note that the input PDF page sizes have to be cached; Triton simply needs fixed shapes, there is no way around it. This is done to keep the model's input/output interface intact with a minimal code change.

An example run using the `-b vlm-vllm-engine` mode:

---

**The results below are from the optimization that replaces the original dot product with a 5D matrix multiplication:**

2025-10-05 15:45:12.985 | INFO | mineru.backend.vlm.vlm_analyze:get_model:128 - get vllm-engine predictor cost: 18.45s
Adding requests: 100%|████████████████████████████████████████████████████████████████████████████████| 14/14 [00:01<00:00, 12.20it/s]
Processed prompts: 100%|█████████████████████| 14/14 [00:08<00:00, 1.56it/s, est. speed input: 2174.18 toks/s, output: 791.87 toks/s]
Adding requests: 100%|█████████████████████████████████████████████████████████████████████████████| 278/278 [00:00<00:00, 323.03it/s]
Processed prompts: 100%|██████████████████| 278/278 [00:07<00:00, 37.63it/s, est. speed input: 5264.66 toks/s, output: 2733.31 toks/s]

Testing via `mineru-gradio --server-name 0.0.0.0 --server-port 7860 --enable-vllm-engine true`:

2025-10-05 15:46:55.953 | WARNING | mineru.cli.common:convert_pdf_bytes_to_bytes_by_pypdfium2:54 - end_page_id is out of range, use pdf_docs length
Two Step Extraction: 100%|████████████████████████████████████████████████████████████████████████████| 14/14 [00:18<00:00, 1.30s/it]

### 2. Why this happens

The vllm backend has a severe performance problem on AMD RDNA: one operator in vllm's **qwen2_vl.py** has no corresponding ROCm kernel implementation, so the convolution falls back to a terribly slow path, taking about 12 s for a single execution. A sorry story. In short, **the MIOpen library lacks an optimized kernel for the specific Conv3d (bfloat16) used by this model**.

The dilated convolution in DocLayout-YOLO's **g2l_crm.py** hits the same issue, and even the professional CDNA MI210 does not escape it.

So both get fixed here in one go.

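As noted at the top, the way to chase this kind of problem on other GPUs is to profile first and see which operator dominates. A minimal sketch of such a check, assuming a ROCm build of PyTorch (the shapes simply mirror the patch-embed defaults quoted below; the 4096-patch batch is arbitrary):

```
import torch
from torch.profiler import profile, ProfilerActivity

# Same shape family as the Qwen2-VL patch embedding: (patches, C_in, T, H, W)
x = torch.randn(4096, 3, 2, 14, 14, dtype=torch.bfloat16, device="cuda")
conv = torch.nn.Conv3d(3, 1152, kernel_size=(2, 14, 14), stride=(2, 14, 14),
                       bias=False).to(device="cuda", dtype=torch.bfloat16)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    conv(x)
    torch.cuda.synchronize()

# On an affected ROCm setup, the Conv3d kernel dominates this table by a wide margin.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```
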
---

### 3. Environment

System: Ubuntu 24.04.3, Kernel: Linux 6.14.0-33-generic, ROCm version: 7.0.1

Python environment:

python 3.12
pytorch-triton-rocm 3.5.0+gitbbb06c03
torch 2.10.0.dev20251001+rocm7.0
torchvision 0.25.0.dev20251003+rocm7.0
vllm 0.11.0rc2.dev198+g736fbf4c8.rocm701

The exact versions don't matter; the fix is the same either way.

### 4. Prerequisite environment setup

```
uv venv --python python3.12
source .venv/bin/activate
uv pip install --pre torch torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple/ --extra-index-url https://download.pytorch.org/whl/nightly/rocm7.0
uv pip install pip
# switch from uv pip to plain pip here so the locally installed PyTorch is not overwritten
pip install -U "mineru[core]" -i https://pypi.mirrors.ustc.edu.cn/simple/
```

For installing vllm, refer to the official manual: [Vllm](https://docs.vllm.com.cn/en/latest/getting_started/installation/gpu.html#amd-rocm)

```
# manually install aiter, vllm, amd-smi, etc.; pick a directory to clone into, then cd into it
git clone --recursive https://github.com/ROCm/aiter.git
cd aiter
git submodule sync; git submodule update --init --recursive
python setup.py develop
cd ..
git clone https://github.com/vllm-project/vllm.git
cd vllm/
cp -r /opt/rocm/share/amd_smi ~/Pytorch/vllm/
pip install amd_smi/
pip install --upgrade numba \
    scipy \
    huggingface-hub[cli,hf_transfer] \
    setuptools_scm
pip install -r requirements/rocm.txt
export PYTORCH_ROCM_ARCH="gfx1100"  # match your GPU architecture: rocminfo | grep gfx
python setup.py develop
```

---

### 5. Adding the key Triton operator to vllm

Here I give two solutions. The first is the one mentioned above, which gets you to roughly 1.5-1.8 s/it. The second manually rewrites the operator as a matrix multiplication: it is definitely right for the 7900 XTX, at about 1.3 s/it, and it should still be faster than Scheme 1 on other AMD GPUs, but it is not necessarily the optimal speed there and the hand-tuned parts may need adjusting.

**Note: uninstall the Triton-backend flash_attn with pip. After a lot of fiddling it still kept throwing errors, so the simplest fix is to just not use it.**

```
# locate your own vllm install path (the XXX used below)
pip show vllm
```

**Key changes**

In the file XXX/vllm/model_executor/models/qwen2_vl.py:

**1. Below line 33 of qwen2_vl.py, add `from .qwen2_vl_vision_kernels import triton_conv3d_patchify`:**

```
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Annotated, Any, Callable, Literal, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from .qwen2_vl_vision_kernels import triton_conv3d_patchify
```

**What follows is split into Scheme 1 (2.1 and 3.1) and Scheme 2 (2.2 and 3.2); pick one of the two implementations.**

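Both schemes rest on the same observation: because the Conv3d's stride equals its kernel size, every output element is just the dot product of one flattened patch with one flattened filter, i.e. the whole op is a plain GEMM. A standalone sketch to convince yourself of the equivalence (shapes follow the patch-embed defaults; nothing here touches vllm):

```
import torch

L, C_in, T, P, D = 64, 3, 2, 14, 1152   # patches, channels, temporal, patch size, embed dim
conv = torch.nn.Conv3d(C_in, D, kernel_size=(T, P, P), stride=(T, P, P), bias=False)

x = torch.randn(L, C_in, T, P, P)
ref = conv(x).view(L, D)                             # (L, D, 1, 1, 1) -> (L, D)

# The same computation as a GEMM: flatten patches and filters, then matrix-multiply.
gemm = x.view(L, -1) @ conv.weight.view(D, -1).t()

print(torch.allclose(ref, gemm, atol=1e-4))          # True, up to float rounding
```
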
---

**Scheme 1**

**2.1 In qwen2_vl.py, around line 498, class Qwen2VisionPatchEmbed(nn.Module). (PS: this is the module AMD has no ready-made kernel for, which causes the fallback.)**

```
class Qwen2VisionPatchEmbed(nn.Module):

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(in_channels,
                              embed_dim,
                              kernel_size=kernel_size,
                              stride=kernel_size,
                              bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        L, C = x.shape
        x_reshaped = x.view(L, -1, self.temporal_patch_size, self.patch_size,
                            self.patch_size)

        # Call your custom Triton kernel instead of self.proj
        x_out = triton_conv3d_patchify(x_reshaped, self.proj.weight)

        # The output of our kernel is already the correct shape [L, embed_dim]
        return x_out
```

**3.1 Create a file qwen2_vl_vision_kernels.py under XXX/vllm/model_executor/models/, implemented with Triton:**

```
import torch
from vllm.triton_utils import tl, triton


@triton.jit
def _conv3d_patchify_kernel(
    # Pointers to tensors
    X, W, Y,
    # Tensor dimensions
    N, C_in, D_in, H_in, W_in,
    C_out, KD, KH, KW,
    # Strides for memory access
    stride_xn, stride_xc, stride_xd, stride_xh, stride_xw,
    stride_wn, stride_wc, stride_wd, stride_wh, stride_ww,
    stride_yn, stride_yc,
    # Triton-specific metaparameters
    BLOCK_SIZE: tl.constexpr,
):
    """
    Triton kernel for a non-overlapping 3D patching convolution.
    Each kernel instance computes one output value for one patch.
    """
    # Get the program IDs for the N (patch) and C_out (output channel) dimensions
    pid_n = tl.program_id(0)     # The index of the patch we are processing
    pid_cout = tl.program_id(1)  # The index of the output channel we are computing

    # --- Calculate memory pointers ---
    # Pointer to the start of the current input patch
    x_ptr = X + (pid_n * stride_xn)
    # Pointer to the start of the current filter (weight)
    w_ptr = W + (pid_cout * stride_wn)
    # Pointer to where the output will be stored
    y_ptr = Y + (pid_n * stride_yn + pid_cout * stride_yc)

    # --- Perform the convolution (element-wise product and sum) ---
    # This is a dot product between the flattened patch and the flattened filter.
    accumulator = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    # Iterate over the elements of the patch/filter
    for c_offset in range(0, C_in):
        for d_offset in range(0, KD):
            for h_offset in range(0, KH):
                # Blocked loop over the innermost dimension (width) for performance
                for w_offset in range(0, KW, BLOCK_SIZE):
                    # Create masks to handle cases where KW is not a multiple of BLOCK_SIZE
                    w_range = w_offset + tl.arange(0, BLOCK_SIZE)
                    w_mask = w_range < KW

                    # Calculate offsets to load data
                    patch_offset = (c_offset * stride_xc + d_offset * stride_xd +
                                    h_offset * stride_xh + w_range * stride_xw)
                    filter_offset = (c_offset * stride_wc + d_offset * stride_wd +
                                     h_offset * stride_wh + w_range * stride_ww)

                    # Load patch and filter data, applying masks
                    patch_vals = tl.load(x_ptr + patch_offset, mask=w_mask, other=0.0)
                    filter_vals = tl.load(w_ptr + filter_offset, mask=w_mask, other=0.0)

                    # Multiply and accumulate
                    accumulator += patch_vals.to(tl.float32) * filter_vals.to(tl.float32)

    # Sum the accumulator block and store the single output value
    output_val = tl.sum(accumulator, axis=0)
    tl.store(y_ptr, output_val)


def triton_conv3d_patchify(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """
    Python wrapper for the 3D patching convolution Triton kernel.
    """
    # Get tensor dimensions
    N, C_in, D_in, H_in, W_in = x.shape
    C_out, _, KD, KH, KW = weight.shape

    # Create the output tensor
    # The output of this specific conv is (N, C_out, 1, 1, 1), which we squeeze
    Y = torch.empty((N, C_out), dtype=x.dtype, device=x.device)

    # Define the grid for launching the Triton kernel
    # Each kernel instance handles one patch (N) for one output channel (C_out)
    grid = (N, C_out)

    # Launch the kernel
    # We pass all strides to make the kernel flexible
    _conv3d_patchify_kernel[grid](
        x, weight, Y,
        N, C_in, D_in, H_in, W_in,
        C_out, KD, KH, KW,
        x.stride(0), x.stride(1), x.stride(2), x.stride(3), x.stride(4),
        weight.stride(0), weight.stride(1), weight.stride(2), weight.stride(3), weight.stride(4),
        Y.stride(0), Y.stride(1),
        BLOCK_SIZE=16,  # A reasonable default, can be tuned
    )

    return Y
```

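Before wiring Scheme 1 into vllm, it is worth sanity-checking the kernel against the PyTorch reference on realistic shapes. A minimal sketch, assuming the file was created at the path above so the import resolves (the weight scale and patch count are arbitrary):

```
import torch
from vllm.model_executor.models.qwen2_vl_vision_kernels import triton_conv3d_patchify

torch.manual_seed(0)
x = torch.randn(512, 3, 2, 14, 14, dtype=torch.bfloat16, device="cuda")
w = torch.randn(1152, 3, 2, 14, 14, dtype=torch.bfloat16, device="cuda") * 0.02

# Reference: the original Conv3d path (stride == kernel size, no bias)
ref = torch.nn.functional.conv3d(x, w, stride=(2, 14, 14)).view(512, 1152)
out = triton_conv3d_patchify(x, w)

# The difference should be on the order of bfloat16 rounding error
print((out.float() - ref.float()).abs().max())
```
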
---

**Scheme 2**

**2.2 Again in qwen2_vl.py, around line 498, class Qwen2VisionPatchEmbed(nn.Module). (PS: same module, same missing AMD kernel causing the fallback.) Here we go straight from the 5D tensor in one step and turn the convolution into a matrix multiplication.**

```
class Qwen2VisionPatchEmbed(nn.Module):

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.embed_dim = embed_dim

        kernel_size = (temporal_patch_size, patch_size, patch_size)

        self.proj = nn.Conv3d(in_channels,
                              embed_dim,
                              kernel_size=kernel_size,
                              stride=kernel_size,
                              bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        L, C = x.shape
        x_reshaped_5d = x.view(L, -1, self.temporal_patch_size, self.patch_size,
                               self.patch_size)

        return triton_conv3d_patchify(x_reshaped_5d, self.proj.weight)
```

**3.2 Create a file qwen2_vl_vision_kernels.py under XXX/vllm/model_executor/models/, implemented with Triton:**

```
import torch
from vllm.triton_utils import tl, triton


@triton.jit
def _conv_gemm_kernel(
    A, B, C, M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
        offs_k += BLOCK_K
    c = accumulator.to(C.dtype.element_ty)
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def triton_conv3d_patchify(x_5d: torch.Tensor, weight_5d: torch.Tensor) -> torch.Tensor:
    N_patches, _, _, _, _ = x_5d.shape
    C_out, _, _, _, _ = weight_5d.shape
    A = x_5d.view(N_patches, -1)
    B = weight_5d.view(C_out, -1).transpose(0, 1).contiguous()
    M, K = A.shape
    _K, N = B.shape
    assert K == _K
    C = torch.empty((M, N), device=A.device, dtype=A.dtype)

    # --- Hand-tuned config for the 7900 XTX; other GPUs may need their own best combination.
    # --- Triton autotune on AMD was of no help here.
    best_config = {
        'BLOCK_M': 128,
        'BLOCK_N': 128,
        'BLOCK_K': 32,
    }
    num_stages = 4
    num_warps = 8

    grid = (triton.cdiv(M, best_config['BLOCK_M']),
            triton.cdiv(N, best_config['BLOCK_N']))

    _conv_gemm_kernel[grid](
        A, B, C,
        M, N, K,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        C.stride(0), C.stride(1),
        **best_config,
        num_stages=num_stages,
        num_warps=num_warps
    )

    return C
```

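The BLOCK_M/BLOCK_N/BLOCK_K, num_stages and num_warps values above came from manual trial and error on the 7900 XTX. On a different GPU, one way to find your own combination is to time the Scheme 2 kernel directly over a small grid of candidates with triton.testing.do_bench; a rough sketch (the candidate grid and the GEMM shape are only illustrative, and the import assumes the Scheme 2 file exists at the path above):

```
import itertools
import torch
import triton
from triton.testing import do_bench
from vllm.model_executor.models.qwen2_vl_vision_kernels import _conv_gemm_kernel

# GEMM shape seen by the patch embed: M = number of patches, K = 3*2*14*14, N = embed_dim
M, K, N = 4096, 1176, 1152
A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")
C = torch.empty(M, N, dtype=torch.bfloat16, device="cuda")

for bm, bn, bk, warps, stages in itertools.product(
        (64, 128), (64, 128), (32, 64), (4, 8), (2, 4)):
    grid = (triton.cdiv(M, bm), triton.cdiv(N, bn))
    ms = do_bench(lambda: _conv_gemm_kernel[grid](
        A, B, C, M, N, K,
        A.stride(0), A.stride(1), B.stride(0), B.stride(1), C.stride(0), C.stride(1),
        BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk, num_warps=warps, num_stages=stages))
    print(f"BM={bm} BN={bn} BK={bk} warps={warps} stages={stages}: {ms:.3f} ms")
```
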
---

**4. After closing the terminal, running mineru-gradio again reports a LoRA error; patch the code to skip it.**

```
pip show mineru_vl_utils
```

Open the file XXX/mineru_vl_utils/vlm_client/vllm_async_engine_client.py and change line 58, `self.tokenizer = vllm_async_llm.tokenizer.get_lora_tokenizer()`, to:

```
try:
    self.tokenizer = vllm_async_llm.tokenizer.get_lora_tokenizer()
except AttributeError:
    # if there is no get_lora_tokenizer method, just use the original tokenizer
    self.tokenizer = vllm_async_llm.tokenizer
```

**Finally, set these two environment variables and enjoy:**

```
export MINERU_MODEL_SOURCE=modelscope
export TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1
```

---

### 6. With the vllm backend sorted, what remains is the dilated-convolution problem in the DocLayout-YOLO layout model used by the pipeline backend

I posted an answer under [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO/issues/120#issuecomment-3368144275), so I won't repeat the pipeline's dilated-convolution fix here; just follow the link for details.

Find where your doclayout-yolo is installed as shown below, then go in and edit the file described in the linked reply.

```
pip show doclayout-yolo
```