will-jl944 il y a 4 ans
Parent
commit
7fd6c0ff28
1 fichiers modifiés avec 1 ajouts et 1 suppressions
  1. 1 1
      paddlex/cv/models/detector.py

+ 1 - 1
paddlex/cv/models/detector.py

@@ -686,7 +686,7 @@ class PicoDet(BaseDetector):
             loss_dfl = ppdet.modeling.DistributionFocalLoss(loss_weight=.25)
             loss_bbox = ppdet.modeling.GIoULoss(loss_weight=2.0)
             assigner = ppdet.modeling.SimOTAAssigner(
-                candidate_topk=10, iou_weight=6)
+                candidate_topk=10, iou_weight=6, num_classes=num_classes)
             nms = ppdet.modeling.MultiClassNMS(
                 nms_top_k=nms_top_k,
                 keep_top_k=nms_keep_top_k,