@@ -83,20 +83,26 @@ class DetPredictor(BasicPredictor):
                 raise ValueError(
                     f"The type of `img_size` must be int or Tuple[int, int], but got {type(img_size)}."
                 )
-
+
         if layout_unclip_ratio is not None:
             if isinstance(layout_unclip_ratio, float):
                 layout_unclip_ratio = (layout_unclip_ratio, layout_unclip_ratio)
             elif isinstance(layout_unclip_ratio, (tuple, list)):
-                assert len(layout_unclip_ratio) == 2, f"The length of `layout_unclip_ratio` should be 2."
+                assert (
+                    len(layout_unclip_ratio) == 2
+                ), f"The length of `layout_unclip_ratio` should be 2."
             else:
                 raise ValueError(
                     f"The type of `layout_unclip_ratio` must be float or Tuple[float, float], but got {type(layout_unclip_ratio)}."
                 )
-
+
         if layout_merge_bboxes_mode is not None:
-            assert layout_merge_bboxes_mode in ["union", "large", "small"], \
-                f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
+            assert layout_merge_bboxes_mode in [
+                "union",
+                "large",
+                "small",
+            ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {layout_merge_bboxes_mode}"
+
         self.img_size = img_size
         self.threshold = threshold
         self.layout_nms = layout_nms
@@ -197,13 +203,14 @@ class DetPredictor(BasicPredictor):
         else:
             return [{"boxes": np.array(res)} for res in pred_box]
 
-    def process(self,
-                batch_data: List[Any],
-                threshold: Optional[Union[float, dict]] = None,
-                layout_nms: Optional[bool] = None,
-                layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
-                layout_merge_bboxes_mode: Optional[str] = None,
-                ):
+    def process(
+        self,
+        batch_data: List[Any],
+        threshold: Optional[Union[float, dict]] = None,
+        layout_nms: bool = False,
+        layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
+        layout_merge_bboxes_mode: Optional[str] = None,
+    ):
         """
         Process a batch of data through the preprocessing, inference, and postprocessing.
 
@@ -218,7 +225,7 @@ class DetPredictor(BasicPredictor):
             dict: A dictionary containing the input path, raw image, class IDs, scores, and label names
                 for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
         """
-        datas = batch_data
+        datas = batch_data.instances
         # preprocess
         for pre_op in self.pre_ops[:-1]:
             datas = pre_op(datas)
@@ -233,16 +240,18 @@ class DetPredictor(BasicPredictor):
         preds_list = self._format_output(batch_preds)
         # postprocess
         boxes = self.post_op(
-            preds_list,
-            datas,
-            threshold = threshold or self.threshold,
+            preds_list,
+            datas,
+            threshold=threshold or self.threshold,
             layout_nms=layout_nms or self.layout_nms,
             layout_unclip_ratio=layout_unclip_ratio or self.layout_unclip_ratio,
-            layout_merge_bboxes_mode=layout_merge_bboxes_mode or self.layout_merge_bboxes_mode
+            layout_merge_bboxes_mode=layout_merge_bboxes_mode
+            or self.layout_merge_bboxes_mode,
         )
 
         return {
-            "input_path": [data.get("img_path", None) for data in datas],
+            "input_path": batch_data.input_paths,
+            "page_index": batch_data.page_indexes,
             "input_img": [data["ori_img"] for data in datas],
             "boxes": boxes,
         }
@@ -330,7 +339,7 @@ class DetPredictor(BasicPredictor):
         if self.layout_unclip_ratio is None:
             self.layout_unclip_ratio = self.config.get("layout_unclip_ratio", None)
         if self.layout_merge_bboxes_mode is None:
-            self.layout_merge_bboxes_mode = self.config.get("layout_merge_bboxes_mode", None)
-        return DetPostProcess(
-            labels=self.config["label_list"]
-        )
+            self.layout_merge_bboxes_mode = self.config.get(
+                "layout_merge_bboxes_mode", None
+            )
+        return DetPostProcess(labels=self.config["label_list"])
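
Usage note (not part of the patch): a minimal sketch of how the reworked `process` parameters might surface through the public API, assuming a PaddleX 3.x style `create_model` entry point that forwards these keyword arguments to the layout-detection predictor. The model name and image path below are placeholders, not taken from this diff.

from paddlex import create_model

# Hypothetical layout-detection model name and input image, for illustration only.
model = create_model("PP-DocLayout-L")
output = model.predict(
    "page_1.png",
    threshold=0.5,                     # per-call value; falls back to the predictor default via `or`
    layout_nms=True,                   # the new signature above defaults this to False
    layout_unclip_ratio=(1.0, 1.05),   # float or (width_ratio, height_ratio)
    layout_merge_bboxes_mode="large",  # must be one of "union", "large", "small"
)
for res in output:
    print(res["boxes"])                # results also carry "input_path" and "page_index", per the return dict above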