Przeglądaj źródła

add params for ocr det (#2701)

* add params for ocr det

* add params for ocr det

* adddet v3

* adddet v3 infer

* adddet v3 infer

* adddet ocr params
Sunflower7788 10 miesięcy temu
rodzic
commit
903b522ee2

+ 40 - 0
paddlex/configs/modules/text_detection/PP-OCRv3_mobile_det.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: PP-OCRv3_mobile_det
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  module: text_det
+  dataset_dir: "/paddle/dataset/paddlex/ocr_det/ocr_det_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 100
+  batch_size: 4
+  learning_rate: 0.001
+  pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv3_mobile_det_pretrained.pdparams
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy/best_accuracy.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv3_mobile_det_pretrained.pdparams
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_accuracy/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png"
+  kernel_option:
+    run_mode: paddle

+ 40 - 0
paddlex/configs/modules/text_detection/PP-OCRv3_server_det.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: PP-OCRv3_server_det
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  module: text_det
+  dataset_dir: "/paddle/dataset/paddlex/ocr_det/ocr_det_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 100
+  batch_size: 4
+  learning_rate: 0.001
+  pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv3_server_det_pretrained.pdparams
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy/best_accuracy.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-OCRv3_server_det_pretrained.pdparams
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_accuracy/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_001.png"
+  kernel_option:
+    run_mode: paddle

+ 77 - 16
paddlex/inference/models_new/text_detection/predictor.py

@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
+from typing import List, Union
+
 from ....utils.func_register import FuncRegister
 from ....modules.text_detection.model_list import MODELS
 from ...common.batch_sampler import ImageBatchSampler
@@ -36,8 +39,27 @@ class TextDetPredictor(BasicPredictor):
     _FUNC_MAP = {}
     register = FuncRegister(_FUNC_MAP)
 
-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self,
+        limit_side_len: Union[int, None] = None,
+        limit_type: Union[str, None] = None,
+        thresh: Union[float, None] = None,
+        box_thresh: Union[float, None] = None,
+        max_candidates: Union[int, None] = None,
+        unclip_ratio: Union[float, None] = None,
+        use_dilation: Union[bool, None] = None,
+        *args,
+        **kwargs
+    ):
         super().__init__(*args, **kwargs)
+
+        self.limit_side_len = limit_side_len
+        self.limit_type = limit_type
+        self.thresh = thresh
+        self.box_thresh = box_thresh
+        self.max_candidates = max_candidates
+        self.unclip_ratio = unclip_ratio
+        self.use_dilation = use_dilation
         self.pre_tfs, self.infer, self.post_op = self._build()
 
     def _build_batch_sampler(self):
@@ -67,14 +89,37 @@ class TextDetPredictor(BasicPredictor):
         post_op = self.build_postprocess(**self.config["PostProcess"])
         return pre_tfs, infer, post_op
 
-    def process(self, batch_data):
+    def process(
+        self,
+        batch_data: List[Union[str, np.ndarray]],
+        limit_side_len: Union[int, None] = None,
+        limit_type: Union[str, None] = None,
+        thresh: Union[float, None] = None,
+        box_thresh: Union[float, None] = None,
+        max_candidates: Union[int, None] = None,
+        unclip_ratio: Union[float, None] = None,
+        use_dilation: Union[bool, None] = None,
+    ):
+
         batch_raw_imgs = self.pre_tfs["Read"](imgs=batch_data)
-        batch_imgs, batch_shapes = self.pre_tfs["Resize"](imgs=batch_raw_imgs)
+        batch_imgs, batch_shapes = self.pre_tfs["Resize"](
+            imgs=batch_raw_imgs,
+            limit_side_len=limit_side_len or self.limit_side_len,
+            limit_type=limit_type or self.limit_type,
+        )
         batch_imgs = self.pre_tfs["Normalize"](imgs=batch_imgs)
         batch_imgs = self.pre_tfs["ToCHW"](imgs=batch_imgs)
         x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
         batch_preds = self.infer(x=x)
-        polys, scores = self.post_op(batch_preds, batch_shapes)
+        polys, scores = self.post_op(
+            batch_preds,
+            batch_shapes,
+            thresh=thresh or self.thresh,
+            box_thresh=box_thresh or self.box_thresh,
+            max_candidates=max_candidates or self.max_candidates,
+            unclip_ratio=unclip_ratio or self.unclip_ratio,
+            use_dilation=use_dilation or self.use_dilation,
+        )
         return {
             "input_path": batch_data,
             "input_img": batch_raw_imgs,
@@ -88,14 +133,29 @@ class TextDetPredictor(BasicPredictor):
         return "Read", ReadImage(format=img_mode)
 
     @register("DetResizeForTest")
-    def build_resize(self, **kwargs):
+    def build_resize(
+        self,
+        limit_side_len: Union[int, None] = None,
+        limit_type: Union[str, None] = None,
+        **kwargs
+    ):
         # TODO: align to PaddleOCR
-        if self.model_name in ("PP-OCRv4_server_det", "PP-OCRv4_mobile_det"):
-            resize_long = kwargs.get("resize_long", 960)
-            return "Resize", DetResizeForTest(
-                limit_side_len=resize_long, limit_type="max"
-            )
-        return "Resize", DetResizeForTest(**kwargs)
+
+        if self.model_name in (
+            "PP-OCRv4_server_det",
+            "PP-OCRv4_mobile_det",
+            "PP-OCRv3_server_det",
+            "PP-OCRv3_mobile_det",
+        ):
+            limit_side_len = self.limit_side_len or kwargs.get("resize_long", 960)
+            limit_type = self.limit_type or kwargs.get("limit_type", "max")
+        else:
+            limit_side_len = self.limit_side_len or kwargs.get("resize_long", 736)
+            limit_type = self.limit_type or kwargs.get("limit_type", "min")
+
+        return "Resize", DetResizeForTest(
+            limit_side_len=limit_side_len, limit_type=limit_type, **kwargs
+        )
 
     @register("NormalizeImage")
     def build_normalize(
@@ -117,11 +177,12 @@ class TextDetPredictor(BasicPredictor):
     def build_postprocess(self, **kwargs):
         if kwargs.get("name") == "DBPostProcess":
             return DBPostProcess(
-                thresh=kwargs.get("thresh", 0.3),
-                box_thresh=kwargs.get("box_thresh", 0.7),
-                max_candidates=kwargs.get("max_candidates", 1000),
-                unclip_ratio=kwargs.get("unclip_ratio", 2.0),
-                use_dilation=kwargs.get("use_dilation", False),
+                thresh=self.thresh or kwargs.get("thresh", 0.3),
+                box_thresh=self.box_thresh or kwargs.get("box_thresh", 0.6),
+                max_candidates=self.max_candidates
+                or kwargs.get("max_candidates", 1000),
+                unclip_ratio=self.unclip_ratio or kwargs.get("unclip_ratio", 2.0),
+                use_dilation=self.use_dilation or kwargs.get("use_dilation", False),
                 score_mode=kwargs.get("score_mode", "fast"),
                 box_type=kwargs.get("box_type", "quad"),
             )

+ 88 - 29
paddlex/inference/models_new/text_detection/processors.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+from typing import List, Tuple, Union
 import os
 import sys
 import cv2
@@ -50,23 +50,32 @@ class DetResizeForTest:
             self.limit_side_len = 736
             self.limit_type = "min"
 
-    def __call__(self, imgs):
+    def __call__(
+        self,
+        imgs,
+        limit_side_len: Union[int, None] = None,
+        limit_type: Union[str, None] = None,
+    ):
         """apply"""
         resize_imgs, img_shapes = [], []
         for ori_img in imgs:
-            img, shape = self.resize(ori_img)
+            img, shape = self.resize(ori_img, limit_side_len, limit_type)
             resize_imgs.append(img)
             img_shapes.append(shape)
         return resize_imgs, img_shapes
 
-    def resize(self, img):
+    def resize(
+        self, img, limit_side_len: Union[int, None], limit_type: Union[str, None]
+    ):
         src_h, src_w, _ = img.shape
         if sum([src_h, src_w]) < 64:
             img = self.image_padding(img)
 
         if self.resize_type == 0:
             # img, shape = self.resize_image_type0(img)
-            img, [ratio_h, ratio_w] = self.resize_image_type0(img)
+            img, [ratio_h, ratio_w] = self.resize_image_type0(
+                img, limit_side_len, limit_type
+            )
         elif self.resize_type == 2:
             img, [ratio_h, ratio_w] = self.resize_image_type2(img)
         else:
@@ -95,7 +104,9 @@ class DetResizeForTest:
         # return img, np.array([ori_h, ori_w])
         return img, [ratio_h, ratio_w]
 
-    def resize_image_type0(self, img):
+    def resize_image_type0(
+        self, img, limit_side_len: Union[int, None], limit_type: Union[str, None]
+    ):
         """
         resize image to a size multiple of 32 which is required by the network
         args:
@@ -103,11 +114,12 @@ class DetResizeForTest:
         return(tuple):
             img, (ratio_h, ratio_w)
         """
-        limit_side_len = self.limit_side_len
+        limit_side_len = limit_side_len or self.limit_side_len
+        limit_type = limit_type or self.limit_type
         h, w, c = img.shape
 
         # limit the max side
-        if self.limit_type == "max":
+        if limit_type == "max":
             if max(h, w) > limit_side_len:
                 if h > w:
                     ratio = float(limit_side_len) / h
@@ -115,7 +127,7 @@ class DetResizeForTest:
                     ratio = float(limit_side_len) / w
             else:
                 ratio = 1.0
-        elif self.limit_type == "min":
+        elif limit_type == "min":
             if min(h, w) < limit_side_len:
                 if h < w:
                     ratio = float(limit_side_len) / h
@@ -123,7 +135,7 @@ class DetResizeForTest:
                     ratio = float(limit_side_len) / w
             else:
                 ratio = 1.0
-        elif self.limit_type == "resize_long":
+        elif limit_type == "resize_long":
             ratio = float(limit_side_len) / max(h, w)
         else:
             raise Exception("not support limit type, image ")
@@ -221,10 +233,18 @@ class DBPostProcess:
             "slow",
             "fast",
         ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
+        self.use_dilation = use_dilation
 
-        self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
-
-    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+    def polygons_from_bitmap(
+        self,
+        pred,
+        _bitmap,
+        dest_width,
+        dest_height,
+        box_thresh,
+        max_candidates,
+        unclip_ratio,
+    ):
         """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
 
         bitmap = _bitmap
@@ -237,7 +257,7 @@ class DBPostProcess:
             (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
         )
 
-        for contour in contours[: self.max_candidates]:
+        for contour in contours[:max_candidates]:
             epsilon = 0.002 * cv2.arcLength(contour, True)
             approx = cv2.approxPolyDP(contour, epsilon, True)
             points = approx.reshape((-1, 2))
@@ -245,11 +265,11 @@ class DBPostProcess:
                 continue
 
             score = self.box_score_fast(pred, points.reshape(-1, 2))
-            if self.box_thresh > score:
+            if box_thresh > score:
                 continue
 
             if points.shape[0] > 2:
-                box = self.unclip(points, self.unclip_ratio)
+                box = self.unclip(points, unclip_ratio)
                 if len(box) > 1:
                     continue
             else:
@@ -272,7 +292,16 @@ class DBPostProcess:
             scores.append(score)
         return boxes, scores
 
-    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+    def boxes_from_bitmap(
+        self,
+        pred,
+        _bitmap,
+        dest_width,
+        dest_height,
+        box_thresh,
+        max_candidates,
+        unclip_ratio,
+    ):
         """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
 
         bitmap = _bitmap
@@ -286,7 +315,7 @@ class DBPostProcess:
         elif len(outs) == 2:
             contours, _ = outs[0], outs[1]
 
-        num_contours = min(len(contours), self.max_candidates)
+        num_contours = min(len(contours), max_candidates)
 
         boxes = []
         scores = []
@@ -300,10 +329,10 @@ class DBPostProcess:
                 score = self.box_score_fast(pred, points.reshape(-1, 2))
             else:
                 score = self.box_score_slow(pred, contour)
-            if self.box_thresh > score:
+            if box_thresh > score:
                 continue
 
-            box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
+            box = self.unclip(points, unclip_ratio).reshape(-1, 1, 2)
             box, sside = self.get_mini_boxes(box)
             if sside < self.min_size + 2:
                 continue
@@ -385,31 +414,61 @@ class DBPostProcess:
         cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
         return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
 
-    def __call__(self, preds, img_shapes):
+    def __call__(
+        self,
+        preds,
+        img_shapes,
+        thresh: Union[float, None] = None,
+        box_thresh: Union[float, None] = None,
+        max_candidates: Union[int, None] = None,
+        unclip_ratio: Union[float, None] = None,
+        use_dilation: Union[bool, None] = None,
+    ):
         """apply"""
         boxes, scores = [], []
         for pred, img_shape in zip(preds[0], img_shapes):
-            box, score = self.process(pred, img_shape)
+            box, score = self.process(
+                pred,
+                img_shape,
+                thresh or self.thresh,
+                box_thresh or self.box_thresh,
+                max_candidates or self.max_candidates,
+                unclip_ratio or self.unclip_ratio,
+                use_dilation or self.use_dilation,
+            )
             boxes.append(box)
             scores.append(score)
         return boxes, scores
 
-    def process(self, pred, img_shape):
+    def process(
+        self,
+        pred,
+        img_shape,
+        thresh,
+        box_thresh,
+        max_candidates,
+        unclip_ratio,
+        use_dilation,
+    ):
         pred = pred[0, :, :]
-        segmentation = pred > self.thresh
-
+        segmentation = pred > thresh
+        dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
         src_h, src_w, ratio_h, ratio_w = img_shape
-        if self.dilation_kernel is not None:
+        if dilation_kernel is not None:
             mask = cv2.dilate(
                 np.array(segmentation).astype(np.uint8),
-                self.dilation_kernel,
+                dilation_kernel,
             )
         else:
             mask = segmentation
         if self.box_type == "poly":
-            boxes, scores = self.polygons_from_bitmap(pred, mask, src_w, src_h)
+            boxes, scores = self.polygons_from_bitmap(
+                pred, mask, src_w, src_h, box_thresh, max_candidates, unclip_ratio
+            )
         elif self.box_type == "quad":
-            boxes, scores = self.boxes_from_bitmap(pred, mask, src_w, src_h)
+            boxes, scores = self.boxes_from_bitmap(
+                pred, mask, src_w, src_h, box_thresh, max_candidates, unclip_ratio
+            )
         else:
             raise ValueError("box_type can only be one of ['quad', 'poly']")
         return boxes, scores

+ 4 - 0
paddlex/inference/utils/official_models.py

@@ -227,6 +227,10 @@ PP-OCRv4_mobile_rec_infer.tar",
 PP-OCRv4_server_det_infer.tar",
     "PP-OCRv4_mobile_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/\
 PP-OCRv4_mobile_det_infer.tar",
+    "PP-OCRv3_server_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/\
+PP-OCRv3_server_det_infer.tar",
+    "PP-OCRv3_mobile_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/\
+PP-OCRv3_mobile_det_infer.tar",
     "PP-OCRv4_server_seal_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/\
 PP-OCRv4_server_seal_det_infer.tar",
     "PP-OCRv4_mobile_seal_det": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b2/\

+ 2 - 0
paddlex/modules/text_detection/model_list.py

@@ -17,6 +17,8 @@ MODELS = [
     "PP-OCRv4_server_det",
     "PP-OCRv4_mobile_seal_det",
     "PP-OCRv4_server_seal_det",
+    "PP-OCRv3_mobile_det",
+    "PP-OCRv3_server_det",
 ]
 
 CURVE_MODELS = ["PP-OCRv4_mobile_seal_det", "PP-OCRv4_server_seal_det"]

+ 162 - 0
paddlex/repo_apis/PaddleOCR_api/configs/PP-OCRv3_mobile_det.yaml

@@ -0,0 +1,162 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 500
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/ch_PP-OCR_V3_det/
+  save_epoch_step: 100
+  eval_batch_step:
+  - 0
+  - 400
+  cal_metric_during_train: false
+  pretrained_model: ch_PP-OCRv3_det_distill_train/student.pdparams
+  checkpoints: null
+  save_inference_dir: null
+  use_visualdl: false
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./checkpoints/det_db/predicts_db.txt
+  distributed: true
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform:
+  Backbone:
+    name: MobileNetV3
+    scale: 0.5
+    model_name: large
+    disable_se: True
+  Neck:
+    name: RSEFPN
+    out_channels: 96
+    shortcut: True
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 5.0e-05
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+Train:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/train.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 960
+        - 960
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.4
+        thresh_min: 0.3
+        thresh_max: 0.7
+    - MakeShrinkMap:
+        shrink_ratio: 0.4
+        min_text_size: 8
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 8
+    num_workers: 4
+Eval:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/val.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest: null
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 2

+ 159 - 0
paddlex/repo_apis/PaddleOCR_api/configs/PP-OCRv3_server_det.yaml

@@ -0,0 +1,159 @@
+Global:
+  debug: false
+  use_gpu: true
+  epoch_num: 500
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/ch_PP-OCR_v3_det/
+  save_epoch_step: 100
+  eval_batch_step:
+  - 0
+  - 400
+  cal_metric_during_train: false
+  pretrained_model: ch_PP-OCRv3_det_distill_train/best_accuracy_new.pdparams
+  checkpoints: null
+  save_inference_dir: null
+  use_visualdl: false
+  infer_img: doc/imgs_en/img_10.jpg
+  save_res_path: ./checkpoints/det_db/predicts_db.txt
+  distributed: true
+  d2s_train_image_shape: [3, -1, -1]
+  amp_dtype: bfloat16
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Backbone:
+    name: ResNet_vd
+    in_channels: 3
+    layers: 50
+  Neck:
+    name: LKPAN
+    out_channels: 256
+  Head:
+    name: DBHead
+    kernel_list: [7,2,2]
+    k: 50
+
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.001
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 5.0e-05
+
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+
+Train:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/train.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - CopyPaste:
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 960
+        - 960
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.4
+        thresh_min: 0.3
+        thresh_max: 0.7
+    - MakeShrinkMap:
+        shrink_ratio: 0.4
+        min_text_size: 8
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 8
+    num_workers: 4
+
+Eval:
+  dataset:
+    name: TextDetDataset
+    data_dir: datasets/ICDAR2015
+    label_file_list:
+      - datasets/ICDAR2015/val.txt
+    transforms:
+      - DecodeImage: # load image
+          img_mode: BGR
+          channel_first: False
+      - DetLabelEncode: # Class handling label
+      - DetResizeForTest:
+      - NormalizeImage:
+          scale: 1./255.
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+          order: 'hwc'
+      - ToCHWImage:
+      - KeepKeys:
+          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
+  loader:
+    shuffle: False
+    drop_last: False
+    batch_size_per_card: 1 # must be 1
+    num_workers: 2
+profiler_options: null

+ 18 - 0
paddlex/repo_apis/PaddleOCR_api/text_det/register.py

@@ -70,3 +70,21 @@ register_model_info(
         "supported_apis": ["train", "evaluate", "predict", "export"],
     }
 )
+
+register_model_info(
+    {
+        "model_name": "PP-OCRv3_mobile_det",
+        "suite": "TextDet",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-OCRv3_mobile_det.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "PP-OCRv3_server_det",
+        "suite": "TextDet",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-OCRv3_server_det.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+    }
+)