
Merge pull request #1191 from will-jl944/develop_jf

Set num_workers to 0 when negative samples are added.
FlyingQianMM 4 years ago
parent
commit
31b86ddab3
48 changed files with 2119 additions and 238 deletions
  1. PaddleDetection (+1 -1)
  2. paddlex/cv/datasets/coco.py (+2 -1)
  3. paddlex/cv/datasets/voc.py (+6 -2)
  4. paddlex/ppdet/data/source/category.py (+5 -2)
  5. paddlex/ppdet/data/source/coco.py (+1 -1)
  6. paddlex/ppdet/data/source/keypoint_coco.py (+23 -13)
  7. paddlex/ppdet/data/source/voc.py (+17 -14)
  8. paddlex/ppdet/data/transform/atss_assigner.py (+3 -3)
  9. paddlex/ppdet/data/transform/operators.py (+81 -0)
  10. paddlex/ppdet/engine/callbacks.py (+1 -1)
  11. paddlex/ppdet/engine/export_utils.py (+6 -2)
  12. paddlex/ppdet/engine/tracker.py (+2 -2)
  13. paddlex/ppdet/engine/trainer.py (+17 -4)
  14. paddlex/ppdet/metrics/mot_metrics.py (+5 -6)
  15. paddlex/ppdet/model_zoo/model_zoo.py (+1 -1)
  16. paddlex/ppdet/model_zoo/tests/test_get_model.py (+3 -3)
  17. paddlex/ppdet/model_zoo/tests/test_list_model.py (+3 -3)
  18. paddlex/ppdet/modeling/architectures/centernet.py (+11 -9)
  19. paddlex/ppdet/modeling/architectures/fairmot.py (+2 -1)
  20. paddlex/ppdet/modeling/architectures/keypoint_hrnet.py (+10 -1)
  21. paddlex/ppdet/modeling/architectures/meta_arch.py (+30 -1)
  22. paddlex/ppdet/modeling/architectures/picodet.py (+3 -3)
  23. paddlex/ppdet/modeling/architectures/ssd.py (+10 -1)
  24. paddlex/ppdet/modeling/backbones/__init__.py (+6 -0)
  25. paddlex/ppdet/modeling/backbones/hardnet.py (+224 -0)
  26. paddlex/ppdet/modeling/backbones/hrnet.py (+17 -3)
  27. paddlex/ppdet/modeling/backbones/lcnet.py (+259 -0)
  28. paddlex/ppdet/modeling/backbones/shufflenet_v2.py (+7 -23)
  29. paddlex/ppdet/modeling/backbones/swin_transformer.py (+737 -0)
  30. paddlex/ppdet/modeling/heads/centernet_head.py (+2 -1)
  31. paddlex/ppdet/modeling/heads/detr_head.py (+5 -3)
  32. paddlex/ppdet/modeling/heads/gfl_head.py (+22 -15)
  33. paddlex/ppdet/modeling/heads/pico_head.py (+3 -74)
  34. paddlex/ppdet/modeling/heads/solov2_head.py (+3 -0)
  35. paddlex/ppdet/modeling/heads/ssd_head.py (+47 -7)
  36. paddlex/ppdet/modeling/layers.py (+1 -2)
  37. paddlex/ppdet/modeling/mot/tracker/deepsort_tracker.py (+1 -1)
  38. paddlex/ppdet/modeling/necks/__init__.py (+2 -0)
  39. paddlex/ppdet/modeling/necks/bifpn.py (+302 -0)
  40. paddlex/ppdet/modeling/necks/centernet_fpn.py (+150 -2)
  41. paddlex/ppdet/modeling/necks/pan.py (+4 -12)
  42. paddlex/ppdet/modeling/necks/yolo_fpn.py (+2 -2)
  43. paddlex/ppdet/modeling/proposal_generator/target.py (+15 -6)
  44. paddlex/ppdet/modeling/proposal_generator/target_layer.py (+24 -7)
  45. paddlex/ppdet/modeling/tests/test_architectures.py (+3 -3)
  46. paddlex/ppdet/optimizer.py (+32 -1)
  47. paddlex/ppdet/utils/checkpoint.py (+7 -0)
  48. paddlex/ppdet/utils/download.py (+1 -1)

+ 1 - 1
PaddleDetection

@@ -1 +1 @@
-Subproject commit 3bdf2671f3188de3c4158c9056a46e949cf02eb8
+Subproject commit 340da51be9227167f3673902533417bde19aae96

+ 2 - 1
paddlex/cv/datasets/coco.py

@@ -196,8 +196,9 @@ class CocoDetection(VOCDetection):
             logging.error(
                 "No coco record found in %s' % (ann_file)", exit=True)
         self.pos_num = len(self.file_list)
-        if self.allow_empty:
+        if self.allow_empty and neg_file_list:
             self.file_list += self._sample_empty(neg_file_list)
+            self.num_workers = 0
         logging.info(
             "{} samples in file {}, including {} positive samples and {} negative samples.".
             format(

+ 6 - 2
paddlex/cv/datasets/voc.py

@@ -290,8 +290,9 @@ class VOCDetection(Dataset):
             logging.error(
                 "No voc record found in %s' % (file_list)", exit=True)
         self.pos_num = len(self.file_list)
-        if self.allow_empty:
+        if self.allow_empty and neg_file_list:
             self.file_list += self._sample_empty(neg_file_list)
+            self.num_workers = 0
         logging.info(
             "{} samples in file {}, including {} positive samples and {} negative samples.".
             format(
@@ -423,7 +424,10 @@ class VOCDetection(Dataset):
                 **
                 label_info
             })
-        self.file_list += self._sample_empty(neg_file_list)
+        if neg_file_list:
+            self.allow_empty = True
+            self.file_list += self._sample_empty(neg_file_list)
+            self.num_workers = 0
         logging.info(
             "{} negative samples added. Dataset contains {} positive samples and {} negative samples.".
             format(
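Note: the guard added in both coco.py and voc.py above is the heart of this PR: negative (background-only) samples are appended only when some were actually collected, and multi-process loading is then switched off, as the commit message says. A minimal sketch of the pattern, using stand-in names rather than the real paddlex classes:

    class DetDatasetSketch:
        """Stand-in for the paddlex dataset classes above; only the
        guard logic is reproduced, not the real negative sampling."""

        def __init__(self, allow_empty=False, num_workers=4):
            self.allow_empty = allow_empty
            self.num_workers = num_workers
            self.file_list = []

        def add_negatives(self, neg_file_list):
            # Append negatives only if any were found, then force
            # single-process loading: "Set num_workers to 0 when
            # negative samples are added."
            if self.allow_empty and neg_file_list:
                self.file_list += list(neg_file_list)
                self.num_workers = 0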

+ 5 - 2
paddlex/ppdet/data/source/category.py

@@ -90,16 +90,19 @@ def get_categories(metric_type, anno_file=None, arch=None):
     elif metric_type.lower() in ['mot', 'motdet', 'reid']:
         return _mot_category()
 
+    elif metric_type.lower() in ['kitti', 'bdd100k']:
+        return _mot_category(category='car')
+
     else:
         raise ValueError("unknown metric type {}".format(metric_type))
 
 
-def _mot_category():
+def _mot_category(category='person'):
     """
     Get class id to category id map and category id
     to category name map of mot dataset
     """
-    label_map = {'person': 0}
+    label_map = {category: 0}
     label_map = sorted(label_map.items(), key=lambda x: x[1])
     cats = [l[0] for l in label_map]
 

+ 1 - 1
paddlex/ppdet/data/source/coco.py

@@ -245,7 +245,7 @@ class COCODataSet(DetDataset):
                 break
         assert ct > 0, 'not found any coco record in %s' % (anno_path)
         logger.debug('{} samples in file {}'.format(ct, anno_path))
-        if len(empty_records) > 0:
+        if self.allow_empty and len(empty_records) > 0:
             empty_records = self._sample_empty(empty_records, len(records))
             records += empty_records
         self.roidbs = records

+ 23 - 13
paddlex/ppdet/data/source/keypoint_coco.py

@@ -63,6 +63,9 @@ class KeypointBottomUpBaseDataset(DetDataset):
         self.ann_info['num_joints'] = num_joints
         self.img_ids = []
 
+    def parse_dataset(self):
+        pass
+
     def __len__(self):
         """Get dataset length."""
         return len(self.img_ids)
@@ -136,26 +139,30 @@ class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
         super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                          transform, shard, test_mode)
 
-        ann_file = os.path.join(dataset_dir, anno_path)
-        self.coco = COCO(ann_file)
+        self.ann_file = os.path.join(dataset_dir, anno_path)
+        self.shard = shard
+        self.test_mode = test_mode
+
+    def parse_dataset(self):
+        self.coco = COCO(self.ann_file)
 
         self.img_ids = self.coco.getImgIds()
-        if not test_mode:
+        if not self.test_mode:
             self.img_ids = [
                 img_id for img_id in self.img_ids
                 if len(self.coco.getAnnIds(
                     imgIds=img_id, iscrowd=None)) > 0
             ]
-        blocknum = int(len(self.img_ids) / shard[1])
-        self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0]
-                                                                       + 1))]
+        blocknum = int(len(self.img_ids) / self.shard[1])
+        self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
+            self.shard[0] + 1))]
         self.num_images = len(self.img_ids)
         self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
         self.dataset_name = 'coco'
 
         cat_ids = self.coco.getCatIds()
         self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
-        print(f'=> num_images: {self.num_images}')
+        print('=> num_images: {}'.format(self.num_images))
 
     @staticmethod
     def _get_mapping_id_name(imgs):
@@ -301,20 +308,23 @@ class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
         super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                          transform, shard, test_mode)
 
-        ann_file = os.path.join(dataset_dir, anno_path)
+        self.ann_file = os.path.join(dataset_dir, anno_path)
+        self.shard = shard
+        self.test_mode = test_mode
 
-        self.coco = COCO(ann_file)
+    def parse_dataset(self):
+        self.coco = COCO(self.ann_file)
 
         self.img_ids = self.coco.getImgIds()
-        if not test_mode:
+        if not self.test_mode:
             self.img_ids = [
                 img_id for img_id in self.img_ids
                 if len(self.coco.getAnnIds(
                     imgIds=img_id, iscrowd=None)) > 0
             ]
-        blocknum = int(len(self.img_ids) / shard[1])
-        self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0]
-                                                                       + 1))]
+        blocknum = int(len(self.img_ids) / self.shard[1])
+        self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
+            self.shard[0] + 1))]
         self.num_images = len(self.img_ids)
         self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
 

+ 17 - 14
paddlex/ppdet/data/source/voc.py

@@ -131,11 +131,13 @@ class VOCDataSet(DetDataset):
                         'Illegal width: {} or height: {} in annotation, '
                         'and {} will be ignored'.format(im_w, im_h, xml_file))
                     continue
-                gt_bbox = []
-                gt_class = []
-                gt_score = []
-                difficult = []
-                for i, obj in enumerate(objs):
+
+                num_bbox, i = len(objs), 0
+                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
+                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
+                gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
+                difficult = np.zeros((num_bbox, 1), dtype=np.int32)
+                for obj in objs:
                     cname = obj.find('name').text
 
                     # user dataset may not contain difficult field
@@ -152,19 +154,20 @@ class VOCDataSet(DetDataset):
                     x2 = min(im_w - 1, x2)
                     y2 = min(im_h - 1, y2)
                     if x2 > x1 and y2 > y1:
-                        gt_bbox.append([x1, y1, x2, y2])
-                        gt_class.append([cname2cid[cname]])
-                        gt_score.append([1.])
-                        difficult.append([_difficult])
+                        gt_bbox[i, :] = [x1, y1, x2, y2]
+                        gt_class[i, 0] = cname2cid[cname]
+                        gt_score[i, 0] = 1.
+                        difficult[i, 0] = _difficult
+                        i += 1
                     else:
                         logger.warning(
                             'Found an invalid bbox in annotations: xml_file: {}'
                             ', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                                 xml_file, x1, y1, x2, y2))
-                gt_bbox = np.array(gt_bbox).astype('float32')
-                gt_class = np.array(gt_class).astype('int32')
-                gt_score = np.array(gt_score).astype('float32')
-                difficult = np.array(difficult).astype('int32')
+                gt_bbox = gt_bbox[:i, :]
+                gt_class = gt_class[:i, :]
+                gt_score = gt_score[:i, :]
+                difficult = difficult[:i, :]
 
                 voc_rec = {
                     'im_file': img_file,
@@ -193,7 +196,7 @@ class VOCDataSet(DetDataset):
                     break
         assert ct > 0, 'not found any voc record in %s' % (self.anno_path)
         logger.debug('{} samples in file {}'.format(ct, anno_path))
-        if len(empty_records) > 0:
+        if self.allow_empty and len(empty_records) > 0:
             empty_records = self._sample_empty(empty_records, len(records))
             records += empty_records
         self.roidbs, self.cname2cid = records, cname2cid
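Note: the VOC parser now preallocates fixed-size arrays for one image's boxes and truncates to the number of valid boxes, instead of growing Python lists and converting at the end. A self-contained sketch of that preallocate-then-slice idiom with toy data:

    import numpy as np

    boxes_in = [(0, 0, 10, 10), (5, 5, 4, 4), (2, 2, 8, 9)]  # toy input
    gt_bbox = np.zeros((len(boxes_in), 4), dtype=np.float32)
    i = 0
    for x1, y1, x2, y2 in boxes_in:
        if x2 > x1 and y2 > y1:           # keep only valid boxes
            gt_bbox[i, :] = [x1, y1, x2, y2]
            i += 1
    gt_bbox = gt_bbox[:i, :]              # truncate to the valid count
    print(gt_bbox.shape)                  # (2, 4); (5, 5, 4, 4) was dropped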

+ 3 - 3
paddlex/ppdet/data/transform/atss_assigner.py

@@ -178,8 +178,6 @@ class ATSSAssigner(object):
         """
         bboxes = bboxes[:, :4]
         num_gt, num_bboxes = gt_bboxes.shape[0], bboxes.shape[0]
-        # compute iou between all bbox and gt
-        overlaps = bbox_overlaps(bboxes, gt_bboxes)
 
         # assign 0 by default
         assigned_gt_inds = np.zeros((num_bboxes, ), dtype=np.int64)
@@ -194,8 +192,10 @@ class ATSSAssigner(object):
                 assigned_labels = None
             else:
                 assigned_labels = -np.ones((num_bboxes, ), dtype=np.int64)
-            return assigned_gt_inds, max_overlaps, assigned_labels
+            return assigned_gt_inds, max_overlaps
 
+        # compute iou between all bbox and gt
+        overlaps = bbox_overlaps(bboxes, gt_bboxes)
         # compute center distance between all bbox and gt
         gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
         gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
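Note: moving bbox_overlaps below the empty-ground-truth early exit avoids building an O(num_bboxes x num_gt) IoU matrix on images with no annotations, which is exactly the negative-sample case this PR enables. A sketch of the ordering (pairwise_iou is a toy helper, not the library's bbox_overlaps):

    import numpy as np

    def pairwise_iou(a, b):
        # Tiny IoU helper for the sketch only.
        lt = np.maximum(a[:, None, :2], b[None, :, :2])
        rb = np.minimum(a[:, None, 2:], b[None, :, 2:])
        wh = np.clip(rb - lt, 0, None)
        inter = wh[..., 0] * wh[..., 1]
        area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
        area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
        return inter / (area_a[:, None] + area_b[None, :] - inter)

    def assign_sketch(bboxes, gt_bboxes):
        assigned = np.zeros((bboxes.shape[0], ), dtype=np.int64)  # 0 = background
        if gt_bboxes.shape[0] == 0 or bboxes.shape[0] == 0:
            return assigned               # early exit: no IoU matrix built
        overlaps = pairwise_iou(bboxes, gt_bboxes)
        assigned[overlaps.max(axis=1) > 0.5] = 1  # toy positive rule
        return assigned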

+ 81 - 0
paddlex/ppdet/data/transform/operators.py

@@ -36,6 +36,9 @@ import copy
 import logging
 import cv2
 from PIL import Image, ImageDraw
+import pickle
+import threading
+MUTEX = threading.Lock()
 
 from paddlex.ppdet.core.workspace import serializable
 from paddlex.ppdet.modeling import bbox_utils
@@ -150,6 +153,84 @@ class Decode(BaseOperator):
         return sample
 
 
+def _make_dirs(dirname):
+    try:
+        from pathlib import Path
+    except ImportError:
+        from pathlib2 import Path
+    Path(dirname).mkdir(exist_ok=True)
+
+
+@register_op
+class DecodeCache(BaseOperator):
+    def __init__(self, cache_root=None):
+        '''decode image and caching
+        '''
+        super(DecodeCache, self).__init__()
+
+        self.use_cache = False if cache_root is None else True
+        self.cache_root = cache_root
+
+        if cache_root is not None:
+            _make_dirs(cache_root)
+
+    def apply(self, sample, context=None):
+
+        if self.use_cache and os.path.exists(
+                self.cache_path(self.cache_root, sample['im_file'])):
+            path = self.cache_path(self.cache_root, sample['im_file'])
+            im = self.load(path)
+
+        else:
+            if 'image' not in sample:
+                with open(sample['im_file'], 'rb') as f:
+                    sample['image'] = f.read()
+
+            im = sample['image']
+            data = np.frombuffer(im, dtype='uint8')
+            im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
+            if 'keep_ori_im' in sample and sample['keep_ori_im']:
+                sample['ori_image'] = im
+            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+
+            if self.use_cache and not os.path.exists(
+                    self.cache_path(self.cache_root, sample['im_file'])):
+                path = self.cache_path(self.cache_root, sample['im_file'])
+                self.dump(im, path)
+
+        sample['image'] = im
+        sample['h'] = im.shape[0]
+        sample['w'] = im.shape[1]
+
+        sample['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
+        sample['scale_factor'] = np.array([1., 1.], dtype=np.float32)
+
+        return sample
+
+    @staticmethod
+    def cache_path(dir_oot, im_file):
+        return os.path.join(dir_oot, os.path.basename(im_file) + '.pkl')
+
+    @staticmethod
+    def load(path):
+        with open(path, 'rb') as f:
+            im = pickle.load(f)
+        return im
+
+    @staticmethod
+    def dump(obj, path):
+        MUTEX.acquire()
+        try:
+            with open(path, 'wb') as f:
+                pickle.dump(obj, f)
+
+        except Exception as e:
+            logger.warning('dump {} occurs exception {}'.format(path, str(e)))
+
+        finally:
+            MUTEX.release()
+
+
 @register_op
 class Permute(BaseOperator):
     def __init__(self):
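Note: a hedged usage sketch for the new DecodeCache operator above; the file name is hypothetical, and the call pattern follows the apply() signature shown in the diff. Concurrent dumps are serialized through the module-level MUTEX, so parallel workers do not corrupt a cache file:

    op = DecodeCache(cache_root='./decode_cache')
    sample = op.apply({'im_file': 'demo.jpg'})  # 1st call: decode + pickle
    sample = op.apply({'im_file': 'demo.jpg'})  # 2nd call: loads demo.jpg.pkl
    print(sample['im_shape'], sample['scale_factor'])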

+ 1 - 1
paddlex/ppdet/engine/callbacks.py

@@ -26,7 +26,7 @@ import paddle.distributed as dist
 from paddlex.ppdet.utils.checkpoint import save_model
 
 from paddlex.ppdet.utils.logger import setup_logger
-logger = setup_logger('paddlex.ppdet.engine')
+logger = setup_logger('ppdet.engine')
 
 __all__ = ['Callback', 'ComposeCallback', 'LogPrinter', 'Checkpointer']
 

+ 6 - 2
paddlex/ppdet/engine/export_utils.py

@@ -23,7 +23,7 @@ from collections import OrderedDict
 from paddlex.ppdet.data.source.category import get_categories
 
 from paddlex.ppdet.utils.logger import setup_logger
-logger = setup_logger('paddlex.ppdet.engine')
+logger = setup_logger('ppdet.engine')
 
 # Global dictionary
 TRT_MIN_SUBGRAPH = {
@@ -59,6 +59,7 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
 
     label_list = [str(cat) for cat in catid2name.values()]
 
+    fuse_normalize = reader_cfg.get('fuse_normalize', False)
     sample_transforms = reader_cfg['sample_transforms']
     for st in sample_transforms[1:]:
         for key, value in st.items():
@@ -66,6 +67,8 @@ def _parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape):
             if key == 'Resize':
                 if int(image_shape[1]) != -1:
                     value['target_size'] = image_shape[1:]
+            if fuse_normalize and key == 'NormalizeImage':
+                continue
             p.update(value)
             preprocess_list.append(p)
     batch_transforms = reader_cfg.get('batch_transforms', None)
@@ -122,7 +125,8 @@ def _dump_infer_config(config, path, image_shape, model):
             format(infer_arch) +
             'Please set TRT_MIN_SUBGRAPH in ppdet/engine/export_utils.py')
         os._exit(0)
-    if 'Mask' in infer_arch:
+    if 'mask_head' in config[config['architecture']] and config[config[
+            'architecture']]['mask_head']:
         infer_cfg['mask'] = True
     label_arch = 'detection_arch'
     if infer_arch in KEYPOINT_ARCH:

+ 2 - 2
paddlex/ppdet/engine/tracker.py

@@ -333,7 +333,7 @@ class Tracker(object):
             if save_videos:
                 output_video_path = os.path.join(save_dir, '..',
                                                  '{}_vis.mp4'.format(seq))
-                cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
+                cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(
                     save_dir, output_video_path)
                 os.system(cmd_str)
                 logger.info('Save video in {}.'.format(output_video_path))
@@ -451,7 +451,7 @@ class Tracker(object):
         if save_videos:
             output_video_path = os.path.join(save_dir, '..',
                                              '{}_vis.mp4'.format(seq))
-            cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
+            cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg {}'.format(
                 save_dir, output_video_path)
             os.system(cmd_str)
             logger.info('Save video in {}'.format(output_video_path))

+ 17 - 4
paddlex/ppdet/engine/trainer.py

@@ -43,7 +43,7 @@ from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, Wife
 from .export_utils import _dump_infer_config
 
 from paddlex.ppdet.utils.logger import setup_logger
-logger = setup_logger('paddlex.ppdet.engine')
+logger = setup_logger('ppdet.engine')
 
 __all__ = ['Trainer']
 
@@ -88,10 +88,18 @@ class Trainer(object):
             self.model = self.cfg.model
             self.is_loaded_weights = True
 
+        #normalize params for deploy
+        self.model.load_meanstd(cfg['TestReader']['sample_transforms'])
+
         self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
         if self.use_ema:
+            ema_decay = self.cfg.get('ema_decay', 0.9998)
+            cycle_epoch = self.cfg.get('cycle_epoch', -1)
             self.ema = ModelEMA(
-                cfg['ema_decay'], self.model, use_thres_step=True)
+                self.model,
+                decay=ema_decay,
+                use_thres_step=True,
+                cycle_epoch=cycle_epoch)
 
         # EvalDataset build with BatchSampler to evaluate in single device
         # TODO: multi-device evaluate
@@ -547,8 +555,13 @@ class Trainer(object):
         if image_shape is None:
             image_shape = [3, -1, -1]
 
-        self.model.eval()
-        if hasattr(self.model, 'deploy'): self.model.deploy = True
+        if hasattr(self.model, 'deploy'):
+            self.model.deploy = True
+        if hasattr(self.model, 'fuse_norm'):
+            self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
+                                                              False)
+        if hasattr(self.cfg, 'lite_deploy'):
+            self.model.lite_deploy = self.cfg.lite_deploy
 
         # Save infer cfg
         _dump_infer_config(self.cfg,
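Note: the EMA settings now come from config with defaults (decay 0.9998, cycle_epoch -1). A minimal pure-Python sketch of what an exponential moving average of weights maintains; the real ModelEMA wraps Paddle state dicts, and the names here are illustrative:

    class EMASketch:
        def __init__(self, params, decay=0.9998):
            self.decay = decay
            self.shadow = {k: float(v) for k, v in params.items()}

        def update(self, params):
            d = self.decay
            for k, v in params.items():
                # shadow <- d * shadow + (1 - d) * current
                self.shadow[k] = d * self.shadow[k] + (1 - d) * float(v)

    ema = EMASketch({'w': 1.0})
    ema.update({'w': 0.0})
    print(ema.shadow['w'])  # 0.9998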

+ 5 - 6
paddlex/ppdet/metrics/mot_metrics.py

@@ -37,7 +37,7 @@ __all__ = ['MOTEvaluator', 'MOTMetric', 'JDEDetMetric', 'KITTIMOTMetric']
 
 def read_mot_results(filename, is_gt=False, is_ignore=False):
     valid_labels = {1}
-    ignore_labels = {2, 7, 8, 12}
+    ignore_labels = {2, 7, 8, 12}  # only in motchallenge datasets like 'MOT16'
     results_dict = dict()
     if os.path.isfile(filename):
         with open(filename, 'r') as f:
@@ -53,11 +53,10 @@ def read_mot_results(filename, is_gt=False, is_ignore=False):
                 box_size = float(linelist[4]) * float(linelist[5])
 
                 if is_gt:
-                    if 'MOT16-' in filename or 'MOT17-' in filename or 'MOT15-' in filename or 'MOT20-' in filename:
-                        label = int(float(linelist[7]))
-                        mark = int(float(linelist[6]))
-                        if mark == 0 or label not in valid_labels:
-                            continue
+                    label = int(float(linelist[7]))
+                    mark = int(float(linelist[6]))
+                    if mark == 0 or label not in valid_labels:
+                        continue
                     score = 1
                 elif is_ignore:
                     if 'MOT16-' in filename or 'MOT17-' in filename or 'MOT15-' in filename or 'MOT20-' in filename:

+ 1 - 1
paddlex/ppdet/model_zoo/model_zoo.py

@@ -36,7 +36,7 @@ MODEL_ZOO_FILENAME = 'MODEL_ZOO'
 
 
 def list_model(filters=[]):
-    model_zoo_file = pkg_resources.resource_filename('paddlex.ppdet.model_zoo',
+    model_zoo_file = pkg_resources.resource_filename('ppdet.model_zoo',
                                                      MODEL_ZOO_FILENAME)
     with open(model_zoo_file) as f:
         model_names = f.read().splitlines()

+ 3 - 3
paddlex/ppdet/model_zoo/tests/test_get_model.py

@@ -18,7 +18,7 @@ from __future__ import print_function
 
 import os
 import paddle
-import paddlex.ppdet
+import paddlex.ppdet as ppdet
 import unittest
 
 # NOTE: weights downloading costs time, we choose
@@ -29,7 +29,7 @@ MODEL_NAME = 'ppyolo/ppyolo_tiny_650e_coco'
 class TestGetConfigFile(unittest.TestCase):
     def test_main(self):
         try:
-            cfg_file = paddlex.ppdet.model_zoo.get_config_file(MODEL_NAME)
+            cfg_file = ppdet.model_zoo.get_config_file(MODEL_NAME)
             assert os.path.isfile(cfg_file)
         except:
             self.assertTrue(False)
@@ -38,7 +38,7 @@ class TestGetConfigFile(unittest.TestCase):
 class TestGetModel(unittest.TestCase):
     def test_main(self):
         try:
-            model = paddlex.ppdet.model_zoo.get_model(MODEL_NAME)
+            model = ppdet.model_zoo.get_model(MODEL_NAME)
             assert isinstance(model, paddle.nn.Layer)
         except:
             self.assertTrue(False)

+ 3 - 3
paddlex/ppdet/model_zoo/tests/test_list_model.py

@@ -17,7 +17,7 @@ from __future__ import division
 from __future__ import print_function
 
 import unittest
-import paddlex.ppdet
+import paddlex.ppdet as ppdet
 
 
 class TestListModel(unittest.TestCase):
@@ -26,7 +26,7 @@ class TestListModel(unittest.TestCase):
 
     def test_main(self):
         try:
-            paddlex.ppdet.model_zoo.list_model(self._filter)
+            ppdet.model_zoo.list_model(self._filter)
             self.assertTrue(True)
         except:
             self.assertTrue(False)
@@ -58,7 +58,7 @@ class TestListModelError(unittest.TestCase):
 
     def test_main(self):
         try:
-            paddlex.ppdet.model_zoo.list_model(self._filter)
+            ppdet.model_zoo.list_model(self._filter)
             self.assertTrue(False)
         except ValueError:
             self.assertTrue(True)

+ 11 - 9
paddlex/ppdet/modeling/architectures/centernet.py

@@ -29,8 +29,8 @@ class CenterNet(BaseArch):
 
     Args:
         backbone (object): backbone instance
-        neck (object): 'CenterDLAFPN' instance
-        head (object): 'CenterHead' instance
+        neck (object): FPN instance, default use 'CenterNetDLAFPN'
+        head (object): 'CenterNetHead' instance
         post_process (object): 'CenterNetPostProcess' instance
         for_mot (bool): whether return other features used in tracking model
 
@@ -39,9 +39,9 @@ class CenterNet(BaseArch):
     __inject__ = ['post_process']
 
     def __init__(self,
-                 backbone='DLA',
-                 neck='CenterDLAFPN',
-                 head='CenterHead',
+                 backbone,
+                 neck='CenterNetDLAFPN',
+                 head='CenterNetHead',
                  post_process='CenterNetPostProcess',
                  for_mot=False):
         super(CenterNet, self).__init__()
@@ -56,16 +56,18 @@ class CenterNet(BaseArch):
         backbone = create(cfg['backbone'])
 
         kwargs = {'input_shape': backbone.out_shape}
-        neck = create(cfg['neck'], **kwargs)
+        neck = cfg['neck'] and create(cfg['neck'], **kwargs)
 
-        kwargs = {'input_shape': neck.out_shape}
+        out_shape = neck and neck.out_shape or backbone.out_shape
+        kwargs = {'input_shape': out_shape}
         head = create(cfg['head'], **kwargs)
 
         return {'backbone': backbone, 'neck': neck, "head": head}
 
     def _forward(self):
-        body_feats = self.backbone(self.inputs)
-        neck_feat = self.neck(body_feats)
+        neck_feat = self.backbone(self.inputs)
+        if self.neck is not None:
+            neck_feat = self.neck(neck_feat)
         head_out = self.head(neck_feat, self.inputs)
         if self.for_mot:
             head_out.update({'neck_feat': neck_feat})

+ 2 - 1
paddlex/ppdet/modeling/architectures/fairmot.py

@@ -53,8 +53,9 @@ class FairMOT(BaseArch):
     @classmethod
     def from_config(cls, cfg, *args, **kwargs):
         detector = create(cfg['detector'])
+        detector_out_shape = detector.neck and detector.neck.out_shape or detector.backbone.out_shape
 
-        kwargs = {'input_shape': detector.neck.out_shape}
+        kwargs = {'input_shape': detector_out_shape}
         reid = create(cfg['reid'], **kwargs)
         loss = create(cfg['loss'])
         tracker = create(cfg['tracker'])

+ 10 - 1
paddlex/ppdet/modeling/architectures/keypoint_hrnet.py

@@ -76,7 +76,12 @@ class TopDownHRNet(BaseArch):
         if self.training:
             return self.loss(hrnet_outputs, self.inputs)
         elif self.deploy:
-            return hrnet_outputs
+            outshape = hrnet_outputs.shape
+            max_idx = paddle.argmax(
+                hrnet_outputs.reshape(
+                    (outshape[0], outshape[1], outshape[2] * outshape[3])),
+                axis=-1)
+            return hrnet_outputs, max_idx
         else:
             if self.flip:
                 self.inputs['image'] = self.inputs['image'].flip([3])
@@ -199,6 +204,10 @@ class HRNetPostProcess(object):
         return coord
 
     def dark_postprocess(self, hm, coords, kernelsize):
+        '''DARK postprocessing, Zhang et al. Distribution-Aware Coordinate
+        Representation for Human Pose Estimation (CVPR 2020).
+        '''
+
         hm = self.gaussian_blur(hm, kernelsize)
         hm = np.maximum(hm, 1e-10)
         hm = np.log(hm)
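Note: in deploy mode the model now also returns the per-joint argmax over the flattened heatmaps, so coordinate decoding can run outside the exported graph. A small sketch of that reduction with toy shapes:

    import paddle

    hm = paddle.rand([1, 17, 64, 48])     # N x K x H x W heatmaps
    n, k, h, w = hm.shape
    max_idx = paddle.argmax(hm.reshape((n, k, h * w)), axis=-1)
    print(max_idx.shape)                  # [1, 17], one flat index per joint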

+ 30 - 1
paddlex/ppdet/modeling/architectures/meta_arch.py

@@ -14,12 +14,41 @@ class BaseArch(nn.Layer):
     def __init__(self, data_format='NCHW'):
         super(BaseArch, self).__init__()
         self.data_format = data_format
+        self.inputs = {}
+        self.fuse_norm = False
+
+    def load_meanstd(self, cfg_transform):
+        self.scale = 1.
+        self.mean = paddle.to_tensor([0.485, 0.456, 0.406]).reshape(
+            (1, 3, 1, 1))
+        self.std = paddle.to_tensor([0.229, 0.224, 0.225]).reshape(
+            (1, 3, 1, 1))
+        for item in cfg_transform:
+            if 'NormalizeImage' in item:
+                self.mean = paddle.to_tensor(item['NormalizeImage'][
+                    'mean']).reshape((1, 3, 1, 1))
+                self.std = paddle.to_tensor(item['NormalizeImage'][
+                    'std']).reshape((1, 3, 1, 1))
+                if item['NormalizeImage']['is_scale']:
+                    self.scale = 1. / 255.
+                break
+        if self.data_format == 'NHWC':
+            self.mean = self.mean.reshape(1, 1, 1, 3)
+            self.std = self.std.reshape(1, 1, 1, 3)
 
     def forward(self, inputs):
         if self.data_format == 'NHWC':
             image = inputs['image']
             inputs['image'] = paddle.transpose(image, [0, 2, 3, 1])
-        self.inputs = inputs
+
+        if self.fuse_norm:
+            image = inputs['image']
+            self.inputs['image'] = (image * self.scale - self.mean) / self.std
+            self.inputs['im_shape'] = inputs['im_shape']
+            self.inputs['scale_factor'] = inputs['scale_factor']
+        else:
+            self.inputs = inputs
+
         self.model_arch()
 
         if self.training:
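Note: with fuse_norm enabled, the normalization that NormalizeImage used to apply on the CPU is folded into the network's forward pass. The arithmetic, using the ImageNet defaults set in load_meanstd (a numpy sketch):

    import numpy as np

    scale = 1. / 255.                                    # is_scale=True case
    mean = np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))
    std = np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))
    raw = np.random.randint(0, 256, (1, 3, 4, 4)).astype('float32')
    fused = (raw * scale - mean) / std    # same result NormalizeImage produces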

+ 3 - 3
paddlex/ppdet/modeling/architectures/picodet.py

@@ -41,7 +41,7 @@ class PicoDet(BaseArch):
         self.backbone = backbone
         self.neck = neck
         self.head = head
-        self.deploy = False
+        self.lite_deploy = False
 
     @classmethod
     def from_config(cls, cfg, *args, **kwargs):
@@ -63,7 +63,7 @@ class PicoDet(BaseArch):
         body_feats = self.backbone(self.inputs)
         fpn_feats = self.neck(body_feats)
         head_outs = self.head(fpn_feats)
-        if self.training or self.deploy:
+        if self.training or self.lite_deploy:
             return head_outs
         else:
             im_shape = self.inputs['im_shape']
@@ -83,7 +83,7 @@ class PicoDet(BaseArch):
         return loss
 
     def get_pred(self):
-        if self.deploy:
+        if self.lite_deploy:
             return {'picodet': self._forward()[0]}
         else:
             bbox_pred, bbox_num = self._forward()

+ 10 - 1
paddlex/ppdet/modeling/architectures/ssd.py

@@ -36,11 +36,20 @@ class SSD(BaseArch):
     __category__ = 'architecture'
     __inject__ = ['post_process']
 
-    def __init__(self, backbone, ssd_head, post_process):
+    def __init__(self, backbone, ssd_head, post_process, r34_backbone=False):
         super(SSD, self).__init__()
         self.backbone = backbone
         self.ssd_head = ssd_head
         self.post_process = post_process
+        self.r34_backbone = r34_backbone
+        if self.r34_backbone:
+            from paddlex.ppdet.modeling.backbones.resnet import ResNet
+            assert isinstance(self.backbone, ResNet) and \
+                   self.backbone.depth == 34, \
+                "If you set r34_backbone=True, please use ResNet-34 as backbone."
+            self.backbone.res_layers[2].blocks[
+                0].branch2a.conv._stride = [1, 1]
+            self.backbone.res_layers[2].blocks[0].short.conv._stride = [1, 1]
 
     @classmethod
     def from_config(cls, cfg, *args, **kwargs):

+ 6 - 0
paddlex/ppdet/modeling/backbones/__init__.py

@@ -25,6 +25,9 @@ from . import senet
 from . import res2net
 from . import dla
 from . import shufflenet_v2
+from . import swin_transformer
+from . import lcnet
+from . import hardnet
 
 from .vgg import *
 from .resnet import *
@@ -39,3 +42,6 @@ from .senet import *
 from .res2net import *
 from .dla import *
 from .shufflenet_v2 import *
+from .swin_transformer import *
+from .lcnet import *
+from .hardnet import *

+ 224 - 0
paddlex/ppdet/modeling/backbones/hardnet.py

@@ -0,0 +1,224 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+from paddlex.ppdet.core.workspace import register
+from ..shape_spec import ShapeSpec
+
+__all__ = ['HarDNet']
+
+
+def ConvLayer(in_channels,
+              out_channels,
+              kernel_size=3,
+              stride=1,
+              bias_attr=False):
+    layer = nn.Sequential(
+        ('conv', nn.Conv2D(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=kernel_size // 2,
+            groups=1,
+            bias_attr=bias_attr)), ('norm', nn.BatchNorm2D(out_channels)),
+        ('relu', nn.ReLU6()))
+    return layer
+
+
+def DWConvLayer(in_channels,
+                out_channels,
+                kernel_size=3,
+                stride=1,
+                bias_attr=False):
+    layer = nn.Sequential(
+        ('dwconv', nn.Conv2D(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=1,
+            groups=out_channels,
+            bias_attr=bias_attr)), ('norm', nn.BatchNorm2D(out_channels)))
+    return layer
+
+
+def CombConvLayer(in_channels, out_channels, kernel_size=1, stride=1):
+    layer = nn.Sequential(
+        ('layer1', ConvLayer(
+            in_channels, out_channels, kernel_size=kernel_size)),
+        ('layer2', DWConvLayer(
+            out_channels, out_channels, stride=stride)))
+    return layer
+
+
+class HarDBlock(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 growth_rate,
+                 grmul,
+                 n_layers,
+                 keepBase=False,
+                 residual_out=False,
+                 dwconv=False):
+        super().__init__()
+        self.keepBase = keepBase
+        self.links = []
+        layers_ = []
+        self.out_channels = 0
+        for i in range(n_layers):
+            outch, inch, link = self.get_link(i + 1, in_channels, growth_rate,
+                                              grmul)
+            self.links.append(link)
+            if dwconv:
+                layers_.append(CombConvLayer(inch, outch))
+            else:
+                layers_.append(ConvLayer(inch, outch))
+
+            if (i % 2 == 0) or (i == n_layers - 1):
+                self.out_channels += outch
+        self.layers = nn.LayerList(layers_)
+
+    def get_out_ch(self):
+        return self.out_channels
+
+    def get_link(self, layer, base_ch, growth_rate, grmul):
+        if layer == 0:
+            return base_ch, 0, []
+        out_channels = growth_rate
+
+        link = []
+        for i in range(10):
+            dv = 2**i
+            if layer % dv == 0:
+                k = layer - dv
+                link.append(k)
+                if i > 0:
+                    out_channels *= grmul
+
+        out_channels = int(int(out_channels + 1) / 2) * 2
+        in_channels = 0
+
+        for i in link:
+            ch, _, _ = self.get_link(i, base_ch, growth_rate, grmul)
+            in_channels += ch
+
+        return out_channels, in_channels, link
+
+    def forward(self, x):
+        layers_ = [x]
+
+        for layer in range(len(self.layers)):
+            link = self.links[layer]
+            tin = []
+            for i in link:
+                tin.append(layers_[i])
+            if len(tin) > 1:
+                x = paddle.concat(tin, 1)
+            else:
+                x = tin[0]
+            out = self.layers[layer](x)
+            layers_.append(out)
+
+        t = len(layers_)
+        out_ = []
+        for i in range(t):
+            if (i == 0 and self.keepBase) or (i == t - 1) or (i % 2 == 1):
+                out_.append(layers_[i])
+        out = paddle.concat(out_, 1)
+
+        return out
+
+
+@register
+class HarDNet(nn.Layer):
+    def __init__(self, depth_wise=False, return_idx=[1, 3, 8, 13], arch=85):
+        super(HarDNet, self).__init__()
+        assert arch in [39, 68, 85], "HarDNet-{} not support.".format(arch)
+        if arch == 85:
+            first_ch = [48, 96]
+            second_kernel = 3
+            ch_list = [192, 256, 320, 480, 720]
+            grmul = 1.7
+            gr = [24, 24, 28, 36, 48]
+            n_layers = [8, 16, 16, 16, 16]
+        elif arch == 68:
+            first_ch = [32, 64]
+            second_kernel = 3
+            ch_list = [128, 256, 320, 640]
+            grmul = 1.7
+            gr = [14, 16, 20, 40]
+            n_layers = [8, 16, 16, 16]
+
+        self.return_idx = return_idx
+        self._out_channels = [96, 214, 458, 784]
+
+        avg_pool = True
+        if depth_wise:
+            second_kernel = 1
+            avg_pool = False
+
+        blks = len(n_layers)
+        self.base = nn.LayerList([])
+
+        # First Layer: Standard Conv3x3, Stride=2
+        self.base.append(
+            ConvLayer(
+                in_channels=3,
+                out_channels=first_ch[0],
+                kernel_size=3,
+                stride=2,
+                bias_attr=False))
+
+        # Second Layer
+        self.base.append(
+            ConvLayer(
+                first_ch[0], first_ch[1], kernel_size=second_kernel))
+
+        # Avgpooling or DWConv3x3 downsampling
+        if avg_pool:
+            self.base.append(nn.AvgPool2D(kernel_size=3, stride=2, padding=1))
+        else:
+            self.base.append(DWConvLayer(first_ch[1], first_ch[1], stride=2))
+
+        # Build all HarDNet blocks
+        ch = first_ch[1]
+        for i in range(blks):
+            blk = HarDBlock(ch, gr[i], grmul, n_layers[i], dwconv=depth_wise)
+            ch = blk.out_channels
+            self.base.append(blk)
+
+            if i != blks - 1:
+                self.base.append(ConvLayer(ch, ch_list[i], kernel_size=1))
+            ch = ch_list[i]
+            if i == 0:
+                self.base.append(
+                    nn.AvgPool2D(
+                        kernel_size=2, stride=2, ceil_mode=True))
+            elif i != blks - 1 and i != 1 and i != 3:
+                self.base.append(nn.AvgPool2D(kernel_size=2, stride=2))
+
+    def forward(self, inputs):
+        x = inputs['image']
+        outs = []
+        for i, layer in enumerate(self.base):
+            x = layer(x)
+            if i in self.return_idx:
+                outs.append(x)
+        return outs
+
+    @property
+    def out_shape(self):
+        return [ShapeSpec(channels=self._out_channels[i]) for i in range(4)]
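Note: a hedged smoke test for the new backbone; the input dict convention ({'image': tensor}) matches the forward above, and the spatial size is an arbitrary choice:

    import paddle

    model = HarDNet(arch=85)             # default return_idx=[1, 3, 8, 13]
    feats = model({'image': paddle.randn([1, 3, 512, 512])})
    print([f.shape[1] for f in feats])   # expected [96, 214, 458, 784]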

+ 17 - 3
paddlex/ppdet/modeling/backbones/hrnet.py

@@ -569,6 +569,7 @@ class HRNet(nn.Layer):
         freeze_norm (bool): whether to freeze norm in HRNet
         norm_decay (float): weight decay for normalization layer weights
         return_idx (List): the stage to return
+        upsample (bool): whether to upsample and concat the backbone feats
     """
 
     def __init__(self,
@@ -577,7 +578,8 @@ class HRNet(nn.Layer):
                  freeze_at=0,
                  freeze_norm=True,
                  norm_decay=0.,
-                 return_idx=[0, 1, 2, 3]):
+                 return_idx=[0, 1, 2, 3],
+                 upsample=False):
         super(HRNet, self).__init__()
 
         self.width = width
@@ -588,6 +590,7 @@ class HRNet(nn.Layer):
         assert len(return_idx) > 0, "need one or more return index"
         self.freeze_at = freeze_at
         self.return_idx = return_idx
+        self.upsample = upsample
 
         self.channels = {
             18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
@@ -602,8 +605,8 @@ class HRNet(nn.Layer):
 
         channels_2, channels_3, channels_4 = self.channels[width]
         num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
-        self._out_channels = channels_4
-        self._out_strides = [4, 8, 16, 32]
+        self._out_channels = [sum(channels_4)] if self.upsample else channels_4
+        self._out_strides = [4] if self.upsample else [4, 8, 16, 32]
 
         self.conv_layer1_1 = ConvNormLayer(
             ch_in=3,
@@ -695,6 +698,15 @@ class HRNet(nn.Layer):
 
         st4 = self.st4(tr3)
 
+        if self.upsample:
+            # Upsampling
+            x0_h, x0_w = st4[0].shape[2:4]
+            x1 = F.upsample(st4[1], size=(x0_h, x0_w), mode='bilinear')
+            x2 = F.upsample(st4[2], size=(x0_h, x0_w), mode='bilinear')
+            x3 = F.upsample(st4[3], size=(x0_h, x0_w), mode='bilinear')
+            x = paddle.concat([st4[0], x1, x2, x3], 1)
+            return x
+
         res = []
         for i, layer in enumerate(st4):
             if i == self.freeze_at:
@@ -706,6 +718,8 @@ class HRNet(nn.Layer):
 
     @property
     def out_shape(self):
+        if self.upsample:
+            self.return_idx = [0]
         return [
             ShapeSpec(
                 channels=self._out_channels[i], stride=self._out_strides[i])
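Note: with upsample=True the four stage-4 branches are fused into a single stride-4 feature map; for width=18 the concatenated width is 18+36+72+144 = 270 channels. A toy reproduction of that fusion:

    import paddle
    import paddle.nn.functional as F

    st4 = [paddle.randn([1, c, 64 // s, 64 // s])
           for c, s in zip([18, 36, 72, 144], [1, 2, 4, 8])]
    h, w = st4[0].shape[2:4]
    x = paddle.concat([st4[0]] + [
        F.upsample(t, size=(h, w), mode='bilinear') for t in st4[1:]], 1)
    print(x.shape)  # [1, 270, 64, 64]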

+ 259 - 0
paddlex/ppdet/modeling/backbones/lcnet.py

@@ -0,0 +1,259 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+from paddle import ParamAttr
+from paddle.nn import AdaptiveAvgPool2D, BatchNorm, Conv2D, Dropout, Linear
+from paddle.regularizer import L2Decay
+from paddle.nn.initializer import KaimingNormal
+
+from paddlex.ppdet.core.workspace import register, serializable
+from numbers import Integral
+from ..shape_spec import ShapeSpec
+
+__all__ = ['LCNet']
+
+NET_CONFIG = {
+    "blocks2":
+    #k, in_c, out_c, s, use_se
+    [[3, 16, 32, 1, False], ],
+    "blocks3": [
+        [3, 32, 64, 2, False],
+        [3, 64, 64, 1, False],
+    ],
+    "blocks4": [
+        [3, 64, 128, 2, False],
+        [3, 128, 128, 1, False],
+    ],
+    "blocks5": [
+        [3, 128, 256, 2, False],
+        [5, 256, 256, 1, False],
+        [5, 256, 256, 1, False],
+        [5, 256, 256, 1, False],
+        [5, 256, 256, 1, False],
+        [5, 256, 256, 1, False],
+    ],
+    "blocks6": [[5, 256, 512, 2, True], [5, 512, 512, 1, True]]
+}
+
+
+def make_divisible(v, divisor=8, min_value=None):
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class ConvBNLayer(nn.Layer):
+    def __init__(self,
+                 num_channels,
+                 filter_size,
+                 num_filters,
+                 stride,
+                 num_groups=1):
+        super().__init__()
+
+        self.conv = Conv2D(
+            in_channels=num_channels,
+            out_channels=num_filters,
+            kernel_size=filter_size,
+            stride=stride,
+            padding=(filter_size - 1) // 2,
+            groups=num_groups,
+            weight_attr=ParamAttr(initializer=KaimingNormal()),
+            bias_attr=False)
+
+        self.bn = BatchNorm(
+            num_filters,
+            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
+        self.hardswish = nn.Hardswish()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.hardswish(x)
+        return x
+
+
+class DepthwiseSeparable(nn.Layer):
+    def __init__(self,
+                 num_channels,
+                 num_filters,
+                 stride,
+                 dw_size=3,
+                 use_se=False):
+        super().__init__()
+        self.use_se = use_se
+        self.dw_conv = ConvBNLayer(
+            num_channels=num_channels,
+            num_filters=num_channels,
+            filter_size=dw_size,
+            stride=stride,
+            num_groups=num_channels)
+        if use_se:
+            self.se = SEModule(num_channels)
+        self.pw_conv = ConvBNLayer(
+            num_channels=num_channels,
+            filter_size=1,
+            num_filters=num_filters,
+            stride=1)
+
+    def forward(self, x):
+        x = self.dw_conv(x)
+        if self.use_se:
+            x = self.se(x)
+        x = self.pw_conv(x)
+        return x
+
+
+class SEModule(nn.Layer):
+    def __init__(self, channel, reduction=4):
+        super().__init__()
+        self.avg_pool = AdaptiveAvgPool2D(1)
+        self.conv1 = Conv2D(
+            in_channels=channel,
+            out_channels=channel // reduction,
+            kernel_size=1,
+            stride=1,
+            padding=0)
+        self.relu = nn.ReLU()
+        self.conv2 = Conv2D(
+            in_channels=channel // reduction,
+            out_channels=channel,
+            kernel_size=1,
+            stride=1,
+            padding=0)
+        self.hardsigmoid = nn.Hardsigmoid()
+
+    def forward(self, x):
+        identity = x
+        x = self.avg_pool(x)
+        x = self.conv1(x)
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.hardsigmoid(x)
+        x = paddle.multiply(x=identity, y=x)
+        return x
+
+
+@register
+@serializable
+class LCNet(nn.Layer):
+    def __init__(self, scale=1.0, feature_maps=[3, 4, 5]):
+        super().__init__()
+        self.scale = scale
+        self.feature_maps = feature_maps
+
+        out_channels = []
+
+        self.conv1 = ConvBNLayer(
+            num_channels=3,
+            filter_size=3,
+            num_filters=make_divisible(16 * scale),
+            stride=2)
+
+        self.blocks2 = nn.Sequential(*[
+            DepthwiseSeparable(
+                num_channels=make_divisible(in_c * scale),
+                num_filters=make_divisible(out_c * scale),
+                dw_size=k,
+                stride=s,
+                use_se=se)
+            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks2"])
+        ])
+
+        self.blocks3 = nn.Sequential(*[
+            DepthwiseSeparable(
+                num_channels=make_divisible(in_c * scale),
+                num_filters=make_divisible(out_c * scale),
+                dw_size=k,
+                stride=s,
+                use_se=se)
+            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks3"])
+        ])
+
+        out_channels.append(
+            make_divisible(NET_CONFIG["blocks3"][-1][2] * scale))
+
+        self.blocks4 = nn.Sequential(*[
+            DepthwiseSeparable(
+                num_channels=make_divisible(in_c * scale),
+                num_filters=make_divisible(out_c * scale),
+                dw_size=k,
+                stride=s,
+                use_se=se)
+            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks4"])
+        ])
+
+        out_channels.append(
+            make_divisible(NET_CONFIG["blocks4"][-1][2] * scale))
+
+        self.blocks5 = nn.Sequential(*[
+            DepthwiseSeparable(
+                num_channels=make_divisible(in_c * scale),
+                num_filters=make_divisible(out_c * scale),
+                dw_size=k,
+                stride=s,
+                use_se=se)
+            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks5"])
+        ])
+
+        out_channels.append(
+            make_divisible(NET_CONFIG["blocks5"][-1][2] * scale))
+
+        self.blocks6 = nn.Sequential(*[
+            DepthwiseSeparable(
+                num_channels=make_divisible(in_c * scale),
+                num_filters=make_divisible(out_c * scale),
+                dw_size=k,
+                stride=s,
+                use_se=se)
+            for i, (k, in_c, out_c, s, se) in enumerate(NET_CONFIG["blocks6"])
+        ])
+
+        out_channels.append(
+            make_divisible(NET_CONFIG["blocks6"][-1][2] * scale))
+        self._out_channels = [
+            ch for idx, ch in enumerate(out_channels)
+            if idx + 2 in feature_maps
+        ]
+
+    def forward(self, inputs):
+        x = inputs['image']
+        outs = []
+
+        x = self.conv1(x)
+        x = self.blocks2(x)
+        x = self.blocks3(x)
+        outs.append(x)
+        x = self.blocks4(x)
+        outs.append(x)
+        x = self.blocks5(x)
+        outs.append(x)
+        x = self.blocks6(x)
+        outs.append(x)
+        outs = [o for i, o in enumerate(outs) if i + 2 in self.feature_maps]
+        return outs
+
+    @property
+    def out_shape(self):
+        return [ShapeSpec(channels=c) for c in self._out_channels]
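Note: make_divisible, used throughout LCNet, rounds scaled channel counts to multiples of 8 and rounds up whenever plain rounding would shed more than 10% of the requested width:

    print(make_divisible(16 * 0.5))    # 8
    print(make_divisible(256 * 0.75))  # 192
    print(make_divisible(10))          # 16: rounding down to 8 loses >10%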

+ 7 - 23
paddlex/ppdet/modeling/backbones/shufflenet_v2.py

@@ -21,6 +21,7 @@ import paddle.nn as nn
 from paddle import ParamAttr
 from paddle.nn import Conv2D, MaxPool2D, AdaptiveAvgPool2D, BatchNorm
 from paddle.nn.initializer import KaimingNormal
+from paddle.regularizer import L2Decay
 
 from paddlex.ppdet.core.workspace import register, serializable
 from numbers import Integral
@@ -50,7 +51,11 @@ class ConvBNLayer(nn.Layer):
             weight_attr=ParamAttr(initializer=KaimingNormal()),
             bias_attr=False)
 
-        self._batch_norm = BatchNorm(out_channels, act=act)
+        self._batch_norm = BatchNorm(
+            out_channels,
+            param_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            bias_attr=ParamAttr(regularizer=L2Decay(0.0)),
+            act=act)
 
     def forward(self, inputs):
         y = self._conv(inputs)
@@ -159,14 +164,9 @@ class InvertedResidualDS(nn.Layer):
 @register
 @serializable
 class ShuffleNetV2(nn.Layer):
-    def __init__(self,
-                 scale=1.0,
-                 act="relu",
-                 feature_maps=[5, 13, 17],
-                 with_last_conv=False):
+    def __init__(self, scale=1.0, act="relu", feature_maps=[5, 13, 17]):
         super(ShuffleNetV2, self).__init__()
         self.scale = scale
-        self.with_last_conv = with_last_conv
         if isinstance(feature_maps, Integral):
             feature_maps = [feature_maps]
         self.feature_maps = feature_maps
@@ -226,19 +226,6 @@ class ShuffleNetV2(nn.Layer):
                 self._update_out_channels(stage_out_channels[stage_id + 2],
                                           self._feature_idx, self.feature_maps)
 
-        if self.with_last_conv:
-            # last_conv
-            self._last_conv = ConvBNLayer(
-                in_channels=stage_out_channels[-2],
-                out_channels=stage_out_channels[-1],
-                kernel_size=1,
-                stride=1,
-                padding=0,
-                act=act)
-            self._feature_idx += 1
-            self._update_out_channels(stage_out_channels[-1],
-                                      self._feature_idx, self.feature_maps)
-
     def _update_out_channels(self, channel, feature_idx, feature_maps):
         if feature_idx in feature_maps:
             self._out_channels.append(channel)
@@ -252,9 +239,6 @@ class ShuffleNetV2(nn.Layer):
             if i + 2 in self.feature_maps:
                 outs.append(y)
 
-        if self.with_last_conv:
-            y = self._last_conv(y)
-            outs.append(y)
         return outs
 
     @property
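Note: the ShuffleNetV2 change above applies a common Paddle idiom: giving the BatchNorm scale and offset a per-parameter L2Decay(0.0) regularizer exempts them from the global weight decay. In isolation:

    from paddle import ParamAttr
    from paddle.nn import BatchNorm
    from paddle.regularizer import L2Decay

    bn = BatchNorm(
        32,
        param_attr=ParamAttr(regularizer=L2Decay(0.0)),  # scale: no decay
        bias_attr=ParamAttr(regularizer=L2Decay(0.0)))   # offset: no decay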

+ 737 - 0
paddlex/ppdet/modeling/backbones/swin_transformer.py

@@ -0,0 +1,737 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle.nn.initializer import TruncatedNormal, Constant, Assign
+from paddlex.ppdet.modeling.shape_spec import ShapeSpec
+from paddlex.ppdet.core.workspace import register, serializable
+import numpy as np
+
+# Common initializations
+ones_ = Constant(value=1.)
+zeros_ = Constant(value=0.)
+trunc_normal_ = TruncatedNormal(std=.02)
+
+
+# Common Functions
+def to_2tuple(x):
+    return tuple([x] * 2)
+
+
+def add_parameter(layer, datas, name=None):
+    parameter = layer.create_parameter(
+        shape=(datas.shape), default_initializer=Assign(datas))
+    if name:
+        layer.add_parameter(name, parameter)
+    return parameter
+
+
+# Common Layers
+def drop_path(x, drop_prob=0., training=False):
+    """
+        Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+        the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+        See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = paddle.to_tensor(1 - drop_prob)
+    shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
+    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
+    random_tensor = paddle.floor(random_tensor)  # binarize
+    output = x.divide(keep_prob) * random_tensor
+    return output
+
+
+class DropPath(nn.Layer):
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
+
+
+class Identity(nn.Layer):
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, input):
+        return input
+
+
+class Mlp(nn.Layer):
+    def __init__(self,
+                 in_features,
+                 hidden_features=None,
+                 out_features=None,
+                 act_layer=nn.GELU,
+                 drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.reshape(
+        [B, H // window_size, window_size, W // window_size, window_size, C])
+    windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape(
+        [-1, window_size, window_size, C])
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.reshape(
+        [B, H // window_size, W // window_size, window_size, window_size, -1])
+    x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, H, W, -1])
+    return x
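(Note: window_partition and window_reverse are exact inverses whenever the window size divides the spatial dims; a quick round-trip check on a toy tensor:

    import paddle

    x = paddle.randn([2, 8, 8, 3])            # (B, H, W, C)
    win = window_partition(x, window_size=4)  # (8, 4, 4, 3): num_windows*B
    back = window_reverse(win, 4, 8, 8)       # (2, 8, 8, 3) again
    print(bool((back == x).all()))            # True
)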
+
+
+class WindowAttention(nn.Layer):
+    """ Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self,
+                 dim,
+                 window_size,
+                 num_heads,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 attn_drop=0.,
+                 proj_drop=0.):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = add_parameter(
+            self,
+            paddle.zeros(((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
+                          num_heads)))  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = paddle.arange(self.window_size[0])
+        coords_w = paddle.arange(self.window_size[1])
+        coords = paddle.stack(paddle.meshgrid(
+            [coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = paddle.flatten(coords, 1)  # 2, Wh*Ww
+        coords_flatten_1 = coords_flatten.unsqueeze(axis=2)
+        coords_flatten_2 = coords_flatten.unsqueeze(axis=1)
+        relative_coords = coords_flatten_1 - coords_flatten_2
+        relative_coords = relative_coords.transpose(
+            [1, 2, 0])  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[
+            0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
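+        # flatten the 2-D offset into a single index in [0, (2*Wh-1)*(2*Ww-1))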
+        self.relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index",
+                             self.relative_position_index)
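+        # registering it as a persistable buffer keeps the index in the state dict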
+
+        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table)
+        self.softmax = nn.Softmax(axis=-1)
+
+    def forward(self, x, mask=None):
+        """ Forward function.
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(
+            [B_, N, 3, self.num_heads, C // self.num_heads]).transpose(
+                [2, 0, 3, 1, 4])
+        q, k, v = qkv[0], qkv[1], qkv[2]
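+        # q, k, v each have shape [B_, num_heads, N, C // num_heads]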
+
+        q = q * self.scale
+        attn = paddle.mm(q, k.transpose([0, 1, 3, 2]))
+
+        index = self.relative_position_index.reshape([-1])
+
+        relative_position_bias = paddle.index_select(
+            self.relative_position_bias_table, index)
+        relative_position_bias = relative_position_bias.reshape([
+            self.window_size[0] * self.window_size[1],
+            self.window_size[0] * self.window_size[1], -1
+        ])  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.transpose(
+            [2, 0, 1])  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.reshape([B_ // nW, nW, self.num_heads, N, N
+                                 ]) + mask.unsqueeze(1).unsqueeze(0)
+            attn = attn.reshape([-1, self.num_heads, N, N])
+        attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = paddle.mm(attn, v).transpose([0, 2, 1, 3]).reshape([B_, N, C])
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class SwinTransformerBlock(nn.Layer):
+    """ Swin Transformer Block.
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Layer, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Layer, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self,
+                 dim,
+                 num_heads,
+                 window_size=7,
+                 shift_size=0,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 act_layer=nn.GELU,
+                 norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim,
+            window_size=to_2tuple(self.window_size),
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(in_features=dim,
+                       hidden_features=mlp_hidden_dim,
+                       act_layer=act_layer,
+                       drop=drop)
+
+        self.H = None
+        self.W = None
+
+    def forward(self, x, mask_matrix):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+            mask_matrix: Attention mask for cyclic shift.
+        """
+        B, L, C = x.shape
+        H, W = self.H, self.W
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.reshape([B, H, W, C])
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, [0, pad_l, 0, pad_b, 0, pad_r, 0, pad_t])
+        _, Hp, Wp, _ = x.shape
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = paddle.roll(
+                x, shifts=(-self.shift_size, -self.shift_size), axis=(1, 2))
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(
+            shifted_x, self.window_size)  # nW*B, window_size, window_size, C
+        x_windows = x_windows.reshape(
+            [-1, self.window_size * self.window_size,
+             C])  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(
+            x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.reshape(
+            [-1, self.window_size, self.window_size, C])
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp,
+                                   Wp)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = paddle.roll(
+                shifted_x,
+                shifts=(self.shift_size, self.shift_size),
+                axis=(1, 2))
+        else:
+            x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :]
+
+        x = x.reshape([B, H * W, C])
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x
+
+
+class PatchMerging(nn.Layer):
+    r""" Patch Merging Layer.
+    Args:
+        dim (int): Number of input channels.
+        norm_layer (nn.Layer, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x, H, W):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        x = x.reshape([B, H, W, C])
+
+        # padding
+        pad_input = (H % 2 == 1) or (W % 2 == 1)
+        if pad_input:
+            x = F.pad(x, [0, 0, 0, W % 2, 0, H % 2])
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = paddle.concat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.reshape([B, H * W // 4, 4 * C])  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
+
+
+class BasicLayer(nn.Layer):
+    """ A basic Swin Transformer layer for one stage.
+    Args:
+        dim (int): Number of input channels.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Layer, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Layer | None, optional): Downsample layer at the end of the layer. Default: None
+    """
+
+    def __init__(self,
+                 dim,
+                 depth,
+                 num_heads,
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop=0.,
+                 attn_drop=0.,
+                 drop_path=0.,
+                 norm_layer=nn.LayerNorm,
+                 downsample=None):
+        super().__init__()
+        self.window_size = window_size
+        self.shift_size = window_size // 2
+        self.depth = depth
+
+        # build blocks
+        self.blocks = nn.LayerList([
+            SwinTransformerBlock(
+                dim=dim,
+                num_heads=num_heads,
+                window_size=window_size,
+                shift_size=0 if (i % 2 == 0) else window_size // 2,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop,
+                attn_drop=attn_drop,
+                drop_path=drop_path[i]
+                if isinstance(drop_path, np.ndarray) else drop_path,
+                norm_layer=norm_layer) for i in range(depth)
+        ])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x, H, W):
+        """ Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+
+        # calculate attention mask for SW-MSA
+        Hp = int(np.ceil(H / self.window_size)) * self.window_size
+        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        img_mask = paddle.zeros([1, Hp, Wp, 1], dtype='float32')  # 1 Hp Wp 1
+        h_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        w_slices = (slice(0, -self.window_size),
+                    slice(-self.window_size, -self.shift_size),
+                    slice(-self.shift_size, None))
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+        mask_windows = window_partition(
+            img_mask, self.window_size)  # nW, window_size, window_size, 1
+        mask_windows = mask_windows.reshape(
+            [-1, self.window_size * self.window_size])
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        huns = -100.0 * paddle.ones_like(attn_mask)
+        attn_mask = huns * (attn_mask != 0).astype("float32")
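+        # token pairs from different shifted regions get -100 (≈ -inf before softmax); same-region pairs get 0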
+
+        for blk in self.blocks:
+            blk.H, blk.W = H, W
+            x = blk(x, attn_mask)
+        if self.downsample is not None:
+            x_down = self.downsample(x, H, W)
+            Wh, Ww = (H + 1) // 2, (W + 1) // 2
+            return x, H, W, x_down, Wh, Ww
+        else:
+            return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Layer):
+    """ Image to Patch Embedding
+    Args:
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Layer, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, patch_size=4, in_chans=3, embed_dim=96,
+                 norm_layer=None):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        self.patch_size = patch_size
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2D(
+            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        B, C, H, W = x.shape
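+        # pad on the right/bottom so H and W become divisible by the patch size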
+        if W % self.patch_size[1] != 0:
+            x = F.pad(x, [0, self.patch_size[1] - W % self.patch_size[1]])
+        if H % self.patch_size[0] != 0:
+            x = F.pad(x,
+                      [0, 0, 0, self.patch_size[0] - H % self.patch_size[0]])
+
+        x = self.proj(x)
+        if self.norm is not None:
+            _, _, Wh, Ww = x.shape
+            x = x.flatten(2).transpose([0, 2, 1])
+            x = self.norm(x)
+            x = x.transpose([0, 2, 1]).reshape([-1, self.embed_dim, Wh, Ww])
+
+        return x
+
+
+@register
+@serializable
+class SwinTransformer(nn.Layer):
+    """ Swin Transformer
+        A PaddlePaddle impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
+          https://arxiv.org/pdf/2103.14030
+
+    Args:
+        pretrain_img_size (int | tuple(int)): Input image size used in pretraining. Default 224
+        patch_size (int | tuple(int)): Patch size. Default: 4
+        in_chans (int): Number of input image channels. Default: 3
+        embed_dim (int): Patch embedding dimension. Default: 96
+        depths (tuple(int)): Depth of each Swin Transformer layer.
+        num_heads (tuple(int)): Number of attention heads in different layers.
+        window_size (int): Window size. Default: 7
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+        drop_rate (float): Dropout rate. Default: 0
+        attn_drop_rate (float): Attention dropout rate. Default: 0
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2
+        norm_layer (nn.Layer): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True
+        out_indices (tuple(int)): Indices of the stages whose features are returned. Default: (0, 1, 2, 3)
+        frozen_stages (int): Number of leading stages to freeze, -1 means none. Default: -1
+        pretrained (str): Path or URL of pretrained weights to load. Default: None
+    """
+
+    def __init__(self,
+                 pretrain_img_size=224,
+                 patch_size=4,
+                 in_chans=3,
+                 embed_dim=96,
+                 depths=[2, 2, 6, 2],
+                 num_heads=[3, 6, 12, 24],
+                 window_size=7,
+                 mlp_ratio=4.,
+                 qkv_bias=True,
+                 qk_scale=None,
+                 drop_rate=0.,
+                 attn_drop_rate=0.,
+                 drop_path_rate=0.2,
+                 norm_layer=nn.LayerNorm,
+                 ape=False,
+                 patch_norm=True,
+                 out_indices=(0, 1, 2, 3),
+                 frozen_stages=-1,
+                 pretrained=None):
+        super(SwinTransformer, self).__init__()
+
+        self.pretrain_img_size = pretrain_img_size
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None)
+
+        # absolute position embedding
+        if self.ape:
+            pretrain_img_size = to_2tuple(pretrain_img_size)
+            patch_size = to_2tuple(patch_size)
+            patches_resolution = [
+                pretrain_img_size[0] // patch_size[0],
+                pretrain_img_size[1] // patch_size[1]
+            ]
+
+            self.absolute_pos_embed = add_parameter(
+                self,
+                paddle.zeros((1, embed_dim, patches_resolution[0],
+                              patches_resolution[1])))
+            trunc_normal_(self.absolute_pos_embed)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = np.linspace(0, drop_path_rate,
+                          sum(depths))  # stochastic depth decay rule
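+        # deeper blocks get a larger drop-path probability (linear decay rule)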
+
+        # build layers
+        self.layers = nn.LayerList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(
+                dim=int(embed_dim * 2**i_layer),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=PatchMerging
+                if (i_layer < self.num_layers - 1) else None)
+            self.layers.append(layer)
+
+        num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
+        self.num_features = num_features
+
+        # add a norm layer for each output
+        for i_layer in out_indices:
+            layer = norm_layer(num_features[i_layer])
+            layer_name = f'norm{i_layer}'
+            self.add_sublayer(layer_name, layer)
+
+        self.apply(self._init_weights)
+        self._freeze_stages()
+        if pretrained:
+            if 'http' in pretrained:  # URL
+                path = paddle.utils.download.get_weights_path_from_url(
+                    pretrained)
+            else:  # model in local path
+                path = pretrained
+            self.set_state_dict(paddle.load(path))
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.stop_gradient = True  # paddle Tensors use stop_gradient
+
+        if self.frozen_stages >= 1 and self.ape:
+            self.absolute_pos_embed.stop_gradient = True
+
+        if self.frozen_stages >= 2:
+            self.pos_drop.eval()
+            for i in range(0, self.frozen_stages - 1):
+                m = self.layers[i]
+                m.eval()
+                for param in m.parameters():
+                    param.stop_gradient = True
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                zeros_(m.bias)
+        elif isinstance(m, nn.LayerNorm):
+            zeros_(m.bias)
+            ones_(m.weight)
+
+    def forward(self, x):
+        """Forward function."""
+        x = self.patch_embed(x['image'])
+        _, _, Wh, Ww = x.shape
+        if self.ape:
+            # interpolate the position embedding to the corresponding size
+            absolute_pos_embed = F.interpolate(
+                self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+            x = (x + absolute_pos_embed).flatten(2).transpose([0, 2, 1])
+        else:
+            x = x.flatten(2).transpose([0, 2, 1])
+        x = self.pos_drop(x)
+        outs = []
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+            if i in self.out_indices:
+                norm_layer = getattr(self, f'norm{i}')
+                x_out = norm_layer(x_out)
+                out = x_out.reshape(
+                    (-1, H, W, self.num_features[i])).transpose((0, 3, 1, 2))
+                outs.append(out)
+
+        return tuple(outs)
+
+    @property
+    def out_shape(self):
+        out_strides = [4, 8, 16, 32]
+        return [
+            ShapeSpec(
+                channels=self.num_features[i], stride=out_strides[i])
+            for i in self.out_indices
+        ]

+ 2 - 1
paddlex/ppdet/modeling/heads/centernet_head.py

@@ -98,7 +98,8 @@ class CenterNetHead(nn.Layer):
                 stride=1,
                 padding=0,
                 bias=True))
-        self.heatmap[2].conv.bias[:] = -2.19
+        with paddle.no_grad():
+            self.heatmap[2].conv.bias[:] = -2.19
         self.size = nn.Sequential(
             ConvLayer(
                 in_channels, head_planes, kernel_size=3, padding=1, bias=True),

+ 5 - 3
paddlex/ppdet/modeling/heads/detr_head.py

@@ -311,9 +311,11 @@ class DeformableDETRHead(nn.Layer):
         linear_init_(self.score_head)
         constant_(self.score_head.bias, -4.595)
         constant_(self.bbox_head.layers[-1].weight)
-        bias = paddle.zeros_like(self.bbox_head.layers[-1].bias)
-        bias[2:] = -2.0
-        self.bbox_head.layers[-1].bias.set_value(bias)
+
+        with paddle.no_grad():
+            bias = paddle.zeros_like(self.bbox_head.layers[-1].bias)
+            bias[2:] = -2.0
+            self.bbox_head.layers[-1].bias.set_value(bias)
 
     @classmethod
     def from_config(cls, cfg, hidden_dim, nhead, input_shape):

+ 22 - 15
paddlex/ppdet/modeling/heads/gfl_head.py

@@ -245,6 +245,9 @@ class GFLHead(nn.Layer):
             if self.dgqp_module:
                 quality_score = self.dgqp_module(bbox_reg)
                 cls_logits = F.sigmoid(cls_logits) * quality_score
+            if not self.training:
+                cls_logits = F.sigmoid(cls_logits.transpose([0, 2, 3, 1]))
+                bbox_reg = bbox_reg.transpose([0, 2, 3, 1])
             cls_logits_list.append(cls_logits)
             bboxes_reg_list.append(bbox_reg)
 
@@ -288,6 +291,11 @@ class GFLHead(nn.Layer):
         bbox_targets_list = self._images_to_levels(gt_meta['bbox_targets'],
                                                    num_level_anchors)
         num_total_pos = sum(gt_meta['pos_num'])
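+        # average the positive-anchor count across cards; fall back to single-card behavior when not distributed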
+        try:
+            num_total_pos = paddle.distributed.all_reduce(
+                num_total_pos.clone()) / paddle.distributed.get_world_size()
+        except Exception:
+            num_total_pos = max(num_total_pos, 1)
 
         loss_bbox_list, loss_dfl_list, loss_qfl_list, avg_factor = [], [], [], []
         for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride in zip(
@@ -317,7 +325,7 @@ class GFLHead(nn.Layer):
 
                 weight_targets = F.sigmoid(cls_score.detach())
                 weight_targets = paddle.gather(
-                    weight_targets.max(axis=1), pos_inds, axis=0)
+                    weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0)
                 pos_bbox_pred_corners = self.distribution_project(
                     pos_bbox_pred)
                 pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers,
@@ -335,20 +343,18 @@ class GFLHead(nn.Layer):
                 # regression loss
                 loss_bbox = paddle.sum(
                     self.loss_bbox(pos_decode_bbox_pred,
-                                   pos_decode_bbox_targets) *
-                    weight_targets.mean(axis=-1))
+                                   pos_decode_bbox_targets) * weight_targets)
 
                 # dfl loss
                 loss_dfl = self.loss_dfl(
                     pred_corners,
                     target_corners,
-                    weight=weight_targets.unsqueeze(-1).expand(
-                        [-1, 4]).reshape([-1]),
+                    weight=weight_targets.expand([-1, 4]).reshape([-1]),
                     avg_factor=4.0)
             else:
                 loss_bbox = bbox_pred.sum() * 0
                 loss_dfl = bbox_pred.sum() * 0
-                weight_targets = paddle.to_tensor([0])
+                weight_targets = paddle.to_tensor([0], dtype='float32')
 
             # qfl loss
             score = paddle.to_tensor(score)
@@ -362,6 +368,12 @@ class GFLHead(nn.Layer):
             avg_factor.append(weight_targets.sum())
 
         avg_factor = sum(avg_factor)
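+        # synchronize the normalization factor across cards, clipping it to at least 1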
+        try:
+            avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
+            avg_factor = paddle.clip(
+                avg_factor / paddle.distributed.get_world_size(), min=1)
+        except Exception:
+            avg_factor = max(avg_factor.item(), 1)
         if avg_factor <= 0:
             loss_qfl = paddle.to_tensor(
                 0, dtype='float32', stop_gradient=False)
@@ -413,14 +425,13 @@ class GFLHead(nn.Layer):
         mlvl_scores = []
         for stride, cls_score, bbox_pred in zip(self.fpn_stride, cls_scores,
                                                 bbox_preds):
-            featmap_size = cls_score.shape[-2:]
+            featmap_size = [
+                paddle.shape(cls_score)[0], paddle.shape(cls_score)[1]
+            ]
             y, x = self.get_single_level_center_point(
                 featmap_size, stride, cell_offset=cell_offset)
             center_points = paddle.stack([x, y], axis=-1)
-            scores = F.sigmoid(
-                cls_score.transpose([1, 2, 0]).reshape(
-                    [-1, self.cls_out_channels]))
-            bbox_pred = bbox_pred.transpose([1, 2, 0])
+            scores = cls_score.reshape([-1, self.cls_out_channels])
             bbox_pred = self.distribution_project(bbox_pred) * stride
 
             if scores.shape[0] > self.nms_pre:
@@ -440,10 +451,6 @@ class GFLHead(nn.Layer):
             im_scale = paddle.concat([scale_factor[::-1], scale_factor[::-1]])
             mlvl_bboxes /= im_scale
         mlvl_scores = paddle.concat(mlvl_scores)
-        if self.use_sigmoid:
-            # add a dummy background class to the backend when use_sigmoid
-            padding = paddle.zeros([mlvl_scores.shape[0], 1])
-            mlvl_scores = paddle.concat([mlvl_scores, padding], axis=1)
         mlvl_scores = mlvl_scores.transpose([1, 0])
         return mlvl_bboxes, mlvl_scores
 

+ 3 - 74
paddlex/ppdet/modeling/heads/pico_head.py

@@ -66,7 +66,7 @@ class PicoFeat(nn.Layer):
                     ConvNormLayer(
                         ch_in=in_c,
                         ch_out=feat_out,
-                        filter_size=3,
+                        filter_size=5,
                         stride=1,
                         groups=feat_out,
                         norm_type=norm_type,
@@ -91,7 +91,7 @@ class PicoFeat(nn.Layer):
                         ConvNormLayer(
                             ch_in=in_c,
                             ch_out=feat_out,
-                            filter_size=3,
+                            filter_size=5,
                             stride=1,
                             groups=feat_out,
                             norm_type=norm_type,
@@ -250,80 +250,9 @@ class PicoHead(GFLHead):
 
             if not self.training:
                 cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1]))
-                bbox_pred = self.distribution_project(
-                    bbox_pred.transpose([0, 2, 3, 1])) * self.fpn_stride[i]
+                bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
 
             cls_logits_list.append(cls_score)
             bboxes_reg_list.append(bbox_pred)
 
         return (cls_logits_list, bboxes_reg_list)
-
-    def get_bboxes_single(self,
-                          cls_scores,
-                          bbox_preds,
-                          img_shape,
-                          scale_factor,
-                          rescale=True,
-                          cell_offset=0):
-        assert len(cls_scores) == len(bbox_preds)
-        mlvl_bboxes = []
-        mlvl_scores = []
-        for stride, cls_score, bbox_pred in zip(self.fpn_stride, cls_scores,
-                                                bbox_preds):
-            featmap_size = cls_score.shape[0:2]
-            y, x = self.get_single_level_center_point(
-                featmap_size, stride, cell_offset=cell_offset)
-            center_points = paddle.stack([x, y], axis=-1)
-            scores = cls_score.reshape([-1, self.cls_out_channels])
-
-            if scores.shape[0] > self.nms_pre:
-                max_scores = scores.max(axis=1)
-                _, topk_inds = max_scores.topk(self.nms_pre)
-                center_points = center_points.gather(topk_inds)
-                bbox_pred = bbox_pred.gather(topk_inds)
-                scores = scores.gather(topk_inds)
-
-            bboxes = distance2bbox(
-                center_points, bbox_pred, max_shape=img_shape)
-            mlvl_bboxes.append(bboxes)
-            mlvl_scores.append(scores)
-        mlvl_bboxes = paddle.concat(mlvl_bboxes)
-        if rescale:
-            # [h_scale, w_scale] to [w_scale, h_scale, w_scale, h_scale]
-            im_scale = paddle.concat([scale_factor[::-1], scale_factor[::-1]])
-            mlvl_bboxes /= im_scale
-        mlvl_scores = paddle.concat(mlvl_scores)
-        mlvl_scores = mlvl_scores.transpose([1, 0])
-        return mlvl_bboxes, mlvl_scores
-
-    def decode(self, cls_scores, bbox_preds, im_shape, scale_factor,
-               cell_offset):
-        batch_bboxes = []
-        batch_scores = []
-        batch_size = cls_scores[0].shape[0]
-        for img_id in range(batch_size):
-            num_levels = len(cls_scores)
-            cls_score_list = [cls_scores[i][img_id] for i in range(num_levels)]
-            bbox_pred_list = [
-                bbox_preds[i].reshape([batch_size, -1, 4])[img_id]
-                for i in range(num_levels)
-            ]
-            bboxes, scores = self.get_bboxes_single(
-                cls_score_list,
-                bbox_pred_list,
-                im_shape[img_id],
-                scale_factor[img_id],
-                cell_offset=cell_offset)
-            batch_bboxes.append(bboxes)
-            batch_scores.append(scores)
-        batch_bboxes = paddle.stack(batch_bboxes, axis=0)
-        batch_scores = paddle.stack(batch_scores, axis=0)
-
-        return batch_bboxes, batch_scores
-
-    def post_process(self, gfl_head_outs, im_shape, scale_factor):
-        cls_scores, bboxes_reg = gfl_head_outs
-        bboxes, score = self.decode(cls_scores, bboxes_reg, im_shape,
-                                    scale_factor, self.cell_offset)
-        bbox_pred, bbox_num, _ = self.nms(bboxes, score)
-        return bbox_pred, bbox_num

+ 3 - 0
paddlex/ppdet/modeling/heads/solov2_head.py

@@ -490,6 +490,9 @@ class SOLOv2Head(nn.Layer):
                     fill_value=self.segm_strides[_ind],
                     dtype="int32"))
         strides = paddle.concat(strides)
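+        # append a zero-stride sentinel entry so the gather below stays in range in the empty-prediction case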
+        strides = paddle.concat(
+            [strides, paddle.zeros(
+                shape=[1], dtype='int32')])
         strides = paddle.gather(strides, index=inds[:, 0])
 
         # mask encoding.

+ 47 - 7
paddlex/ppdet/modeling/heads/ssd_head.py

@@ -28,7 +28,7 @@ class SepConvLayer(nn.Layer):
                  out_channels,
                  kernel_size=3,
                  padding=1,
-                 conv_decay=0):
+                 conv_decay=0.):
         super(SepConvLayer, self).__init__()
         self.dw_conv = nn.Conv2D(
             in_channels=in_channels,
@@ -61,6 +61,35 @@ class SepConvLayer(nn.Layer):
         return x
 
 
+class SSDExtraHead(nn.Layer):
+    def __init__(self,
+                 in_channels=256,
+                 out_channels=([256, 512], [256, 512], [128, 256], [128, 256],
+                               [128, 256]),
+                 strides=(2, 2, 2, 1, 1),
+                 paddings=(1, 1, 1, 0, 0)):
+        super(SSDExtraHead, self).__init__()
+        self.convs = nn.LayerList()
+        for out_channel, stride, padding in zip(out_channels, strides,
+                                                paddings):
+            self.convs.append(
+                self._make_layers(in_channels, out_channel[0], out_channel[1],
+                                  stride, padding))
+            in_channels = out_channel[-1]
+
+    def _make_layers(self, c_in, c_hidden, c_out, stride_3x3, padding_3x3):
+        return nn.Sequential(
+            nn.Conv2D(c_in, c_hidden, 1),
+            nn.ReLU(),
+            nn.Conv2D(c_hidden, c_out, 3, stride_3x3, padding_3x3), nn.ReLU())
+
+    def forward(self, x):
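+        # return the input followed by each progressively downsampled extra feature map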
+        out = [x]
+        for conv_layer in self.convs:
+            out.append(conv_layer(out[-1]))
+        return out
+
+
 @register
 class SSDHead(nn.Layer):
     """
@@ -75,6 +104,7 @@ class SSDHead(nn.Layer):
         use_sepconv (bool): Use SepConvLayer if true
         conv_decay (float): Conv regularization coeff
         loss (object): 'SSDLoss' instance
+        use_extra_head (bool): If ResNet34 is used as the backbone, `use_extra_head` should be set to True
     """
 
     __shared__ = ['num_classes']
@@ -88,13 +118,19 @@ class SSDHead(nn.Layer):
                  padding=1,
                  use_sepconv=False,
                  conv_decay=0.,
-                 loss='SSDLoss'):
+                 loss='SSDLoss',
+                 use_extra_head=False):
         super(SSDHead, self).__init__()
         # add background class
         self.num_classes = num_classes + 1
         self.in_channels = in_channels
         self.anchor_generator = anchor_generator
         self.loss = loss
+        self.use_extra_head = use_extra_head
+
+        if self.use_extra_head:
+            self.ssd_extra_head = SSDExtraHead()
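+            # channels of the input map plus the five SSDExtraHead stages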
+            self.in_channels = [256, 512, 512, 256, 256, 256]
 
         if isinstance(anchor_generator, dict):
             self.anchor_generator = AnchorGeneratorSSD(**anchor_generator)
@@ -108,7 +144,7 @@ class SSDHead(nn.Layer):
                 box_conv = self.add_sublayer(
                     box_conv_name,
                     nn.Conv2D(
-                        in_channels=in_channels[i],
+                        in_channels=self.in_channels[i],
                         out_channels=num_prior * 4,
                         kernel_size=kernel_size,
                         padding=padding))
@@ -116,7 +152,7 @@ class SSDHead(nn.Layer):
                 box_conv = self.add_sublayer(
                     box_conv_name,
                     SepConvLayer(
-                        in_channels=in_channels[i],
+                        in_channels=self.in_channels[i],
                         out_channels=num_prior * 4,
                         kernel_size=kernel_size,
                         padding=padding,
@@ -128,7 +164,7 @@ class SSDHead(nn.Layer):
                 score_conv = self.add_sublayer(
                     score_conv_name,
                     nn.Conv2D(
-                        in_channels=in_channels[i],
+                        in_channels=self.in_channels[i],
                         out_channels=num_prior * self.num_classes,
                         kernel_size=kernel_size,
                         padding=padding))
@@ -136,7 +172,7 @@ class SSDHead(nn.Layer):
                 score_conv = self.add_sublayer(
                     score_conv_name,
                     SepConvLayer(
-                        in_channels=in_channels[i],
+                        in_channels=self.in_channels[i],
                         out_channels=num_prior * self.num_classes,
                         kernel_size=kernel_size,
                         padding=padding,
@@ -148,9 +184,13 @@ class SSDHead(nn.Layer):
         return {'in_channels': [i.channels for i in input_shape], }
 
     def forward(self, feats, image, gt_bbox=None, gt_class=None):
+        if self.use_extra_head:
+            assert len(feats) == 1, \
+                ("If you set use_extra_head=True, backbone feature "
+                 "list length should be 1.")
+            feats = self.ssd_extra_head(feats[0])
         box_preds = []
         cls_scores = []
-        prior_boxes = []
         for feat, box_conv, score_conv in zip(feats, self.box_convs,
                                               self.score_convs):
             box_pred = box_conv(feat)

+ 1 - 2
paddlex/ppdet/modeling/layers.py

@@ -281,8 +281,7 @@ class DropBlock(nn.Layer):
             for s in shape:
                 gamma *= s / (s - self.block_size + 1)
 
-            matrix = paddle.cast(
-                paddle.rand(x.shape, x.dtype) < gamma, x.dtype)
+            matrix = paddle.cast(paddle.rand(x.shape) < gamma, x.dtype)
             mask_inv = F.max_pool2d(
                 matrix,
                 self.block_size,

+ 1 - 1
paddlex/ppdet/modeling/mot/tracker/deepsort_tracker.py

@@ -84,7 +84,7 @@ class DeepSORTTracker(object):
         """
         Perform measurement update and track management.
         Args:
-            detections (list): List[paddlex.ppdet.modeling.mot.utils.Detection]
+            detections (list): List[ppdet.modeling.mot.utils.Detection]
             A list of detections at the current time step.
         """
         # Run matching cascade.

+ 2 - 0
paddlex/ppdet/modeling/necks/__init__.py

@@ -18,6 +18,7 @@ from . import hrfpn
 from . import ttf_fpn
 from . import centernet_fpn
 from . import pan
+from . import bifpn
 
 from .fpn import *
 from .yolo_fpn import *
@@ -26,3 +27,4 @@ from .ttf_fpn import *
 from .centernet_fpn import *
 from .blazeface_fpn import *
 from .pan import *
+from .bifpn import *

+ 302 - 0
paddlex/ppdet/modeling/necks/bifpn.py

@@ -0,0 +1,302 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+from paddle import ParamAttr
+from paddle.nn.initializer import Constant
+
+from paddlex.ppdet.core.workspace import register, serializable
+from paddlex.ppdet.modeling.layers import ConvNormLayer
+from ..shape_spec import ShapeSpec
+
+__all__ = ['BiFPN']
+
+
+class SeparableConvLayer(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels=None,
+                 kernel_size=3,
+                 norm_type='bn',
+                 norm_groups=32,
+                 act='swish'):
+        super(SeparableConvLayer, self).__init__()
+        assert norm_type in ['bn', 'sync_bn', 'gn', None]
+        assert act in ['swish', 'relu', None]
+
+        self.in_channels = in_channels
+        # default to in_channels so self.out_channels is always defined
+        self.out_channels = out_channels if out_channels is not None else in_channels
+        self.norm_type = norm_type
+        self.norm_groups = norm_groups
+        self.depthwise_conv = nn.Conv2D(
+            in_channels,
+            in_channels,
+            kernel_size,
+            padding=kernel_size // 2,
+            groups=in_channels,
+            bias_attr=False)
+        self.pointwise_conv = nn.Conv2D(in_channels, self.out_channels, 1)
+
+        # norm type
+        if self.norm_type == 'bn':
+            self.norm = nn.BatchNorm2D(self.out_channels)
+        elif self.norm_type == 'sync_bn':
+            self.norm = nn.SyncBatchNorm(self.out_channels)
+        elif self.norm_type == 'gn':
+            self.norm = nn.GroupNorm(
+                num_groups=self.norm_groups, num_channels=self.out_channels)
+
+        # activation (self.act stays None when act is None)
+        if act == 'swish':
+            self.act = nn.Swish()
+        elif act == 'relu':
+            self.act = nn.ReLU()
+        else:
+            self.act = None
+
+    def forward(self, x):
+        if self.act is not None:
+            x = self.act(x)
+        out = self.depthwise_conv(x)
+        out = self.pointwise_conv(out)
+        if self.norm_type is not None:
+            out = self.norm(out)
+        return out
+
+
+class BiFPNCell(nn.Layer):
+    def __init__(self,
+                 channels=256,
+                 num_levels=5,
+                 eps=1e-5,
+                 use_weighted_fusion=True,
+                 kernel_size=3,
+                 norm_type='bn',
+                 norm_groups=32,
+                 act='swish'):
+        super(BiFPNCell, self).__init__()
+        self.channels = channels
+        self.num_levels = num_levels
+        self.eps = eps
+        self.use_weighted_fusion = use_weighted_fusion
+
+        # up
+        self.conv_up = nn.LayerList([
+            SeparableConvLayer(
+                self.channels,
+                kernel_size=kernel_size,
+                norm_type=norm_type,
+                norm_groups=norm_groups,
+                act=act) for _ in range(self.num_levels - 1)
+        ])
+        # down
+        self.conv_down = nn.LayerList([
+            SeparableConvLayer(
+                self.channels,
+                kernel_size=kernel_size,
+                norm_type=norm_type,
+                norm_groups=norm_groups,
+                act=act) for _ in range(self.num_levels - 1)
+        ])
+
+        if self.use_weighted_fusion:
+            self.up_weights = self.create_parameter(
+                shape=[self.num_levels - 1, 2],
+                attr=ParamAttr(initializer=Constant(1.)))
+            self.down_weights = self.create_parameter(
+                shape=[self.num_levels - 1, 3],
+                attr=ParamAttr(initializer=Constant(1.)))
+
+    def _feature_fusion_cell(self,
+                             conv_layer,
+                             lateral_feat,
+                             sampling_feat,
+                             route_feat=None,
+                             weights=None):
+        if self.use_weighted_fusion:
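+            # fast normalized fusion (EfficientDet): w_i = relu(w_i) / (sum_j relu(w_j) + eps)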
+            weights = F.relu(weights)
+            weights = weights / (weights.sum() + self.eps)
+            if route_feat is not None:
+                out_feat = weights[0] * lateral_feat + \
+                           weights[1] * sampling_feat + \
+                           weights[2] * route_feat
+            else:
+                out_feat = weights[0] * lateral_feat + \
+                           weights[1] * sampling_feat
+        else:
+            if route_feat is not None:
+                out_feat = lateral_feat + sampling_feat + route_feat
+            else:
+                out_feat = lateral_feat + sampling_feat
+
+        out_feat = conv_layer(out_feat)
+        return out_feat
+
+    def forward(self, feats):
+        # feats: [P3 - P7]
+        lateral_feats = []
+
+        # up
+        up_feature = feats[-1]
+        for i, feature in enumerate(feats[::-1]):
+            if i == 0:
+                lateral_feats.append(feature)
+            else:
+                shape = paddle.shape(feature)
+                up_feature = F.interpolate(
+                    up_feature, size=[shape[2], shape[3]])
+                lateral_feature = self._feature_fusion_cell(
+                    self.conv_up[i - 1],
+                    feature,
+                    up_feature,
+                    weights=self.up_weights[i - 1]
+                    if self.use_weighted_fusion else None)
+                lateral_feats.append(lateral_feature)
+                up_feature = lateral_feature
+
+        out_feats = []
+        # down
+        down_feature = lateral_feats[-1]
+        for i, (lateral_feature,
+                route_feature) in enumerate(zip(lateral_feats[::-1], feats)):
+            if i == 0:
+                out_feats.append(lateral_feature)
+            else:
+                down_feature = F.max_pool2d(down_feature, 3, 2, 1)
+                if i == len(feats) - 1:
+                    route_feature = None
+                    weights = self.down_weights[
+                        i - 1][:2] if self.use_weighted_fusion else None
+                else:
+                    weights = self.down_weights[
+                        i - 1] if self.use_weighted_fusion else None
+                out_feature = self._feature_fusion_cell(
+                    self.conv_down[i - 1],
+                    lateral_feature,
+                    down_feature,
+                    route_feature,
+                    weights=weights)
+                out_feats.append(out_feature)
+                down_feature = out_feature
+
+        return out_feats
+
+
+@register
+@serializable
+class BiFPN(nn.Layer):
+    """
+    Bidirectional Feature Pyramid Network, see https://arxiv.org/abs/1911.09070
+
+    Args:
+        in_channels (list[int]): input channels of each level which can be
+            derived from the output shape of backbone by from_config.
+        out_channel (int): output channel of each level.
+        num_extra_levels (int): the number of extra stages added to the last level.
+            default: 2
+        fpn_strides (List): The stride of each level.
+        num_stacks (int): the number of stacks for BiFPN, default: 1.
+        use_weighted_fusion (bool): use weighted feature fusion in BiFPN, default: True.
+        norm_type (string|None): the normalization type in the BiFPN module. If
+            norm_type is None, no norm is applied after conv; otherwise 'bn',
+            'gn' and 'sync_bn' are available. default: bn.
+        norm_groups (int): the number of groups when norm_type is 'gn'.
+        act (string|None): the activation function of BiFPN.
+    """
+
+    def __init__(self,
+                 in_channels=(512, 1024, 2048),
+                 out_channel=256,
+                 num_extra_levels=2,
+                 fpn_strides=[8, 16, 32, 64, 128],
+                 num_stacks=1,
+                 use_weighted_fusion=True,
+                 norm_type='bn',
+                 norm_groups=32,
+                 act='swish'):
+        super(BiFPN, self).__init__()
+        assert num_stacks > 0, "The number of stacks of BiFPN must be at least 1."
+        assert norm_type in ['bn', 'sync_bn', 'gn', None]
+        assert act in ['swish', 'relu', None]
+        assert num_extra_levels >= 0, \
+            "The `num_extra_levels` must be non negative(>=0)."
+
+        self.in_channels = in_channels
+        self.out_channel = out_channel
+        self.num_extra_levels = num_extra_levels
+        self.num_stacks = num_stacks
+        self.use_weighted_fusion = use_weighted_fusion
+        self.norm_type = norm_type
+        self.norm_groups = norm_groups
+        self.act = act
+        self.num_levels = len(self.in_channels) + self.num_extra_levels
+        if len(fpn_strides) != self.num_levels:
+            for i in range(self.num_extra_levels):
+                fpn_strides += [fpn_strides[-1] * 2]
+        self.fpn_strides = fpn_strides
+
+        self.lateral_convs = nn.LayerList()
+        for in_c in in_channels:
+            self.lateral_convs.append(
+                ConvNormLayer(in_c, self.out_channel, 1, 1))
+        if self.num_extra_levels > 0:
+            self.extra_convs = nn.LayerList()
+            for i in range(self.num_extra_levels):
+                if i == 0:
+                    self.extra_convs.append(
+                        ConvNormLayer(self.in_channels[-1], self.out_channel,
+                                      3, 2))
+                else:
+                    self.extra_convs.append(nn.MaxPool2D(3, 2, 1))
+
+        self.bifpn_cells = nn.LayerList()
+        for i in range(self.num_stacks):
+            self.bifpn_cells.append(
+                BiFPNCell(
+                    self.out_channel,
+                    self.num_levels,
+                    use_weighted_fusion=self.use_weighted_fusion,
+                    norm_type=self.norm_type,
+                    norm_groups=self.norm_groups,
+                    act=self.act))
+
+    @classmethod
+    def from_config(cls, cfg, input_shape):
+        return {
+            'in_channels': [i.channels for i in input_shape],
+            'fpn_strides': [i.stride for i in input_shape]
+        }
+
+    @property
+    def out_shape(self):
+        return [
+            ShapeSpec(
+                channels=self.out_channel, stride=s) for s in self.fpn_strides
+        ]
+
+    def forward(self, feats):
+        assert len(feats) == len(self.in_channels)
+        fpn_feats = []
+        for conv_layer, feature in zip(self.lateral_convs, feats):
+            fpn_feats.append(conv_layer(feature))
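+        # derive the extra pyramid levels (e.g. P6, P7) from the last backbone feature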
+        if self.num_extra_levels > 0:
+            feat = feats[-1]
+            for conv_layer in self.extra_convs:
+                feat = conv_layer(feat)
+                fpn_feats.append(feat)
+
+        for bifpn_cell in self.bifpn_cells:
+            fpn_feats = bifpn_cell(fpn_feats)
+        return fpn_feats

+ 150 - 2
paddlex/ppdet/modeling/necks/centernet_fpn.py

@@ -16,11 +16,15 @@ import numpy as np
 import math
 import paddle
 import paddle.nn as nn
+import paddle.nn.functional as F
 from paddle.nn.initializer import KaimingUniform
 from paddlex.ppdet.core.workspace import register, serializable
 from paddlex.ppdet.modeling.layers import ConvNormLayer
+from paddlex.ppdet.modeling.backbones.hardnet import ConvLayer, HarDBlock
 from ..shape_spec import ShapeSpec
 
+__all__ = ['CenterNetDLAFPN', 'CenterNetHarDNetFPN']
+
 
 def fill_up_weights(up):
     weight = up.weight
@@ -134,8 +138,9 @@ class CenterNetDLAFPN(nn.Layer):
         last_level (int): the last level of input feature fed into the upsampling block
         out_channel (int): the channel of the output feature, 0 by default means
             the channel of the input feature whose down ratio is `down_ratio`
+        first_level (int): the first level of input feature fed into the upsampling
+            block, -1 by default, in which case it is derived from down_ratio
         dcn_v2 (bool): whether use the DCNv2, true by default
-
     """
 
     def __init__(self,
@@ -143,9 +148,11 @@ class CenterNetDLAFPN(nn.Layer):
                  down_ratio=4,
                  last_level=5,
                  out_channel=0,
+                 first_level=-1,
                  dcn_v2=True):
         super(CenterNetDLAFPN, self).__init__()
-        self.first_level = int(np.log2(down_ratio))
+        self.first_level = int(np.log2(
+            down_ratio)) if first_level == -1 else first_level
         self.down_ratio = down_ratio
         self.last_level = last_level
         scales = [2**i for i in range(len(in_channels[self.first_level:]))]
@@ -168,6 +175,7 @@ class CenterNetDLAFPN(nn.Layer):
         return {'in_channels': [i.channels for i in input_shape]}
 
     def forward(self, body_feats):
+
         dla_up_feats = self.dla_up(body_feats)
 
         ida_up_feats = []
@@ -181,3 +189,143 @@ class CenterNetDLAFPN(nn.Layer):
     @property
     def out_shape(self):
         return [ShapeSpec(channels=self.out_channel, stride=self.down_ratio)]
+
+
+class TransitionUp(nn.Layer):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+
+    def forward(self, x, skip, concat=True):
+        h, w = skip.shape[2], skip.shape[3]
+        out = F.interpolate(
+            x, size=(h, w), mode="bilinear", align_corners=True)
+        if concat:
+            out = paddle.concat([out, skip], 1)
+        return out
+
+
+@register
+@serializable
+class CenterNetHarDNetFPN(nn.Layer):
+    """
+    Args:
+        in_channels (list): number of input feature channels from backbone.
+            [96, 214, 458, 784] by default, means the channels of HarDNet85
+        num_layers (int): HarDNet layers, 85 by default
+        down_ratio (int): the down ratio from images to heatmap, 4 by default
+        first_level (int): the first level of input feature fed into the
+            upsampling block
+        last_level (int): the last level of input feature fed into the upsampling block
+        out_channel (int): the channel of the output feature, 0 by default means
+            the channel of the input feature whose down ratio is `down_ratio`
+    """
+
+    def __init__(self,
+                 in_channels,
+                 num_layers=85,
+                 down_ratio=4,
+                 first_level=-1,
+                 last_level=4,
+                 out_channel=0):
+        super(CenterNetHarDNetFPN, self).__init__()
+        self.first_level = int(np.log2(
+            down_ratio)) - 1 if first_level == -1 else first_level
+        self.down_ratio = down_ratio
+        self.last_level = last_level
+        self.last_pool = nn.AvgPool2D(kernel_size=2, stride=2)
+
+        assert num_layers in [68, 85], "HarDNet-{} not support.".format(
+            num_layers)
+        if num_layers == 85:
+            self.last_proj = ConvLayer(784, 256, kernel_size=1)
+            self.last_blk = HarDBlock(768, 80, 1.7, 8)
+            self.skip_nodes = [1, 3, 8, 13]
+            self.SC = [32, 32, 0]
+            gr = [64, 48, 28]
+            layers = [8, 8, 4]
+            ch_list2 = [224 + self.SC[0], 160 + self.SC[1], 96 + self.SC[2]]
+            channels = [96, 214, 458, 784]
+            self.skip_lv = 3
+
+        elif num_layers == 68:
+            self.last_proj = ConvLayer(654, 192, kernel_size=1)
+            self.last_blk = HarDBlock(576, 72, 1.7, 8)
+            self.skip_nodes = [1, 3, 8, 11]
+            self.SC = [32, 32, 0]
+            gr = [48, 32, 20]
+            layers = [8, 8, 4]
+            ch_list2 = [224 + self.SC[0], 96 + self.SC[1], 64 + self.SC[2]]
+            channels = [64, 124, 328, 654]
+            self.skip_lv = 2
+
+        self.transUpBlocks = nn.LayerList([])
+        self.denseBlocksUp = nn.LayerList([])
+        self.conv1x1_up = nn.LayerList([])
+        self.avg9x9 = nn.AvgPool2D(
+            kernel_size=(9, 9), stride=1, padding=(4, 4))
+        prev_ch = self.last_blk.get_out_ch()
+
+        for i in range(3):
+            skip_ch = channels[3 - i]
+            self.transUpBlocks.append(TransitionUp(prev_ch, prev_ch))
+            if i < self.skip_lv:
+                cur_ch = prev_ch + skip_ch
+            else:
+                cur_ch = prev_ch
+            self.conv1x1_up.append(
+                ConvLayer(
+                    cur_ch, ch_list2[i], kernel_size=1))
+            cur_ch = ch_list2[i]
+            cur_ch -= self.SC[i]
+            cur_ch *= 3
+
+            blk = HarDBlock(cur_ch, gr[i], 1.7, layers[i])
+            self.denseBlocksUp.append(blk)
+            prev_ch = blk.get_out_ch()
+
+        prev_ch += self.SC[0] + self.SC[1] + self.SC[2]
+        self.out_channel = prev_ch
+
+    @classmethod
+    def from_config(cls, cfg, input_shape):
+        return {'in_channels': [i.channels for i in input_shape]}
+
+    def forward(self, body_feats):
+        x = body_feats[-1]
+        x_sc = []
+        x = self.last_proj(x)
+        x = self.last_pool(x)
+        x2 = self.avg9x9(x)
+        x3 = x / (x.sum((2, 3), keepdim=True) + 0.1)
+        x = paddle.concat([x, x2, x3], 1)
+        x = self.last_blk(x)
+
+        for i in range(3):
+            skip_x = body_feats[3 - i]
+            x = self.transUpBlocks[i](x, skip_x, (i < self.skip_lv))
+            x = self.conv1x1_up[i](x)
+            if self.SC[i] > 0:
+                end = x.shape[1]
+                x_sc.append(x[:, end - self.SC[i]:, :, :])
+                x = x[:, :end - self.SC[i], :, :]
+            x2 = self.avg9x9(x)
+            x3 = x / (x.sum((2, 3), keepdim=True) + 0.1)
+            x = paddle.concat([x, x2, x3], 1)
+            x = self.denseBlocksUp[i](x)
+
+        scs = [x]
+        for i in range(3):
+            if self.SC[i] > 0:
+                scs.insert(
+                    0,
+                    F.interpolate(
+                        x_sc[i],
+                        size=(x.shape[2], x.shape[3]),
+                        mode="bilinear",
+                        align_corners=True))
+        neck_feat = paddle.concat(scs, 1)
+        return neck_feat
+
+    @property
+    def out_shape(self):
+        return [ShapeSpec(channels=self.out_channel, stride=self.down_ratio)]
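
Both FPN variants above now derive first_level from down_ratio when it is left at -1, and TransitionUp reduces to a bilinear resize plus an optional concat. A minimal sketch of both, assuming only paddle and numpy (tensor shapes below are made up for the example):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    # first_level defaults: log2(down_ratio) for DLA, log2(down_ratio) - 1 for HarDNet
    down_ratio = 4
    print(int(np.log2(down_ratio)))       # 2  (CenterNetDLAFPN)
    print(int(np.log2(down_ratio)) - 1)   # 1  (CenterNetHarDNetFPN)

    # TransitionUp: resize x to the skip feature's spatial size, then concat
    x = paddle.rand([1, 8, 16, 16])
    skip = paddle.rand([1, 4, 32, 32])
    out = F.interpolate(x, size=(skip.shape[2], skip.shape[3]),
                        mode="bilinear", align_corners=True)
    out = paddle.concat([out, skip], axis=1)
    print(out.shape)  # [1, 12, 32, 32]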

+ 4 - 12
paddlex/ppdet/modeling/necks/pan.py

@@ -39,21 +39,13 @@ class PAN(nn.Layer):
         spatial_scales (list[float]): the spatial scales between input feature
             maps and original input image which can be derived from the output
             shape of backbone by from_config
-        has_extra_convs (bool): whether to add extra conv to the last level.
-            default False
-        extra_stage (int): the number of extra stages added to the last level.
-            default 1
-        use_c5 (bool): Whether to use c5 as the input of extra stage,
-            otherwise p5 is used. default True
+        start_level (int): Index of the start input backbone level used to
+            build the feature pyramid. Default: 0.
+        end_level (int): Index of the end input backbone level (exclusive) to
+            build the feature pyramid. Default: -1, which means the last level.
         norm_type (string|None): The normalization type in FPN module. If
             norm_type is None, norm will not be used after conv and if
             norm_type is string, bn, gn, sync_bn are available. default None
-        norm_decay (float): weight decay for normalization layer weights.
-            default 0.
-        freeze_norm (bool): whether to freeze normalization layer.
-            default False
-        relu_before_extra_convs (bool): whether to add relu before extra convs.
-            default False
     """
 
     def __init__(self,

+ 2 - 2
paddlex/ppdet/modeling/necks/yolo_fpn.py

@@ -30,8 +30,8 @@ def add_coord(x, data_format):
     else:
         h, w = x.shape[1], x.shape[2]
 
-    gx = paddle.arange(w, dtype=x.dtype) / ((w - 1.) * 2.0) - 1.
-    gy = paddle.arange(h, dtype=x.dtype) / ((h - 1.) * 2.0) - 1.
+    gx = paddle.cast(paddle.arange(w) / ((w - 1.) * 2.0) - 1., x.dtype)
+    gy = paddle.cast(paddle.arange(h) / ((h - 1.) * 2.0) - 1., x.dtype)
 
     if data_format == 'NCHW':
         gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w])
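
The change above defers the dtype cast until after the arithmetic, so the coordinate grid is built in paddle.arange's default dtype and converted to the feature map's dtype in one step (presumably to cope with dtypes such as float16 that arange handles poorly; that rationale is an inference, not stated in the diff). A stand-alone sketch:

    import paddle

    x = paddle.zeros([2, 3, 4, 8], dtype='float32')  # stand-in NCHW feature map
    w = x.shape[3]
    gx = paddle.cast(paddle.arange(w) / ((w - 1.) * 2.0) - 1., x.dtype)
    print(gx.dtype, gx.shape)  # paddle.float32, [8]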

+ 15 - 6
paddlex/ppdet/modeling/proposal_generator/target.py

@@ -27,7 +27,8 @@ def rpn_anchor_target(anchors,
                       batch_size=1,
                       ignore_thresh=-1,
                       is_crowd=None,
-                      weights=[1., 1., 1., 1.]):
+                      weights=[1., 1., 1., 1.],
+                      assign_on_cpu=False):
     tgt_labels = []
     tgt_bboxes = []
     tgt_deltas = []
@@ -37,7 +38,7 @@ def rpn_anchor_target(anchors,
         # Step1: match anchor and gt_bbox
         matches, match_labels = label_box(
             anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
-            ignore_thresh, is_crowd_i)
+            ignore_thresh, is_crowd_i, assign_on_cpu)
         # Step2: sample anchor
         fg_inds, bg_inds = subsample_labels(match_labels,
                                             rpn_batch_size_per_im,
@@ -72,8 +73,14 @@ def label_box(anchors,
               negative_overlap,
               allow_low_quality,
               ignore_thresh,
-              is_crowd=None):
-    iou = bbox_overlaps(gt_boxes, anchors)
+              is_crowd=None,
+              assign_on_cpu=False):
+    if assign_on_cpu:
+        paddle.set_device("cpu")
+        iou = bbox_overlaps(gt_boxes, anchors)
+        paddle.set_device("gpu")
+    else:
+        iou = bbox_overlaps(gt_boxes, anchors)
     n_gt = gt_boxes.shape[0]
     if n_gt == 0 or is_crowd is None:
         n_gt_crowd = 0
@@ -178,7 +185,8 @@ def generate_proposal_target(rpn_rois,
                              is_crowd=None,
                              use_random=True,
                              is_cascade=False,
-                             cascade_iou=0.5):
+                             cascade_iou=0.5,
+                             assign_on_cpu=False):
 
     rois_with_gt = []
     tgt_labels = []
@@ -203,7 +211,8 @@ def generate_proposal_target(rpn_rois,
 
         # Step1: label bbox
         matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh,
-                                          False, ignore_thresh, is_crowd_i)
+                                          False, ignore_thresh, is_crowd_i,
+                                          assign_on_cpu)
         # Step2: sample bbox
         sampled_inds, sampled_gt_classes = sample_bbox(
             matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
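
label_box now round-trips the IoU computation through the CPU when assign_on_cpu is set, trading speed for host memory when the number of gt boxes is huge. A self-contained sketch of the same pattern; pairwise_iou below is a toy stand-in for bbox_overlaps, and the unconditional switch back to "gpu" mirrors the diff (so it assumes GPU training):

    import paddle

    def pairwise_iou(a, b):
        # a: [N, 4], b: [M, 4] boxes as x1, y1, x2, y2 -> [N, M] IoU matrix
        area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
        area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
        lt = paddle.maximum(a.unsqueeze(1)[:, :, :2], b.unsqueeze(0)[:, :, :2])
        rb = paddle.minimum(a.unsqueeze(1)[:, :, 2:], b.unsqueeze(0)[:, :, 2:])
        wh = (rb - lt).clip(min=0)
        inter = wh[:, :, 0] * wh[:, :, 1]
        return inter / (area_a.unsqueeze(1) + area_b.unsqueeze(0) - inter)

    def overlaps(gt_boxes, anchors, assign_on_cpu=False):
        if assign_on_cpu:
            paddle.set_device("cpu")   # build the big [N, M] matrix in host memory
            iou = pairwise_iou(gt_boxes, anchors)
            paddle.set_device("gpu")
        else:
            iou = pairwise_iou(gt_boxes, anchors)
        return iou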

+ 24 - 7
paddlex/ppdet/modeling/proposal_generator/target_layer.py

@@ -22,6 +22,7 @@ import numpy as np
 @register
 @serializable
 class RPNTargetAssign(object):
+    __shared__ = ['assign_on_cpu']
     """
     RPN targets assignment module
 
@@ -48,6 +49,8 @@ class RPNTargetAssign(object):
             if the value is larger than zero.
         use_random (bool): Use random sampling to choose foreground and
             background boxes, default true.
+        assign_on_cpu (bool): Whether to compute IoU on CPU when the number
+            of gt boxes is too large, default false.
     """
 
     def __init__(self,
@@ -56,7 +59,8 @@ class RPNTargetAssign(object):
                  positive_overlap=0.7,
                  negative_overlap=0.3,
                  ignore_thresh=-1.,
-                 use_random=True):
+                 use_random=True,
+                 assign_on_cpu=False):
         super(RPNTargetAssign, self).__init__()
         self.batch_size_per_im = batch_size_per_im
         self.fg_fraction = fg_fraction
@@ -64,6 +68,7 @@ class RPNTargetAssign(object):
         self.negative_overlap = negative_overlap
         self.ignore_thresh = ignore_thresh
         self.use_random = use_random
+        self.assign_on_cpu = assign_on_cpu
 
     def __call__(self, inputs, anchors):
         """
@@ -74,9 +79,17 @@ class RPNTargetAssign(object):
         is_crowd = inputs.get('is_crowd', None)
         batch_size = len(gt_boxes)
         tgt_labels, tgt_bboxes, tgt_deltas = rpn_anchor_target(
-            anchors, gt_boxes, self.batch_size_per_im, self.positive_overlap,
-            self.negative_overlap, self.fg_fraction, self.use_random,
-            batch_size, self.ignore_thresh, is_crowd)
+            anchors,
+            gt_boxes,
+            self.batch_size_per_im,
+            self.positive_overlap,
+            self.negative_overlap,
+            self.fg_fraction,
+            self.use_random,
+            batch_size,
+            self.ignore_thresh,
+            is_crowd,
+            assign_on_cpu=self.assign_on_cpu)
         norm = self.batch_size_per_im * batch_size
 
         return tgt_labels, tgt_bboxes, tgt_deltas, norm
@@ -84,7 +97,7 @@ class RPNTargetAssign(object):
 
 @register
 class BBoxAssigner(object):
-    __shared__ = ['num_classes']
+    __shared__ = ['num_classes', 'assign_on_cpu']
     """
     RCNN targets assignment module
 
@@ -113,6 +126,8 @@ class BBoxAssigner(object):
         cascade_iou (list[iou]): The list of overlap to select foreground and
             background of each stage, which is only used In Cascade RCNN.
         num_classes (int): The number of class.
+        assign_on_cpu (bool): Whether to compute IoU on CPU when the number
+            of gt boxes is too large, default false.
     """
 
     def __init__(self,
@@ -123,7 +138,8 @@ class BBoxAssigner(object):
                  ignore_thresh=-1.,
                  use_random=True,
                  cascade_iou=[0.5, 0.6, 0.7],
-                 num_classes=80):
+                 num_classes=80,
+                 assign_on_cpu=False):
         super(BBoxAssigner, self).__init__()
         self.batch_size_per_im = batch_size_per_im
         self.fg_fraction = fg_fraction
@@ -133,6 +149,7 @@ class BBoxAssigner(object):
         self.use_random = use_random
         self.cascade_iou = cascade_iou
         self.num_classes = num_classes
+        self.assign_on_cpu = assign_on_cpu
 
     def __call__(self,
                  rpn_rois,
@@ -149,7 +166,7 @@ class BBoxAssigner(object):
             rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
             self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
             self.ignore_thresh, is_crowd, self.use_random, is_cascade,
-            self.cascade_iou[stage])
+            self.cascade_iou[stage], self.assign_on_cpu)
         rois = outs[0]
         rois_num = outs[-1]
         # tgt_labels, tgt_bboxes, tgt_gt_inds
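
Because assign_on_cpu is listed in __shared__ for both classes, one global key in the config reaches the RPN and RCNN assigners alike. Constructed directly in Python, the equivalent is simply (a hedged sketch, other arguments left at the defaults shown in the diff):

    from paddlex.ppdet.modeling.proposal_generator.target_layer import (
        BBoxAssigner, RPNTargetAssign)

    rpn_assigner = RPNTargetAssign(assign_on_cpu=True)   # RPN anchor targets
    bbox_assigner = BBoxAssigner(assign_on_cpu=True)     # RCNN proposal targets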

+ 3 - 3
paddlex/ppdet/modeling/tests/test_architectures.py

@@ -17,7 +17,7 @@ from __future__ import division
 from __future__ import print_function
 
 import unittest
-import paddlex.ppdet
+import paddlex.ppdet as ppdet
 
 
 class TestFasterRCNN(unittest.TestCase):
@@ -31,8 +31,8 @@ class TestFasterRCNN(unittest.TestCase):
         # Trainer __init__ will build model and DataLoader
         # 'train' and 'eval' mode include dataset loading
         # use 'test' mode to simplify tests
-        cfg = paddlex.ppdet.core.workspace.load_config(self.cfg_file)
-        trainer = paddlex.ppdet.engine.Trainer(cfg, mode='test')
+        cfg = ppdet.core.workspace.load_config(self.cfg_file)
+        trainer = ppdet.engine.Trainer(cfg, mode='test')
 
 
 class TestMaskRCNN(TestFasterRCNN):

+ 32 - 1
paddlex/ppdet/optimizer.py

@@ -251,13 +251,40 @@ class OptimizerBuilder():
 
 
 class ModelEMA(object):
-    def __init__(self, decay, model, use_thres_step=False):
+    """
+    Exponential Moving Average for deep neural networks
+    Args:
+        model (nn.Layer): The detector model whose parameters are averaged.
+        decay (float): The decay used for updating the EMA parameters.
+            EMA parameters are updated with the formula:
+           `ema_param = decay * ema_param + (1 - decay) * cur_param`.
+            Default is 0.9998.
+        use_thres_step (bool): Whether to set the decay based on a step
+            threshold, default False.
+        cycle_epoch (int): The interval, in epochs, at which ema_param and
+            step are reset. Default is -1, which means no reset. Resetting
+            adds a regularizing effect to EMA; it is chosen empirically and
+            is useful when the total number of training epochs is large.
+    """
+
+    def __init__(self,
+                 model,
+                 decay=0.9998,
+                 use_thres_step=False,
+                 cycle_epoch=-1):
         self.step = 0
+        self.epoch = 0
         self.decay = decay
         self.state_dict = dict()
         for k, v in model.state_dict().items():
             self.state_dict[k] = paddle.zeros_like(v)
         self.use_thres_step = use_thres_step
+        self.cycle_epoch = cycle_epoch
+
+    def reset(self):
+        self.step = 0
+        self.epoch = 0
+        for k, v in self.state_dict.items():
+            self.state_dict[k] = paddle.zeros_like(v)
 
     def update(self, model):
         if self.use_thres_step:
@@ -280,4 +307,8 @@ class ModelEMA(object):
             v = v / (1 - self._decay**self.step)
             v.stop_gradient = True
             state_dict[k] = v
+        self.epoch += 1
+        if self.cycle_epoch > 0 and self.epoch == self.cycle_epoch:
+            self.reset()
+
         return state_dict
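
A minimal stand-alone sketch of what the docstring and this diff describe: plain EMA accumulation, the bias correction applied when the averaged weights are fetched, and the new cycle_epoch reset. It ignores use_thres_step and is an illustration, not the ppdet implementation:

    import paddle

    class TinyEMA:
        def __init__(self, model, decay=0.9998, cycle_epoch=-1):
            self.step, self.epoch = 0, 0
            self.decay, self.cycle_epoch = decay, cycle_epoch
            self.state = {k: paddle.zeros_like(v)
                          for k, v in model.state_dict().items()}

        def update(self, model):
            self.step += 1
            for k, v in model.state_dict().items():
                # ema = decay * ema + (1 - decay) * cur
                self.state[k] = self.decay * self.state[k] + (1 - self.decay) * v

        def apply(self):
            # divide out the zero-initialization bias (cf. the diff above)
            out = {k: v / (1 - self.decay**self.step)
                   for k, v in self.state.items()}
            self.epoch += 1
            if self.cycle_epoch > 0 and self.epoch == self.cycle_epoch:
                self.step, self.epoch = 0, 0
                self.state = {k: paddle.zeros_like(v)
                              for k, v in self.state.items()}
            return out

    model = paddle.nn.Linear(4, 2)
    ema = TinyEMA(model, decay=0.9, cycle_epoch=2)
    ema.update(model)
    weights = ema.apply()  # bias-corrected EMA state_dict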

+ 7 - 0
paddlex/ppdet/utils/checkpoint.py

@@ -139,6 +139,13 @@ def match_state_dict(model_state_dict, weight_state_dict):
     max_id = match_matrix.argmax(1)
     max_len = match_matrix.max(1)
     max_id[max_len == 0] = -1
+    not_load_weight_name = []
+    for match_idx in range(len(max_id)):
+        if match_idx < len(weight_keys) and max_id[match_idx] == -1:
+            not_load_weight_name.append(weight_keys[match_idx])
+    if len(not_load_weight_name) > 0:
+        logger.info('{} in the pretrained weights are not used in the model '
+                    'and will not be loaded'.format(not_load_weight_name))
     matched_keys = {}
     result_state_dict = {}
     for model_id, weight_id in enumerate(max_id):
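
The new warning reuses the match_matrix bookkeeping already in place: rows whose best match length is 0 get max_id -1, and their keys are reported. A toy numpy illustration (key names and matrix values are made up, and the orientation is simplified to one row per pretrained key):

    import numpy as np

    weight_keys = ['backbone.conv1', 'head.fc', 'extra.bn']
    # entries are matched-suffix lengths; 0 means no model key matches at all
    match_matrix = np.array([[14, 0], [0, 7], [0, 0]])
    max_id = match_matrix.argmax(1)
    max_id[match_matrix.max(1) == 0] = -1
    not_loaded = [weight_keys[i] for i in range(len(max_id)) if max_id[i] == -1]
    print(not_loaded)  # ['extra.bn'] -- what the new log line reports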

+ 1 - 1
paddlex/ppdet/utils/download.py

@@ -140,7 +140,7 @@ def get_config_path(url):
 
     # 2. get url
     try:
-        from ppdet import __version__ as version
+        from paddlex.ppdet import __version__ as version
     except ImportError:
         version = None