Pārlūkot izejas kodu

support latin and korean rec model (#4274)

* support latin and korean rec model

* refine font support

* fixed bugs
学卿 4 mēneši atpakaļ
vecāks
revīzija
1d9d645708

+ 39 - 0
paddlex/configs/modules/text_recognition/korean_PP-OCRv5_mobile_rec.yaml

@@ -0,0 +1,39 @@
+Global:
+  model: korean_PP-OCRv5_mobile_rec
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/ocr_rec/ocr_rec_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: null
+  split: 
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 20
+  batch_size: 8
+  learning_rate: 0.001
+  pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/korean_PP-OCRv5_mobile_rec_pretrained.pdparams
+  resume_path: null
+  log_interval: 20
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy/best_accuracy.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/korean_PP-OCRv5_mobile_rec_pretrained.pdparams
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_accuracy/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_rec_003_korean.png"
+  kernel_option:
+    run_mode: paddle

+ 39 - 0
paddlex/configs/modules/text_recognition/latin_PP-OCRv5_mobile_rec.yaml

@@ -0,0 +1,39 @@
+Global:
+  model: latin_PP-OCRv5_mobile_rec
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/ocr_rec/ocr_rec_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: null
+  split: 
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 20
+  batch_size: 8
+  learning_rate: 0.001
+  pretrain_weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/latin_PP-OCRv5_mobile_rec_pretrained.pdparams
+  resume_path: null
+  log_interval: 20
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy/best_accuracy.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/latin_PP-OCRv5_mobile_rec_pretrained.pdparams
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_accuracy/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_ocr_rec_009_latin.png"
+  kernel_option:
+    run_mode: paddle

+ 51 - 1
paddlex/inference/models/text_recognition/predictor.py

@@ -13,6 +13,17 @@
 # limitations under the License.
 
 from ....modules.text_recognition.model_list import MODELS
+from ....utils.fonts import (
+    ARABIC_FONT,
+    CYRILLIC_FONT,
+    DEVANAGARI_FONT,
+    KANNADA_FONT,
+    KOREAN_FONT,
+    LATIN_FONT,
+    SIMFANG_FONT,
+    TAMIL_FONT,
+    TELUGU_FONT,
+)
 from ....utils.func_register import FuncRegister
 from ...common.batch_sampler import ImageBatchSampler
 from ...common.reader import ReadImage
@@ -31,6 +42,7 @@ class TextRecPredictor(BasePredictor):
     def __init__(self, *args, input_shape=None, **kwargs):
         super().__init__(*args, **kwargs)
         self.input_shape = input_shape
+        self.vis_font = self.get_vis_font()
         self.pre_tfs, self.infer, self.post_op = self._build()
 
     def _build_batch_sampler(self):
@@ -68,6 +80,7 @@ class TextRecPredictor(BasePredictor):
             "input_img": batch_raw_imgs,
             "rec_text": texts,
             "rec_score": scores,
+            "vis_font": [self.vis_font] * len(batch_raw_imgs),
         }
 
     @register("DecodeImage")
@@ -76,7 +89,7 @@ class TextRecPredictor(BasePredictor):
         return "Read", ReadImage(format=img_mode)
 
     @register("RecResizeImg")
-    def build_resize(self, image_shape):
+    def build_resize(self, image_shape, **kwargs):
         return "ReisizeNorm", OCRReisizeNormImg(
             rec_image_shape=image_shape, input_shape=self.input_shape
         )
@@ -96,3 +109,40 @@ class TextRecPredictor(BasePredictor):
     @register("KeepKeys")
     def foo(self, *args, **kwargs):
         return None, None
+
+    def get_vis_font(self):  # Choose the font used when rendering recognized text, keyed on self.model_name.
+        if self.model_name.startswith("PP-OCR"):  # names with no language prefix, e.g. "PP-OCRv5_mobile_rec"
+            return SIMFANG_FONT
+        # Latin-script recognition models (PP-OCRv3 and PP-OCRv5 variants).
+        if self.model_name in (
+            "latin_PP-OCRv3_mobile_rec",
+            "latin_PP-OCRv5_mobile_rec",
+        ):
+            return LATIN_FONT
+        # The Cyrillic PP-OCRv3 model and the "eslav" PP-OCRv5 model share the Cyrillic font.
+        if self.model_name in (
+            "cyrillic_PP-OCRv3_mobile_rec",
+            "eslav_PP-OCRv5_mobile_rec",
+        ):
+            return CYRILLIC_FONT
+        # Korean recognition models (PP-OCRv3 and PP-OCRv5 variants).
+        if self.model_name in (
+            "korean_PP-OCRv3_mobile_rec",
+            "korean_PP-OCRv5_mobile_rec",
+        ):
+            return KOREAN_FONT
+        # The remaining scripts are each matched against one exact PP-OCRv3 model name.
+        if self.model_name == "arabic_PP-OCRv3_mobile_rec":
+            return ARABIC_FONT
+        # "ka" prefix -> Kannada font.
+        if self.model_name == "ka_PP-OCRv3_mobile_rec":
+            return KANNADA_FONT
+        # "te" prefix -> Telugu font.
+        if self.model_name == "te_PP-OCRv3_mobile_rec":
+            return TELUGU_FONT
+        # "ta" prefix -> Tamil font.
+        if self.model_name == "ta_PP-OCRv3_mobile_rec":
+            return TAMIL_FONT
+        # Last branch: any model name not matched here falls through, so the method implicitly returns None.
+        if self.model_name == "devanagari_PP-OCRv3_mobile_rec":
+            return DEVANAGARI_FONT

+ 5 - 2
paddlex/inference/models/text_recognition/result.py

@@ -17,7 +17,7 @@ import copy
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 
-from ....utils.fonts import PINGFANG_FONT
+from ....utils.fonts import SIMFANG_FONT
 from ...common.result import BaseCVResult, JsonMixin
 
 
@@ -26,11 +26,13 @@ class TextRecResult(BaseCVResult):
     def _to_str(self, *args, **kwargs):
         data = copy.deepcopy(self)
         data.pop("input_img")
+        data.pop("vis_font")
         return JsonMixin._to_str(data, *args, **kwargs)
 
     def _to_json(self, *args, **kwargs):
         data = copy.deepcopy(self)
         data.pop("input_img")
+        data.pop("vis_font")
         return JsonMixin._to_json(data, *args, **kwargs)
 
     def _to_img(self):
@@ -38,10 +40,11 @@ class TextRecResult(BaseCVResult):
         image = Image.fromarray(self["input_img"][:, :, ::-1])
         rec_text = self["rec_text"]
         rec_score = self["rec_score"]
+        vis_font = self["vis_font"] if self["vis_font"] is not None else SIMFANG_FONT
         image = image.convert("RGB")
         image_width, image_height = image.size
         text = f"{rec_text} ({rec_score})"
-        font = self.adjust_font_size(image_width, text, PINGFANG_FONT.path)
+        font = self.adjust_font_size(image_width, text, vis_font.path)
         row_height = font.getbbox(text)[3]
         new_image_height = image_height + int(row_height * 1.2)
         new_image = Image.new("RGB", (image_width, new_image_height), (255, 255, 255))

+ 2 - 0
paddlex/inference/pipelines/ocr/pipeline.py

@@ -368,6 +368,7 @@ class _OCRPipeline(BasePipeline):
                     "rec_texts": [],
                     "rec_scores": [],
                     "rec_polys": [],
+                    "vis_fonts": [],
                 }
                 for input_path, page_index, doc_preprocessor_res, dt_polys in zip(
                     batch_data.input_paths,
@@ -439,6 +440,7 @@ class _OCRPipeline(BasePipeline):
                         if rec_res["rec_score"] >= text_rec_score_thresh:
                             res["rec_texts"].append(rec_res["rec_text"])
                             res["rec_scores"].append(rec_res["rec_score"])
+                            res["vis_fonts"].append(rec_res["vis_font"])
                             res["rec_polys"].append(dt_polys[sno])
 
             for res in results:

+ 6 - 1
paddlex/inference/pipelines/ocr/result.py

@@ -82,6 +82,11 @@ class OCRResult(BaseCVResult):
         random.seed(0)
         draw_left = ImageDraw.Draw(img_left)
         for idx, (box, txt) in enumerate(zip(boxes, txts)):
+            vis_font = (
+                self["vis_fonts"][idx]
+                if self["vis_fonts"][idx] is not None
+                else SIMFANG_FONT
+            )
             try:
                 color = (
                     random.randint(0, 255),
@@ -100,7 +105,7 @@ class OCRResult(BaseCVResult):
                     box_pts = [(int(x), int(y)) for x, y in box.tolist()]
                     draw_left.polygon(box_pts, fill=color)
 
-                img_right_text = draw_box_txt_fine((w, h), box, txt, SIMFANG_FONT.path)
+                img_right_text = draw_box_txt_fine((w, h), box, txt, vis_font.path)
                 pts = np.array(box, np.int32).reshape((-1, 1, 2))
                 cv2.polylines(img_right_text, [pts], True, color, 1)
                 img_right = cv2.bitwise_and(img_right, img_right_text)

+ 2 - 0
paddlex/inference/utils/official_models.py

@@ -362,6 +362,8 @@ PP-OCRv5_mobile_rec_infer.tar",
     "eslav_PP-OCRv5_mobile_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/\
 eslav_PP-OCRv5_mobile_rec_infer.tar",
     "PP-DocBee2-3B": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/PP-DocBee2-3B_infer.tar",
+    "latin_PP-OCRv5_mobile_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/latin_PP-OCRv5_mobile_rec_infer.tar",
+    "korean_PP-OCRv5_mobile_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0.0/korean_PP-OCRv5_mobile_rec_infer.tar",
 }
 
 

+ 2 - 0
paddlex/modules/text_recognition/model_list.py

@@ -33,5 +33,7 @@ MODELS = [
     "ch_RepSVTR_rec",
     "PP-OCRv5_server_rec",
     "PP-OCRv5_mobile_rec",
+    "latin_PP-OCRv5_mobile_rec",
     "eslav_PP-OCRv5_mobile_rec",
+    "korean_PP-OCRv5_mobile_rec",
 ]

+ 1 - 0
paddlex/repo_apis/PaddleOCR_api/configs/eslav_PP-OCRv5_mobile_rec.yaml

@@ -1,4 +1,5 @@
 Global:
+  model_name: eslav_PP-OCRv5_mobile_rec # To use static model for inference.
   debug: false
   use_gpu: true
   epoch_num: 75

+ 140 - 0
paddlex/repo_apis/PaddleOCR_api/configs/korean_PP-OCRv5_mobile_rec.yaml

@@ -0,0 +1,140 @@
+Global:
+  model_name: korean_PP-OCRv5_mobile_rec # To use static model for inference.
+  debug: false
+  use_gpu: true
+  epoch_num: 200
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/korean_PP-OCRv5_mobile_rec
+  save_epoch_step: 10
+  eval_batch_step: [0, 500]
+  cal_metric_during_train: true
+  pretrained_model: 
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: false
+  infer_img: doc/imgs_words/ch/word_1.jpg
+  character_dict_path: ./ppocr/utils/dict/ppocrv5_korean_dict.txt
+  max_text_length: &max_text_length 25
+  infer_mode: false
+  use_space_char: true
+  distributed: true
+  save_res_path: ./output/rec/predicts_ppocrv5_korean.txt
+  d2s_train_image_shape: [3, 48, 320]
+
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.0005
+    warmup_epoch: 5
+  regularizer:
+    name: L2
+    factor: 3.0e-05
+
+
+Architecture:
+  model_type: rec
+  algorithm: SVTR_LCNet
+  Transform:
+  Backbone:
+    name: PPLCNetV3
+    scale: 0.95
+  Head:
+    name: MultiHead
+    head_list:
+      - CTCHead:
+          Neck:
+            name: svtr
+            dims: 120
+            depth: 2
+            hidden_dims: 120
+            kernel_size: [1, 3]
+            use_guide: True
+          Head:
+            fc_decay: 0.00001
+      - NRTRHead:
+          nrtr_dim: 384
+          max_text_length: *max_text_length
+
+Loss:
+  name: MultiLoss
+  loss_config_list:
+    - CTCLoss:
+    - NRTRLoss:
+
+PostProcess:  
+  name: CTCLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+
+Train:
+  dataset:
+    name: MultiScaleDataSet
+    ds_width: false
+    data_dir: ./train_data/
+    ext_op_transform_idx: 1
+    label_file_list:
+    - ./train_data/train_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - RecConAug:
+        prob: 0.5
+        ext_data_num: 2
+        image_shape: [48, 320, 3]
+        max_text_length: *max_text_length
+    - RecAug:
+    - MultiLabelEncode:
+        gtc_encode: NRTRLabelEncode
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  sampler:
+    name: MultiScaleSampler
+    scales: [[320, 32], [320, 48], [320, 64]]
+    first_bs: &bs 128
+    fix_bs: false
+    divided_factor: [8, 16] # w, h
+    is_training: True
+  loader:
+    shuffle: true
+    batch_size_per_card: *bs
+    drop_last: true
+    num_workers: 8
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/
+    label_file_list:
+    - ./train_data/val_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - MultiLabelEncode:
+        gtc_encode: NRTRLabelEncode
+    - RecResizeImg:
+        image_shape: [3, 48, 320]
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 128
+    num_workers: 4

+ 143 - 0
paddlex/repo_apis/PaddleOCR_api/configs/latin_PP-OCRv5_mobile_rec.yaml

@@ -0,0 +1,143 @@
+Global:
+  model_name: latin_PP-OCRv5_mobile_rec # To use static model for inference.
+  debug: false
+  use_gpu: true
+  epoch_num: 75
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/latin_rec_ppocr_v5
+  save_epoch_step: 10
+  eval_batch_step: [0, 500]
+  cal_metric_during_train: true
+  pretrained_model: 
+  checkpoints:
+  save_inference_dir:
+  use_visualdl: false
+  infer_img:
+  character_dict_path: ppocr/utils/dict/ppocrv5_latin_dict.txt
+  infer_mode: false
+  use_space_char: true
+  distributed: true
+  save_res_path: ./output/rec/predicts_ppocrv5_latin.txt
+  d2s_train_image_shape: [3, 48, 320]
+
+
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.0005
+    warmup_epoch: 5
+  regularizer:
+    name: L2
+    factor: 3.0e-05
+
+
+Architecture:
+  model_type: rec
+  algorithm: SVTR_LCNet
+  Transform:
+  Backbone:
+    name: PPLCNetV3
+    scale: 0.95
+  Head:
+    name: MultiHead
+    head_list:
+      - CTCHead:
+          Neck:
+            name: svtr
+            dims: 120
+            depth: 2
+            hidden_dims: 120
+            kernel_size: [1, 3]
+            use_guide: True
+          Head:
+            fc_decay: 0.00001
+      - NRTRHead:
+          nrtr_dim: 384
+          max_text_length: 25
+
+Loss:
+  name: MultiLoss
+  loss_config_list:
+    - CTCLoss:
+    - NRTRLoss:
+
+PostProcess:  
+  name: CTCLabelDecode
+
+Metric:
+  name: RecMetric
+  main_indicator: acc
+  ignore_space: False
+
+Train:
+  dataset:
+    name: MultiScaleDataSet
+    ds_width: false
+    data_dir: ./train_data/
+    ext_op_transform_idx: 1
+    label_file_list:
+    - ./train_data/train_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - RecConAug:
+        prob: 0.5
+        ext_data_num: 2
+        image_shape: [48, 320, 3]
+        max_text_length: 25
+    - RecAug:
+    - MultiLabelEncode:
+        max_text_length: 25
+        gtc_encode: NRTRLabelEncode
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  sampler:
+    name: MultiScaleSampler
+    scales: [[320, 32], [320, 48], [320, 64]]
+    first_bs: &bs 128
+    fix_bs: false
+    divided_factor: [8, 16] # w, h
+    is_training: True
+  loader:
+    shuffle: true
+    batch_size_per_card: *bs
+    drop_last: true
+    num_workers: 8
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: ./train_data/
+    label_file_list:
+    - ./train_data/eval_list.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - MultiLabelEncode:
+        max_text_length: 1000
+        gtc_encode: NRTRLabelEncode
+    - RecResizeImg:
+        eval_mode: True
+        image_shape: [3, 48, 320]
+    - KeepKeys:
+        keep_keys:
+        - image
+        - label_ctc
+        - label_gtc
+        - length
+        - valid_ratio
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 4

+ 18 - 0
paddlex/repo_apis/PaddleOCR_api/text_rec/register.py

@@ -217,9 +217,27 @@ register_model_info(
 
 register_model_info(
     {
+        "model_name": "latin_PP-OCRv5_mobile_rec",
+        "suite": "TextRec",
+        "config_path": osp.join(PDX_CONFIG_DIR, "latin_PP-OCRv5_mobile_rec.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+    }
+)
+
+register_model_info(
+    {
         "model_name": "eslav_PP-OCRv5_mobile_rec",
         "suite": "TextRec",
         "config_path": osp.join(PDX_CONFIG_DIR, "eslav_PP-OCRv5_mobile_rec.yaml"),
         "supported_apis": ["train", "evaluate", "predict", "export"],
     }
 )
+
+register_model_info(
+    {
+        "model_name": "korean_PP-OCRv5_mobile_rec",
+        "suite": "TextRec",
+        "config_path": osp.join(PDX_CONFIG_DIR, "korean_PP-OCRv5_mobile_rec.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+    }
+)

+ 14 - 5
paddlex/utils/fonts.py

@@ -72,6 +72,9 @@ def create_font_vertical(
 
 class Font:
     def __init__(self, font_name=None, local_path=None):
+        if local_path is None:
+            if Path(str(LOCAL_FONT_FILE_PATH)).is_file():
+                local_path = str(LOCAL_FONT_FILE_PATH)
         self._local_path = local_path
         if not local_path:
             assert font_name is not None
@@ -101,8 +104,14 @@ if Path(str(LOCAL_FONT_FILE_PATH)).is_file():
     logging.warning(
         f"Using the local font file(`{LOCAL_FONT_FILE_PATH}`) specified by `LOCAL_FONT_FILE_PATH`!"
     )
-    PINGFANG_FONT = Font(local_path=LOCAL_FONT_FILE_PATH)
-    SIMFANG_FONT = Font(local_path=LOCAL_FONT_FILE_PATH)
-else:
-    PINGFANG_FONT = Font(font_name="PingFang-SC-Regular.ttf")
-    SIMFANG_FONT = Font(font_name="simfang.ttf")
+
+PINGFANG_FONT = Font(font_name="PingFang-SC-Regular.ttf")
+SIMFANG_FONT = Font(font_name="simfang.ttf")
+LATIN_FONT = Font(font_name="latin.ttf")
+KOREAN_FONT = Font(font_name="korean.ttf")
+ARABIC_FONT = Font(font_name="arabic.ttf")
+CYRILLIC_FONT = Font(font_name="cyrillic.ttf")
+KANNADA_FONT = Font(font_name="kannada.ttf")
+TELUGU_FONT = Font(font_name="telugu.ttf")
+TAMIL_FONT = Font(font_name="tamil.ttf")
+DEVANAGARI_FONT = Font(font_name="devanagari.ttf")