|
|
@@ -30,10 +30,6 @@ from .base import BaseAPI
|
|
|
from collections import OrderedDict
|
|
|
from .utils.detection_eval import eval_results, bbox2out
|
|
|
|
|
|
-import random
|
|
|
-random.seed(0)
|
|
|
-np.random.seed(0)
|
|
|
-
|
|
|
|
|
|
class YOLOv3(BaseAPI):
|
|
|
"""构建YOLOv3,并实现其训练、评估、预测和模型导出。
|
|
|
@@ -181,7 +177,7 @@ class YOLOv3(BaseAPI):
|
|
|
model.max_width = self.max_width
|
|
|
inputs = model.generate_inputs()
|
|
|
model_out = model.build_net(inputs)
|
|
|
- outputs = OrderedDict([('bbox', model_out[0])])
|
|
|
+ outputs = OrderedDict([('bbox', model_out)])
|
|
|
if mode == 'train':
|
|
|
self.optimizer.minimize(model_out)
|
|
|
outputs = OrderedDict([('loss', model_out)])
|