# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np

from ....modules.object_detection.model_list import MODELS
from ....utils.func_register import FuncRegister
from ...common.batch_sampler import ImageBatchSampler
from ..base import BasePredictor
from .processors import (
    DetPad,
    DetPostProcess,
    Normalize,
    PadStride,
    ReadImage,
    Resize,
    ToBatch,
    ToCHWImage,
    WarpAffine,
)
from .result import DetResult
from .utils import STATIC_SHAPE_MODEL_LIST


class DetPredictor(BasePredictor):

    entities = MODELS

    _FUNC_MAP = {}
    register = FuncRegister(_FUNC_MAP)
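
    # Builder methods decorated with `@register("<OpName>")` below are recorded
    # in `_FUNC_MAP`; `_build` looks up each entry of the model config's
    # "Preprocess" list by its "type" key and calls the matching builder to
    # construct the corresponding preprocessing op.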

    def __init__(
        self,
        *args,
        img_size: Optional[Union[int, Tuple[int, int]]] = None,
        threshold: Optional[Union[float, dict]] = None,
        layout_nms: Optional[bool] = None,
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
        layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
        **kwargs,
    ):
        """Initializes DetPredictor.

        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            img_size (Optional[Union[int, Tuple[int, int]]], optional): The input image size (w, h).
                Defaults to None.
            threshold (Optional[Union[float, dict]], optional): The threshold for filtering out
                low-confidence predictions, either a single float applied to all categories or a
                dict of per-category thresholds. Defaults to None.
            layout_nms (Optional[bool], optional): Whether to use layout-aware NMS. Defaults to None.
            layout_unclip_ratio (Optional[Union[float, Tuple[float, float], dict]], optional): The ratio
                for unclipping the bounding box. Defaults to None.
                If it's a single number, it is used for both width and height.
                If it's a tuple of two numbers, they are used for width and height respectively.
                If it's None, no unclipping is performed.
            layout_merge_bboxes_mode (Optional[Union[str, dict]], optional): The mode for merging
                bounding boxes. Defaults to None.
            **kwargs: Arbitrary keyword arguments passed to the superclass.
        """
        super().__init__(*args, **kwargs)
        if img_size is not None:
            assert (
                self.model_name not in STATIC_SHAPE_MODEL_LIST
            ), f"The model {self.model_name} has a static input shape and does not support setting `img_size`."
            if isinstance(img_size, int):
                img_size = (img_size, img_size)
            elif isinstance(img_size, (tuple, list)):
                assert len(img_size) == 2, "The length of `img_size` should be 2."
            else:
                raise ValueError(
                    f"The type of `img_size` must be int or Tuple[int, int], but got {type(img_size)}."
                )

        if layout_unclip_ratio is not None:
            if isinstance(layout_unclip_ratio, float):
                layout_unclip_ratio = (layout_unclip_ratio, layout_unclip_ratio)
            elif isinstance(layout_unclip_ratio, (tuple, list)):
                assert (
                    len(layout_unclip_ratio) == 2
                ), "The length of `layout_unclip_ratio` should be 2."
            elif isinstance(layout_unclip_ratio, dict):
                pass
            else:
                raise ValueError(
                    f"The type of `layout_unclip_ratio` must be float, Tuple[float, float] or Dict, but got {type(layout_unclip_ratio)}."
                )

        if layout_merge_bboxes_mode is not None:
            if isinstance(layout_merge_bboxes_mode, str):
                assert layout_merge_bboxes_mode in [
                    "union",
                    "large",
                    "small",
                ], f"The value of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'] or a dict, but got {layout_merge_bboxes_mode}"
            elif isinstance(layout_merge_bboxes_mode, dict):
                for value in layout_merge_bboxes_mode.values():
                    assert value in [
                        "union",
                        "large",
                        "small",
                    ], f"The per-category values of `layout_merge_bboxes_mode` must be one of ['union', 'large', 'small'], but got {value}"
            else:
                raise ValueError(
                    f"The type of `layout_merge_bboxes_mode` must be str or dict, but got {type(layout_merge_bboxes_mode)}."
                )

        self.img_size = img_size
        self.threshold = threshold
        self.layout_nms = layout_nms
        self.layout_unclip_ratio = layout_unclip_ratio
        self.layout_merge_bboxes_mode = layout_merge_bboxes_mode
        self.pre_ops, self.infer, self.post_op = self._build()

    def _build_batch_sampler(self):
        return ImageBatchSampler()

    def _get_result_class(self):
        return DetResult

    def _build(self) -> Tuple:
        """Build the preprocessors, inference engine, and postprocessors based on the configuration.

        Returns:
            tuple: A tuple containing the preprocessors, inference engine, and postprocessors.
        """
        # build preprocess ops
        pre_ops = [ReadImage(format="RGB")]
        for cfg in self.config["Preprocess"]:
            tf_key = cfg["type"]
            func = self._FUNC_MAP[tf_key]
            cfg.pop("type")
            args = cfg
            op = func(self, **args) if args else func(self)
            if op:
                pre_ops.append(op)
        pre_ops.append(self.build_to_batch())
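        # If the user passed `img_size`, it overrides any Resize from the model
        # config: `pre_ops[1]` is the first op after ReadImage, so a
        # config-provided Resize there is dropped and replaced with one built
        # from the requested size (keep_ratio=False, bicubic interpolation).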
        if self.img_size is not None:
            if isinstance(pre_ops[1], Resize):
                pre_ops.pop(1)
            pre_ops.insert(1, self.build_resize(self.img_size, False, 2))

        # build infer
        infer = self.create_static_infer()

        # build postprocess op
        post_op = self.build_postprocess()

        return pre_ops, infer, post_op

    def _format_output(self, pred: Sequence[Any]) -> List[dict]:
        """
        Transform batch outputs into a list of single image outputs.

        Args:
            pred (Sequence[Any]): The input predictions, a list of 2, 3, or 4 elements.
                - When len(pred) == 4, it is expected to be in the format
                  [boxes, class_ids, scores, masks], compatible with SOLOv2 output.
                - When len(pred) == 3, it is expected to be in the format
                  [boxes, box_nums, masks], compatible with instance segmentation output.
                - Otherwise, it is expected to be in the format [boxes, box_nums],
                  i.e. plain detection output without masks.

        Returns:
            List[dict]: A list of dictionaries, each containing either 'class_id' and 'masks'
                (for SOLOv2), 'boxes' and 'masks' (for instance segmentation), or just 'boxes'
                if no masks are provided.
        """
        box_idx_start = 0
        pred_box = []

        if len(pred) == 4:
            # Adapt to SOLOv2
            pred_class_id = []
            pred_mask = []
            pred_class_id.append([pred[1], pred[2]])
            pred_mask.append(pred[3])
            return [
                {
                    "class_id": np.array(pred_class_id[i]),
                    "masks": np.array(pred_mask[i]),
                }
                for i in range(len(pred_class_id))
            ]

        if len(pred) == 3:
            # Adapt to instance segmentation
            pred_mask = []

        for idx in range(len(pred[1])):
            np_boxes_num = pred[1][idx]
            box_idx_end = box_idx_start + np_boxes_num
            np_boxes = pred[0][box_idx_start:box_idx_end]
            pred_box.append(np_boxes)
            if len(pred) == 3:
                np_masks = pred[2][box_idx_start:box_idx_end]
                pred_mask.append(np_masks)
            box_idx_start = box_idx_end

        if len(pred) == 3:
            return [
                {"boxes": np.array(pred_box[i]), "masks": np.array(pred_mask[i])}
                for i in range(len(pred_box))
            ]
        else:
            return [{"boxes": np.array(res)} for res in pred_box]

    def process(
        self,
        batch_data: List[Any],
        threshold: Optional[Union[float, dict]] = None,
        layout_nms: Optional[bool] = None,
        layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
        layout_merge_bboxes_mode: Optional[Union[str, dict]] = None,
    ):
        """
        Process a batch of data through preprocessing, inference, and postprocessing.

        Args:
            batch_data (List[Union[str, np.ndarray]]): A batch of input data (e.g., image file paths).
            threshold (Optional[Union[float, dict]], optional): The threshold for filtering out
                low-confidence predictions. Defaults to None.
            layout_nms (Optional[bool], optional): Whether to use layout-aware NMS. Defaults to None.
            layout_unclip_ratio (Optional[Union[float, Tuple[float, float], dict]], optional): The ratio
                for unclipping the bounding box. Defaults to None.
            layout_merge_bboxes_mode (Optional[Union[str, dict]], optional): The mode for merging
                bounding boxes. Defaults to None.

        Returns:
            dict: A dictionary with the batch results, keyed by 'input_path', 'page_index',
                'input_img', and 'boxes'.
        """
        datas = batch_data.instances
        # preprocess
        for pre_op in self.pre_ops[:-1]:
            datas = pre_op(datas)

        # use `ToBatch` to assemble the batch inputs
        batch_inputs = self.pre_ops[-1](datas)

        # do infer
        batch_preds = self.infer(batch_inputs)

        # process a batch of predictions into a list of single image results
        preds_list = self._format_output(batch_preds)

        # postprocess
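        # (explicit call-time arguments take precedence over the values
        # configured on the predictor)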
        boxes = self.post_op(
            preds_list,
            datas,
            threshold=threshold if threshold is not None else self.threshold,
            layout_nms=layout_nms or self.layout_nms,
            layout_unclip_ratio=layout_unclip_ratio or self.layout_unclip_ratio,
            layout_merge_bboxes_mode=layout_merge_bboxes_mode
            or self.layout_merge_bboxes_mode,
        )

        return {
            "input_path": batch_data.input_paths,
            "page_index": batch_data.page_indexes,
            "input_img": [data["ori_img"] for data in datas],
            "boxes": boxes,
        }
- @register("Resize")
- def build_resize(self, target_size, keep_ratio=False, interp=2):
- assert target_size
- if isinstance(interp, int):
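            # integer codes follow OpenCV's interpolation flags:
            # 0=INTER_NEAREST, 1=INTER_LINEAR, 2=INTER_CUBIC, 3=INTER_AREA,
            # 4=INTER_LANCZOS4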
            interp = {
                0: "NEAREST",
                1: "LINEAR",
                2: "BICUBIC",
                3: "AREA",
                4: "LANCZOS4",
            }[interp]
        op = Resize(
            target_size=target_size[::-1], keep_ratio=keep_ratio, interp=interp
        )
        return op
- @register("NormalizeImage")
- def build_normalize(
- self,
- norm_type=None,
- mean=[0.485, 0.456, 0.406],
- std=[0.229, 0.224, 0.225],
- is_scale=True,
- ):
- if is_scale:
- scale = 1.0 / 255.0
- else:
- scale = 1
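        # norm types other than "mean_std" disable mean/std normalization,
        # leaving only the optional 1/255 scaling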
        if not norm_type or norm_type == "none":
            norm_type = "mean_std"
        if norm_type != "mean_std":
            mean = 0
            std = 1
        return Normalize(scale=scale, mean=mean, std=std)
- @register("Permute")
- def build_to_chw(self):
- return ToCHWImage()
- @register("Pad")
- def build_pad(self, fill_value=None, size=None):
- if fill_value is None:
- fill_value = [127.5, 127.5, 127.5]
- if size is None:
- size = [3, 640, 640]
- return DetPad(size=size, fill_value=fill_value)
- @register("PadStride")
- def build_pad_stride(self, stride=32):
- return PadStride(stride=stride)
- @register("WarpAffine")
- def build_warp_affine(self, input_h=512, input_w=512, keep_res=True):
- return WarpAffine(input_h=input_h, input_w=input_w, keep_res=keep_res)

    def build_to_batch(self):
        models_required_imgsize = [
            "DETR",
            "DINO",
            "RCNN",
            "YOLOv3",
            "CenterNet",
            "BlazeFace",
            "BlazeFace-FPN-SSH",
            "PP-DocLayout-L",
            "PP-DocLayout_plus-L",
            "PP-DocBlockLayout",
        ]
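        # the architectures listed above additionally consume the input image
        # shape as a network input, so "img_size" must be part of each batch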
        if any(name in self.model_name for name in models_required_imgsize):
            ordered_required_keys = (
                "img_size",
                "img",
                "scale_factors",
            )
        else:
            ordered_required_keys = ("img", "scale_factors")

        return ToBatch(ordered_required_keys=ordered_required_keys)

    def build_postprocess(self):
        if self.threshold is None:
            self.threshold = self.config.get("draw_threshold", 0.5)
        if not self.layout_nms:
            self.layout_nms = self.config.get("layout_nms", None)
        if self.layout_unclip_ratio is None:
            self.layout_unclip_ratio = self.config.get("layout_unclip_ratio", None)
        if self.layout_merge_bboxes_mode is None:
            self.layout_merge_bboxes_mode = self.config.get(
                "layout_merge_bboxes_mode", None
            )
        return DetPostProcess(labels=self.config["label_list"])
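
# Usage sketch (illustrative only, not part of the module). It assumes that
# `BasePredictor` accepts `model_name`/`model_dir` keyword arguments and
# exposes a `predict()` generator yielding per-image results; both are
# assumptions about the superclass, which is defined elsewhere. The model
# name and path below are hypothetical.
#
#     predictor = DetPredictor(
#         model_name="PP-DocLayout-L",   # hypothetical model choice
#         model_dir="./PP-DocLayout-L",  # hypothetical local export dir
#         threshold=0.5,
#         layout_nms=True,
#     )
#     for res in predictor.predict("page_0.png"):
#         print(res["boxes"])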
|