|
|
@@ -18,11 +18,17 @@ from ....modules.video_classification.model_list import MODELS
|
|
|
from ...common.batch_sampler import VideoBatchSampler
|
|
|
from ...common.reader import ReadVideo
|
|
|
from ..common import (
|
|
|
- ToBatch,
|
|
|
StaticInfer,
|
|
|
)
|
|
|
from ..base import BasicPredictor
|
|
|
-from .processors import Scale, CenterCrop, Image2Array, NormalizeVideo, VideoClasTopk
|
|
|
+from .processors import (
|
|
|
+ Scale,
|
|
|
+ CenterCrop,
|
|
|
+ Image2Array,
|
|
|
+ NormalizeVideo,
|
|
|
+ VideoClasTopk,
|
|
|
+ ToBatch,
|
|
|
+)
|
|
|
from .result import TopkVideoResult
|
|
|
|
|
|
|
|
|
@@ -76,8 +82,8 @@ class VideoClasPredictor(BasicPredictor):
|
|
|
batch_videos = self.pre_tfs["Scale"](videos=batch_raw_videos)
|
|
|
batch_videos = self.pre_tfs["CenterCrop"](videos=batch_videos)
|
|
|
batch_videos = self.pre_tfs["Image2Array"](videos=batch_videos)
|
|
|
- x = self.pre_tfs["NormalizeVideo"](videos=batch_videos)
|
|
|
-
|
|
|
+ batch_videos = self.pre_tfs["NormalizeVideo"](videos=batch_videos)
|
|
|
+ x = self.pre_tfs["ToBatch"](videos=batch_videos)
|
|
|
batch_preds = self.infer(x=x)
|
|
|
|
|
|
batch_class_ids, batch_scores, batch_label_names = self.post_op["Topk"](
|