
Seal text det (#2043)

* add seal det

* add seal det

* add seal det

* add seal det

* add seal det

* Update PP-OCRv4_mobile_seal_det.yaml

* Update PP-OCRv4_server_seal_det.yaml

* Update support_model_list.md

* Update ocr.py

* Update text_det.py

* Update transforms.py

* Update text_det.py

* Update transforms.py

* Update transforms.py

* Update text_det.py

* ocr pipeline
Sunflower7788 1 year ago
parent commit 984425190e

+ 6 - 2
README.md

@@ -117,12 +117,16 @@ PaddleX 3.0 covers 16 industrial-grade model pipelines, of which 9 basic pipelines can
    <summary><b>more</b></summary><br/>Mask-RT-DETR-L<br/>Mask-RT-DETR-X<br/>Mask-RT-DETR-H<br/>SOLOv2<br/>MaskRCNN-ResNet50<br/>MaskRCNN-ResNet50-FPN<br/>MaskRCNN-ResNet50-vd-FPN<br/>MaskRCNN-ResNet50-vd-SSLDv2-FPN<br/>MaskRCNN-ResNet101-FPN<br/>MaskRCNN-ResNet101-vd-FPN<br/>MaskRCNN-ResNeXt101-vd-FPN<br/>Cascade-MaskRCNN-ResNet50-FPN<br/>Cascade-MaskRCNN-ResNet50-vd-SSLDv2-FPN<br/>PP-YOLOE_seg-S</td>
   </tr>
   <tr>
-    <td rowspan="3">基础产线</td>
-    <td rowspan="3">通用OCR</td>
+    <td rowspan="4">基础产线</td>
+    <td rowspan="4">通用OCR</td>
     <td>文本检测</td>
     <td>PP-OCRv4_mobile_det<br/>PP-OCRv4_server_det</td>
   </tr>
   <tr>
+    <td>Seal Text Detection</td>
+    <td>PP-OCRv4_mobile_seal_det<br/>PP-OCRv4_server_seal_det</td>
+  </tr>
+  <tr>
     <td>Text Recognition</td>
     <td>PP-OCRv4_mobile_rec<br/>PP-OCRv4_server_rec</td>
   </tr>

+ 7 - 1
docs/tutorials/models/support_model_list.md

@@ -264,11 +264,17 @@
 | :--- | :---: |
 | SLANet | [SLANet.yaml](../../../paddlex/configs/table_recognition/SLANet.yaml)|
 ## VI. Text Detection
-### 1. PP-OCRv4 Series
+### 1. PP-OCRv4 Standard Text Detection Series
 | Model Name | config |
 | :--- | :---: |
 | PP-OCRv4_server_det | [PP-OCRv4_server_det.yaml](../../../paddlex/configs/text_detection/PP-OCRv4_server_det.yaml)|
 | PP-OCRv4_mobile_det | [PP-OCRv4_mobile_det.yaml](../../../paddlex/configs/text_detection/PP-OCRv4_mobile_det.yaml)|
+
+### 2. PP-OCRv4 Seal Text Detection Series
+| Model Name | config |
+| :--- | :---: |
+| PP-OCRv4_server_seal_det | [PP-OCRv4_server_seal_det.yaml](../../../paddlex/configs/text_detection_seal/PP-OCRv4_server_seal_det.yaml)|
+| PP-OCRv4_mobile_seal_det | [PP-OCRv4_mobile_seal_det.yaml](../../../paddlex/configs/text_detection_seal/PP-OCRv4_mobile_seal_det.yaml)|
 ## VII. Text Recognition
 ### 1. PP-OCRv4 Series
 | 模型名称 | config |

+ 40 - 0
paddlex/configs/text_detection_seal/PP-OCRv4_mobile_seal_det.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: PP-OCRv4_mobile_seal_det
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  module: text_det
+  dataset_dir: "/paddle/dataset/paddlex/ocr_det/ocr_curve_det_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 100
+  batch_size: 8
+  learning_rate: 0.001
+  pretrain_weight_path: null
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddleocr.bj.bcebos.com/pretrained/ch_PP-OCRv4_mobile_det_curve_trained.pdparams
+
+Predict:
+  model_dir: "output/best_accuracy"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/seal_text_det.png"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1
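
This config (and the server variant that follows) drives the standard PaddleX workflow through `Global.mode`, which switches between `check_dataset`, `train`, `evaluate`, and `predict`. As a hedged sketch, assuming the repository's usual `main.py` entry point (not shown in this diff), a run might be launched like this:

```python
# Hypothetical launch of the seal-detection config; only the config path and
# the Global.mode values come from this diff, the entry point is an assumption.
import subprocess

subprocess.run(
    [
        "python", "main.py",
        "-c", "paddlex/configs/text_detection_seal/PP-OCRv4_mobile_seal_det.yaml",
        "-o", "Global.mode=train",  # or check_dataset / evaluate / predict
    ],
    check=True,
)
```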

+ 40 - 0
paddlex/configs/text_detection_seal/PP-OCRv4_server_seal_det.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: PP-OCRv4_server_seal_det
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  module: text_det
+  dataset_dir: "/paddle/dataset/paddlex/ocr_det/ocr_curve_det_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 100
+  batch_size: 8
+  learning_rate: 0.001
+  pretrain_weight_path: null
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddleocr.bj.bcebos.com/pretrained/ch_PP-OCRv4_det_curve_trained.pdparams
+
+Predict:
+  model_dir: "output/best_accuracy"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/seal_text_det.png"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1

+ 699 - 0
paddlex/inference/components/task_related/seal_det_warp.py

@@ -0,0 +1,699 @@
+import copy
+
+import cv2
+import numpy as np
+from numpy import cos, sin, arctan, sqrt
+
+
+def Homography(image, img_points, world_width, world_height,
+               interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+    """
+    将图像透视变换到新的视角,返回变换后的图像。
+    
+    Args:
+        image (np.ndarray): 输入的图像,应为numpy数组类型。
+        img_points (List[Tuple[int, int]]): 图像上的四个点的坐标,顺序为左上角、右上角、右下角、左下角。
+        world_width (int): 变换后图像在世界坐标系中的宽度。
+        world_height (int): 变换后图像在世界坐标系中的高度。
+        interpolation (int, optional): 插值方式,默认为cv2.INTER_CUBIC。
+        ratio_width (float, optional): 变换后图像在x轴上的缩放比例,默认为1.0。
+        ratio_height (float, optional): 变换后图像在y轴上的缩放比例,默认为1.0。
+    
+    Returns:
+        np.ndarray: 变换后的图像,为numpy数组类型。
+    
+    """
+    _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+    expand_x = int(0.5 * world_width * (ratio_width - 1))
+    expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+    pt_lefttop = [expand_x, expand_y]
+    pt_righttop = [expand_x + world_width, expand_y]
+    pt_rightbottom = [expand_x + world_width, expand_y + world_height]
+    pt_leftbottom = [expand_x, expand_y + world_height]
+
+    pts_std = np.float32([pt_lefttop, pt_righttop,
+                          pt_rightbottom, pt_leftbottom])
+
+    img_crop_width = int(world_width * ratio_width)
+    img_crop_height = int(world_height * ratio_height)
+
+    M = cv2.getPerspectiveTransform(_points, pts_std)
+
+    dst_img = cv2.warpPerspective(
+        image,
+        M, (img_crop_width, img_crop_height),
+        borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+        flags=interpolation)
+
+    return dst_img
+
+
+class CurveTextRectifier:
+    """
+    spatial transformer via monocular vision
+    """
+    def __init__(self):
+        self.get_virtual_camera_parameter()
+
+
+    def get_virtual_camera_parameter(self):
+        vcam_thz = 0
+        vcam_thx1 = 180
+        vcam_thy = 180
+        vcam_thx2 = 0
+
+        vcam_x = 0
+        vcam_y = 0
+        vcam_z = 100
+
+        radian = np.pi / 180
+
+        angle_z = radian * vcam_thz
+        angle_x1 = radian * vcam_thx1
+        angle_y = radian * vcam_thy
+        angle_x2 = radian * vcam_thx2
+
+        optic_x = vcam_x
+        optic_y = vcam_y
+        optic_z = vcam_z
+
+        fu = 100
+        fv = 100
+
+        matT = np.zeros((4, 4))
+        matT[0, 0] = cos(angle_z) * cos(angle_y) - sin(angle_z) * sin(angle_x1) * sin(angle_y)
+        matT[0, 1] = cos(angle_z) * sin(angle_y) * sin(angle_x2) - sin(angle_z) * (
+                    cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2))
+        matT[0, 2] = cos(angle_z) * sin(angle_y) * cos(angle_x2) + sin(angle_z) * (
+                    cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2))
+        matT[0, 3] = optic_x
+        matT[1, 0] = sin(angle_z) * cos(angle_y) + cos(angle_z) * sin(angle_x1) * sin(angle_y)
+        matT[1, 1] = sin(angle_z) * sin(angle_y) * sin(angle_x2) + cos(angle_z) * (
+                    cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2))
+        matT[1, 2] = sin(angle_z) * sin(angle_y) * cos(angle_x2) - cos(angle_z) * (
+                    cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2))
+        matT[1, 3] = optic_y
+        matT[2, 0] = -cos(angle_x1) * sin(angle_y)
+        matT[2, 1] = cos(angle_x1) * cos(angle_y) * sin(angle_x2) + sin(angle_x1) * cos(angle_x2)
+        matT[2, 2] = cos(angle_x1) * cos(angle_y) * cos(angle_x2) - sin(angle_x1) * sin(angle_x2)
+        matT[2, 3] = optic_z
+        matT[3, 0] = 0
+        matT[3, 1] = 0
+        matT[3, 2] = 0
+        matT[3, 3] = 1
+
+        matS = np.zeros((4, 4))
+        matS[2, 3] = 0.5
+        matS[3, 2] = 0.5
+
+        self.ifu = 1 / fu
+        self.ifv = 1 / fv
+
+        self.matT = matT
+        self.matS = matS
+        self.K = np.dot(matT.T, matS)
+        self.K = np.dot(self.K, matT)
+
+
+    def vertical_text_process(self, points, org_size):
+        """
+        reorder the point sequence and process as horizontal text
+        :param points:
+        :param org_size:
+        :return:
+        """
+        org_w, org_h = org_size
+        _points = np.array(points).reshape(-1).tolist()
+        _points = np.array(_points[2:] + _points[:2]).reshape(-1, 2)
+
+        # convert to horizontal points
+        adjusted_points = np.zeros(_points.shape, dtype=np.float32)
+        adjusted_points[:, 0] = _points[:, 1]
+        adjusted_points[:, 1] = org_h - _points[:, 0] - 1
+
+        _image_coord, _world_coord, _new_image_size = self.horizontal_text_process(adjusted_points)
+
+        # convert back to vertical points
+        image_coord = _points.reshape(1, -1, 2)
+        world_coord = np.zeros(_world_coord.shape, dtype=np.float32)
+        world_coord[:, :, 0] = 0 - _world_coord[:, :, 1]
+        world_coord[:, :, 1] = _world_coord[:, :, 0]
+        world_coord[:, :, 2] = _world_coord[:, :, 2]
+        new_image_size = (_new_image_size[1], _new_image_size[0])
+
+        return image_coord, world_coord, new_image_size
+
+
+    def horizontal_text_process(self, points):
+        """
+        get image coordinate and world coordinate
+        :param points:
+        :return:
+        """
+        poly = np.array(points).reshape(-1)
+
+        dx_list = []
+        dy_list = []
+        for i in range(1, len(poly) // 2):
+            xdx = poly[i * 2] - poly[(i - 1) * 2]
+            xdy = poly[i * 2 + 1] - poly[(i - 1) * 2 + 1]
+            d = sqrt(xdx ** 2 + xdy ** 2)
+            dx_list.append(d)
+
+        for i in range(0, len(poly) // 4):
+            ydx = poly[i * 2] - poly[len(poly) - 1 - (i * 2 + 1)]
+            ydy = poly[i * 2 + 1] - poly[len(poly) - 1 - (i * 2)]
+            d = sqrt(ydx ** 2 + ydy ** 2)
+            dy_list.append(d)
+
+        dx_list = [(dx_list[i] + dx_list[len(dx_list) - 1 - i]) / 2 for i in range(len(dx_list) // 2)]
+
+        height = np.around(np.mean(dy_list))
+
+        rect_coord = [0, 0]
+        for i in range(0, len(poly) // 4 - 1):
+            x = rect_coord[-2]
+            x += dx_list[i]
+            y = 0
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        rect_coord_half = copy.deepcopy(rect_coord)
+        for i in range(0, len(poly) // 4):
+            x = rect_coord_half[len(rect_coord_half) - 2 * i - 2]
+            y = height
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        np_rect_coord = np.array(rect_coord).reshape(-1, 2)
+        x_min = np.min(np_rect_coord[:, 0])
+        y_min = np.min(np_rect_coord[:, 1])
+        x_max = np.max(np_rect_coord[:, 0])
+        y_max = np.max(np_rect_coord[:, 1])
+        new_image_size = (int(x_max - x_min + 0.5), int(y_max - y_min + 0.5))
+        x_mean = (x_max - x_min) / 2
+        y_mean = (y_max - y_min) / 2
+        np_rect_coord[:, 0] -= x_mean
+        np_rect_coord[:, 1] -= y_mean
+        rect_coord = np_rect_coord.reshape(-1).tolist()
+
+        rect_coord = np.array(rect_coord).reshape(-1, 2)
+        world_coord = np.ones((len(rect_coord), 3)) * 0
+
+        world_coord[:, :2] = rect_coord
+
+        image_coord = np.array(poly).reshape(1, -1, 2)
+        world_coord = world_coord.reshape(1, -1, 3)
+
+        return image_coord, world_coord, new_image_size
+
+
+    def horizontal_text_estimate(self, points):
+        """
+        horizontal or vertical text
+        :param points:
+        :return:
+        """
+        pts = np.array(points).reshape(-1, 2)
+        x_min = int(np.min(pts[:, 0]))
+        y_min = int(np.min(pts[:, 1]))
+        x_max = int(np.max(pts[:, 0]))
+        y_max = int(np.max(pts[:, 1]))
+        x = x_max - x_min
+        y = y_max - y_min
+        is_horizontal_text = True
+        if y / x > 1.5: # vertical text condition
+            is_horizontal_text = False
+        return is_horizontal_text
+
+
+    def virtual_camera_to_world(self, size):
+        ifu, ifv = self.ifu, self.ifv
+        K, matT = self.K, self.matT
+
+        ppu = size[0] / 2 + 1e-6
+        ppv = size[1] / 2 + 1e-6
+
+        P = np.zeros((size[1], size[0], 3))
+
+        lu = np.array([i for i in range(size[0])])
+        lv = np.array([i for i in range(size[1])])
+        u, v = np.meshgrid(lu, lv)
+
+        yp = (v - ppv) * ifv
+        xp = (u - ppu) * ifu
+        angle_a = arctan(sqrt(xp * xp + yp * yp))
+        angle_b = arctan(yp / xp)
+
+        D0 = sin(angle_a) * cos(angle_b)
+        D1 = sin(angle_a) * sin(angle_b)
+        D2 = cos(angle_a)
+
+        D0[xp <= 0] = -D0[xp <= 0]
+        D1[xp <= 0] = -D1[xp <= 0]
+
+        ratio_a = K[0, 0] * D0 * D0 + K[1, 1] * D1 * D1 + K[2, 2] * D2 * D2 + \
+                  (K[0, 1] + K[1, 0]) * D0 * D1 + (K[0, 2] + K[2, 0]) * D0 * D2 + (K[1, 2] + K[2, 1]) * D1 * D2
+        ratio_b = (K[0, 3] + K[3, 0]) * D0 + (K[1, 3] + K[3, 1]) * D1 + (K[2, 3] + K[3, 2]) * D2
+        ratio_c = K[3, 3] * np.ones(ratio_b.shape)
+
+        delta = ratio_b * ratio_b - 4 * ratio_a * ratio_c
+        t = np.zeros(delta.shape)
+        t[ratio_a == 0] = -ratio_c[ratio_a == 0] / ratio_b[ratio_a == 0]
+        t[ratio_a != 0] = (-ratio_b[ratio_a != 0] + sqrt(delta[ratio_a != 0])) / (2 * ratio_a[ratio_a != 0])
+        t[delta < 0] = 0
+
+        P[:, :, 0] = matT[0, 3] + t * (matT[0, 0] * D0 + matT[0, 1] * D1 + matT[0, 2] * D2)
+        P[:, :, 1] = matT[1, 3] + t * (matT[1, 0] * D0 + matT[1, 1] * D1 + matT[1, 2] * D2)
+        P[:, :, 2] = matT[2, 3] + t * (matT[2, 0] * D0 + matT[2, 1] * D1 + matT[2, 2] * D2)
+
+        return P
+
+
+    def world_to_image(self, image_size, world, intrinsic, distCoeffs, rotation, tvec):
+        r11 = rotation[0, 0]
+        r12 = rotation[0, 1]
+        r13 = rotation[0, 2]
+        r21 = rotation[1, 0]
+        r22 = rotation[1, 1]
+        r23 = rotation[1, 2]
+        r31 = rotation[2, 0]
+        r32 = rotation[2, 1]
+        r33 = rotation[2, 2]
+
+        t1 = tvec[0]
+        t2 = tvec[1]
+        t3 = tvec[2]
+
+        k1 = distCoeffs[0]
+        k2 = distCoeffs[1]
+        p1 = distCoeffs[2]
+        p2 = distCoeffs[3]
+        k3 = distCoeffs[4]
+        k4 = distCoeffs[5]
+        k5 = distCoeffs[6]
+        k6 = distCoeffs[7]
+
+        if len(distCoeffs) > 8:
+            s1 = distCoeffs[8]
+            s2 = distCoeffs[9]
+            s3 = distCoeffs[10]
+            s4 = distCoeffs[11]
+        else:
+            s1 = s2 = s3 = s4 = 0
+
+        if len(distCoeffs) > 12:
+            tx = distCoeffs[12]
+            ty = distCoeffs[13]
+        else:
+            tx = ty = 0
+
+        fu = intrinsic[0, 0]
+        fv = intrinsic[1, 1]
+        ppu = intrinsic[0, 2]
+        ppv = intrinsic[1, 2]
+
+        cos_tx = cos(tx)
+        cos_ty = cos(ty)
+        sin_tx = sin(tx)
+        sin_ty = sin(ty)
+
+        tao11 = cos_ty * cos_tx * cos_ty + sin_ty * cos_tx * sin_ty
+        tao12 = cos_ty * cos_tx * sin_ty * sin_tx - sin_ty * cos_tx * cos_ty * sin_tx
+        tao13 = -cos_ty * cos_tx * sin_ty * cos_tx + sin_ty * cos_tx * cos_ty * cos_tx
+        tao21 = -sin_tx * sin_ty
+        tao22 = cos_ty * cos_tx * cos_tx + sin_tx * cos_ty * sin_tx
+        tao23 = cos_ty * cos_tx * sin_tx - sin_tx * cos_ty * cos_tx
+
+        P = np.zeros((image_size[1], image_size[0], 2))
+
+        c3 = r31 * world[:, :, 0] + r32 * world[:, :, 1] + r33 * world[:, :, 2] + t3
+        c1 = r11 * world[:, :, 0] + r12 * world[:, :, 1] + r13 * world[:, :, 2] + t1
+        c2 = r21 * world[:, :, 0] + r22 * world[:, :, 1] + r23 * world[:, :, 2] + t2
+
+        x1 = c1 / c3
+        y1 = c2 / c3
+        x12 = x1 * x1
+        y12 = y1 * y1
+        x1y1 = 2 * x1 * y1
+        r2 = x12 + y12
+        r4 = r2 * r2
+        r6 = r2 * r4
+
+        radial_distortion = (1 + k1 * r2 + k2 * r4 + k3 * r6) / (1 + k4 * r2 + k5 * r4 + k6 * r6)
+        x2 = x1 * radial_distortion + p1 * x1y1 + p2 * (r2 + 2 * x12) + s1 * r2 + s2 * r4
+        y2 = y1 * radial_distortion + p2 * x1y1 + p1 * (r2 + 2 * y12) + s3 * r2 + s4 * r4
+
+        x3 = tao11 * x2 + tao12 * y2 + tao13
+        y3 = tao21 * x2 + tao22 * y2 + tao23
+
+        P[:, :, 0] = fu * x3 + ppu
+        P[:, :, 1] = fv * y3 + ppv
+        P[c3 <= 0] = 0
+
+        return P
+
+
+    def spatial_transform(self, image_data, new_image_size, mtx, dist, rvecs, tvecs, interpolation):
+        rotation, _ = cv2.Rodrigues(rvecs)
+        world_map = self.virtual_camera_to_world(new_image_size)
+        image_map = self.world_to_image(new_image_size, world_map, mtx, dist, rotation, tvecs)
+        image_map = image_map.astype(np.float32)
+        dst = cv2.remap(image_data, image_map[:, :, 0], image_map[:, :, 1], interpolation)
+        return dst
+
+
+    def calibrate(self, org_size, image_coord, world_coord):
+        """
+        calibration
+        :param org_size:
+        :param image_coord:
+        :param world_coord:
+        :return:
+        """
+        # flag = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL  | cv2.CALIB_THIN_PRISM_MODEL
+        flag = cv2.CALIB_RATIONAL_MODEL
+        flag2 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL
+        flag3 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_THIN_PRISM_MODEL
+        flag4 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_ZERO_TANGENT_DIST | cv2.CALIB_FIX_ASPECT_RATIO
+        flag5 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL | cv2.CALIB_ZERO_TANGENT_DIST
+        flag6 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_FIX_ASPECT_RATIO
+        flag_list = [flag2, flag3, flag4, flag5, flag6]
+
+        ret, mtx, dist, rvecs, tvecs = cv2.calibrateCamera(world_coord.astype(np.float32),
+                                                                image_coord.astype(np.float32),
+                                                                org_size,
+                                                                None,
+                                                                None,
+                                                                flags=flag)
+        if ret > 2:
+            # strategies
+            min_ret = ret
+            for i, flag in enumerate(flag_list):
+                _ret, _mtx, _dist, _rvecs, _tvecs = cv2.calibrateCamera(world_coord.astype(np.float32),
+                                                                   image_coord.astype(np.float32),
+                                                                   org_size,
+                                                                   None,
+                                                                   None,
+                                                                   flags=flag)
+                if _ret < min_ret:
+                    min_ret = _ret
+                    ret, mtx, dist, rvecs, tvecs = _ret, _mtx, _dist, _rvecs, _tvecs
+
+        return ret, mtx, dist, rvecs, tvecs
+
+
+    def dc_homo(self, img, img_points, obj_points, is_horizontal_text, interpolation=cv2.INTER_LINEAR,
+                ratio_width=1.0, ratio_height=1.0):
+        """
+        divide and conquer: homography
+        # ratio_width and ratio_height must be 1.0 here
+        """
+        _img_points = img_points.reshape(-1, 2)
+        _obj_points = obj_points.reshape(-1, 3)
+
+        homo_img_list = []
+        width_list = []
+        height_list = []
+        # divide and conquer
+        for i in range(len(_img_points) // 2 - 1):
+            new_img_points = np.zeros((4, 2)).astype(np.float32)
+            new_obj_points = np.zeros((4, 2)).astype(np.float32)
+
+            new_img_points[0:2, :] = _img_points[i:(i + 2), :2]
+            new_img_points[2:4, :] = _img_points[::-1, :][i:(i + 2), :2][::-1, :]
+
+            new_obj_points[0:2, :] = _obj_points[i:(i + 2), :2]
+            new_obj_points[2:4, :] = _obj_points[::-1, :][i:(i + 2), :2][::-1, :]
+
+            if is_horizontal_text:
+                world_width = np.abs(new_obj_points[1, 0] - new_obj_points[0, 0])
+                world_height = np.abs(new_obj_points[3, 1] - new_obj_points[0, 1])
+            else:
+                world_width = np.abs(new_obj_points[1, 1] - new_obj_points[0, 1])
+                world_height = np.abs(new_obj_points[3, 0] - new_obj_points[0, 0])
+
+            homo_img = Homography(img, new_img_points, world_width, world_height,
+                                              interpolation=interpolation,
+                                              ratio_width=ratio_width, ratio_height=ratio_height)
+
+            homo_img_list.append(homo_img)
+            _h, _w = homo_img.shape[:2]
+            width_list.append(_w)
+            height_list.append(_h)
+
+        # stitching
+        rectified_image = np.zeros((np.max(height_list), sum(width_list), 3)).astype(np.uint8)
+
+        st = 0
+        for (homo_img, w, h) in zip(homo_img_list, width_list, height_list):
+            rectified_image[:h, st:st + w, :] = homo_img
+            st += w
+
+        if not is_horizontal_text:
+            # vertical rotation
+            rectified_image = np.rot90(rectified_image, 3)
+
+        return rectified_image
+
+    def Homography(self, image, img_points, world_width, world_height,
+                interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+        """
+        将图像透视变换到新的视角,返回变换后的图像。
+        
+        Args:
+            image (np.ndarray): 输入的图像,应为numpy数组类型。
+            img_points (List[Tuple[int, int]]): 图像上的四个点的坐标,顺序为左上角、右上角、右下角、左下角。
+            world_width (int): 变换后图像在世界坐标系中的宽度。
+            world_height (int): 变换后图像在世界坐标系中的高度。
+            interpolation (int, optional): 插值方式,默认为cv2.INTER_CUBIC。
+            ratio_width (float, optional): 变换后图像在x轴上的缩放比例,默认为1.0。
+            ratio_height (float, optional): 变换后图像在y轴上的缩放比例,默认为1.0。
+        
+        Returns:
+            np.ndarray: 变换后的图像,为numpy数组类型。
+        
+        """
+        _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+        expand_x = int(0.5 * world_width * (ratio_width - 1))
+        expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+        pt_lefttop = [expand_x, expand_y]
+        pt_righttop = [expand_x + world_width, expand_y]
+        pt_rightbottom = [expand_x + world_width, expand_y + world_height]
+        pt_leftbottom = [expand_x, expand_y + world_height]
+
+        pts_std = np.float32([pt_lefttop, pt_righttop,
+                            pt_rightbottom, pt_leftbottom])
+
+        img_crop_width = int(world_width * ratio_width)
+        img_crop_height = int(world_height * ratio_height)
+
+        M = cv2.getPerspectiveTransform(_points, pts_std)
+
+        dst_img = cv2.warpPerspective(
+            image,
+            M, (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+            flags=interpolation)
+
+        return dst_img
+
+
+    def __call__(self, image_data, points, interpolation=cv2.INTER_LINEAR, ratio_width=1.0, ratio_height=1.0, mode='calibration'):
+        """
+        spatial transform for a poly text
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        org_h, org_w = image_data.shape[:2]
+        org_size = (org_w, org_h)
+        self.image = image_data
+
+        is_horizontal_text = self.horizontal_text_estimate(points)
+        if is_horizontal_text:
+            image_coord, world_coord, new_image_size = self.horizontal_text_process(points)
+        else:
+            image_coord, world_coord, new_image_size = self.vertical_text_process(points, org_size)
+
+        if mode.lower() == 'calibration':
+            ret, mtx, dist, rvecs, tvecs = self.calibrate(org_size, image_coord, world_coord)
+
+            st_size = (int(new_image_size[0]*ratio_width), int(new_image_size[1]*ratio_height))
+            dst = self.spatial_transform(image_data, st_size, mtx, dist[0], rvecs[0], tvecs[0], interpolation)
+        elif mode.lower() == 'homography':
+            # ratio_width and ratio_height must be 1.0 here; ret is set to 0.01 manually since homography reports no loss
+            ret = 0.01
+            dst = self.dc_homo(image_data, image_coord, world_coord, is_horizontal_text,
+                               interpolation=interpolation, ratio_width=1.0, ratio_height=1.0)
+        else:
+            raise ValueError('mode must be ["calibration", "homography"], but got {}'.format(mode))
+
+        return dst, ret
+
+
+class AutoRectifier:
+    def __init__(self):
+        self.npoints = 10
+        self.curveTextRectifier = CurveTextRectifier()
+
+    @staticmethod
+    def get_rotate_crop_image(img, points, interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+        """
+        crop or homography
+        :param img:
+        :param points:
+        :param interpolation:
+        :param ratio_width:
+        :param ratio_height:
+        :return:
+        """
+        h, w = img.shape[:2]
+        _points = np.array(points).reshape(-1, 2).astype(np.float32)
+
+        if len(_points) != 4:
+            x_min = int(np.min(_points[:, 0]))
+            y_min = int(np.min(_points[:, 1]))
+            x_max = int(np.max(_points[:, 0]))
+            y_max = int(np.max(_points[:, 1]))
+            dx = x_max - x_min
+            dy = y_max - y_min
+            expand_x = int(0.5 * dx * (ratio_width - 1))
+            expand_y = int(0.5 * dy * (ratio_height - 1))
+            x_min = np.clip(int(x_min - expand_x), 0, w - 1)
+            y_min = np.clip(int(y_min - expand_y), 0, h - 1)
+            x_max = np.clip(int(x_max + expand_x), 0, w - 1)
+            y_max = np.clip(int(y_max + expand_y), 0, h - 1)
+
+            dst_img = img[y_min:y_max, x_min:x_max, :].copy()
+        else:
+            img_crop_width = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[1]),
+                    np.linalg.norm(_points[2] - _points[3])))
+            img_crop_height = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[3]),
+                    np.linalg.norm(_points[1] - _points[2])))
+
+            dst_img = Homography(img, _points, img_crop_width, img_crop_height, interpolation, ratio_width, ratio_height)
+
+        return dst_img
+
+
+    def visualize(self, image_data, points_list):
+        visualization = image_data.copy()
+
+        for box in points_list:
+            box = np.array(box).reshape(-1, 2).astype(np.int32)
+            cv2.drawContours(visualization, [np.array(box).reshape((-1, 1, 2))], -1, (0, 0, 255), 2)
+            for i, p in enumerate(box):
+                if i != 0:
+                    cv2.circle(visualization, tuple(p), radius=1, color=(255, 0, 0), thickness=2)
+                else:
+                    cv2.circle(visualization, tuple(p), radius=1, color=(255, 255, 0), thickness=2)
+        return visualization
+
+
+    def __call__(self, image_data, points, interpolation=cv2.INTER_LINEAR,
+                 ratio_width=1.0, ratio_height=1.0, loss_thresh=5.0, mode='calibration'):
+        """
+        rectification in strategies for a poly text
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        _points = np.array(points).reshape(-1,2)
+        if len(_points) >= self.npoints and len(_points) % 2 == 0:
+            try:
+                curveTextRectifier = CurveTextRectifier()
+
+                dst_img, loss = curveTextRectifier(image_data, points, interpolation, ratio_width, ratio_height, mode)
+                if loss >= 2:
+                    # for robustness: a large loss means the text cannot be
+                    # reconstructed correctly, so try other strategies as well
+                    img_list, loss_list = [dst_img], [loss]
+                    _dst_img, _loss = PlanB()(image_data, points, curveTextRectifier,
+                                              interpolation, ratio_width, ratio_height,
+                                              loss_thresh=loss_thresh,
+                                              square=True)
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    _dst_img, _loss = PlanB()(image_data, points, curveTextRectifier,
+                                              interpolation, ratio_width, ratio_height,
+                                              loss_thresh=loss_thresh, square=False)
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    min_loss = min(loss_list)
+                    dst_img = img_list[loss_list.index(min_loss)]
+
+                    if min_loss >= loss_thresh:
+                        print('calibration loss: {} is too large for the spatial transformer; falling back to get_rotate_crop_image'.format(loss))
+                        dst_img = self.get_rotate_crop_image(image_data, points, interpolation, ratio_width, ratio_height)
+            except Exception as e:
+                print(e)
+                dst_img = self.get_rotate_crop_image(image_data, points, interpolation, ratio_width, ratio_height)
+        else:
+            dst_img = self.get_rotate_crop_image(image_data, _points, interpolation, ratio_width, ratio_height)
+
+        return dst_img
+
+
+    def run(self, image_data, points_list, interpolation=cv2.INTER_LINEAR,
+            ratio_width=1.0, ratio_height=1.0, loss_thresh=5.0, mode='calibration'):
+        """
+        run for texts in an image
+        :param image_data: numpy.ndarray. The shape is [h, w, 3]
+        :param points_list: [[x1,y1,x2,y2,x3,y3,...], [x1,y1,x2,y2,x3,y3,...], ...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return: res: roi-image list, visualized_image: draw polys in original image
+        """
+        if image_data is None:
+            raise ValueError('image_data cannot be None')
+        if not isinstance(points_list, list):
+            raise ValueError('points_list must be a list')
+        for points in points_list:
+            if not isinstance(points, list):
+                raise ValueError('each element of points_list must be a list')
+
+        if ratio_width < 1.0 or ratio_height < 1.0:
+            raise ValueError('ratio_width and ratio_height cannot be smaller than 1, but got {}'.format((ratio_width, ratio_height)))
+
+        if mode.lower() != 'calibration' and mode.lower() != 'homography':
+            raise ValueError('mode must be ["calibration", "homography"], but got {}'.format(mode))
+
+        if mode.lower() == 'homography' and (ratio_width != 1.0 or ratio_height != 1.0):
+            raise ValueError('ratio_width and ratio_height must be 1.0 when mode is homography, but got mode:{}, ratio:({},{})'.format(mode, ratio_width, ratio_height))
+
+        res = []
+        for points in points_list:
+            rectified_img = self(image_data, points, interpolation, ratio_width, ratio_height,
+                                 loss_thresh=loss_thresh, mode=mode)
+            res.append(rectified_img)
+
+        # visualize
+        visualized_image = self.visualize(image_data, points_list)
+
+        return res, visualized_image
+

+ 338 - 8
paddlex/inference/components/task_related/text_det.py

@@ -20,12 +20,14 @@ import copy
 import math
 import pyclipper
 import numpy as np
+from numpy.linalg import norm
 from PIL import Image
 from shapely.geometry import Polygon
 
 from ...utils.io import ImageReader
 from ....utils import logging
 from ..base import BaseComponent
+from .seal_det_warp import AutoRectifier
 
 
 __all__ = ["DetResizeForTest", "NormalizeImage", "DBPostProcess", "CropByPolys"]
@@ -430,19 +432,31 @@ class CropByPolys(BaseComponent):
     def apply(self, img_path, dt_polys):
         """apply"""
         img = self._reader.read(img_path)
-        dt_boxes = np.array(dt_polys)
+        
         # TODO
         # dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
-        output_list = []
-        for bno in range(len(dt_boxes)):
-            tmp_box = copy.deepcopy(dt_boxes[bno])
-            if self.det_box_type == "quad":
-                img_crop = self.get_rotate_crop_image(img, tmp_box)
-            else:
+        if self.det_box_type == "quad":
+            dt_boxes = self.sorted_boxes(dt_polys)
+            dt_boxes = np.array(dt_boxes)
+            output_list = []
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
                 img_crop = self.get_minarea_rect_crop(img, tmp_box)
-            output_list.append(
+                output_list.append(
                 {"img": img_crop, "img_size": [img_crop.shape[1], img_crop.shape[0]]}
             )
+        elif self.det_box_type == "poly":
+            output_list = []
+            dt_boxes = dt_polys
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
+                img_crop = self.get_poly_rect_crop(img.copy(), tmp_box)
+                output_list.append(
+                    {"img": img_crop, "img_size": [img_crop.shape[1], img_crop.shape[0]]}
+                )
+        else:
+            raise NotImplementedError
+
         return output_list
 
     def sorted_boxes(self, dt_boxes):
@@ -537,3 +551,319 @@ class CropByPolys(BaseComponent):
         if dst_img_height * 1.0 / dst_img_width >= 1.5:
             dst_img = np.rot90(dst_img)
         return dst_img
+
+    def reorder_poly_edge(self, points):
+        """Get the respective points composing head edge, tail edge, top
+        sideline and bottom sideline.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+
+        Returns:
+            head_edge (ndarray): The two points composing the head edge of text
+                polygon.
+            tail_edge (ndarray): The two points composing the tail edge of text
+                polygon.
+            top_sideline (ndarray): The points composing top curved sideline of
+                text polygon.
+            bot_sideline (ndarray): The points composing bottom curved sideline
+                of text polygon.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+
+        orientation_thr = 2.0  # an empirical hyperparameter
+
+        head_inds, tail_inds = self.find_head_tail(points, orientation_thr)
+        head_edge, tail_edge = points[head_inds], points[tail_inds]
+
+
+        pad_points = np.vstack([points, points])
+        if tail_inds[1] < 1:
+            tail_inds[1] = len(points)
+        sideline1 = pad_points[head_inds[1]:tail_inds[1]]
+        sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
+        return head_edge, tail_edge, sideline1, sideline2
+
+    def vector_slope(self, vec):
+        assert len(vec) == 2
+        return abs(vec[1] / (vec[0] + 1e-8)) 
+
+    def find_head_tail(self, points, orientation_thr):
+        """Find the head edge and tail edge of a text polygon.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+            orientation_thr (float): The threshold for distinguishing between
+                head edge and tail edge among the horizontal and vertical edges
+                of a quadrangle.
+
+        Returns:
+            head_inds (list): The indexes of two points composing head edge.
+            tail_inds (list): The indexes of two points composing tail edge.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+        assert isinstance(orientation_thr, float)
+
+        if len(points) > 4:
+            pad_points = np.vstack([points, points[0]])
+            edge_vec = pad_points[1:] - pad_points[:-1]
+
+            theta_sum = []
+            adjacent_vec_theta = []
+            for i, edge_vec1 in enumerate(edge_vec):
+                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
+                adjacent_edge_vec = edge_vec[adjacent_ind]
+                temp_theta_sum = np.sum(
+                    self.vector_angle(edge_vec1, adjacent_edge_vec))
+                temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
+                                                        adjacent_edge_vec[1])
+                theta_sum.append(temp_theta_sum)
+                adjacent_vec_theta.append(temp_adjacent_theta)
+            theta_sum_score = np.array(theta_sum) / np.pi
+            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
+            poly_center = np.mean(points, axis=0)
+            edge_dist = np.maximum(
+                norm(
+                    pad_points[1:] - poly_center, axis=-1),
+                norm(
+                    pad_points[:-1] - poly_center, axis=-1))
+            dist_score = edge_dist / np.max(edge_dist)
+            position_score = np.zeros(len(edge_vec))
+            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
+            score += 0.35 * dist_score
+            if len(points) % 2 == 0:
+                position_score[(len(score) // 2 - 1)] += 1
+                position_score[-1] += 1
+            score += 0.1 * position_score
+            pad_score = np.concatenate([score, score])
+            score_matrix = np.zeros((len(score), len(score) - 3))
+            x = np.arange(len(score) - 3) / float(len(score) - 4)
+            gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
+                (x - 0.5) / 0.5, 2.) / 2)
+            gaussian = gaussian / np.max(gaussian)
+            for i in range(len(score)):
+                score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
+                    score) - 1)] * gaussian * 0.3
+
+            head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
+                                                            score_matrix.shape)
+            tail_start = (head_start + tail_increment + 2) % len(points)
+            head_end = (head_start + 1) % len(points)
+            tail_end = (tail_start + 1) % len(points)
+
+            if head_end > tail_end:
+                head_start, tail_start = tail_start, head_start
+                head_end, tail_end = tail_end, head_end
+            head_inds = [head_start, head_end]
+            tail_inds = [tail_start, tail_end]
+        else:
+            if self.vector_slope(points[1] - points[0]) + self.vector_slope(
+                    points[3] - points[2]) < self.vector_slope(
+                        points[2] - points[1]) + self.vector_slope(points[0] - points[3]):
+                horizontal_edge_inds = [[0, 1], [2, 3]]
+                vertical_edge_inds = [[3, 0], [1, 2]]
+            else:
+                horizontal_edge_inds = [[3, 0], [1, 2]]
+                vertical_edge_inds = [[0, 1], [2, 3]]
+
+            vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
+                vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
+                    0]] - points[vertical_edge_inds[1][1]])
+            horizontal_len_sum = norm(points[horizontal_edge_inds[0][
+                0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
+                    horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
+                                                            [1]])
+
+            if vertical_len_sum > horizontal_len_sum * orientation_thr:
+                head_inds = horizontal_edge_inds[0]
+                tail_inds = horizontal_edge_inds[1]
+            else:
+                head_inds = vertical_edge_inds[0]
+                tail_inds = vertical_edge_inds[1]
+
+        return head_inds, tail_inds
+
+    def vector_angle(self, vec1, vec2):
+        if vec1.ndim > 1:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
+        if vec2.ndim > 1:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
+        return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+
+    def get_minarea_rect(self, img, points):
+        bounding_box = cv2.minAreaRect(points)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
+        crop_img = self.get_rotate_crop_image(img, np.array(box))
+        return crop_img, box
+
+    def sample_points_on_bbox_bp(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        # sanity-check the input arguments
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [
+            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+        ]
+        total_length = sum(length_list)
+        length_cumsum = np.cumsum([0.0] + length_list)
+        delta_length = total_length / (float(n) + 1e-8)
+        current_edge_ind = 0
+        resampled_line = [line[0]]
+
+        for i in range(1, n):
+            current_line_len = i * delta_length
+            while current_edge_ind + 1 < len(
+                    length_cumsum) and current_line_len >= length_cumsum[
+                        current_edge_ind + 1]:
+                current_edge_ind += 1
+            current_edge_end_shift = current_line_len - length_cumsum[
+                current_edge_ind]
+            if current_edge_ind >= len(length_list):
+                break
+            end_shift_ratio = current_edge_end_shift / length_list[
+                current_edge_ind]
+            current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
+                                                    - line[current_edge_ind]
+                                                    ) * end_shift_ratio
+            resampled_line.append(current_point)
+        resampled_line.append(line[-1])
+        resampled_line = np.array(resampled_line)
+        return resampled_line
+
+    def sample_points_on_bbox(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [
+            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+        ]
+        total_length = sum(length_list)
+        mean_length = total_length / (len(length_list) + 1e-8)
+        group = [[0]]
+        for i in range(len(length_list)):
+            point_id = i+1
+            if length_list[i] < 0.9 * mean_length:
+                for g in group:
+                    if i in g:
+                        g.append(point_id)
+                        break
+            else:
+                g = [point_id]
+                group.append(g)
+
+        top_tail_len = norm(line[0] - line[-1])
+        if top_tail_len < 0.9 * mean_length:
+            group[0].extend(g)
+            group.remove(g)
+        mean_positions = []  
+        for indices in group:  
+            x_sum = 0  
+            y_sum = 0  
+            for index in indices:  
+                x, y = line[index]  
+                x_sum += x  
+                y_sum += y  
+            num_points = len(indices)  
+            mean_x = x_sum / num_points  
+            mean_y = y_sum / num_points  
+            mean_positions.append((mean_x, mean_y)) 
+        resampled_line = np.array(mean_positions)
+        return resampled_line
+
+    def get_poly_rect_crop(self, img, points):
+        '''
+            Rectify and crop an irregular or curved text region using its polygon.
+            args: img: image as an ndarray
+            points: polygon point coordinates, N*2 shape, ndarray
+            return: the rectified image as an ndarray
+        '''
+        points = np.array(points).astype(np.int32).reshape(-1, 2)
+        temp_crop_img, temp_box = self.get_minarea_rect(img, points)
+        # compute the IoU between the minimum-area rectangle and the polygon
+        def get_union(pD, pG):
+            return Polygon(pD).union(Polygon(pG)).area
+
+        def get_intersection_over_union(pD, pG):
+            return get_intersection(pD, pG) / (get_union(pD, pG) + 1e-10)
+
+        def get_intersection(pD, pG):
+            return Polygon(pD).intersection(Polygon(pG)).area
+
+        cal_IoU = get_intersection_over_union(points, temp_box)
+
+        if cal_IoU >= 0.7:
+            # the polygon is close to its min-area rectangle; use that crop directly
+            return temp_crop_img
+
+        points_sample = self.sample_points_on_bbox(points)
+        points_sample = points_sample.astype(np.int32)
+        head_edge, tail_edge, top_line, bot_line = self.reorder_poly_edge(points_sample)
+
+        resample_top_line = self.sample_points_on_bbox_bp(top_line, 15)
+        resample_bot_line = self.sample_points_on_bbox_bp(bot_line, 15)
+
+        sideline_mean_shift = np.mean(
+            resample_top_line, axis=0) - np.mean(
+                resample_bot_line, axis=0)
+        if sideline_mean_shift[1] > 0:
+            resample_bot_line, resample_top_line = resample_top_line, resample_bot_line
+        rectifier = AutoRectifier()
+        new_points = np.concatenate([resample_top_line, resample_bot_line])
+        new_points_list = list(new_points.astype(np.float32).reshape(1, -1).tolist())
+
+        if len(img.shape) == 2:
+            img = np.stack((img,)*3, axis=-1)
+        img_crop, image = rectifier.run(img, new_points_list, mode='homography')
+        return img_crop[0]
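
A hedged sketch of the extended `CropByPolys` in poly mode; the import path, image path, and polygon are illustrative, while the constructor argument and `apply` signature come from this diff:

```python
# Illustrative poly-mode cropping (assumed import path; polygon made up).
from paddlex.inference.components.task_related.text_det import CropByPolys

cropper = CropByPolys(det_box_type="poly")
# One curved text instance as N x 2 clockwise points (10 points here).
dt_polys = [[[10, 30], [40, 22], [70, 20], [100, 22], [130, 30],
             [130, 60], [100, 52], [70, 50], [40, 52], [10, 60]]]
for crop in cropper.apply("seal_text_det.png", dt_polys):
    print(crop["img_size"])  # [width, height] of each rectified crop
```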

+ 6 - 2
paddlex/inference/pipelines/ocr.py

@@ -23,13 +23,17 @@ class OCRPipeline(BasePipeline):
     entities = "ocr"
 
     def __init__(
-        self, det_model, rec_model, rec_batch_size, predictor_kwargs=None, **kwargs
+        self, det_model, rec_model, rec_batch_size, predictor_kwargs=None, is_curve=False, **kwargs
     ):
         super().__init__(predictor_kwargs)
         self._det_predict = self._create_predictor(det_model)
         self._rec_predict = self._create_predictor(rec_model, batch_size=rec_batch_size)
         # TODO: foo
-        self._crop_by_polys = CropByPolys(det_box_type="foo")
+        if is_curve:
+            det_box_type = 'poly'
+        else:
+            det_box_type = 'quad'
+        self._crop_by_polys = CropByPolys(det_box_type=det_box_type)
 
     def predict(self, x):
         for det_res in self._det_predict(x):
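
The new `is_curve` flag only changes which cropper the pipeline builds. A hedged construction sketch (model names come from this PR; the import path and call pattern are assumptions):

```python
# Hypothetical pipeline setup enabling the seal text detection models.
from paddlex.inference.pipelines.ocr import OCRPipeline

pipeline = OCRPipeline(
    det_model="PP-OCRv4_mobile_seal_det",  # seal model registered in this PR
    rec_model="PP-OCRv4_mobile_rec",
    rec_batch_size=1,
    is_curve=True,  # selects CropByPolys(det_box_type='poly')
)
for res in pipeline.predict("seal_text_det.png"):
    print(res)
```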

+ 4 - 0
paddlex/inference/predictors/official_models.py

@@ -183,6 +183,10 @@ PP-OCRv4_mobile_rec_infer.tar",
 PP-OCRv4_server_det_infer.tar",
     "PP-OCRv4_mobile_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
 PP-OCRv4_mobile_det_infer.tar",
+    "PP-OCRv4_server_seal_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_server_seal_det_infer.tar",
+    "PP-OCRv4_mobile_seal_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_mobile_seal_det_infer.tar",
     "ch_RepSVTR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
 openatom_rec_repsvtr_ch_infer.tar",
     "ch_SVTRv2_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\

+ 45 - 13
paddlex/inference/results/ocr.py

@@ -31,13 +31,35 @@ class OCRResult(BaseResult):
         if len(self["dt_polys"]) == 0:
             logging.warning("No text detected!")
 
+    def get_minarea_rect(self, points):
+        bounding_box = cv2.minAreaRect(points)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = np.array([points[index_a], points[index_b], points[index_c], points[index_d]]).astype(np.int32)
+
+        return box
+
     def _get_res_img(
         self,
         drop_score=0.5,
         font_path=PINGFANG_FONT_FILE_PATH,
     ):
         """draw ocr result"""
-        boxes = np.array(self["dt_polys"])
+        boxes = self["dt_polys"]
         txts = self["rec_text"]
         scores = self["rec_score"]
         img = self._img_reader.read(self["img_path"])
@@ -46,23 +68,33 @@ class OCRResult(BaseResult):
         img_left = image.copy()
         img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
         random.seed(0)
-
         draw_left = ImageDraw.Draw(img_left)
         if txts is None or len(txts) != len(boxes):
             txts = [None] * len(boxes)
         for idx, (box, txt) in enumerate(zip(boxes, txts)):
-            if scores is not None and scores[idx] < drop_score:
+            try:
+                if scores is not None and scores[idx] < drop_score:
+                    continue
+                color = (
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                    random.randint(0, 255),
+                )
+                box = np.array(box)
+                if len(box) > 4:
+                    pts = [(x, y) for x, y in box.tolist()]
+                    draw_left.polygon(pts, outline=color, width=8)
+                    box = self.get_minarea_rect(box)
+                    height = int(0.5 * (max(box[:,1]) - min(box[:,1])))
+                    box[:2,1] = np.mean(box[:,1])
+                    box[2:,1] = np.mean(box[:,1]) + min(20, height)
+                draw_left.polygon(box, fill=color)
+                img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
+                pts = np.array(box, np.int32).reshape((-1, 1, 2))
+                cv2.polylines(img_right_text, [pts], True, color, 1)
+                img_right = cv2.bitwise_and(img_right, img_right_text)
+            except Exception:
                 continue
-            color = (
-                random.randint(0, 255),
-                random.randint(0, 255),
-                random.randint(0, 255),
-            )
-            draw_left.polygon(box, fill=color)
-            img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
-            pts = np.array(box, np.int32).reshape((-1, 1, 2))
-            cv2.polylines(img_right_text, [pts], True, color, 1)
-            img_right = cv2.bitwise_and(img_right, img_right_text)
         img_left = Image.blend(image, img_left, 0.5)
         img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
         img_show.paste(img_left, (0, 0, w, h))
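
When a detected box has more than four points, the drawing code outlines the polygon, collapses it to a min-area rectangle, and flattens that rectangle into a thin horizontal band for the right-hand text panel. A numeric sketch of the flattening step (values illustrative):

```python
# Numeric sketch of the band-flattening in _get_res_img (values made up).
import numpy as np

box = np.array([[10, 30], [100, 20], [105, 60], [12, 70]])  # ordered min-area rect
height = int(0.5 * (max(box[:, 1]) - min(box[:, 1])))
box[:2, 1] = np.mean(box[:, 1])                    # top edge -> vertical center
box[2:, 1] = np.mean(box[:, 1]) + min(20, height)  # note: the mean is recomputed
                                                   # after the first write, as above
print(box)
```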

+ 3 - 3
paddlex/inference/results/text_det.py

@@ -23,10 +23,10 @@ class TextDetResult(BaseResult):
 
     def _get_res_img(self):
         """draw rectangle"""
-        boxes = np.array(self["dt_polys"])
+        boxes = self["dt_polys"]
         img = self._img_reader.read(self["img_path"])
         res_img = img.copy()
-        for box in boxes.astype(int):
-            box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
+        for box in boxes:
+            box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
             cv2.polylines(res_img, [box], True, (0, 0, 255), 2)
         return res_img

+ 4 - 0
paddlex/modules/base/predictor/utils/official_models.py

@@ -196,6 +196,10 @@ PP-OCRv4_mobile_rec_infer.tar",
 PP-OCRv4_server_det_infer.tar",
     "PP-OCRv4_mobile_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
 PP-OCRv4_mobile_det_infer.tar",
+    "PP-OCRv4_server_seal_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_server_seal_det_infer.tar",
+    "PP-OCRv4_mobile_seal_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
+PP-OCRv4_mobile_seal_det_infer.tar",
     "ch_RepSVTR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
 openatom_rec_repsvtr_ch_infer.tar",
     "ch_SVTRv2_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\

+ 2 - 0
paddlex/modules/text_detection/model_list.py

@@ -15,4 +15,6 @@
 MODELS = [
     "PP-OCRv4_mobile_det",
     "PP-OCRv4_server_det",
+    "PP-OCRv4_mobile_seal_det",
+    "PP-OCRv4_server_seal_det"
 ]

+ 34 - 13
paddlex/modules/text_detection/predictor/predictor.py

@@ -60,9 +60,15 @@ class TextDetPredictor(BasePredictor):
 
     def _get_pre_transforms_from_config(self):
         """get preprocess transforms"""
+
+        if self.model_name in ['PP-OCRv4_server_seal_det', 'PP-OCRv4_mobile_seal_det']:
+            limit_side_len = 736
+        else:
+            limit_side_len = 960
+    
         return [
             image_common.ReadImage(),
-            T.DetResizeForTest(limit_side_len=960, limit_type="max"),
+            T.DetResizeForTest(limit_side_len=limit_side_len, limit_type="max"),
             T.NormalizeImage(
                 mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225],
@@ -74,21 +80,36 @@ class TextDetPredictor(BasePredictor):
 
     def _get_post_transforms_from_config(self):
         """get postprocess transforms"""
-        post_transforms = [
-            T.DBPostProcess(
-                thresh=0.3,
-                box_thresh=0.6,
-                max_candidates=1000,
-                unclip_ratio=1.5,
-                use_dilation=False,
-                score_mode="fast",
-                box_type="quad",
-            )
-        ]
+        if self.model_name in ['PP-OCRv4_server_seal_det', 'PP-OCRv4_mobile_seal_det']:
+            task = 'poly'
+            post_transforms = [
+                T.DBPostProcess(
+                    thresh=0.2,
+                    box_thresh=0.6,
+                    max_candidates=1000,
+                    unclip_ratio=1.5,
+                    use_dilation=False,
+                    score_mode="fast",
+                    box_type="poly",
+                )
+            ]
+        else:
+            task = 'quad'
+            post_transforms = [
+                T.DBPostProcess(
+                    thresh=0.3,
+                    box_thresh=0.6,
+                    max_candidates=1000,
+                    unclip_ratio=1.5,
+                    use_dilation=False,
+                    score_mode="fast",
+                    box_type="quad",
+                )
+            ]
         if not self.disable_print:
             post_transforms.append(T.PrintResult())
         if not self.disable_save:
             post_transforms.append(
-                T.SaveTextDetResults(self.output),
+                T.SaveTextDetResults(self.output, task),
             )
         return post_transforms

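The seal branch differs from the standard detector in exactly three knobs: a 736 resize limit, a lower DB binarization threshold (0.2 instead of 0.3), and polygon boxes instead of quads. A compact way to see the mapping — the values are copied from the hunks above, the lookup helper itself is only illustrative:

    SEAL_MODELS = {"PP-OCRv4_server_seal_det", "PP-OCRv4_mobile_seal_det"}

    def det_settings(model_name):
        if model_name in SEAL_MODELS:
            return {"limit_side_len": 736, "thresh": 0.2, "box_type": "poly"}
        return {"limit_side_len": 960, "thresh": 0.3, "box_type": "quad"}

    assert det_settings("PP-OCRv4_mobile_seal_det")["box_type"] == "poly"
    assert det_settings("PP-OCRv4_mobile_det")["limit_side_len"] == 960
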
+ 353 - 12
paddlex/modules/text_detection/predictor/transforms.py

@@ -20,6 +20,7 @@ import copy
 import math
 import pyclipper
 import numpy as np
+from numpy.linalg import norm
 from PIL import Image
 from shapely.geometry import Polygon
 
@@ -28,6 +29,7 @@ from ...base.predictor.io.writers import ImageWriter
 from ...base.predictor.io.readers import ImageReader
 from ...base.predictor import BaseTransform
 from .keys import TextDetKeys as K
+from .utils import AutoRectifier
 
 __all__ = [
     "DetResizeForTest",
@@ -461,17 +463,23 @@ class CropByPolys(BaseTransform):
     def apply(self, data):
         """apply"""
         ori_im = data[K.ORI_IM]
-        # TODO
-        # dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
-        dt_boxes = np.array(data[K.DT_POLYS])
-        img_crop_list = []
-        for bno in range(len(dt_boxes)):
-            tmp_box = copy.deepcopy(dt_boxes[bno])
-            if self.det_box_type == "quad":
-                img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
-            else:
+        if self.det_box_type == "quad":
+            dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
+            dt_boxes = np.array(dt_boxes)
+            img_crop_list = []
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
                 img_crop = self.get_minarea_rect_crop(ori_im, tmp_box)
-            img_crop_list.append(img_crop)
+                img_crop_list.append(img_crop)
+        elif self.det_box_type == "poly":
+            img_crop_list = []
+            dt_boxes = data[K.DT_POLYS]
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
+                img_crop = self.get_poly_rect_crop(ori_im.copy(), tmp_box)
+                img_crop_list.append(img_crop)
+        else:
+            raise NotImplementedError
         data[K.SUB_IMGS] = img_crop_list
         return data
 
@@ -533,6 +541,7 @@ class CropByPolys(BaseTransform):
         crop_img = self.get_rotate_crop_image(img, np.array(box))
         return crop_img
 
+
     def get_rotate_crop_image(self, img, points):
         """
         img_height, img_width = img.shape[0:2]
@@ -578,13 +587,330 @@ class CropByPolys(BaseTransform):
             dst_img = np.rot90(dst_img)
         return dst_img
 
+    def reorder_poly_edge(self, points):
+        """Get the respective points composing head edge, tail edge, top
+        sideline and bottom sideline.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+
+        Returns:
+            head_edge (ndarray): The two points composing the head edge of text
+                polygon.
+            tail_edge (ndarray): The two points composing the tail edge of text
+                polygon.
+            top_sideline (ndarray): The points composing top curved sideline of
+                text polygon.
+            bot_sideline (ndarray): The points composing bottom curved sideline
+                of text polygon.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+
+        orientation_thr = 2.0  # an empirical hyperparameter
+
+        head_inds, tail_inds = self.find_head_tail(points, orientation_thr)
+        head_edge, tail_edge = points[head_inds], points[tail_inds]
+
+
+        pad_points = np.vstack([points, points])
+        if tail_inds[1] < 1:
+            tail_inds[1] = len(points)
+        sideline1 = pad_points[head_inds[1]:tail_inds[1]]
+        sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]
+        return head_edge, tail_edge, sideline1, sideline2
+
+    def vector_slope(self, vec):
+        assert len(vec) == 2
+        return abs(vec[1] / (vec[0] + 1e-8)) 
+
+    def find_head_tail(self, points, orientation_thr):
+        """Find the head edge and tail edge of a text polygon.
+
+        Args:
+            points (ndarray): The points composing a text polygon.
+            orientation_thr (float): The threshold for distinguishing between
+                head edge and tail edge among the horizontal and vertical edges
+                of a quadrangle.
+
+        Returns:
+            head_inds (list): The indexes of two points composing head edge.
+            tail_inds (list): The indexes of two points composing tail edge.
+        """
+
+        assert points.ndim == 2
+        assert points.shape[0] >= 4
+        assert points.shape[1] == 2
+        assert isinstance(orientation_thr, float)
+
+        if len(points) > 4:
+            pad_points = np.vstack([points, points[0]])
+            edge_vec = pad_points[1:] - pad_points[:-1]
+
+            theta_sum = []
+            adjacent_vec_theta = []
+            for i, edge_vec1 in enumerate(edge_vec):
+                adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
+                adjacent_edge_vec = edge_vec[adjacent_ind]
+                temp_theta_sum = np.sum(
+                    self.vector_angle(edge_vec1, adjacent_edge_vec))
+                temp_adjacent_theta = self.vector_angle(adjacent_edge_vec[0],
+                                                        adjacent_edge_vec[1])
+                theta_sum.append(temp_theta_sum)
+                adjacent_vec_theta.append(temp_adjacent_theta)
+            theta_sum_score = np.array(theta_sum) / np.pi
+            adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi
+            poly_center = np.mean(points, axis=0)
+            edge_dist = np.maximum(
+                norm(
+                    pad_points[1:] - poly_center, axis=-1),
+                norm(
+                    pad_points[:-1] - poly_center, axis=-1))
+            dist_score = edge_dist / np.max(edge_dist)
+            position_score = np.zeros(len(edge_vec))
+            score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
+            score += 0.35 * dist_score
+            if len(points) % 2 == 0:
+                position_score[(len(score) // 2 - 1)] += 1
+                position_score[-1] += 1
+            score += 0.1 * position_score
+            pad_score = np.concatenate([score, score])
+            score_matrix = np.zeros((len(score), len(score) - 3))
+            x = np.arange(len(score) - 3) / float(len(score) - 4)
+            gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
+                (x - 0.5) / 0.5, 2.) / 2)
+            gaussian = gaussian / np.max(gaussian)
+            for i in range(len(score)):
+                score_matrix[i, :] = score[i] + pad_score[(i + 2):(i + len(
+                    score) - 1)] * gaussian * 0.3
+
+            head_start, tail_increment = np.unravel_index(score_matrix.argmax(),
+                                                            score_matrix.shape)
+            tail_start = (head_start + tail_increment + 2) % len(points)
+            head_end = (head_start + 1) % len(points)
+            tail_end = (tail_start + 1) % len(points)
+
+            if head_end > tail_end:
+                head_start, tail_start = tail_start, head_start
+                head_end, tail_end = tail_end, head_end
+            head_inds = [head_start, head_end]
+            tail_inds = [tail_start, tail_end]
+        else:
+            if self.vector_slope(points[1] - points[0]) + self.vector_slope(
+                    points[3] - points[2]) < self.vector_slope(
+                        points[2] - points[1]) + self.vector_slope(points[0] - points[3]):
+                horizontal_edge_inds = [[0, 1], [2, 3]]
+                vertical_edge_inds = [[3, 0], [1, 2]]
+            else:
+                horizontal_edge_inds = [[3, 0], [1, 2]]
+                vertical_edge_inds = [[0, 1], [2, 3]]
+
+            vertical_len_sum = norm(points[vertical_edge_inds[0][0]] - points[
+                vertical_edge_inds[0][1]]) + norm(points[vertical_edge_inds[1][
+                    0]] - points[vertical_edge_inds[1][1]])
+            horizontal_len_sum = norm(points[horizontal_edge_inds[0][
+                0]] - points[horizontal_edge_inds[0][1]]) + norm(points[
+                    horizontal_edge_inds[1][0]] - points[horizontal_edge_inds[1]
+                                                            [1]])
+
+            if vertical_len_sum > horizontal_len_sum * orientation_thr:
+                head_inds = horizontal_edge_inds[0]
+                tail_inds = horizontal_edge_inds[1]
+            else:
+                head_inds = vertical_edge_inds[0]
+                tail_inds = vertical_edge_inds[1]
+
+        return head_inds, tail_inds
+
+    def vector_angle(self, vec1, vec2):
+        if vec1.ndim > 1:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec1 = vec1 / (norm(vec1, axis=-1) + 1e-8)
+        if vec2.ndim > 1:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8).reshape((-1, 1))
+        else:
+            unit_vec2 = vec2 / (norm(vec2, axis=-1) + 1e-8)
+        return np.arccos(np.clip(np.sum(unit_vec1 * unit_vec2, axis=-1), -1.0, 1.0))
+
+
+    def get_minarea_rect(self, img, points):
+        bounding_box = cv2.minAreaRect(points)
+        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+        index_a, index_b, index_c, index_d = 0, 1, 2, 3
+        if points[1][1] > points[0][1]:
+            index_a = 0
+            index_d = 1
+        else:
+            index_a = 1
+            index_d = 0
+        if points[3][1] > points[2][1]:
+            index_b = 2
+            index_c = 3
+        else:
+            index_b = 3
+            index_c = 2
+
+        box = [points[index_a], points[index_b], points[index_c], points[index_d]]
+        crop_img = self.get_rotate_crop_image(img, np.array(box))
+        return crop_img, box
+
+    def sample_points_on_bbox_bp(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        # sanity-check the input arguments
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [
+            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+        ]
+        total_length = sum(length_list)
+        length_cumsum = np.cumsum([0.0] + length_list)
+        delta_length = total_length / (float(n) + 1e-8)
+        current_edge_ind = 0
+        resampled_line = [line[0]]
+
+        for i in range(1, n):
+            current_line_len = i * delta_length
+            while current_edge_ind + 1 < len(
+                    length_cumsum) and current_line_len >= length_cumsum[
+                        current_edge_ind + 1]:
+                current_edge_ind += 1
+            current_edge_end_shift = current_line_len - length_cumsum[
+                current_edge_ind]
+            if current_edge_ind >= len(length_list):
+                break
+            end_shift_ratio = current_edge_end_shift / length_list[
+                current_edge_ind]
+            current_point = line[current_edge_ind] + (line[current_edge_ind + 1]
+                                                    - line[current_edge_ind]
+                                                    ) * end_shift_ratio
+            resampled_line.append(current_point)
+        resampled_line.append(line[-1])
+        resampled_line = np.array(resampled_line)
+        return resampled_line
+
+    def sample_points_on_bbox(self, line, n=50):
+        """Resample n points on a line.
+
+        Args:
+            line (ndarray): The points composing a line.
+            n (int): The resampled points number.
+
+        Returns:
+            resampled_line (ndarray): The points composing the resampled line.
+        """
+        assert line.ndim == 2
+        assert line.shape[0] >= 2
+        assert line.shape[1] == 2
+        assert isinstance(n, int)
+        assert n > 0
+
+        length_list = [
+            norm(line[i + 1] - line[i]) for i in range(len(line) - 1)
+        ]
+        total_length = sum(length_list)
+        mean_length = total_length / (len(length_list) + 1e-8)
+        group = [[0]]
+        for i in range(len(length_list)):
+            point_id = i+1
+            if length_list[i] < 0.9 * mean_length:
+                for g in group:
+                    if i in g:
+                        g.append(point_id)
+                        break
+            else:
+                g = [point_id]
+                group.append(g)
+
+        top_tail_len = norm(line[0] - line[-1])
+        if top_tail_len < 0.9 * mean_length:
+            group[0].extend(g)
+            group.remove(g)
+        mean_positions = []  
+        for indices in group:  
+            x_sum = 0  
+            y_sum = 0  
+            for index in indices:  
+                x, y = line[index]  
+                x_sum += x  
+                y_sum += y  
+            num_points = len(indices)  
+            mean_x = x_sum / num_points  
+            mean_y = y_sum / num_points  
+            mean_positions.append((mean_x, mean_y)) 
+        resampled_line = np.array(mean_positions)
+        return resampled_line
+
+    def get_poly_rect_crop(self, img, points):
+        '''
+            Use the polygon to rectify and crop irregular or curved text.
+            args: img: image as an ndarray
+            points: polygon vertices as an ndarray of shape N*2
+            return: the rectified crop as an ndarray
+        '''
+        points = np.array(points).astype(np.int32).reshape(-1, 2)
+        temp_crop_img, temp_box = self.get_minarea_rect(img, points)
+        # compute the IoU between the polygon and its minimum-area rectangle
+        def get_union(pD, pG):
+            return Polygon(pD).union(Polygon(pG)).area
+
+        def get_intersection_over_union(pD, pG):
+            return get_intersection(pD, pG) / (get_union(pD, pG)+ 1e-10)
+
+        def get_intersection(pD, pG):
+            return Polygon(pD).intersection(Polygon(pG)).area
+
+        cal_IoU = get_intersection_over_union(points, temp_box)
+
+        if cal_IoU >= 0.7:
+            return temp_crop_img
+
+        points_sample = self.sample_points_on_bbox(points)
+        points_sample = points_sample.astype(np.int32)
+        head_edge, tail_edge, top_line, bot_line = self.reorder_poly_edge(points_sample)
+
+        resample_top_line = self.sample_points_on_bbox_bp(top_line, 15)
+        resample_bot_line = self.sample_points_on_bbox_bp(bot_line, 15)
+
+        sideline_mean_shift = np.mean(
+            resample_top_line, axis=0) - np.mean(
+                resample_bot_line, axis=0)
+        if sideline_mean_shift[1] > 0:
+            resample_bot_line, resample_top_line = resample_top_line, resample_bot_line
+        rectifier = AutoRectifier()
+        new_points = np.concatenate([resample_top_line, resample_bot_line])
+        new_points_list = list(new_points.astype(np.float32).reshape(1, -1).tolist())
+
+        if len(img.shape) == 2:
+            img = np.stack((img,)*3, axis=-1)
+        img_crop, image = rectifier.run(img, new_points_list, mode='homography')
+        return img_crop[0]
+
 
 class SaveTextDetResults(BaseTransform):
     """Save Text Det Results"""
 
-    def __init__(self, save_dir):
+    def __init__(self, save_dir, task='quad'):
         super().__init__()
         self.save_dir = save_dir
+        self.task = task
         # We use the opencv backend to save both numpy arrays and PIL Image objects
         self._writer = ImageWriter(backend="opencv")
 
@@ -598,7 +924,10 @@ class SaveTextDetResults(BaseTransform):
         fn = os.path.basename(data["input_path"])
         save_path = os.path.join(self.save_dir, fn)
         bbox_res = data[K.DT_POLYS]
-        vis_img = self.draw_rectangle(data[K.IM_PATH], bbox_res)
+        if self.task == "quad":
+            vis_img = self.draw_rectangle(data[K.IM_PATH], bbox_res)
+        else:
+            vis_img = self.draw_polyline(data[K.IM_PATH], bbox_res)
         self._writer.write(save_path, vis_img)
         return data
 
@@ -621,6 +950,16 @@ class SaveTextDetResults(BaseTransform):
             box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
             cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
         return img_show
+    
+    def draw_polyline(self, img_path, boxes):
+        """draw polyline"""
+        img = cv2.imread(img_path)
+        img_show = img.copy()
+        for box in boxes:
+            box = np.array(box).astype(int)
+            box = np.reshape(box, [-1, 1, 2]).astype(np.int64)
+            cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
+        return img_show
 
 
 class PrintResult(BaseTransform):
@@ -644,3 +983,5 @@ class PrintResult(BaseTransform):
 
     # DT_SCORES = 'dt_scores'
     # DT_POLYS = 'dt_polys'
+
+

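A note on get_poly_rect_crop, the core of this file's additions: it first crops with the minimum-area rectangle, and only when the polygon and that rectangle overlap poorly (IoU < 0.7) does it resample the top and bottom sidelines and hand them to AutoRectifier. The gate itself is easy to reproduce standalone (toy polygons below, helper name is ours):

    import cv2
    import numpy as np
    from shapely.geometry import Polygon

    def min_rect_iou(points):
        rect = cv2.boxPoints(cv2.minAreaRect(points.astype(np.float32)))
        poly = Polygon(points)
        box = Polygon(rect)
        return poly.intersection(box).area / (poly.union(box).area + 1e-10)

    straight = np.array([[0, 0], [100, 0], [100, 40], [0, 40]])
    curved = np.array([[0, 0], [50, 30], [100, 0], [100, 40], [50, 70], [0, 40]])
    print(min_rect_iou(straight))  # ~1.0 -> cheap rotated crop
    print(min_rect_iou(curved))    # < 0.7 -> sideline resampling + rectification
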
+ 698 - 0
paddlex/modules/text_detection/predictor/utils.py

@@ -0,0 +1,698 @@
+import os, sys
+import numpy as np
+from numpy import cos, sin, arctan, sqrt
+import cv2
+import copy
+import time
+
+def Homography(image, img_points, world_width, world_height,
+               interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+    """
+    Perspective-transform the image to a new view and return the result.
+
+    Args:
+        image (np.ndarray): Input image as a numpy array.
+        img_points (List[Tuple[int, int]]): Four image points, ordered
+            top-left, top-right, bottom-right, bottom-left.
+        world_width (int): Width of the output image in world coordinates.
+        world_height (int): Height of the output image in world coordinates.
+        interpolation (int, optional): Interpolation flag, default cv2.INTER_CUBIC.
+        ratio_width (float, optional): Horizontal expansion ratio, default 1.0.
+        ratio_height (float, optional): Vertical expansion ratio, default 1.0.
+
+    Returns:
+        np.ndarray: The transformed image as a numpy array.
+    
+    """
+    _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+    expand_x = int(0.5 * world_width * (ratio_width - 1))
+    expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+    pt_lefttop = [expand_x, expand_y]
+    pt_righttop = [expand_x + world_width, expand_y]
+    pt_rightbottom = [expand_x + world_width, expand_y + world_height]
+    pt_leftbottom = [expand_x, expand_y + world_height]
+
+    pts_std = np.float32([pt_lefttop, pt_righttop,
+                          pt_rightbottom, pt_leftbottom])
+
+    img_crop_width = int(world_width * ratio_width)
+    img_crop_height = int(world_height * ratio_height)
+
+    M = cv2.getPerspectiveTransform(_points, pts_std)
+
+    dst_img = cv2.warpPerspective(
+        image,
+        M, (img_crop_width, img_crop_height),
+        borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+        flags=interpolation)
+
+    return dst_img
+
+
+class CurveTextRectifier:
+    """
+    spatial transformer via monocular vision
+    """
+    def __init__(self):
+        self.get_virtual_camera_parameter()
+
+
+    def get_virtual_camera_parameter(self):
+        vcam_thz = 0
+        vcam_thx1 = 180
+        vcam_thy = 180
+        vcam_thx2 = 0
+
+        vcam_x = 0
+        vcam_y = 0
+        vcam_z = 100
+
+        radian = np.pi / 180
+
+        angle_z = radian * vcam_thz
+        angle_x1 = radian * vcam_thx1
+        angle_y = radian * vcam_thy
+        angle_x2 = radian * vcam_thx2
+
+        optic_x = vcam_x
+        optic_y = vcam_y
+        optic_z = vcam_z
+
+        fu = 100
+        fv = 100
+
+        matT = np.zeros((4, 4))
+        matT[0, 0] = cos(angle_z) * cos(angle_y) - sin(angle_z) * sin(angle_x1) * sin(angle_y)
+        matT[0, 1] = cos(angle_z) * sin(angle_y) * sin(angle_x2) - sin(angle_z) * (
+                    cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2))
+        matT[0, 2] = cos(angle_z) * sin(angle_y) * cos(angle_x2) + sin(angle_z) * (
+                    cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2))
+        matT[0, 3] = optic_x
+        matT[1, 0] = sin(angle_z) * cos(angle_y) + cos(angle_z) * sin(angle_x1) * sin(angle_y)
+        matT[1, 1] = sin(angle_z) * sin(angle_y) * sin(angle_x2) + cos(angle_z) * (
+                    cos(angle_x1) * cos(angle_x2) - sin(angle_x1) * cos(angle_y) * sin(angle_x2))
+        matT[1, 2] = sin(angle_z) * sin(angle_y) * cos(angle_x2) - cos(angle_z) * (
+                    cos(angle_x1) * sin(angle_x2) + sin(angle_x1) * cos(angle_y) * cos(angle_x2))
+        matT[1, 3] = optic_y
+        matT[2, 0] = -cos(angle_x1) * sin(angle_y)
+        matT[2, 1] = cos(angle_x1) * cos(angle_y) * sin(angle_x2) + sin(angle_x1) * cos(angle_x2)
+        matT[2, 2] = cos(angle_x1) * cos(angle_y) * cos(angle_x2) - sin(angle_x1) * sin(angle_x2)
+        matT[2, 3] = optic_z
+        matT[3, 0] = 0
+        matT[3, 1] = 0
+        matT[3, 2] = 0
+        matT[3, 3] = 1
+
+        matS = np.zeros((4, 4))
+        matS[2, 3] = 0.5
+        matS[3, 2] = 0.5
+
+        self.ifu = 1 / fu
+        self.ifv = 1 / fv
+
+        self.matT = matT
+        self.matS = matS
+        self.K = np.dot(matT.T, matS)
+        self.K = np.dot(self.K, matT)
+
+
+    def vertical_text_process(self, points, org_size):
+        """
+        change point order and process as horizontal text
+        :param points:
+        :param org_size:
+        :return:
+        """
+        org_w, org_h = org_size
+        _points = np.array(points).reshape(-1).tolist()
+        _points = np.array(_points[2:] + _points[:2]).reshape(-1, 2)
+
+        # convert to horizontal points
+        adjusted_points = np.zeros(_points.shape, dtype=np.float32)
+        adjusted_points[:, 0] = _points[:, 1]
+        adjusted_points[:, 1] = org_h - _points[:, 0] - 1
+
+        _image_coord, _world_coord, _new_image_size = self.horizontal_text_process(adjusted_points)
+
+        # # convert to vertical points back
+        image_coord = _points.reshape(1, -1, 2)
+        world_coord = np.zeros(_world_coord.shape, dtype=np.float32)
+        world_coord[:, :, 0] = 0 - _world_coord[:, :, 1]
+        world_coord[:, :, 1] = _world_coord[:, :, 0]
+        world_coord[:, :, 2] = _world_coord[:, :, 2]
+        new_image_size = (_new_image_size[1], _new_image_size[0])
+
+        return image_coord, world_coord, new_image_size
+
+
+    def horizontal_text_process(self, points):
+        """
+        get image coordinate and world coordinate
+        :param points:
+        :return:
+        """
+        poly = np.array(points).reshape(-1)
+
+        dx_list = []
+        dy_list = []
+        for i in range(1, len(poly) // 2):
+            xdx = poly[i * 2] - poly[(i - 1) * 2]
+            xdy = poly[i * 2 + 1] - poly[(i - 1) * 2 + 1]
+            d = sqrt(xdx ** 2 + xdy ** 2)
+            dx_list.append(d)
+
+        for i in range(0, len(poly) // 4):
+            ydx = poly[i * 2] - poly[len(poly) - 1 - (i * 2 + 1)]
+            ydy = poly[i * 2 + 1] - poly[len(poly) - 1 - (i * 2)]
+            d = sqrt(ydx ** 2 + ydy ** 2)
+            dy_list.append(d)
+
+        dx_list = [(dx_list[i] + dx_list[len(dx_list) - 1 - i]) / 2 for i in range(len(dx_list) // 2)]
+
+        height = np.around(np.mean(dy_list))
+
+        rect_coord = [0, 0]
+        for i in range(0, len(poly) // 4 - 1):
+            x = rect_coord[-2]
+            x += dx_list[i]
+            y = 0
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        rect_coord_half = copy.deepcopy(rect_coord)
+        for i in range(0, len(poly) // 4):
+            x = rect_coord_half[len(rect_coord_half) - 2 * i - 2]
+            y = height
+            rect_coord.append(x)
+            rect_coord.append(y)
+
+        np_rect_coord = np.array(rect_coord).reshape(-1, 2)
+        x_min = np.min(np_rect_coord[:, 0])
+        y_min = np.min(np_rect_coord[:, 1])
+        x_max = np.max(np_rect_coord[:, 0])
+        y_max = np.max(np_rect_coord[:, 1])
+        new_image_size = (int(x_max - x_min + 0.5), int(y_max - y_min + 0.5))
+        x_mean = (x_max - x_min) / 2
+        y_mean = (y_max - y_min) / 2
+        np_rect_coord[:, 0] -= x_mean
+        np_rect_coord[:, 1] -= y_mean
+        rect_coord = np_rect_coord.reshape(-1).tolist()
+
+        rect_coord = np.array(rect_coord).reshape(-1, 2)
+        world_coord = np.ones((len(rect_coord), 3)) * 0
+
+        world_coord[:, :2] = rect_coord
+
+        image_coord = np.array(poly).reshape(1, -1, 2)
+        world_coord = world_coord.reshape(1, -1, 3)
+
+        return image_coord, world_coord, new_image_size
+
+
+    def horizontal_text_estimate(self, points):
+        """
+        horizontal or vertical text
+        :param points:
+        :return:
+        """
+        pts = np.array(points).reshape(-1, 2)
+        x_min = int(np.min(pts[:, 0]))
+        y_min = int(np.min(pts[:, 1]))
+        x_max = int(np.max(pts[:, 0]))
+        y_max = int(np.max(pts[:, 1]))
+        x = x_max - x_min
+        y = y_max - y_min
+        is_horizontal_text = True
+        if y / x > 1.5: # vertical text condition
+            is_horizontal_text = False
+        return is_horizontal_text
+
+
+    def virtual_camera_to_world(self, size):
+        ifu, ifv = self.ifu, self.ifv
+        K, matT = self.K, self.matT
+
+        ppu = size[0] / 2 + 1e-6
+        ppv = size[1] / 2 + 1e-6
+
+        P = np.zeros((size[1], size[0], 3))
+
+        lu = np.array([i for i in range(size[0])])
+        lv = np.array([i for i in range(size[1])])
+        u, v = np.meshgrid(lu, lv)
+
+        yp = (v - ppv) * ifv
+        xp = (u - ppu) * ifu
+        angle_a = arctan(sqrt(xp * xp + yp * yp))
+        angle_b = arctan(yp / xp)
+
+        D0 = sin(angle_a) * cos(angle_b)
+        D1 = sin(angle_a) * sin(angle_b)
+        D2 = cos(angle_a)
+
+        D0[xp <= 0] = -D0[xp <= 0]
+        D1[xp <= 0] = -D1[xp <= 0]
+
+        ratio_a = K[0, 0] * D0 * D0 + K[1, 1] * D1 * D1 + K[2, 2] * D2 * D2 + \
+                  (K[0, 1] + K[1, 0]) * D0 * D1 + (K[0, 2] + K[2, 0]) * D0 * D2 + (K[1, 2] + K[2, 1]) * D1 * D2
+        ratio_b = (K[0, 3] + K[3, 0]) * D0 + (K[1, 3] + K[3, 1]) * D1 + (K[2, 3] + K[3, 2]) * D2
+        ratio_c = K[3, 3] * np.ones(ratio_b.shape)
+
+        delta = ratio_b * ratio_b - 4 * ratio_a * ratio_c
+        t = np.zeros(delta.shape)
+        t[ratio_a == 0] = -ratio_c[ratio_a == 0] / ratio_b[ratio_a == 0]
+        t[ratio_a != 0] = (-ratio_b[ratio_a != 0] + sqrt(delta[ratio_a != 0])) / (2 * ratio_a[ratio_a != 0])
+        t[delta < 0] = 0
+
+        P[:, :, 0] = matT[0, 3] + t * (matT[0, 0] * D0 + matT[0, 1] * D1 + matT[0, 2] * D2)
+        P[:, :, 1] = matT[1, 3] + t * (matT[1, 0] * D0 + matT[1, 1] * D1 + matT[1, 2] * D2)
+        P[:, :, 2] = matT[2, 3] + t * (matT[2, 0] * D0 + matT[2, 1] * D1 + matT[2, 2] * D2)
+
+        return P
+
+
+    def world_to_image(self, image_size, world, intrinsic, distCoeffs, rotation, tvec):
+        r11 = rotation[0, 0]
+        r12 = rotation[0, 1]
+        r13 = rotation[0, 2]
+        r21 = rotation[1, 0]
+        r22 = rotation[1, 1]
+        r23 = rotation[1, 2]
+        r31 = rotation[2, 0]
+        r32 = rotation[2, 1]
+        r33 = rotation[2, 2]
+
+        t1 = tvec[0]
+        t2 = tvec[1]
+        t3 = tvec[2]
+
+        k1 = distCoeffs[0]
+        k2 = distCoeffs[1]
+        p1 = distCoeffs[2]
+        p2 = distCoeffs[3]
+        k3 = distCoeffs[4]
+        k4 = distCoeffs[5]
+        k5 = distCoeffs[6]
+        k6 = distCoeffs[7]
+
+        if len(distCoeffs) > 8:
+            s1 = distCoeffs[8]
+            s2 = distCoeffs[9]
+            s3 = distCoeffs[10]
+            s4 = distCoeffs[11]
+        else:
+            s1 = s2 = s3 = s4 = 0
+
+        if len(distCoeffs) > 12:
+            tx = distCoeffs[12]
+            ty = distCoeffs[13]
+        else:
+            tx = ty = 0
+
+        fu = intrinsic[0, 0]
+        fv = intrinsic[1, 1]
+        ppu = intrinsic[0, 2]
+        ppv = intrinsic[1, 2]
+
+        cos_tx = cos(tx)
+        cos_ty = cos(ty)
+        sin_tx = sin(tx)
+        sin_ty = sin(ty)
+
+        tao11 = cos_ty * cos_tx * cos_ty + sin_ty * cos_tx * sin_ty
+        tao12 = cos_ty * cos_tx * sin_ty * sin_tx - sin_ty * cos_tx * cos_ty * sin_tx
+        tao13 = -cos_ty * cos_tx * sin_ty * cos_tx + sin_ty * cos_tx * cos_ty * cos_tx
+        tao21 = -sin_tx * sin_ty
+        tao22 = cos_ty * cos_tx * cos_tx + sin_tx * cos_ty * sin_tx
+        tao23 = cos_ty * cos_tx * sin_tx - sin_tx * cos_ty * cos_tx
+
+        P = np.zeros((image_size[1], image_size[0], 2))
+
+        c3 = r31 * world[:, :, 0] + r32 * world[:, :, 1] + r33 * world[:, :, 2] + t3
+        c1 = r11 * world[:, :, 0] + r12 * world[:, :, 1] + r13 * world[:, :, 2] + t1
+        c2 = r21 * world[:, :, 0] + r22 * world[:, :, 1] + r23 * world[:, :, 2] + t2
+
+        x1 = c1 / c3
+        y1 = c2 / c3
+        x12 = x1 * x1
+        y12 = y1 * y1
+        x1y1 = 2 * x1 * y1
+        r2 = x12 + y12
+        r4 = r2 * r2
+        r6 = r2 * r4
+
+        radial_distortion = (1 + k1 * r2 + k2 * r4 + k3 * r6) / (1 + k4 * r2 + k5 * r4 + k6 * r6)
+        x2 = x1 * radial_distortion + p1 * x1y1 + p2 * (r2 + 2 * x12) + s1 * r2 + s2 * r4
+        y2 = y1 * radial_distortion + p2 * x1y1 + p1 * (r2 + 2 * y12) + s3 * r2 + s4 * r4
+
+        x3 = tao11 * x2 + tao12 * y2 + tao13
+        y3 = tao21 * x2 + tao22 * y2 + tao23
+
+        P[:, :, 0] = fu * x3 + ppu
+        P[:, :, 1] = fv * y3 + ppv
+        P[c3 <= 0] = 0
+
+        return P
+
+
+    def spatial_transform(self, image_data, new_image_size, mtx, dist, rvecs, tvecs, interpolation):
+        rotation, _ = cv2.Rodrigues(rvecs)
+        world_map = self.virtual_camera_to_world(new_image_size)
+        image_map = self.world_to_image(new_image_size, world_map, mtx, dist, rotation, tvecs)
+        image_map = image_map.astype(np.float32)
+        dst = cv2.remap(image_data, image_map[:, :, 0], image_map[:, :, 1], interpolation)
+        return dst
+
+
+    def calibrate(self, org_size, image_coord, world_coord):
+        """
+        calibration
+        :param org_size:
+        :param image_coord:
+        :param world_coord:
+        :return:
+        """
+        # flag = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL  | cv2.CALIB_THIN_PRISM_MODEL
+        flag = cv2.CALIB_RATIONAL_MODEL
+        flag2 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL
+        flag3 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_THIN_PRISM_MODEL
+        flag4 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_ZERO_TANGENT_DIST | cv2.CALIB_FIX_ASPECT_RATIO
+        flag5 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_TILTED_MODEL | cv2.CALIB_ZERO_TANGENT_DIST
+        flag6 = cv2.CALIB_RATIONAL_MODEL | cv2.CALIB_FIX_ASPECT_RATIO
+        flag_list = [flag2, flag3, flag4, flag5, flag6]
+
+        ret, mtx, dist, rvecs, tvecs = cv2.calibrateCamera(world_coord.astype(np.float32),
+                                                                image_coord.astype(np.float32),
+                                                                org_size,
+                                                                None,
+                                                                None,
+                                                                flags=flag)
+        if ret > 2:
+            # strategies
+            min_ret = ret
+            for i, flag in enumerate(flag_list):
+                _ret, _mtx, _dist, _rvecs, _tvecs = cv2.calibrateCamera(world_coord.astype(np.float32),
+                                                                   image_coord.astype(np.float32),
+                                                                   org_size,
+                                                                   None,
+                                                                   None,
+                                                                   flags=flag)
+                if _ret < min_ret:
+                    min_ret = _ret
+                    ret, mtx, dist, rvecs, tvecs = _ret, _mtx, _dist, _rvecs, _tvecs
+
+        return ret, mtx, dist, rvecs, tvecs
+
+
+    def dc_homo(self, img, img_points, obj_points, is_horizontal_text, interpolation=cv2.INTER_LINEAR,
+                ratio_width=1.0, ratio_height=1.0):
+        """
+        divide and conquer: homography
+        # ratio_width and ratio_height must be 1.0 here
+        """
+        _img_points = img_points.reshape(-1, 2)
+        _obj_points = obj_points.reshape(-1, 3)
+
+        homo_img_list = []
+        width_list = []
+        height_list = []
+        # divide and conquer
+        for i in range(len(_img_points) // 2 - 1):
+            new_img_points = np.zeros((4, 2)).astype(np.float32)
+            new_obj_points = np.zeros((4, 2)).astype(np.float32)
+
+            new_img_points[0:2, :] = _img_points[i:(i + 2), :2]
+            new_img_points[2:4, :] = _img_points[::-1, :][i:(i + 2), :2][::-1, :]
+
+            new_obj_points[0:2, :] = _obj_points[i:(i + 2), :2]
+            new_obj_points[2:4, :] = _obj_points[::-1, :][i:(i + 2), :2][::-1, :]
+
+            if is_horizontal_text:
+                world_width = np.abs(new_obj_points[1, 0] - new_obj_points[0, 0])
+                world_height = np.abs(new_obj_points[3, 1] - new_obj_points[0, 1])
+            else:
+                world_width = np.abs(new_obj_points[1, 1] - new_obj_points[0, 1])
+                world_height = np.abs(new_obj_points[3, 0] - new_obj_points[0, 0])
+
+            homo_img = Homography(img, new_img_points, world_width, world_height,
+                                              interpolation=interpolation,
+                                              ratio_width=ratio_width, ratio_height=ratio_height)
+
+            homo_img_list.append(homo_img)
+            _h, _w = homo_img.shape[:2]
+            width_list.append(_w)
+            height_list.append(_h)
+
+        # stitching
+        rectified_image = np.zeros((np.max(height_list), sum(width_list), 3)).astype(np.uint8)
+
+        st = 0
+        for (homo_img, w, h) in zip(homo_img_list, width_list, height_list):
+            rectified_image[:h, st:st + w, :] = homo_img
+            st += w
+
+        if not is_horizontal_text:
+            # vertical rotation
+            rectified_image = np.rot90(rectified_image, 3)
+
+        return rectified_image
+
+    def Homography(self, image, img_points, world_width, world_height,
+                interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+        """
+        Perspective-transform the image to a new view and return the result.
+        (This duplicates the module-level Homography helper above.)
+
+        Args:
+            image (np.ndarray): Input image as a numpy array.
+            img_points (List[Tuple[int, int]]): Four image points, ordered
+                top-left, top-right, bottom-right, bottom-left.
+            world_width (int): Width of the output image in world coordinates.
+            world_height (int): Height of the output image in world coordinates.
+            interpolation (int, optional): Interpolation flag, default cv2.INTER_CUBIC.
+            ratio_width (float, optional): Horizontal expansion ratio, default 1.0.
+            ratio_height (float, optional): Vertical expansion ratio, default 1.0.
+
+        Returns:
+            np.ndarray: The transformed image as a numpy array.
+        
+        """
+        _points = np.array(img_points).reshape(-1, 2).astype(np.float32)
+
+        expand_x = int(0.5 * world_width * (ratio_width - 1))
+        expand_y = int(0.5 * world_height * (ratio_height - 1))
+
+        pt_lefttop = [expand_x, expand_y]
+        pt_righttop = [expand_x + world_width, expand_y]
+        pt_rightbottom = [expand_x + world_width, expand_y + world_height]
+        pt_leftbottom = [expand_x, expand_y + world_height]
+
+        pts_std = np.float32([pt_lefttop, pt_righttop,
+                            pt_rightbottom, pt_leftbottom])
+
+        img_crop_width = int(world_width * ratio_width)
+        img_crop_height = int(world_height * ratio_height)
+
+        M = cv2.getPerspectiveTransform(_points, pts_std)
+
+        dst_img = cv2.warpPerspective(
+            image,
+            M, (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_CONSTANT,  # BORDER_CONSTANT BORDER_REPLICATE
+            flags=interpolation)
+
+        return dst_img
+
+
+    def __call__(self, image_data, points, interpolation=cv2.INTER_LINEAR, ratio_width=1.0, ratio_height=1.0, mode='calibration'):
+        """
+        spatial transform for a poly text
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        org_h, org_w = image_data.shape[:2]
+        org_size = (org_w, org_h)
+        self.image = image_data
+
+        is_horizontal_text = self.horizontal_text_estimate(points)
+        if is_horizontal_text:
+            image_coord, world_coord, new_image_size = self.horizontal_text_process(points)
+        else:
+            image_coord, world_coord, new_image_size = self.vertical_text_process(points, org_size)
+
+        if mode.lower() == 'calibration':
+            ret, mtx, dist, rvecs, tvecs = self.calibrate(org_size, image_coord, world_coord)
+
+            st_size = (int(new_image_size[0]*ratio_width), int(new_image_size[1]*ratio_height))
+            dst = self.spatial_transform(image_data, st_size, mtx, dist[0], rvecs[0], tvecs[0], interpolation)
+        elif mode.lower() == 'homography':
+            # ratio_width and ratio_height must be 1.0 here; homography reports no loss, so ret is set to 0.01 manually
+            ret = 0.01
+            dst = self.dc_homo(image_data, image_coord, world_coord, is_horizontal_text,
+                               interpolation=interpolation, ratio_width=1.0, ratio_height=1.0)
+        else:
+            raise ValueError('mode must be ["calibration", "homography"], but got {}'.format(mode))
+
+        return dst, ret
+
+
+class AutoRectifier:
+    def __init__(self):
+        self.npoints = 10
+        self.curveTextRectifier = CurveTextRectifier()
+
+    @staticmethod
+    def get_rotate_crop_image(img, points, interpolation=cv2.INTER_CUBIC, ratio_width=1.0, ratio_height=1.0):
+        """
+        crop or homography
+        :param img:
+        :param points:
+        :param interpolation:
+        :param ratio_width:
+        :param ratio_height:
+        :return:
+        """
+        h, w = img.shape[:2]
+        _points = np.array(points).reshape(-1, 2).astype(np.float32)
+
+        if len(_points) != 4:
+            x_min = int(np.min(_points[:, 0]))
+            y_min = int(np.min(_points[:, 1]))
+            x_max = int(np.max(_points[:, 0]))
+            y_max = int(np.max(_points[:, 1]))
+            dx = x_max - x_min
+            dy = y_max - y_min
+            expand_x = int(0.5 * dx * (ratio_width - 1))
+            expand_y = int(0.5 * dy * (ratio_height - 1))
+            x_min = np.clip(int(x_min - expand_x), 0, w - 1)
+            y_min = np.clip(int(y_min - expand_y), 0, h - 1)
+            x_max = np.clip(int(x_max + expand_x), 0, w - 1)
+            y_max = np.clip(int(y_max + expand_y), 0, h - 1)
+
+            dst_img = img[y_min:y_max, x_min:x_max, :].copy()
+        else:
+            img_crop_width = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[1]),
+                    np.linalg.norm(_points[2] - _points[3])))
+            img_crop_height = int(
+                max(
+                    np.linalg.norm(_points[0] - _points[3]),
+                    np.linalg.norm(_points[1] - _points[2])))
+
+            dst_img = Homography(img, _points, img_crop_width, img_crop_height, interpolation, ratio_width, ratio_height)
+
+        return dst_img
+
+
+    def visualize(self, image_data, points_list):
+        visualization = image_data.copy()
+
+        for box in points_list:
+            box = np.array(box).reshape(-1, 2).astype(np.int32)
+            cv2.drawContours(visualization, [np.array(box).reshape((-1, 1, 2))], -1, (0, 0, 255), 2)
+            for i, p in enumerate(box):
+                if i != 0:
+                    cv2.circle(visualization, tuple(p), radius=1, color=(255, 0, 0), thickness=2)
+                else:
+                    cv2.circle(visualization, tuple(p), radius=1, color=(255, 255, 0), thickness=2)
+        return visualization
+
+
+    def __call__(self, image_data, points, interpolation=cv2.INTER_LINEAR,
+                 ratio_width=1.0, ratio_height=1.0, loss_thresh=5.0, mode='calibration'):
+        """
+        rectification in strategies for a poly text
+        :param image_data:
+        :param points: [x1,y1,x2,y2,x3,y3,...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return:
+        """
+        _points = np.array(points).reshape(-1,2)
+        if len(_points) >= self.npoints and len(_points) % 2 == 0:
+            try:
+                curveTextRectifier = CurveTextRectifier()
+
+                dst_img, loss = curveTextRectifier(image_data, points, interpolation, ratio_width, ratio_height, mode)
+                if loss >= 2:
+                    # for robust
+                    # large loss means it cannot be reconstruct correctly, we must find other way to reconstruct
+                    img_list, loss_list = [dst_img], [loss]
+                    _dst_img, _loss = PlanB()(image_data, points, curveTextRectifier,
+                                              interpolation, ratio_width, ratio_height,
+                                              loss_thresh=loss_thresh,
+                                              square=True)
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    _dst_img, _loss = PlanB()(image_data, points, curveTextRectifier,
+                                              interpolation, ratio_width, ratio_height,
+                                              loss_thresh=loss_thresh, square=False)
+                    img_list += [_dst_img]
+                    loss_list += [_loss]
+
+                    min_loss = min(loss_list)
+                    dst_img = img_list[loss_list.index(min_loss)]
+
+                    if min_loss >= loss_thresh:
+                        print('calibration loss: {} is too large for the spatial transformer; falling back to get_rotate_crop_image'.format(min_loss))
+                        dst_img = self.get_rotate_crop_image(image_data, points, interpolation, ratio_width, ratio_height)
+            except Exception as e:
+                print(e)
+                dst_img = self.get_rotate_crop_image(image_data, points, interpolation, ratio_width, ratio_height)
+        else:
+            dst_img = self.get_rotate_crop_image(image_data, _points, interpolation, ratio_width, ratio_height)
+
+        return dst_img
+
+
+    def run(self, image_data, points_list, interpolation=cv2.INTER_LINEAR,
+            ratio_width=1.0, ratio_height=1.0, loss_thresh=5.0, mode='calibration'):
+        """
+        run for texts in an image
+        :param image_data: numpy.ndarray. The shape is [h, w, 3]
+        :param points_list: [[x1,y1,x2,y2,x3,y3,...], [x1,y1,x2,y2,x3,y3,...], ...], clockwise order, (x1,y1) must be the top-left of first char.
+        :param interpolation: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4
+        :param ratio_width:  roi_image width expansion. It should not be smaller than 1.0
+        :param ratio_height: roi_image height expansion. It should not be smaller than 1.0
+        :param loss_thresh: if loss greater than loss_thresh --> get_rotate_crop_image
+        :param mode: 'calibration' or 'homography'. when homography, ratio_width and ratio_height must be 1.0
+        :return: res: roi-image list, visualized_image: draw polys in original image
+        """
+        if image_data is None:
+            raise ValueError
+        if not isinstance(points_list, list):
+            raise ValueError
+        for points in points_list:
+            if not isinstance(points, list):
+                raise ValueError
+
+        if ratio_width < 1.0 or ratio_height < 1.0:
+            raise ValueError('ratio_width and ratio_height cannot be smaller than 1, but got {}'.format((ratio_width, ratio_height)))
+
+        if mode.lower() != 'calibration' and mode.lower() != 'homography':
+            raise ValueError('mode must be ["calibration", "homography"], but got {}'.format(mode))
+
+        if mode.lower() == 'homography' and (ratio_width != 1.0 or ratio_height != 1.0):
+            raise ValueError('ratio_width and ratio_height must be 1.0 when mode is homography, but got mode:{}, ratio:({},{})'.format(mode, ratio_width, ratio_height))
+
+        res = []
+        for points in points_list:
+            rectified_img = self(image_data, points, interpolation, ratio_width, ratio_height,
+                                 loss_thresh=loss_thresh, mode=mode)
+            res.append(rectified_img)
+
+        # visualize
+        visualized_image = self.visualize(image_data, points_list)
+
+        return res, visualized_image
+

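AutoRectifier.run expects one flattened, clockwise [x1, y1, x2, y2, ...] list per text instance, starting at the top-left of the first character, and returns the rectified crops plus a visualization. Fewer than 10 points (or an odd point count) falls back to a plain rotated/bounding crop; note also that PlanB is referenced but never defined in this file, so a large calibration loss currently lands in the except branch and falls back to get_rotate_crop_image. A usage sketch under those assumptions — the image path and coordinates are placeholders:

    import cv2
    from paddlex.modules.text_detection.predictor.utils import AutoRectifier

    img = cv2.imread("seal_text_det.png")  # any BGR image
    poly = [[10, 10, 30, 6, 50, 4, 70, 4, 90, 6, 110, 10,       # top sideline, left -> right
             110, 40, 90, 44, 70, 46, 50, 46, 30, 44, 10, 40]]  # bottom sideline, right -> left
    crops, vis = AutoRectifier().run(img, poly, mode="homography")
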
+ 33 - 13
paddlex/pipelines/OCR/pipeline.py

@@ -49,6 +49,11 @@ class OCRPipeline(BasePipeline):
         self.device = device
         self.text_det_kernel_option = text_det_kernel_option
         self.text_rec_kernel_option = text_rec_kernel_option
+        if self.text_det_model_name in ['PP-OCRv4_server_seal_det', 'PP-OCRv4_mobile_seal_det']:
+            self.task = "poly"
+        else:
+            self.task = "quad"
+        
         if (
             self.text_det_model_name is not None
             and self.text_rec_model_name is not None
@@ -80,19 +85,34 @@ Only support: {text_rec_models}."
             if self.text_rec_kernel_option is None
             else self.text_rec_kernel_option
         )
-        text_det_post_transforms = [
-            text_det_T.DBPostProcess(
-                thresh=0.3,
-                box_thresh=0.6,
-                max_candidates=1000,
-                unclip_ratio=1.5,
-                use_dilation=False,
-                score_mode="fast",
-                box_type="quad",
-            ),
-            # TODO
-            text_det_T.CropByPolys(det_box_type="foo"),
-        ]
+        if self.task == "poly":
+            text_det_post_transforms = [
+                text_det_T.DBPostProcess(
+                    thresh=0.2,
+                    box_thresh=0.6,
+                    max_candidates=1000,
+                    unclip_ratio=1.5,
+                    use_dilation=False,
+                    score_mode="fast",
+                    box_type="poly",
+                ),
+                # TODO
+                text_det_T.CropByPolys(det_box_type="poly"),
+            ]
+        else:
+            text_det_post_transforms = [
+                text_det_T.DBPostProcess(
+                    thresh=0.3,
+                    box_thresh=0.6,
+                    max_candidates=1000,
+                    unclip_ratio=1.5,
+                    use_dilation=False,
+                    score_mode="fast",
+                    box_type="quad",
+                ),
+                # TODO
+                text_det_T.CropByPolys(det_box_type="quad"),
+            ]
 
         self.text_det_model = create_model(
             self.text_det_model_name,

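The pipeline now derives its task purely from the detector name, so pairing a seal detector with a standard recognizer is a one-line change. A hedged usage sketch — the constructor argument names follow this diff, and any parameters not shown are assumed to take their defaults:

    from paddlex.pipelines.OCR.pipeline import OCRPipeline

    pipe = OCRPipeline(
        text_det_model_name="PP-OCRv4_mobile_seal_det",
        text_rec_model_name="PP-OCRv4_mobile_rec",
    )
    # seal detector -> task == "poly" -> DBPostProcess(box_type="poly")
    # plus CropByPolys(det_box_type="poly") feeding curved crops to the recognizer
    assert pipe.task == "poly"
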
+ 41 - 7
paddlex/pipelines/OCR/utils.py

@@ -24,6 +24,28 @@ import copy
 from ...utils.fonts import PINGFANG_FONT_FILE_PATH
 
 
+def get_minarea_rect(points):
+    bounding_box = cv2.minAreaRect(points)
+    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
+
+    index_a, index_b, index_c, index_d = 0, 1, 2, 3
+    if points[1][1] > points[0][1]:
+        index_a = 0
+        index_d = 1
+    else:
+        index_a = 1
+        index_d = 0
+    if points[3][1] > points[2][1]:
+        index_b = 2
+        index_c = 3
+    else:
+        index_b = 3
+        index_c = 2
+
+    box = np.array([points[index_a], points[index_b], points[index_c], points[index_d]]).astype(np.int32)
+
+    return box
+
 def draw_ocr_box_txt(
     img,
     boxes,
@@ -43,14 +65,26 @@ def draw_ocr_box_txt(
     if txts is None or len(txts) != len(boxes):
         txts = [None] * len(boxes)
     for idx, (box, txt) in enumerate(zip(boxes, txts)):
-        if scores is not None and scores[idx] < drop_score:
+        try:
+            if scores is not None and scores[idx] < drop_score:
+                continue
+            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+            box = np.array(box)
+            if len(box) > 4:
+                pts = [(x, y) for x, y in box.tolist()]
+                draw_left.polygon(pts, outline=color, width=8)
+                box = get_minarea_rect(box)
+                height = int(0.5 * (max(box[:,1]) - min(box[:,1])))
+                box[:2,1] = np.mean(box[:,1])
+                box[2:,1] = np.mean(box[:,1]) + min(20, height)
+            draw_left.polygon(box, fill=color)
+            img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
+            pts = np.array(box, np.int32).reshape((-1, 1, 2))
+            cv2.polylines(img_right_text, [pts], True, color, 1)
+            img_right = cv2.bitwise_and(img_right, img_right_text)
+        except Exception:
             continue
-        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
-        draw_left.polygon(box, fill=color)
-        img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
-        pts = np.array(box, np.int32).reshape((-1, 1, 2))
-        cv2.polylines(img_right_text, [pts], True, color, 1)
-        img_right = cv2.bitwise_and(img_right, img_right_text)
+
     img_left = Image.blend(image, img_left, 0.5)
     img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
     img_show.paste(img_left, (0, 0, w, h))

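get_minarea_rect collapses a many-point seal polygon into an ordered quad (top-left, top-right, bottom-right, bottom-left) so the quad-based text rendering in draw_box_txt_fine still applies. A standalone copy for a quick sanity check (condensed but logically identical to the diff):

    import cv2
    import numpy as np

    def get_minarea_rect(points):
        pts = sorted(cv2.boxPoints(cv2.minAreaRect(points)), key=lambda p: p[0])
        a, d = (0, 1) if pts[1][1] > pts[0][1] else (1, 0)  # left pair: top, bottom
        b, c = (2, 3) if pts[3][1] > pts[2][1] else (3, 2)  # right pair: top, bottom
        return np.array([pts[a], pts[b], pts[c], pts[d]]).astype(np.int32)

    poly = np.array([[0, 0], [50, 5], [100, 0], [100, 40], [50, 45], [0, 40]], np.float32)
    print(get_minarea_rect(poly))  # approximately [[0 0] [100 0] [100 45] [0 45]]
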
+ 169 - 0
paddlex/repo_apis/PaddleOCR_api/configs/PP-OCRv4_mobile_seal_det.yaml

@@ -0,0 +1,169 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 100
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: output
+  save_epoch_step: 1
+  eval_batch_step:
+  - 0
+  - 100
+  cal_metric_during_train: false
+  checkpoints:
+  pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/ch_PP-OCRv4_mobile_det_curve_trained.pdparams
+  save_inference_dir: null
+  use_visualdl: false
+  distributed: true
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform: null
+  Backbone:
+    name: PPLCNetV3
+    scale: 0.75
+    det: True
+  Neck:
+    name: RSEFPN
+    out_channels: 96
+    shortcut: True
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 1e-6
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.2
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 0.5
+  box_type: "poly"
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/train.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 640
+        - 640
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.8
+        thresh_min: 0.3
+        thresh_max: 0.7
+        total_epoch: 500
+    - MakeShrinkMap:
+        shrink_ratio: 0.8
+        min_text_size: 8
+        total_epoch: 500
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 8
+    num_workers: 3
+
+Eval:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/val.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest:
+        resize_long: 736
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 0
+profiler_options: null

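Relative to the standard PP-OCRv4 det configs, the seal variants swap in curve-trained pretrained weights, a poly DBPostProcess (thresh 0.2, unclip_ratio 0.5), shrink_ratio 0.8 label maps, and resize_long 736 at eval. If a field needs to be adjusted programmatically, a plain PyYAML round-trip is enough — a hedged sketch, where the paths and overrides are examples:

    import yaml

    src = "paddlex/repo_apis/PaddleOCR_api/configs/PP-OCRv4_mobile_seal_det.yaml"
    with open(src) as f:
        cfg = yaml.safe_load(f)

    cfg["Global"]["epoch_num"] = 50                      # example override
    cfg["Train"]["loader"]["batch_size_per_card"] = 16   # example override
    assert cfg["PostProcess"]["box_type"] == "poly"

    with open("my_seal_det.yaml", "w") as f:             # write a copy, keep the original
        yaml.safe_dump(cfg, f, sort_keys=False)
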
+ 169 - 0
paddlex/repo_apis/PaddleOCR_api/configs/PP-OCRv4_server_seal_det.yaml

@@ -0,0 +1,169 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 100
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: output
+  save_epoch_step: 1
+  eval_batch_step:
+  - 0
+  - 100
+  cal_metric_during_train: false
+  checkpoints:
+  pretrained_model: https://paddleocr.bj.bcebos.com/pretrained/ch_PP-OCRv4_det_curve_trained.pdparams
+  save_inference_dir: null
+  use_visualdl: false
+  distributed: true
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform: null
+  Backbone:
+    name: PPHGNet_small
+    det: True
+  Neck:
+    name: LKPAN
+    out_channels: 256
+    intracl: true
+  Head:
+    name: PFHeadLocal
+    k: 50
+    mode: "large"
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 1e-6
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.2
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 0.5
+  box_type: "poly"
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/train.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 640
+        - 640
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.8
+        thresh_min: 0.3
+        thresh_max: 0.7
+        total_epoch: 500
+    - MakeShrinkMap:
+        shrink_ratio: 0.8
+        min_text_size: 8
+        total_epoch: 500
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 4
+    num_workers: 3
+
+Eval:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/val.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest:
+        resize_long: 736
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 0
+profiler_options: null
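On the inference side, `box_type: "poly"` makes `DBPostProcess` return polygon outlines instead of rotated rectangles, which is what preserves the curved outline of seal text, and `unclip_ratio: 0.5` controls how far each predicted (shrunk) region is grown back. A sketch of that unclip step, assuming the same shapely/pyclipper approach as the training-time shrink above:

```python
import numpy as np
import pyclipper
from shapely.geometry import Polygon

def unclip(pts: np.ndarray, unclip_ratio: float = 0.5):
    """Grow a predicted (shrunk) text region back out (sketch of DBPostProcess)."""
    poly = Polygon(pts)
    # Expansion distance: Area * unclip_ratio / Perimeter.
    distance = poly.area * unclip_ratio / poly.length
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(np.round(pts).astype(np.int64).tolist(),
                   pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    # A positive delta expands the contour toward the full text outline.
    return [np.array(p) for p in offset.Execute(distance)]
```

The small ratio here (0.5, versus the 1.5 typically used in the quadrilateral detection configs) keeps the expanded polygon tight around the seal characters.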

+ 20 - 0
paddlex/repo_apis/PaddleOCR_api/text_det/register.py

@@ -55,3 +55,23 @@ register_model_info(
         "hpi_config_path": HPI_CONFIG_DIR / "PP-OCRv4_server_det.yaml",
     }
 )
+
+register_model_info(
+    {
+        "model_name": "PP-OCRv4_server_seal_det",
+        "suite": "TextDet",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-OCRv4_server_seal_det.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+        "hpi_config_path": HPI_CONFIG_DIR / "PP-OCRv4_server_seal_det.yaml",
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "PP-OCRv4_mobile_seal_det",
+        "suite": "TextDet",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-OCRv4_mobile_seal_det.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+        "hpi_config_path": HPI_CONFIG_DIR / "PP-OCRv4_mobile_seal_det.yaml",
+    }
+)
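For context, `register_model_info` records each model's metadata so that a model name can later be resolved to its training config and HPI config. A stripped-down sketch of that registry pattern (hypothetical; the real implementation lives elsewhere in `paddlex/repo_apis`):

```python
# Hypothetical sketch of the registry the calls above populate.
_MODEL_INFOS: dict = {}

def register_model_info(info: dict) -> None:
    # Key by model name so e.g. "PP-OCRv4_mobile_seal_det" can be looked
    # up when it is selected as Global.model in a PaddleX config.
    _MODEL_INFOS[info["model_name"]] = info

def get_model_info(model_name: str) -> dict:
    try:
        return _MODEL_INFOS[model_name]
    except KeyError:
        raise ValueError(f"unregistered model: {model_name}") from None
```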

+ 35 - 0
paddlex/utils/hpi_configs/PP-OCRv4_mobile_seal_det.yaml

@@ -0,0 +1,35 @@
+Hpi:
+  backend_config:
+    onnx_runtime:
+      cpu_num_threads: 8
+    openvino:
+      cpu_num_threads: 8
+    paddle_infer:
+      cpu_num_threads: 8
+      enable_log_info: false
+    paddle_tensorrt:
+      dynamic_shapes:
+        x:
+        - []
+        - []
+        - []
+      enable_log_info: false
+    tensorrt:
+      dynamic_shapes:
+        x:
+        - []
+        - []
+        - []
+  selected_backends:
+    cpu: onnx_runtime
+    gpu: paddle_tensorrt
+  supported_backends:
+    cpu:
+    - paddle_infer
+    - openvino
+    - onnx_runtime
+    gpu:
+    - paddle_infer
+    - paddle_tensorrt
+    - onnx_runtime
+    - tensorrt
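The two new HPI configs are identical: ONNX Runtime is preselected on CPU and Paddle-TensorRT on GPU, with the other listed backends available as alternatives. A hedged sketch of how such a file could be consumed (`select_backend` below is hypothetical, not a PaddleX API):

```python
import yaml

def select_backend(hpi_config_path: str, device: str) -> str:
    """Return the preselected inference backend for device "cpu" or "gpu"."""
    with open(hpi_config_path, "r", encoding="utf-8") as f:
        hpi = yaml.safe_load(f)["Hpi"]
    backend = hpi["selected_backends"][device]
    if backend not in hpi["supported_backends"][device]:
        raise ValueError(f"{backend} is not supported on {device}")
    return backend

# select_backend("paddlex/utils/hpi_configs/PP-OCRv4_mobile_seal_det.yaml", "gpu")
# -> "paddle_tensorrt"
```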

+ 35 - 0
paddlex/utils/hpi_configs/PP-OCRv4_server_seal_det.yaml

@@ -0,0 +1,35 @@
+Hpi:
+  backend_config:
+    onnx_runtime:
+      cpu_num_threads: 8
+    openvino:
+      cpu_num_threads: 8
+    paddle_infer:
+      cpu_num_threads: 8
+      enable_log_info: false
+    paddle_tensorrt:
+      dynamic_shapes:
+        x:
+        - []
+        - []
+        - []
+      enable_log_info: false
+    tensorrt:
+      dynamic_shapes:
+        x:
+        - []
+        - []
+        - []
+  selected_backends:
+    cpu: onnx_runtime
+    gpu: paddle_tensorrt
+  supported_backends:
+    cpu:
+    - paddle_infer
+    - openvino
+    - onnx_runtime
+    gpu:
+    - paddle_infer
+    - paddle_tensorrt
+    - onnx_runtime
+    - tensorrt
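Note that the `dynamic_shapes` lists for input `x` are committed empty in both files; for TensorRT they would normally hold the [min, opt, max] shape triplet of the input tensor. Purely illustrative values (not from this commit) for an NCHW detector input with a 736-pixel long side might look like:

```python
# Illustrative only: the committed configs leave these lists empty.
dynamic_shapes = {
    "x": [
        [1, 3, 160, 160],    # min shape
        [1, 3, 736, 736],    # opt shape
        [1, 3, 1312, 1312],  # max shape
    ]
}
```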