Forráskód Böngészése

add ShiTu Rec modules (#1988)

cuicheng01 1 éve
szülő
commit
b547c62e9d
32 módosított fájl, 1976 hozzáadás és 0 törlés
  1. 42 0
      paddlex/configs/general_recognition/PP-ShiTuV2_rec.yaml
  2. 42 0
      paddlex/configs/general_recognition/PP-ShiTuV2_rec_CLIP_vit_base.yaml
  3. 41 0
      paddlex/configs/general_recognition/PP-ShiTuV2_rec_CLIP_vit_large.yaml
  4. 8 0
      paddlex/modules/__init__.py
  5. 6 0
      paddlex/modules/base/predictor/utils/official_models.py
  6. 19 0
      paddlex/modules/general_recognition/__init__.py
  7. 105 0
      paddlex/modules/general_recognition/dataset_checker/__init__.py
  8. 18 0
      paddlex/modules/general_recognition/dataset_checker/dataset_src/__init__.py
  9. 92 0
      paddlex/modules/general_recognition/dataset_checker/dataset_src/analyse_dataset.py
  10. 100 0
      paddlex/modules/general_recognition/dataset_checker/dataset_src/check_dataset.py
  11. 82 0
      paddlex/modules/general_recognition/dataset_checker/dataset_src/split_dataset.py
  12. 13 0
      paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/__init__.py
  13. 152 0
      paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/visualizer.py
  14. 29 0
      paddlex/modules/general_recognition/evaluator.py
  15. 22 0
      paddlex/modules/general_recognition/exportor.py
  16. 20 0
      paddlex/modules/general_recognition/model_list.py
  17. 16 0
      paddlex/modules/general_recognition/predictor/__init__.py
  18. 29 0
      paddlex/modules/general_recognition/predictor/keys.py
  19. 79 0
      paddlex/modules/general_recognition/predictor/predictor.py
  20. 67 0
      paddlex/modules/general_recognition/predictor/transforms.py
  21. 83 0
      paddlex/modules/general_recognition/predictor/utils.py
  22. 50 0
      paddlex/modules/general_recognition/trainer.py
  23. 1 0
      paddlex/repo_apis/PaddleClas_api/__init__.py
  24. 2 0
      paddlex/repo_apis/PaddleClas_api/cls/__init__.py
  25. 206 0
      paddlex/repo_apis/PaddleClas_api/configs/PP-ShiTuV2_rec.yaml
  26. 169 0
      paddlex/repo_apis/PaddleClas_api/configs/PP-ShiTuV2_rec_CLIP_vit_base.yaml
  27. 169 0
      paddlex/repo_apis/PaddleClas_api/configs/PP-ShiTuV2_rec_CLIP_vit_large.yaml
  28. 18 0
      paddlex/repo_apis/PaddleClas_api/shitu_rec/__init__.py
  29. 145 0
      paddlex/repo_apis/PaddleClas_api/shitu_rec/config.py
  30. 23 0
      paddlex/repo_apis/PaddleClas_api/shitu_rec/model.py
  31. 74 0
      paddlex/repo_apis/PaddleClas_api/shitu_rec/register.py
  32. 54 0
      paddlex/repo_apis/PaddleClas_api/shitu_rec/runner.py

+ 42 - 0
paddlex/configs/general_recognition/PP-ShiTuV2_rec.yaml

@@ -0,0 +1,42 @@
+Global:
+  model: PP-ShiTuV2_rec
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/shitu_rec/Inshop_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: null
+  split: 
+    enable: False
+    train_percent: null
+    gallery_percent: null
+    query_percent: null
+
+Train:
+  num_classes: 159
+  epochs_iters: 20
+  batch_size: 128
+  learning_rate: 0.01
+  pretrain_weight_path: https://paddleclas.bj.bcebos.com/models/PP-ShiTu/PP-ShiTuV2_rec_pretrained.pdparams
+  warmup_steps: 5
+  resume_path: null
+  log_interval: 1
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_model.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddleclas.bj.bcebos.com/models/PP-ShiTu/PP-ShiTuV2_rec_pretrained.pdparams
+
+Predict:
+  model_dir: "output/best_model"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_image_recognition_001.jpg"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1

+ 42 - 0
paddlex/configs/general_recognition/PP-ShiTuV2_rec_CLIP_vit_base.yaml

@@ -0,0 +1,42 @@
+Global:
+  model: PP-ShiTuV2_rec_CLIP_vit_base
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/shitu_rec/Inshop_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: null
+  split: 
+    enable: False
+    train_percent: null
+    gallery_percent: null
+    query_percent: null
+
+Train:
+  num_classes: 159
+  epochs_iters: 20
+  batch_size: 32
+  learning_rate: 0.001
+  pretrain_weight_path: https://paddleclas.bj.bcebos.com/models/PP-ShiTu/PP-ShiTuV2_rec_CLIP_vit_base_pretrained.pdparams
+  warmup_steps: 5
+  resume_path: null
+  log_interval: 1
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_model.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddleclas.bj.bcebos.com/models/PP-ShiTu/PP-ShiTuV2_rec_CLIP_vit_base_pretrained.pdparams
+
+Predict:
+  model_dir: "output/best_model"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_image_recognition_001.jpg"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1

+ 41 - 0
paddlex/configs/general_recognition/PP-ShiTuV2_rec_CLIP_vit_large.yaml

@@ -0,0 +1,41 @@
+Global:
+  model: PP-ShiTuV2_rec_CLIP_vit_large
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/shitu_rec/Inshop_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: null
+  split: 
+    enable: False
+    train_percent: null
+    gallery_percent: null
+    query_percent: null
+
+Train:
+  num_classes: 159
+  epochs_iters: 20
+  batch_size: 32
+  learning_rate: 0.001
+  pretrain_weight_path: https://paddleclas.bj.bcebos.com/models/PP-ShiTu/PP-ShiTuV2_rec_CLIP_vit_large_pretrained.pdparams
+  warmup_steps: 5
+  resume_path: null
+  log_interval: 1
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_model.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: https://paddleclas.bj.bcebos.com/models/PP-ShiTu/PP-ShiTuV2_rec_CLIP_vit_large_pretrained.pdparams 
+Predict:
+  model_dir: "output/best_model"
+  input_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_image_recognition_001.jpg"
+  kernel_option:
+    run_mode: paddle
+    batch_size: 1

+ 8 - 0
paddlex/modules/__init__.py

@@ -29,6 +29,14 @@ from .image_classification import (
     ClsExportor,
     ClsPredictor,
 )
+
+from .general_recognition import (
+    ShiTuRecDatasetChecker,
+    ShiTuRecTrainer,
+    ShiTuRecEvaluator,
+    ShiTuRecExportor,
+)
+
 from .object_detection import (
     COCODatasetChecker,
     DetTrainer,

+ 6 - 0
paddlex/modules/base/predictor/utils/official_models.py

@@ -106,6 +106,11 @@ SwinTransformer_large_patch4_window12_384_infer.tar",
 CLIP_vit_base_patch16_224_infer.tar",
     "CLIP_vit_large_patch14_224": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/\
 CLIP_vit_large_patch14_224_infer.tar",
+    "PP-ShiTuV2_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/PP-ShiTuV2_rec_infer.tar",
+    "PP-ShiTuV2_rec_CLIP_vit_base": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/\
+PP-ShiTuV2_rec_CLIP_vit_base_infer.tar",
+    "PP-ShiTuV2_rec_CLIP_vit_large": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0/\
+PP-ShiTuV2_rec_CLIP_vit_large_infer.tar",
     "PP-LCNet_x1_0_ML": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-LCNet_x1_0_ML_infer.tar",
     "PP-HGNetV2-B0_ML": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B0_ML_infer.tar",
     "PP-HGNetV2-B4_ML": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PP-HGNetV2-B4_ML_infer.tar",
@@ -185,6 +190,7 @@ openatom_rec_svtrv2_ch_infer.tar",
 }
 
 
+
 class OfficialModelsDict(dict):
     """Official Models Dict"""
 

+ 19 - 0
paddlex/modules/general_recognition/__init__.py

@@ -0,0 +1,19 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .trainer import ShiTuRecTrainer
+from .dataset_checker import ShiTuRecDatasetChecker
+from .evaluator import ShiTuRecEvaluator
+from .exportor import ShiTuRecExportor
+from .predictor import ShiTuRecPredictor, transforms

+ 105 - 0
paddlex/modules/general_recognition/dataset_checker/__init__.py

@@ -0,0 +1,105 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+from ...base import BaseDatasetChecker
+from .dataset_src import check, split_dataset, deep_analyse
+from ..model_list import MODELS
+
+
+class ShiTuRecDatasetChecker(BaseDatasetChecker):
+    """Dataset Checker for ShiTu Recognition Model"""
+
+    entities = MODELS
+    sample_num = 10
+
+    def get_dataset_root(self, dataset_dir: str) -> str:
+        """find the dataset root dir
+
+        Args:
+            dataset_dir (str): the directory that contain dataset.
+
+        Returns:
+            str: the root directory of dataset.
+        """
+        anno_dirs = list(Path(dataset_dir).glob("**/images"))
+        assert len(anno_dirs) == 1
+        dataset_dir = anno_dirs[0].parent.as_posix()
+        return dataset_dir
+
+    def convert_dataset(self, src_dataset_dir: str) -> str:
+        """convert the dataset from other type to specified type
+
+        Args:
+            src_dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            str: the root directory of converted dataset.
+        """
+        return src_dataset_dir
+
+    def split_dataset(self, src_dataset_dir: str) -> str:
+        """repartition the train and validation dataset
+
+        Args:
+            src_dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            str: the root directory of splited dataset.
+        """
+        return split_dataset(
+            src_dataset_dir,
+            self.check_dataset_config.split.train_percent,
+            self.check_dataset_config.split.gallery_percent,
+            self.check_dataset_config.split.query_percent,
+        )
+
+    def check_dataset(self, dataset_dir: str, sample_num: int = sample_num) -> dict:
+        """check if the dataset meets the specifications and get dataset summary
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+            sample_num (int): the number to be sampled.
+        Returns:
+            dict: dataset summary.
+        """
+        return check(dataset_dir, self.output)
+
+    def analyse(self, dataset_dir: str) -> dict:
+        """deep analyse dataset
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            dict: the deep analysis results.
+        """
+        return deep_analyse(dataset_dir, self.output)
+
+    def get_show_type(self) -> str:
+        """get the show type of dataset
+
+        Returns:
+            str: show type
+        """
+        return "image"
+
+    def get_dataset_type(self) -> str:
+        """return the dataset type
+
+        Returns:
+            str: dataset type
+        """
+        return "ShiTuRecDataset"

+ 18 - 0
paddlex/modules/general_recognition/dataset_checker/dataset_src/__init__.py

@@ -0,0 +1,18 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .check_dataset import check
+from .split_dataset import split_dataset
+from .analyse_dataset import deep_analyse

+ 92 - 0
paddlex/modules/general_recognition/dataset_checker/dataset_src/analyse_dataset.py

@@ -0,0 +1,92 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import math
+import platform
+from pathlib import Path
+
+from collections import defaultdict
+from PIL import Image
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import font_manager
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+from .....utils.file_interface import custom_open
+from .....utils.fonts import PINGFANG_FONT_FILE_PATH
+
+
+def deep_analyse(dataset_path, output, dataset_type="ShiTuRec"):
+    """class analysis for dataset"""
+    tags = ["train", "gallery", "query"]
+    tags_info = dict()
+    for tag in tags:
+        anno_path = os.path.join(dataset_path, f"{tag}.txt")
+        with custom_open(anno_path, "r") as f:
+            lines = f.readlines()
+            lines = [line.strip("\n").split(" ") for line in lines]
+            num_images = len(lines)
+            num_labels = len(set([int(line[1]) for line in lines]))
+        tags_info[tag] = {
+            "num_images": num_images,
+            "num_labels": num_labels,
+        }
+
+    categories = list(tags_info.keys())
+    num_images = [tags_info[category]['num_images'] for category in categories]
+    num_labels = [tags_info[category]['num_labels'] for category in categories]
+
+    # bar
+    os_system = platform.system().lower()
+    if os_system == "windows":
+        plt.rcParams["font.sans-serif"] = "FangSong"
+    else:
+        font = font_manager.FontProperties(fname=PINGFANG_FONT_FILE_PATH, size=10)
+
+    x = np.arange(len(categories))  # 标签位置
+    width = 0.35  # 每个条形的宽度
+
+    fig, ax = plt.subplots()
+    rects1 = ax.bar(x - width/2, num_images, width, label="Num Images")
+    rects2 = ax.bar(x + width/2, num_labels, width, label="Num Classes")
+
+    # 添加一些文本标签
+    ax.set_xlabel("集合", fontproperties=None if os_system == "windows" else font)
+    ax.set_ylabel("数量", fontproperties=None if os_system == "windows" else font)
+    ax.set_title("不同集合的图片和类别数量", fontproperties=None if os_system == "windows" else font)
+    ax.set_xticks(x, fontproperties=None if os_system == "windows" else font)
+    ax.set_xticklabels(categories)
+    ax.legend()
+
+    # 在条形图上添加数值标签
+    def autolabel(rects):
+        """Attach a text label above each bar in *rects*, displaying its height."""
+        for rect in rects:
+            height = rect.get_height()
+            ax.annotate('{}'.format(height),
+                        xy=(rect.get_x() + rect.get_width() / 2, height),
+                        xytext=(0, 3),  # 3 points vertical offset
+                        textcoords="offset points",
+                        ha="center", va="bottom")
+
+    autolabel(rects1)
+    autolabel(rects2)
+
+    fig.tight_layout()
+    file_path = os.path.join(output, "histogram.png")
+    fig.savefig(file_path, dpi=300)
+
+    return {"histogram": os.path.join("check_dataset", "histogram.png")}

+ 100 - 0
paddlex/modules/general_recognition/dataset_checker/dataset_src/check_dataset.py

@@ -0,0 +1,100 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import random
+from PIL import Image, ImageOps
+from collections import defaultdict
+
+from .....utils.errors import DatasetFileNotFoundError, CheckFailedError
+from .utils.visualizer import draw_label
+
+
+def check(dataset_dir, output, sample_num=10, dataset_type="ShiTuRec"):
+    """check dataset"""
+    dataset_dir = osp.abspath(dataset_dir)
+    # Custom dataset
+    if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
+        raise DatasetFileNotFoundError(file_path=dataset_dir)
+
+    tags = ["train", "gallery", "query"]
+
+    delim = " "
+    valid_num_parts = 2
+
+    sample_cnts = dict()
+    label_map_dict = dict()
+    sample_paths = defaultdict(list)
+    labels = []
+
+    for tag in tags:
+        file_list = osp.join(dataset_dir, f"{tag}.txt")
+        if not osp.exists(file_list):
+            if tag in ("train", "gallery", "query"):
+                # train, gallery, query file lists must exist
+                raise DatasetFileNotFoundError(
+                    file_path=file_list,
+                    solution=f"Ensure that both `train.txt`, `gallery.txt`, `query.txt` exist in {dataset_dir}",
+                )
+            else:
+                # tag == 'test'
+                continue
+        else:
+            with open(file_list, "r", encoding="utf-8") as f:
+                all_lines = f.readlines()
+                random.seed(123)
+                random.shuffle(all_lines)
+                sample_cnts[tag] = len(all_lines)
+                for line in all_lines:
+                    substr = line.strip("\n").split(delim)
+                    if len(substr) != valid_num_parts:
+                        raise CheckFailedError(
+                            f"The number of delimiter-separated items in each row in {file_list} \
+                                    should be {valid_num_parts} (current delimiter is '{delim}')."
+                        )
+                    file_name = substr[0]
+                    label = substr[1]
+
+                    img_path = osp.join(dataset_dir, file_name)
+
+                    if not osp.exists(img_path):
+                        raise DatasetFileNotFoundError(file_path=img_path)
+
+                    vis_save_dir = osp.join(output, "demo_img")
+                    if not osp.exists(vis_save_dir):
+                        os.makedirs(vis_save_dir)
+
+                    if len(sample_paths[tag]) < sample_num:
+                        img = Image.open(img_path)
+                        img = ImageOps.exif_transpose(img)
+                        vis_im = draw_label(img, label)
+                        vis_path = osp.join(vis_save_dir, osp.basename(file_name))
+                        vis_im.save(vis_path)
+                        sample_path = osp.join(
+                            "check_dataset", os.path.relpath(vis_path, output)
+                        )
+                        sample_paths[tag].append(sample_path)
+
+    attrs = {}
+    attrs["train_samples"] = sample_cnts["train"]
+    attrs["train_sample_paths"] = sample_paths["train"]
+
+    attrs["gallery_samples"] = sample_cnts["gallery"]
+    attrs["gallery_sample_paths"] = sample_paths["gallery"]
+
+    attrs["query_samples"] = sample_cnts["query"]
+    attrs["query_sample_paths"] = sample_paths["query"]
+
+    return attrs

+ 82 - 0
paddlex/modules/general_recognition/dataset_checker/dataset_src/split_dataset.py

@@ -0,0 +1,82 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from random import shuffle
+
+from .....utils.file_interface import custom_open
+
+
+def split_dataset(root_dir, train_rate, gallery_rate, query_rate):
+    """
+    将图像数据集按照比例分成训练集、验证集和测试集,并生成对应的.txt文件。
+
+    Args:
+        root_dir (str): 数据集根目录路径。
+        train_rate (int): 训练集占总数据集的比例(%)。
+        gallery_rate (int): 被查询数据集占总数据集的比例(%)。
+        query_rate (int): 查询数据集占总数据集的比例(%)。
+
+    Returns:
+        str: 数据划分结果信息。
+    """
+    sum_rate = train_rate + gallery_rate + query_rate
+    assert (
+        sum_rate == 100
+    ), f"The sum of train_rate({train_rate}), gallery_rate({gallery_rate}), query_rate({query_rate}) should equal 100!"
+    assert (
+        train_rate > 0 and gallery_rate > 0 and query_rate > 0
+    ), f"The train_rate({train_rate}) and gallery_rate({gallery_rate}) and query_rate({query_rate}) should be greater than 0!"
+    tags = ["train", "gallery", "query"]
+    valid_path = False
+    image_files = []
+    for tag in tags:
+        split_image_list = os.path.abspath(os.path.join(root_dir, f"{tag}.txt"))
+        rename_image_list = os.path.abspath(os.path.join(root_dir, f"{tag}.txt.bak"))
+        if os.path.exists(split_image_list):
+            with custom_open(split_image_list, "r") as f:
+                lines = f.readlines()
+            image_files = image_files + lines
+            valid_path = True
+            if not os.path.exists(rename_image_list):
+                os.rename(split_image_list, rename_image_list)
+
+    assert (
+        valid_path
+    ), f"The files to be divided{tags[0]}.txt, {tags[1]}.txt, {tags[1]}.txt, do not exist in the dataset directory."
+
+    shuffle(image_files)
+    start = 0
+    image_num = len(image_files)
+    rate_list = [train_rate, gallery_rate, query_rate]
+    for i, tag in enumerate(tags):
+
+        rate = rate_list[i]
+        if rate == 0:
+            continue
+
+        end = start + round(image_num * rate / 100)
+        if sum(rate_list[i + 1 :]) == 0:
+            end = image_num
+
+        txt_file = os.path.abspath(os.path.join(root_dir, tag + ".txt"))
+        with custom_open(txt_file, "w") as f:
+            m = 0
+            for id in range(start, end):
+                m += 1
+                f.write(image_files[id])
+        start = end
+
+    return root_dir

+ 13 - 0
paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/__init__.py

@@ -0,0 +1,13 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+ 152 - 0
paddlex/modules/general_recognition/dataset_checker/dataset_src/utils/visualizer.py

@@ -0,0 +1,152 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import json
+from pathlib import Path
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ......utils.fonts import PINGFANG_FONT_FILE_PATH
+
+
+def colormap(rgb=False):
+    """
+    Get colormap
+    """
+    color_list = np.array(
+        [
+            0xFF,
+            0x00,
+            0x00,
+            0xCC,
+            0xFF,
+            0x00,
+            0x00,
+            0xFF,
+            0x66,
+            0x00,
+            0x66,
+            0xFF,
+            0xCC,
+            0x00,
+            0xFF,
+            0xFF,
+            0x4D,
+            0x00,
+            0x80,
+            0xFF,
+            0x00,
+            0x00,
+            0xFF,
+            0xB2,
+            0x00,
+            0x1A,
+            0xFF,
+            0xFF,
+            0x00,
+            0xE5,
+            0xFF,
+            0x99,
+            0x00,
+            0x33,
+            0xFF,
+            0x00,
+            0x00,
+            0xFF,
+            0xFF,
+            0x33,
+            0x00,
+            0xFF,
+            0xFF,
+            0x00,
+            0x99,
+            0xFF,
+            0xE5,
+            0x00,
+            0x00,
+            0xFF,
+            0x1A,
+            0x00,
+            0xB2,
+            0xFF,
+            0x80,
+            0x00,
+            0xFF,
+            0xFF,
+            0x00,
+            0x4D,
+        ]
+    ).astype(np.float32)
+    color_list = color_list.reshape((-1, 3))
+    if not rgb:
+        color_list = color_list[:, ::-1]
+    return color_list.astype("int32")
+
+
+def font_colormap(color_index):
+    """
+    Get font colormap
+    """
+    dark = np.array([0x14, 0x0E, 0x35])
+    light = np.array([0xFF, 0xFF, 0xFF])
+    light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+    if color_index in light_indexs:
+        return light.astype("int32")
+    else:
+        return dark.astype("int32")
+
+
+def draw_label(image, label):
+    """Draw label on image"""
+    image = image.convert("RGB")
+    image_size = image.size
+    draw = ImageDraw.Draw(image)
+    min_font_size = int(image_size[0] * 0.02)
+    max_font_size = int(image_size[0] * 0.05)
+    for font_size in range(max_font_size, min_font_size - 1, -1):
+        font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
+        if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0):
+            text_width_tmp, text_height_tmp = draw.textsize(label, font)
+        else:
+            left, top, right, bottom = draw.textbbox((0, 0), label, font)
+            text_width_tmp, text_height_tmp = right - left, bottom - top
+        if text_width_tmp <= image_size[0]:
+            break
+        else:
+            font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, min_font_size)
+    color_list = colormap(rgb=True)
+    color = tuple(color_list[0])
+    font_color = tuple(font_colormap(3))
+    if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0):
+        text_width, text_height = draw.textsize(label, font)
+    else:
+        left, top, right, bottom = draw.textbbox(
+            (0, 0), label, font
+        )
+        text_width, text_height = right - left, bottom - top
+
+    rect_left = 3
+    rect_top = 3
+    rect_right = rect_left + text_width + 3
+    rect_bottom = rect_top + text_height + 6
+
+    draw.rectangle([(rect_left, rect_top), (rect_right, rect_bottom)], fill=color)
+
+    text_x = rect_left + 3
+    text_y = rect_top
+    draw.text((text_x, text_y), label, fill=font_color, font=font)
+
+    return image

+ 29 - 0
paddlex/modules/general_recognition/evaluator.py

@@ -0,0 +1,29 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..image_classification import ClsEvaluator
+from .model_list import MODELS
+
+class ShiTuRecEvaluator(ClsEvaluator):
+    """ShiTu Recognition Model Evaluator"""
+
+    entities = MODELS
+    
+    def update_config(self):
+        """update evalution config"""
+        if self.eval_config.log_interval:
+            self.pdx_config.update_log_interval(self.eval_config.log_interval)
+        self.pdx_config.update_dataset(self.global_config.dataset_dir, "ShiTuRecDataset")
+        self.pdx_config.update_pretrained_weights(self.eval_config.weight_path)
+

+ 22 - 0
paddlex/modules/general_recognition/exportor.py

@@ -0,0 +1,22 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..image_classification import ClsExportor
+from .model_list import MODELS
+
+
+class ShiTuRecExportor(ClsExportor):
+    """ShiTu Recognition Model Exportor"""
+    
+    entities = MODELS

+ 20 - 0
paddlex/modules/general_recognition/model_list.py

@@ -0,0 +1,20 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+MODELS = [
+"PP-ShiTuV2_rec",
+"PP-ShiTuV2_rec_CLIP_vit_base",
+"PP-ShiTuV2_rec_CLIP_vit_large"
+]
+

+ 16 - 0
paddlex/modules/general_recognition/predictor/__init__.py

@@ -0,0 +1,16 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import ShiTuRecPredictor
+from . import transforms

+ 29 - 0
paddlex/modules/general_recognition/predictor/keys.py

@@ -0,0 +1,29 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class ShiTuRecKeys(object):
+    """
+    This class defines a set of keys used for communication of ShiTuRec predictors
+    and transforms. Both predictors and transforms accept a dict or a list of
+    dicts as input, and they get the objects of their interest from the dict, or
+    put the generated objects into the dict, all based on these keys.
+    """
+
+    # Common keys
+    IMAGE = "image"
+    IM_PATH = "input_path"
+    # Suite-specific keys
+    SHITU_REC_PRED = "shitu_rec_pred"
+    SHITU_REC_RESULT = "shitu_rec_result"

+ 79 - 0
paddlex/modules/general_recognition/predictor/predictor.py

@@ -0,0 +1,79 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import numpy as np
+from pathlib import Path
+
+from ...base import BasePredictor
+from ...base.predictor.transforms import image_common
+from .keys import ShiTuRecKeys as K
+from .utils import InnerConfig
+from ....utils import logging
+from . import transforms as T
+from ..model_list import MODELS
+
+class ShiTuRecPredictor(BasePredictor):
+    """ShiTu Recognition Predictor"""
+
+    entities = MODELS
+
+    def load_other_src(self):
+        """load the inner config file"""
+        infer_cfg_file_path = os.path.join(self.model_dir, "inference.yml")
+        if not os.path.exists(infer_cfg_file_path):
+            raise FileNotFoundError(f"Cannot find config file: {infer_cfg_file_path}")
+        return InnerConfig(infer_cfg_file_path)
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        return [[K.IMAGE], [K.IM_PATH]]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        return [K.SHITU_REC_PRED]
+
+    def _run(self, batch_input):
+        """run"""
+        input_dict = {}
+        input_dict[K.IMAGE] = np.stack(
+            [data[K.IMAGE] for data in batch_input], axis=0
+        ).astype(dtype=np.float32, copy=False)
+        input_ = [input_dict[K.IMAGE]]
+        outputs = self._predictor.predict(input_)
+        shitu_rec_outs = outputs[0]
+        # In-place update
+        pred = batch_input
+        for dict_, shitu_rec_out in zip(pred, shitu_rec_outs):
+            dict_[K.SHITU_REC_PRED] = shitu_rec_out
+        return pred
+
+    def _get_pre_transforms_from_config(self):
+        """get preprocess transforms"""
+        logging.info(
+            f"Transformation operators for data preprocessing will be inferred from config file."
+        )
+        pre_transforms = self.other_src.pre_transforms
+        pre_transforms.insert(0, image_common.ReadImage(format="RGB"))
+        return pre_transforms
+
+    def _get_post_transforms_from_config(self):
+        """get postprocess transforms"""
+        post_transforms = self.other_src.post_transforms
+        if not self.disable_print:
+            post_transforms.append(T.PrintShiTuRecResult())
+        return post_transforms

+ 67 - 0
paddlex/modules/general_recognition/predictor/transforms.py

@@ -0,0 +1,67 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+
+from .keys import ShiTuRecKeys as K
+from ...base import BaseTransform
+from ....utils import logging
+
+__all__ = [
+    "NormalizeFeatures",
+    "PrintShiTuRecResult"
+]
+
+
+class NormalizeFeatures(BaseTransform):
+    """Normalize Features Transform"""
+
+    def apply(self, data):
+        """apply"""
+        x = data[K.SHITU_REC_PRED]
+        feas_norm = np.sqrt(np.sum(np.square(x), axis=0, keepdims=True))
+        x = np.divide(x, feas_norm)
+        data[K.SHITU_REC_RESULT] = x
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        return [K.IM_PATH, K.SHITU_REC_PRED]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        return [K.SHITU_REC_RESULT]
+
+
+class PrintShiTuRecResult(BaseTransform):
+    """Print Result Transform"""
+
+    def apply(self, data):
+        """apply"""
+        logging.info("The prediction result is:")
+        logging.info(data[K.SHITU_REC_RESULT])
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        return [K.SHITU_REC_RESULT]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        return []

+ 83 - 0
paddlex/modules/general_recognition/predictor/utils.py

@@ -0,0 +1,83 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+
+import yaml
+
+from ...base.predictor.transforms import image_common
+from . import transforms as T
+
+
+class InnerConfig(object):
+    """Inner Config"""
+
+    def __init__(self, config_path):
+        self.inner_cfg = self.load(config_path)
+
+    def load(self, config_path):
+        """load infer config"""
+        with codecs.open(config_path, "r", "utf-8") as file:
+            dic = yaml.load(file, Loader=yaml.FullLoader)
+        return dic
+
+    @property
+    def pre_transforms(self):
+        """read preprocess transforms from config file"""
+        if "RecPreProcess" in list(self.inner_cfg.keys()):
+            tfs_cfg = self.inner_cfg["RecPreProcess"]["transform_ops"]
+        else:
+            tfs_cfg = self.inner_cfg["PreProcess"]["transform_ops"]
+        tfs = []
+        for cfg in tfs_cfg:
+            tf_key = list(cfg.keys())[0]
+            if tf_key == "NormalizeImage":
+                tf = image_common.Normalize(
+                    mean=cfg["NormalizeImage"].get("mean", [0.485, 0.456, 0.406]),
+                    std=cfg["NormalizeImage"].get("std", [0.229, 0.224, 0.225]),
+                )
+            elif tf_key == "ResizeImage":
+                if "resize_short" in list(cfg[tf_key].keys()):
+                    tf = image_common.ResizeByShort(
+                        target_short_edge=cfg["ResizeImage"].get("resize_short", 224),
+                        size_divisor=None,
+                        interp="LINEAR",
+                    )
+                else:
+                    tf = image_common.Resize(
+                        target_size=cfg["ResizeImage"].get("size", (224, 224))
+                    )
+            elif tf_key == "CropImage":
+                tf = image_common.Crop(crop_size=cfg["CropImage"].get("size", 224))
+            elif tf_key == "ToCHWImage":
+                tf = image_common.ToCHWImage()
+            else:
+                raise RuntimeError(f"Unsupported type: {tf_key}")
+            tfs.append(tf)
+        return tfs
+
+    @property
+    def post_transforms(self):
+        """read postprocess transforms from config file"""
+        tfs_cfg = self.inner_cfg["PostProcess"]
+        tfs = []
+        if tfs_cfg is None:
+            return tfs
+        for tf_key in tfs_cfg:
+            if tf_key == "NormalizeFeatures":
+                tf = T.NormalizeFeatures()
+            else:
+                raise RuntimeError(f"Unsupported type: {tf_key}")
+            tfs.append(tf)
+        return tfs

+ 50 - 0
paddlex/modules/general_recognition/trainer.py

@@ -0,0 +1,50 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..image_classification import ClsTrainer
+from .model_list import MODELS
+
+
+class ShiTuRecTrainer(ClsTrainer):
+    """ShiTu Recognition Model Trainer"""
+
+    entities = MODELS
+
+    def update_config(self):
+        """update training config"""
+        if self.train_config.log_interval:
+            self.pdx_config.update_log_interval(self.train_config.log_interval)
+        if self.train_config.eval_interval:
+            self.pdx_config.update_eval_interval(self.train_config.eval_interval)
+        if self.train_config.save_interval:
+            self.pdx_config.update_save_interval(self.train_config.save_interval)
+
+        self.pdx_config.update_dataset(self.global_config.dataset_dir, "ShiTuRecDataset")
+        if self.train_config.num_classes is not None:
+            self.pdx_config.update_num_classes(self.train_config.num_classes)
+        if self.train_config.pretrain_weight_path != "":
+            self.pdx_config.update_pretrained_weights(
+                self.train_config.pretrain_weight_path
+            )
+
+        if self.train_config.batch_size is not None:
+            self.pdx_config.update_batch_size(self.train_config.batch_size)
+        if self.train_config.learning_rate is not None:
+            self.pdx_config.update_learning_rate(self.train_config.learning_rate)
+        if self.train_config.epochs_iters is not None:
+            self.pdx_config._update_epochs(self.train_config.epochs_iters)
+        if self.train_config.warmup_steps is not None:
+            self.pdx_config.update_warmup_epochs(self.train_config.warmup_steps)
+        if self.global_config.output is not None:
+            self.pdx_config._update_output_dir(self.global_config.output)

+ 1 - 0
paddlex/repo_apis/PaddleClas_api/__init__.py

@@ -14,3 +14,4 @@
 
 
 from .cls import ClsModel, ClsRunner, register
+from .shitu_rec import ShiTuRecModel, ShiTuRecRunner, register

+ 2 - 0
paddlex/repo_apis/PaddleClas_api/cls/__init__.py

@@ -15,4 +15,6 @@
 
 from .model import ClsModel
 from .runner import ClsRunner
+from .config import ClsConfig
 from . import register
+

+ 206 - 0
paddlex/repo_apis/PaddleClas_api/configs/PP-ShiTuV2_rec.yaml

@@ -0,0 +1,206 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output
+  device: gpu
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 100
+  print_batch_step: 20
+  use_visualdl: False
+  eval_mode: retrieval
+  retrieval_feature_from: features # 'backbone' or 'features'
+  re_ranking: False
+  use_dali: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+
+# mixed precision
+AMP:
+  use_amp: False
+  use_fp16_test: False
+  scale_loss: 128.0
+  use_dynamic_loss_scaling: True
+  use_promote: False
+  # O1: mixed fp16, O2: pure fp16
+  level: O1
+
+# model architecture
+Arch:
+  name: RecModel
+  infer_output_key: features
+  infer_add_softmax: False
+
+  Backbone:
+    name: PPLCNetV2_base_ShiTu
+    pretrained: True
+    use_ssld: True
+    class_expand: &feat_dim 512
+  BackboneStopLayer:
+    name: flatten
+  Neck:
+    name: BNNeck
+    num_features: *feat_dim
+    weight_attr:
+      initializer:
+        name: Constant
+        value: 1.0
+    bias_attr:
+      initializer:
+        name: Constant
+        value: 0.0
+      learning_rate: 1.0e-20 # NOTE: Temporarily set lr small enough to freeze the bias to zero
+  Head:
+    name: FC
+    embedding_size: *feat_dim
+    class_num: 192612
+    weight_attr:
+      initializer:
+        name: Normal
+        std: 0.001
+    bias_attr: False
+
+# loss function config for traing/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+        epsilon: 0.1
+    - TripletAngularMarginLoss:
+        weight: 1.0
+        feature_from: features
+        margin: 0.5
+        reduction: mean
+        add_absolute: True
+        absolute_loss_weight: 0.1
+        normalize_feature: True
+        ap_value: 0.8
+        an_value: 0.4
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: Cosine
+    learning_rate: 0.06 # for 8gpu x 256bs
+    warmup_epoch: 5
+  regularizer:
+    name: L2
+    coeff: 0.00001
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/
+      cls_label_path: ./dataset/train_reg_all_data_v2.txt
+      relabel: True
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            size: [224, 224]
+            return_numpy: False
+            interpolation: bilinear
+            backend: cv2
+        - RandFlipImage:
+            flip_code: 1
+        - Pad:
+            padding: 10
+            backend: cv2
+        - RandCropImageV2:
+            size: [224, 224]
+        - RandomRotation:
+            prob: 0.5
+            degrees: 90
+            interpolation: bilinear
+        - ResizeImage:
+            size: [224, 224]
+            return_numpy: False
+            interpolation: bilinear
+            backend: cv2
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: hwc
+    sampler:
+      name: PKSampler
+      batch_size: 256
+      sample_per_id: 4
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    Query:
+      dataset:
+        name: ImageNetDataset
+        image_root: ./dataset/Inshop/
+        cls_label_path: ./dataset/Inshop/query_list.txt
+        transform_ops:
+          - DecodeImage:
+              to_rgb: True
+              channel_first: False
+          - ResizeImage:
+              size: [224, 224]
+              return_numpy: False
+              interpolation: bilinear
+              backend: cv2
+          - NormalizeImage:
+              scale: 1.0/255.0
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: hwc
+      sampler:
+        name: DistributedBatchSampler
+        batch_size: 64
+        drop_last: False
+        shuffle: False
+      loader:
+        num_workers: 4
+        use_shared_memory: True
+
+    Gallery:
+      dataset:
+        name: ImageNetDataset
+        image_root: ./dataset/Inshop/
+        cls_label_path: ./dataset/Inshop/gallery_list.txt
+        transform_ops:
+          - DecodeImage:
+              to_rgb: True
+              channel_first: False
+          - ResizeImage:
+              size: [224, 224]
+              return_numpy: False
+              interpolation: bilinear
+              backend: cv2
+          - NormalizeImage:
+              scale: 1.0/255.0
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: hwc
+      sampler:
+        name: DistributedBatchSampler
+        batch_size: 64
+        drop_last: False
+        shuffle: False
+      loader:
+        num_workers: 4
+        use_shared_memory: True
+
+Metric:
+  Eval:
+    - Recallk:
+        topk: [1, 5]
+    - mAP: {}

+ 169 - 0
paddlex/repo_apis/PaddleClas_api/configs/PP-ShiTuV2_rec_CLIP_vit_base.yaml

@@ -0,0 +1,169 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 5
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 20
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+  eval_mode: retrieval
+  use_dali: False
+  to_static: False
+
+# mixed precision
+AMP:
+  use_amp: False
+  use_fp16_test: False
+  scale_loss: 128.0
+  use_dynamic_loss_scaling: True
+  use_promote: False
+  # O1: mixed fp16, O2: pure fp16
+  level: O1
+
+# model architecture
+Arch:
+  name: RecModel
+  infer_output_key: features
+  infer_add_softmax: False
+
+  Backbone:
+    name: CLIP_vit_base_patch16_224
+    pretrained: True
+    return_embed: True
+    return_mean_embed: True 
+  BackboneStopLayer:
+    name: "flatten"
+  Neck:
+    name: FC
+    embedding_size: 512
+    class_num: 512
+  Head:
+    name: ArcMargin
+    embedding_size: 512
+    class_num: 192613
+    margin: 0.2
+    scale: 30
+
+# loss function config for traing/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: Cosine
+    learning_rate: 0.02
+    warmup_epoch: 1
+  regularizer:
+    name: "L2"
+    coeff: 0.00002
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/
+      cls_label_path: ./dataset/train_reg_all_data_v2.txt 
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - TimmAutoAugment:
+            config_str: rand-m7-mstd0.5-inc1
+            interpolation: bicubic
+            img_size: 224  
+        - RandomRotation:
+            prob: 0.3
+            degrees: 90
+            interpolation: bicubic
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ""
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 8
+      use_shared_memory: True
+
+  Eval:
+    Query:
+      dataset:
+        name: ImageNetDataset
+        image_root: ./dataset/Inshop/
+        cls_label_path: ./dataset/Inshop/query_list.txt
+        transform_ops:
+          - DecodeImage:
+              to_rgb: True
+              channel_first: False
+          - ResizeImage:
+              size: 224
+              interpolation: bicubic
+          - NormalizeImage:
+              scale: 1.0/255.0
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: ""
+      sampler:
+        name: DistributedBatchSampler
+        batch_size: 32
+        drop_last: False
+        shuffle: False
+      loader:
+        num_workers: 12
+        use_shared_memory: True
+
+    Gallery:
+      dataset:
+        name: ImageNetDataset
+        image_root: ./dataset/Inshop/
+        cls_label_path: ./dataset/Inshop/gallery_list.txt
+        transform_ops:
+          - DecodeImage:
+              to_rgb: True
+              channel_first: False
+          - ResizeImage:
+              size: 224
+              interpolation: bicubic
+          - NormalizeImage:
+              scale: 1.0/255.0
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: ""
+      sampler:
+        name: DistributedBatchSampler
+        batch_size: 32
+        drop_last: False
+        shuffle: False
+      loader:
+        num_workers: 12
+        use_shared_memory: True
+
+Metric:
+  Eval:
+    - Recallk:
+        topk: [1, 5]
+    - mAP: {}

+ 169 - 0
paddlex/repo_apis/PaddleClas_api/configs/PP-ShiTuV2_rec_CLIP_vit_large.yaml

@@ -0,0 +1,169 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 5
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 10
+  print_batch_step: 10
+  use_visualdl: False
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+  eval_mode: retrieval
+  use_dali: False
+  to_static: False
+
+# mixed precision
+AMP:
+  use_amp: False
+  use_fp16_test: False
+  scale_loss: 128.0
+  use_dynamic_loss_scaling: True
+  use_promote: False
+  # O1: mixed fp16, O2: pure fp16
+  level: O1
+
+# model architecture
+Arch:
+  name: RecModel
+  infer_output_key: features
+  infer_add_softmax: False
+
+  Backbone:
+    name: CLIP_vit_large_patch14_224
+    pretrained: True
+    return_embed: True
+    return_mean_embed: True 
+  BackboneStopLayer:
+    name: "flatten"
+  Neck:
+    name: FC
+    embedding_size: 512
+    class_num: 512
+  Head:
+    name: ArcMargin
+    embedding_size: 512
+    class_num: 192613
+    margin: 0.2
+    scale: 30
+
+# loss function config for traing/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: Cosine
+    learning_rate: 0.0025
+    warmup_epoch: 1
+  regularizer:
+    name: "L2"
+    coeff: 0.00002
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ImageNetDataset
+      image_root: ./dataset/
+      cls_label_path: ./dataset/train_reg_all_data_v2.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - TimmAutoAugment:
+            config_str: rand-m7-mstd0.5-inc1
+            interpolation: bicubic
+            img_size: 224  
+        - RandomRotation:
+            prob: 0.3
+            degrees: 90
+            interpolation: bicubic
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ""
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 32
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    Query:
+      dataset:
+        name: ImageNetDataset
+        image_root: ./dataset/Inshop
+        cls_label_path: ./dataset/Inshop/query_list.txt
+        transform_ops:
+          - DecodeImage:
+              to_rgb: True
+              channel_first: False
+          - ResizeImage:
+              size: 224
+              interpolation: bicubic
+          - NormalizeImage:
+              scale: 1.0/255.0
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: ""
+      sampler:
+        name: DistributedBatchSampler
+        batch_size: 32
+        drop_last: False
+        shuffle: False
+      loader:
+        num_workers: 12
+        use_shared_memory: True
+
+    Gallery:
+      dataset:
+        name: ImageNetDataset
+        image_root: ./dataset/Inshop/
+        cls_label_path: ./dataset/Inshop/gallery_list.txt
+        transform_ops:
+          - DecodeImage:
+              to_rgb: True
+              channel_first: False
+          - ResizeImage:
+              size: 224
+              interpolation: bicubic
+          - NormalizeImage:
+              scale: 1.0/255.0
+              mean: [0.485, 0.456, 0.406]
+              std: [0.229, 0.224, 0.225]
+              order: ""
+      sampler:
+        name: DistributedBatchSampler
+        batch_size: 32
+        drop_last: False
+        shuffle: False
+      loader:
+        num_workers: 12
+        use_shared_memory: True
+
+Metric:
+  Eval:
+    - Recallk:
+        topk: [1, 5]
+    - mAP: {}

+ 18 - 0
paddlex/repo_apis/PaddleClas_api/shitu_rec/__init__.py

@@ -0,0 +1,18 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .model import ShiTuRecModel
+from .runner import ShiTuRecRunner
+from . import register

+ 145 - 0
paddlex/repo_apis/PaddleClas_api/shitu_rec/config.py

@@ -0,0 +1,145 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ 
+from ..cls import ClsConfig
+from ....utils.misc import abspath
+
+
+class ShiTuRecConfig(ClsConfig):
+    """ShiTu Recognition Config"""
+
+    def update_dataset(
+        self,
+        dataset_path: str,
+        dataset_type: str = None,
+        *,
+        train_list_path: str = None,
+    ):
+        """update dataset settings
+
+        Args:
+            dataset_path (str): the root path of dataset.
+            dataset_type (str, optional): dataset type. Defaults to None.
+            train_list_path (str, optional): the path of train dataset annotation file . Defaults to None.
+
+        Raises:
+            ValueError: the dataset_type error.
+        """
+        dataset_path = abspath(dataset_path)
+
+        dataset_type = "ShiTuRecDataset"
+        if train_list_path:
+            train_list_path = f"{train_list_path}"
+        else:
+            train_list_path = f"{dataset_path}/train.txt"
+
+
+        ds_cfg = [
+            f"DataLoader.Train.dataset.name={dataset_type}",
+            f"DataLoader.Train.dataset.image_root={dataset_path}",
+            f"DataLoader.Train.dataset.cls_label_path={train_list_path}",
+            f"DataLoader.Eval.Query.dataset.name={dataset_type}",
+            f"DataLoader.Eval.Query.dataset.image_root={dataset_path}",
+            f"DataLoader.Eval.Query.dataset.cls_label_path={dataset_path}/query.txt",
+            f"DataLoader.Eval.Gallery.dataset.name={dataset_type}",
+            f"DataLoader.Eval.Gallery.dataset.image_root={dataset_path}",
+            f"DataLoader.Eval.Gallery.dataset.cls_label_path={dataset_path}/gallery.txt",
+        ]
+
+        self.update(ds_cfg)
+
+    def update_batch_size(self, batch_size: int, mode: str = "train"):
+        """update batch size setting
+
+        Args:
+            batch_size (int): the batch size number to set.
+            mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval'.
+                Defaults to 'train'.
+
+        Raises:
+            ValueError: `mode` error.
+        """
+        if mode == "train":
+            if self.DataLoader["Train"]["sampler"].get("batch_size", False):
+                _cfg = [f"DataLoader.Train.sampler.batch_size={batch_size}"]
+            else:
+                _cfg = [f"DataLoader.Train.sampler.first_bs={batch_size}"]
+                _cfg = [f"DataLoader.Train.dataset.name=MultiScaleDataset"]
+        elif mode == "eval":
+            _cfg = [f"DataLoader.Eval.Query.sampler.batch_size={batch_size}"]
+            _cfg = [f"DataLoader.Eval.Gallery.sampler.batch_size={batch_size}"]
+        else:
+            raise ValueError("The input `mode` should be train or eval")
+        self.update(_cfg)
+
+
+    def update_num_classes(self, num_classes: int):
+        """update classes number
+
+        Args:
+            num_classes (int): the classes number value to set.
+        """
+        update_str_list = [f"Arch.Head.class_num={num_classes}"]
+        self.update(update_str_list)
+
+
+    def update_num_workers(self, num_workers: int):
+        """update workers number of train and eval dataloader
+
+        Args:
+            num_workers (int): the value of train and eval dataloader workers number to set.
+        """
+        _cfg = [
+            f"DataLoader.Train.loader.num_workers={num_workers}",
+            f"DataLoader.Eval.Query.loader.num_workers={num_workers}",
+            f"DataLoader.Eval.Gallery.loader.num_workers={num_workers}",
+        ]
+        self.update(_cfg)
+
+    def update_shared_memory(self, shared_memeory: bool):
+        """update shared memory setting of train and eval dataloader
+
+        Args:
+            shared_memeory (bool): whether or not to use shared memory
+        """
+        assert isinstance(shared_memeory, bool), "shared_memeory should be a bool"
+        _cfg = [
+            f"DataLoader.Train.loader.use_shared_memory={shared_memeory}",
+            f"DataLoader.Eval.Query.loader.use_shared_memory={shared_memeory}",
+            f"DataLoader.Eval.Gallery.loader.use_shared_memory={shared_memeory}",
+        ]
+        self.update(_cfg)
+
+    def update_shuffle(self, shuffle: bool):
+        """update shuffle setting of train and eval dataloader
+
+        Args:
+            shuffle (bool): whether or not to shuffle the data
+        """
+        assert isinstance(shuffle, bool), "shuffle should be a bool"
+        _cfg = [
+            f"DataLoader.Train.loader.shuffle={shuffle}",
+            f"DataLoader.Eval.Query.loader.shuffle={shuffle}",
+            f"DataLoader.Eval.Gallery.loader.shuffle={shuffle}",
+        ]
+        self.update(_cfg)
+
+
+    def _get_backbone_name(self) -> str:
+        """get backbone name of rec model
+
+        Returns:
+            str: the model backbone name, i.e., `Arch.Backbone.name` in config.
+        """
+        return self.dict["Arch"]["Backbone"]["name"]

+ 23 - 0
paddlex/repo_apis/PaddleClas_api/shitu_rec/model.py

@@ -0,0 +1,23 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from ..cls import ClsModel
+
+
+class ShiTuRecModel(ClsModel):
+    """ShiTu Recognition Model"""
+
+    pass

+ 74 - 0
paddlex/repo_apis/PaddleClas_api/shitu_rec/register.py

@@ -0,0 +1,74 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+
+from ...base.register import register_model_info, register_suite_info
+from .model import ShiTuRecModel
+from .runner import ShiTuRecRunner
+from .config import ShiTuRecConfig
+
+REPO_ROOT_PATH = os.environ.get("PADDLE_PDX_PADDLECLAS_PATH")
+PDX_CONFIG_DIR = osp.abspath(osp.join(osp.dirname(__file__), "..", "configs"))
+
+register_suite_info(
+    {
+        "suite_name": "ShiTuRec",
+        "model": ShiTuRecModel,
+        "runner": ShiTuRecRunner,
+        "config": ShiTuRecConfig,
+        "runner_root_path": REPO_ROOT_PATH,
+    }
+)
+
+################ Models Using Universal Config ################
+register_model_info(
+    {
+        "model_name": "PP-ShiTuV2_rec",
+        "suite": "ShiTuRec",
+        "config_path": osp.join(
+            PDX_CONFIG_DIR, "PP-ShiTuV2_rec.yaml"
+        ),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+        "supported_dataset_types": ["ShiTuRecDataset"],
+        "infer_config": None,
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "PP-ShiTuV2_rec_CLIP_vit_base",
+        "suite": "ShiTuRec",
+        "config_path": osp.join(
+            PDX_CONFIG_DIR, "PP-ShiTuV2_rec_CLIP_vit_base.yaml"
+        ),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+        "supported_dataset_types": ["ShiTuRecDataset"],
+        "infer_config": None,
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "PP-ShiTuV2_rec_CLIP_vit_large",
+        "suite": "ShiTuRec",
+        "config_path": osp.join(
+            PDX_CONFIG_DIR, "PP-ShiTuV2_rec_CLIP_vit_large.yaml"
+        ),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+        "supported_dataset_types": ["ShiTuRecDataset"],
+        "infer_config": None,
+    }
+)

+ 54 - 0
paddlex/repo_apis/PaddleClas_api/shitu_rec/runner.py

@@ -0,0 +1,54 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+from ..cls import ClsRunner
+from ...base.utils.subprocess import CompletedProcess
+
+
+class ShiTuRecRunner(ClsRunner):
+    """ShiTuRec Runner"""
+    pass
+
+
+def _extract_eval_metrics(stdout: str) -> dict:
+    """extract evaluation metrics from training log
+
+    Args:
+        stdout (str): the training log
+
+    Returns:
+        dict: the training metric
+    """
+    import re
+
+    _DP = r"[-+]?[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?"
+    patterns = [
+        r"\[Eval\]\[Epoch 0\]\[Avg\].*top1: (_dp), top5: (_dp)".replace("_dp", _DP),
+        r"\[Eval\]\[Epoch 0\]\[Avg\].*recall1: (_dp), recall5: (_dp), mAP: (_dp)".replace(
+            "_dp", _DP
+        ),
+    ]
+    keys = [["val.top1", "val.top5"], ["recall1", "recall5", "mAP"]]
+
+    metric_dict = dict()
+    for pattern, key in zip(patterns, keys):
+        pattern = re.compile(pattern)
+        for line in stdout.splitlines():
+            match = pattern.search(line)
+            if match:
+                for k, v in zip(key, map(float, match.groups())):
+                    metric_dict[k] = v
+    return metric_dict