浏览代码

update for `class Batch` and fix sth (#3578)

* sup read operations

* update for  and fix sth

* update

* update
zhang-prog 8 月之前
父节点
当前提交
9e95cb04dc

+ 1 - 1
libs/paddlex-hpi/src/paddlex_hpi/models/base.py

@@ -29,7 +29,7 @@ import ultra_infer as ui
 from ultra_infer.model import BaseUltraInferModel
 from paddlex.inference.common.reader import ReadImage, ReadTS
 from paddlex.inference.models import BasePredictor
-from paddlex.inference.utils.new_ir_blacklist import NEWIR_BLOCKLIST
+from paddlex.inference.utils.new_ir_blocklist import NEWIR_BLOCKLIST
 from paddlex.utils import device as device_helper
 from paddlex.utils import logging
 from paddlex.utils.subclass_register import AutoRegisterABCMetaClass

+ 3 - 3
paddlex/inference/common/batch_sampler/image_batch_sampler.py

@@ -79,7 +79,7 @@ class ImageBatchSampler(BaseBatchSampler):
                 batch.append(input, None, None)
                 if len(batch) == self.batch_size:
                     yield batch
-                    batch.reset()
+                    batch = ImgBatch()
             elif isinstance(input, str) and input.split(".")[-1] in ("PDF", "pdf"):
                 file_path = (
                     self._download_from_url(input)
@@ -90,7 +90,7 @@ class ImageBatchSampler(BaseBatchSampler):
                     batch.append(page_img, file_path, page_idx)
                     if len(batch) == self.batch_size:
                         yield batch
-                        batch.reset()
+                        batch = ImgBatch()
             elif isinstance(input, str):
                 file_path = (
                     self._download_from_url(input)
@@ -102,7 +102,7 @@ class ImageBatchSampler(BaseBatchSampler):
                     batch.append(file_path, file_path, None)
                     if len(batch) == self.batch_size:
                         yield batch
-                        batch.reset()
+                        batch = ImgBatch()
             else:
                 logging.warning(
                     f"Not supported input data type! Only `numpy.ndarray` and `str` are supported! So has been ignored: {input}."

+ 1 - 2
paddlex/inference/models/common/tokenizer/clip_tokenizer.py

@@ -21,8 +21,6 @@ import unicodedata
 from functools import lru_cache
 from typing import List, Optional
 
-from paddle.utils import try_import
-
 from .tokenizer_utils_base import AddedToken
 from .tokenizer_utils import PretrainedTokenizer
 from .tokenizer_utils import _is_control, _is_punctuation, _is_whitespace
@@ -325,6 +323,7 @@ class CLIPTokenizer(PretrainedTokenizer):
         pad_token="<|endoftext|>",
         **kwargs
     ):
+        from paddle.utils import try_import
 
         bos_token = (
             AddedToken(bos_token, lstrip=False, rstrip=False)

+ 9 - 5
paddlex/inference/models/object_detection/processors.py

@@ -684,7 +684,7 @@ def check_containment(boxes, formula_index=None, category_index=None, mode=None)
                 if mode == "large" and boxes[j][0] == category_index:
                     if is_contained(boxes[i], boxes[j]):
                         contained_by_other[i] = 1
-                        contains_other[j] = 1                   
+                        contains_other[j] = 1
                 if mode == "small" and boxes[i][0] == category_index:
                     if is_contained(boxes[i], boxes[j]):
                         contained_by_other[i] = 1
@@ -759,8 +759,10 @@ class DetPostProcess:
             boxes = np.array(boxes[selected_indices])
 
         if layout_merge_bboxes_mode:
-            formula_index = (self.labels.index("formula") if "formula" in self.labels else None)
-            if isinstance(layout_merge_bboxes_mode, str): 
+            formula_index = (
+                self.labels.index("formula") if "formula" in self.labels else None
+            )
+            if isinstance(layout_merge_bboxes_mode, str):
                 assert layout_merge_bboxes_mode in [
                     "union",
                     "large",
@@ -793,13 +795,15 @@ class DetPostProcess:
                                 boxes, formula_index, category_index, mode=layout_mode
                             )
                             # Remove boxes that are contained by other boxes
-                            keep_mask &= (contained_by_other == 0)
+                            keep_mask &= contained_by_other == 0
                         elif layout_mode == "small":
                             contains_other, contained_by_other = check_containment(
                                 boxes, formula_index, category_index, mode=layout_mode
                             )
                             # Keep boxes that do not contain others or are contained by others
-                            keep_mask &= (contains_other == 0) | (contained_by_other == 1)
+                            keep_mask &= (contains_other == 0) | (
+                                contained_by_other == 1
+                            )
                 boxes = boxes[keep_mask]
 
         if layout_unclip_ratio: