소스 검색

Merge pull request #2506 from myhloli/dev

feat(ocr): implement PPHGNetV2 architecture with multiple stages and layers
Xiaomeng Zhao 5 달 전
부모
커밋
04a712f940
1개의 변경된 파일810개의 추가작업 그리고 0개의 파일을 삭제
  1. 810 0
      magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_pphgnetv2.py

+ 810 - 0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/modeling/backbones/rec_pphgnetv2.py

@@ -0,0 +1,810 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AdaptiveAvgPool2D(nn.AdaptiveAvgPool2d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if isinstance(self.output_size, int) and self.output_size == 1:
+            self._gap = True
+        elif (
+            isinstance(self.output_size, tuple)
+            and self.output_size[0] == 1
+            and self.output_size[1] == 1
+        ):
+            self._gap = True
+        else:
+            self._gap = False
+
+    def forward(self, x):
+        if self._gap:
+            # Global Average Pooling
+            N, C, _, _ = x.shape
+            x_mean = torch.mean(x, dim=[2, 3])
+            x_mean = torch.reshape(x_mean, [N, C, 1, 1])
+            return x_mean
+        else:
+            return F.adaptive_avg_pool2d(
+                x,
+                output_size=self.output_size
+            )
+
+class LearnableAffineBlock(nn.Module):
+    """
+    Create a learnable affine block module. This module can significantly improve accuracy on smaller models.
+
+    Args:
+        scale_value (float): The initial value of the scale parameter, default is 1.0.
+        bias_value (float): The initial value of the bias parameter, default is 0.0.
+        lr_mult (float): The learning rate multiplier, default is 1.0.
+        lab_lr (float): The learning rate, default is 0.01.
+    """
+
+    def __init__(self, scale_value=1.0, bias_value=0.0, lr_mult=1.0, lab_lr=0.01):
+        super().__init__()
+        self.scale = nn.Parameter(torch.Tensor([scale_value]))
+        self.bias = nn.Parameter(torch.Tensor([bias_value]))
+
+    def forward(self, x):
+        return self.scale * x + self.bias
+
+
+class ConvBNAct(nn.Module):
+    """
+    ConvBNAct is a combination of convolution and batchnorm layers.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        kernel_size (int): Size of the convolution kernel. Defaults to 3.
+        stride (int): Stride of the convolution. Defaults to 1.
+        padding (int/str): Padding or padding type for the convolution. Defaults to 1.
+        groups (int): Number of groups for the convolution. Defaults to 1.
+        use_act: (bool): Whether to use activation function. Defaults to True.
+        use_lab (bool): Whether to use the LAB operation. Defaults to False.
+        lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size=3,
+        stride=1,
+        padding=1,
+        groups=1,
+        use_act=True,
+        use_lab=False,
+        lr_mult=1.0,
+    ):
+        super().__init__()
+        self.use_act = use_act
+        self.use_lab = use_lab
+
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding=padding if isinstance(padding, str) else (kernel_size - 1) // 2,
+            # padding=(kernel_size - 1) // 2,
+            groups=groups,
+            bias=False,
+        )
+        self.bn = nn.BatchNorm2d(
+            out_channels,
+        )
+        if self.use_act:
+            self.act = nn.ReLU()
+            if self.use_lab:
+                self.lab = LearnableAffineBlock(lr_mult=lr_mult)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        if self.use_act:
+            x = self.act(x)
+            if self.use_lab:
+                x = self.lab(x)
+        return x
+
+
+class LightConvBNAct(nn.Module):
+    """
+    LightConvBNAct is a combination of pw and dw layers.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        kernel_size (int): Size of the depth-wise convolution kernel.
+        use_lab (bool): Whether to use the LAB operation. Defaults to False.
+        lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        use_lab=False,
+        lr_mult=1.0,
+        **kwargs,
+    ):
+        super().__init__()
+        self.conv1 = ConvBNAct(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            use_act=False,
+            use_lab=use_lab,
+            lr_mult=lr_mult,
+        )
+        self.conv2 = ConvBNAct(
+            in_channels=out_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            groups=out_channels,
+            use_act=True,
+            use_lab=use_lab,
+            lr_mult=lr_mult,
+        )
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return x
+
+
+class CustomMaxPool2d(nn.Module):
+    def __init__(
+            self,
+            kernel_size,
+            stride=None,
+            padding=0,
+            dilation=1,
+            return_indices=False,
+            ceil_mode=False,
+            data_format="NCHW",
+    ):
+        super(CustomMaxPool2d, self).__init__()
+        self.kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, kernel_size)
+        self.stride = stride if stride is not None else self.kernel_size
+        self.stride = self.stride if isinstance(self.stride, (tuple, list)) else (self.stride, self.stride)
+        self.dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
+        self.return_indices = return_indices
+        self.ceil_mode = ceil_mode
+        self.padding_mode = padding
+
+        # 当padding不是"same"时使用标准MaxPool2d
+        if padding != "same":
+            self.padding = padding if isinstance(padding, (tuple, list)) else (padding, padding)
+            self.pool = nn.MaxPool2d(
+                kernel_size=self.kernel_size,
+                stride=self.stride,
+                padding=self.padding,
+                dilation=self.dilation,
+                return_indices=self.return_indices,
+                ceil_mode=self.ceil_mode
+            )
+
+    def forward(self, x):
+        # 处理same padding
+        if self.padding_mode == "same":
+            input_height, input_width = x.size(2), x.size(3)
+
+            # 计算期望的输出尺寸
+            out_height = math.ceil(input_height / self.stride[0])
+            out_width = math.ceil(input_width / self.stride[1])
+
+            # 计算需要的padding
+            pad_height = max((out_height - 1) * self.stride[0] + self.kernel_size[0] - input_height, 0)
+            pad_width = max((out_width - 1) * self.stride[1] + self.kernel_size[1] - input_width, 0)
+
+            # 将padding分配到两边
+            pad_top = pad_height // 2
+            pad_bottom = pad_height - pad_top
+            pad_left = pad_width // 2
+            pad_right = pad_width - pad_left
+
+            # 应用padding
+            x = F.pad(x, (pad_left, pad_right, pad_top, pad_bottom))
+
+            # 使用标准max_pool2d函数
+            if self.return_indices:
+                return F.max_pool2d_with_indices(
+                    x,
+                    kernel_size=self.kernel_size,
+                    stride=self.stride,
+                    padding=0,  # 已经手动pad过了
+                    dilation=self.dilation,
+                    ceil_mode=self.ceil_mode
+                )
+            else:
+                return F.max_pool2d(
+                    x,
+                    kernel_size=self.kernel_size,
+                    stride=self.stride,
+                    padding=0,  # 已经手动pad过了
+                    dilation=self.dilation,
+                    ceil_mode=self.ceil_mode
+                )
+        else:
+            # 使用预定义的MaxPool2d
+            return self.pool(x)
+
+class StemBlock(nn.Module):
+    """
+    StemBlock for PP-HGNetV2.
+
+    Args:
+        in_channels (int): Number of input channels.
+        mid_channels (int): Number of middle channels.
+        out_channels (int): Number of output channels.
+        use_lab (bool): Whether to use the LAB operation. Defaults to False.
+        lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        mid_channels,
+        out_channels,
+        use_lab=False,
+        lr_mult=1.0,
+        text_rec=False,
+    ):
+        super().__init__()
+        self.stem1 = ConvBNAct(
+            in_channels=in_channels,
+            out_channels=mid_channels,
+            kernel_size=3,
+            stride=2,
+            use_lab=use_lab,
+            lr_mult=lr_mult,
+        )
+        self.stem2a = ConvBNAct(
+            in_channels=mid_channels,
+            out_channels=mid_channels // 2,
+            kernel_size=2,
+            stride=1,
+            padding="same",
+            use_lab=use_lab,
+            lr_mult=lr_mult,
+        )
+        self.stem2b = ConvBNAct(
+            in_channels=mid_channels // 2,
+            out_channels=mid_channels,
+            kernel_size=2,
+            stride=1,
+            padding="same",
+            use_lab=use_lab,
+            lr_mult=lr_mult,
+        )
+        self.stem3 = ConvBNAct(
+            in_channels=mid_channels * 2,
+            out_channels=mid_channels,
+            kernel_size=3,
+            stride=1 if text_rec else 2,
+            use_lab=use_lab,
+            lr_mult=lr_mult,
+        )
+        self.stem4 = ConvBNAct(
+            in_channels=mid_channels,
+            out_channels=out_channels,
+            kernel_size=1,
+            stride=1,
+            use_lab=use_lab,
+            lr_mult=lr_mult,
+        )
+        self.pool = CustomMaxPool2d(
+            kernel_size=2, stride=1, ceil_mode=True, padding="same"
+        )
+        # self.pool = nn.MaxPool2d(
+        #     kernel_size=2, stride=1, ceil_mode=True, padding=1
+        # )
+
+    def forward(self, x):
+        x = self.stem1(x)
+        x2 = self.stem2a(x)
+        x2 = self.stem2b(x2)
+        x1 = self.pool(x)
+
+        # if x1.shape[2:] != x2.shape[2:]:
+        #     x1 = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=False)
+
+        x = torch.cat([x1, x2], 1)
+        x = self.stem3(x)
+        x = self.stem4(x)
+
+        return x
+
+
+class HGV2_Block(nn.Module):
+    """
+    HGV2_Block, the basic unit that constitutes the HGV2_Stage.
+
+    Args:
+        in_channels (int): Number of input channels.
+        mid_channels (int): Number of middle channels.
+        out_channels (int): Number of output channels.
+        kernel_size (int): Size of the convolution kernel. Defaults to 3.
+        layer_num (int): Number of layers in the HGV2 block. Defaults to 6.
+        stride (int): Stride of the convolution. Defaults to 1.
+        padding (int/str): Padding or padding type for the convolution. Defaults to 1.
+        groups (int): Number of groups for the convolution. Defaults to 1.
+        use_act (bool): Whether to use activation function. Defaults to True.
+        use_lab (bool): Whether to use the LAB operation. Defaults to False.
+        lr_mult (float): Learning rate multiplier for the layer. Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        mid_channels,
+        out_channels,
+        kernel_size=3,
+        layer_num=6,
+        identity=False,
+        light_block=True,
+        use_lab=False,
+        lr_mult=1.0,
+    ):
+        super().__init__()
+        self.identity = identity
+
+        self.layers = nn.ModuleList()
+        block_type = "LightConvBNAct" if light_block else "ConvBNAct"
+        for i in range(layer_num):
+            self.layers.append(
+                eval(block_type)(
+                    in_channels=in_channels if i == 0 else mid_channels,
+                    out_channels=mid_channels,
+                    stride=1,
+                    kernel_size=kernel_size,
+                    use_lab=use_lab,
+                    lr_mult=lr_mult,
+                )
+            )
+        # feature aggregation
+        total_channels = in_channels + layer_num * mid_channels
+        self.aggregation_squeeze_conv = ConvBNAct(
+            in_channels=total_channels,
+            out_channels=out_channels // 2,
+            kernel_size=1,
+            stride=1,
+            use_lab=use_lab,
+            lr_mult=lr_mult,
+        )
+        self.aggregation_excitation_conv = ConvBNAct(
+            in_channels=out_channels // 2,
+            out_channels=out_channels,
+            kernel_size=1,
+            stride=1,
+            use_lab=use_lab,
+            lr_mult=lr_mult,
+        )
+
+    def forward(self, x):
+        identity = x
+        output = []
+        output.append(x)
+        for layer in self.layers:
+            x = layer(x)
+            output.append(x)
+        x = torch.cat(output, dim=1)
+        x = self.aggregation_squeeze_conv(x)
+        x = self.aggregation_excitation_conv(x)
+        if self.identity:
+            x += identity
+        return x
+
+
+class HGV2_Stage(nn.Module):
+    """
+    HGV2_Stage, the basic unit that constitutes the PPHGNetV2.
+
+    Args:
+        in_channels (int): Number of input channels.
+        mid_channels (int): Number of middle channels.
+        out_channels (int): Number of output channels.
+        block_num (int): Number of blocks in the HGV2 stage.
+        layer_num (int): Number of layers in the HGV2 block. Defaults to 6.
+        is_downsample (bool): Whether to use downsampling operation. Defaults to False.
+        light_block (bool): Whether to use light block. Defaults to True.
+        kernel_size (int): Size of the convolution kernel. Defaults to 3.
+        use_lab (bool, optional): Whether to use the LAB operation. Defaults to False.
+        lr_mult (float, optional): Learning rate multiplier for the layer. Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        mid_channels,
+        out_channels,
+        block_num,
+        layer_num=6,
+        is_downsample=True,
+        light_block=True,
+        kernel_size=3,
+        use_lab=False,
+        stride=2,
+        lr_mult=1.0,
+    ):
+
+        super().__init__()
+        self.is_downsample = is_downsample
+        if self.is_downsample:
+            self.downsample = ConvBNAct(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                kernel_size=3,
+                stride=stride,
+                groups=in_channels,
+                use_act=False,
+                use_lab=use_lab,
+                lr_mult=lr_mult,
+            )
+
+        blocks_list = []
+        for i in range(block_num):
+            blocks_list.append(
+                HGV2_Block(
+                    in_channels=in_channels if i == 0 else out_channels,
+                    mid_channels=mid_channels,
+                    out_channels=out_channels,
+                    kernel_size=kernel_size,
+                    layer_num=layer_num,
+                    identity=False if i == 0 else True,
+                    light_block=light_block,
+                    use_lab=use_lab,
+                    lr_mult=lr_mult,
+                )
+            )
+        self.blocks = nn.Sequential(*blocks_list)
+
+    def forward(self, x):
+        if self.is_downsample:
+            x = self.downsample(x)
+        x = self.blocks(x)
+        return x
+
+
+class DropoutInferDownscale(nn.Module):
+    """
+    实现与Paddle的mode="downscale_in_infer"等效的Dropout
+    训练模式:out = input * mask(直接应用掩码,不进行放大)
+    推理模式:out = input * (1.0 - p)(在推理时按概率缩小)
+    """
+
+    def __init__(self, p=0.5):
+        super().__init__()
+        self.p = p
+
+    def forward(self, x):
+        if self.training:
+            # 训练时:应用随机mask但不放大
+            return F.dropout(x, self.p, training=True) * (1.0 - self.p)
+        else:
+            # 推理时:按照dropout概率缩小输出
+            return x * (1.0 - self.p)
+
+class PPHGNetV2(nn.Module):
+    """
+    PPHGNetV2
+
+    Args:
+        stage_config (dict): Config for PPHGNetV2 stages. such as the number of channels, stride, etc.
+        stem_channels: (list): Number of channels of the stem of the PPHGNetV2.
+        use_lab (bool): Whether to use the LAB operation. Defaults to False.
+        use_last_conv (bool): Whether to use the last conv layer as the output channel. Defaults to True.
+        class_expand (int): Number of channels for the last 1x1 convolutional layer.
+        drop_prob (float): Dropout probability for the last 1x1 convolutional layer. Defaults to 0.0.
+        class_num (int): The number of classes for the classification layer. Defaults to 1000.
+        lr_mult_list (list): Learning rate multiplier for the stages. Defaults to [1.0, 1.0, 1.0, 1.0, 1.0].
+    Returns:
+        model: nn.Layer. Specific PPHGNetV2 model depends on args.
+    """
+
+    def __init__(
+        self,
+        stage_config,
+        stem_channels=[3, 32, 64],
+        use_lab=False,
+        use_last_conv=True,
+        class_expand=2048,
+        dropout_prob=0.0,
+        class_num=1000,
+        lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
+        det=False,
+        text_rec=False,
+        out_indices=None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.det = det
+        self.text_rec = text_rec
+        self.use_lab = use_lab
+        self.use_last_conv = use_last_conv
+        self.class_expand = class_expand
+        self.class_num = class_num
+        self.out_indices = out_indices if out_indices is not None else [0, 1, 2, 3]
+        self.out_channels = []
+
+        # stem
+        self.stem = StemBlock(
+            in_channels=stem_channels[0],
+            mid_channels=stem_channels[1],
+            out_channels=stem_channels[2],
+            use_lab=use_lab,
+            lr_mult=lr_mult_list[0],
+            text_rec=text_rec,
+        )
+
+        # stages
+        self.stages = nn.ModuleList()
+        for i, k in enumerate(stage_config):
+            (
+                in_channels,
+                mid_channels,
+                out_channels,
+                block_num,
+                is_downsample,
+                light_block,
+                kernel_size,
+                layer_num,
+                stride,
+            ) = stage_config[k]
+            self.stages.append(
+                HGV2_Stage(
+                    in_channels,
+                    mid_channels,
+                    out_channels,
+                    block_num,
+                    layer_num,
+                    is_downsample,
+                    light_block,
+                    kernel_size,
+                    use_lab,
+                    stride,
+                    lr_mult=lr_mult_list[i + 1],
+                )
+            )
+            if i in self.out_indices:
+                self.out_channels.append(out_channels)
+        if not self.det:
+            self.out_channels = stage_config["stage4"][2]
+
+        self.avg_pool = AdaptiveAvgPool2D(1)
+
+        if self.use_last_conv:
+            self.last_conv = nn.Conv2d(
+                in_channels=out_channels,
+                out_channels=self.class_expand,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+            )
+            self.act = nn.ReLU()
+            if self.use_lab:
+                self.lab = LearnableAffineBlock()
+            self.dropout = DropoutInferDownscale(p=dropout_prob)
+
+        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
+        if not self.det:
+            self.fc = nn.Linear(
+                self.class_expand if self.use_last_conv else out_channels,
+                self.class_num,
+            )
+
+        self._init_weights()
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.zeros_(m.bias)
+
+    def forward(self, x):
+        x = self.stem(x)
+        out = []
+        for i, stage in enumerate(self.stages):
+            x = stage(x)
+            if self.det and i in self.out_indices:
+                out.append(x)
+        if self.det:
+            return out
+
+        if self.text_rec:
+            if self.training:
+                x = F.adaptive_avg_pool2d(x, [1, 40])
+            else:
+                x = F.avg_pool2d(x, [3, 2])
+        return x
+
+
+def PPHGNetV2_B0(pretrained=False, use_ssld=False, **kwargs):
+    """
+    PPHGNetV2_B0
+    Args:
+        pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
+    Returns:
+        model: nn.Layer. Specific `PPHGNetV2_B0` model depends on args.
+    """
+    stage_config = {
+        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
+        "stage1": [16, 16, 64, 1, False, False, 3, 3],
+        "stage2": [64, 32, 256, 1, True, False, 3, 3],
+        "stage3": [256, 64, 512, 2, True, True, 5, 3],
+        "stage4": [512, 128, 1024, 1, True, True, 5, 3],
+    }
+
+    model = PPHGNetV2(
+        stem_channels=[3, 16, 16], stage_config=stage_config, use_lab=True, **kwargs
+    )
+    return model
+
+
+def PPHGNetV2_B1(pretrained=False, use_ssld=False, **kwargs):
+    """
+    PPHGNetV2_B1
+    Args:
+        pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
+    Returns:
+        model: nn.Layer. Specific `PPHGNetV2_B1` model depends on args.
+    """
+    stage_config = {
+        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
+        "stage1": [32, 32, 64, 1, False, False, 3, 3],
+        "stage2": [64, 48, 256, 1, True, False, 3, 3],
+        "stage3": [256, 96, 512, 2, True, True, 5, 3],
+        "stage4": [512, 192, 1024, 1, True, True, 5, 3],
+    }
+
+    model = PPHGNetV2(
+        stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
+    )
+    return model
+
+
+def PPHGNetV2_B2(pretrained=False, use_ssld=False, **kwargs):
+    """
+    PPHGNetV2_B2
+    Args:
+        pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
+    Returns:
+        model: nn.Layer. Specific `PPHGNetV2_B2` model depends on args.
+    """
+    stage_config = {
+        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
+        "stage1": [32, 32, 96, 1, False, False, 3, 4],
+        "stage2": [96, 64, 384, 1, True, False, 3, 4],
+        "stage3": [384, 128, 768, 3, True, True, 5, 4],
+        "stage4": [768, 256, 1536, 1, True, True, 5, 4],
+    }
+
+    model = PPHGNetV2(
+        stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
+    )
+    return model
+
+
+def PPHGNetV2_B3(pretrained=False, use_ssld=False, **kwargs):
+    """
+    PPHGNetV2_B3
+    Args:
+        pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
+    Returns:
+        model: nn.Layer. Specific `PPHGNetV2_B3` model depends on args.
+    """
+    stage_config = {
+        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
+        "stage1": [32, 32, 128, 1, False, False, 3, 5],
+        "stage2": [128, 64, 512, 1, True, False, 3, 5],
+        "stage3": [512, 128, 1024, 3, True, True, 5, 5],
+        "stage4": [1024, 256, 2048, 1, True, True, 5, 5],
+    }
+
+    model = PPHGNetV2(
+        stem_channels=[3, 24, 32], stage_config=stage_config, use_lab=True, **kwargs
+    )
+    return model
+
+
+def PPHGNetV2_B4(pretrained=False, use_ssld=False, det=False, text_rec=False, **kwargs):
+    """
+    PPHGNetV2_B4
+    Args:
+        pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
+    Returns:
+        model: nn.Layer. Specific `PPHGNetV2_B4` model depends on args.
+    """
+    stage_config_rec = {
+        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num, stride
+        "stage1": [48, 48, 128, 1, True, False, 3, 6, [2, 1]],
+        "stage2": [128, 96, 512, 1, True, False, 3, 6, [1, 2]],
+        "stage3": [512, 192, 1024, 3, True, True, 5, 6, [2, 1]],
+        "stage4": [1024, 384, 2048, 1, True, True, 5, 6, [2, 1]],
+    }
+
+    stage_config_det = {
+        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
+        "stage1": [48, 48, 128, 1, False, False, 3, 6, 2],
+        "stage2": [128, 96, 512, 1, True, False, 3, 6, 2],
+        "stage3": [512, 192, 1024, 3, True, True, 5, 6, 2],
+        "stage4": [1024, 384, 2048, 1, True, True, 5, 6, 2],
+    }
+    model = PPHGNetV2(
+        stem_channels=[3, 32, 48],
+        stage_config=stage_config_det if det else stage_config_rec,
+        use_lab=False,
+        det=det,
+        text_rec=text_rec,
+        **kwargs,
+    )
+    return model
+
+
+def PPHGNetV2_B5(pretrained=False, use_ssld=False, **kwargs):
+    """
+    PPHGNetV2_B5
+    Args:
+        pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
+    Returns:
+        model: nn.Layer. Specific `PPHGNetV2_B5` model depends on args.
+    """
+    stage_config = {
+        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
+        "stage1": [64, 64, 128, 1, False, False, 3, 6],
+        "stage2": [128, 128, 512, 2, True, False, 3, 6],
+        "stage3": [512, 256, 1024, 5, True, True, 5, 6],
+        "stage4": [1024, 512, 2048, 2, True, True, 5, 6],
+    }
+
+    model = PPHGNetV2(
+        stem_channels=[3, 32, 64], stage_config=stage_config, use_lab=False, **kwargs
+    )
+    return model
+
+
+def PPHGNetV2_B6(pretrained=False, use_ssld=False, **kwargs):
+    """
+    PPHGNetV2_B6
+    Args:
+        pretrained (bool/str): If `True` load pretrained parameters, `False` otherwise.
+                    If str, means the path of the pretrained model.
+        use_ssld (bool) Whether using ssld pretrained model when pretrained is True.
+    Returns:
+        model: nn.Layer. Specific `PPHGNetV2_B6` model depends on args.
+    """
+    stage_config = {
+        # in_channels, mid_channels, out_channels, num_blocks, is_downsample, light_block, kernel_size, layer_num
+        "stage1": [96, 96, 192, 2, False, False, 3, 6],
+        "stage2": [192, 192, 512, 3, True, False, 3, 6],
+        "stage3": [512, 384, 1024, 6, True, True, 5, 6],
+        "stage4": [1024, 768, 2048, 3, True, True, 5, 6],
+    }
+
+    model = PPHGNetV2(
+        stem_channels=[3, 48, 96], stage_config=stage_config, use_lab=False, **kwargs
+    )
+    return model