gaotingquan 1 жил өмнө
parent
commit
b26ce06df0

+ 7 - 1
paddlex/inference/components/task_related/__init__.py

@@ -13,7 +13,13 @@
 # limitations under the License.
 
 from .clas import Topk, MultiLabelThreshOutput, NormalizeFeatures
-from .text_det import DetResizeForTest, NormalizeImage, DBPostProcess, CropByPolys
+from .text_det import (
+    DetResizeForTest,
+    NormalizeImage,
+    DBPostProcess,
+    SortBoxes,
+    CropByPolys,
+)
 from .text_rec import OCRReisizeNormImg, CTCLabelDecode
 from .table_rec import TableLabelDecode
 from .det import DetPostProcess, CropByBoxes

+ 36 - 29
paddlex/inference/components/task_related/text_det.py

@@ -433,11 +433,8 @@ class CropByPolys(BaseComponent):
         """apply"""
         img = self._reader.read(img_path)
 
-        # TODO
-        # dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
         if self.det_box_type == "quad":
-            dt_boxes = self.sorted_boxes(dt_polys)
-            dt_boxes = np.array(dt_boxes)
+            dt_boxes = np.array(dt_polys)
             output_list = []
             for bno in range(len(dt_boxes)):
                 tmp_box = copy.deepcopy(dt_boxes[bno])
@@ -465,31 +462,6 @@ class CropByPolys(BaseComponent):
 
         return output_list
 
-    def sorted_boxes(self, dt_boxes):
-        """
-        Sort text boxes in order from top to bottom, left to right
-        args:
-            dt_boxes(array):detected text boxes with shape [4, 2]
-        return:
-            sorted boxes(array) with shape [4, 2]
-        """
-        dt_boxes = np.array(dt_boxes)
-        num_boxes = dt_boxes.shape[0]
-        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
-        _boxes = list(sorted_boxes)
-
-        for i in range(num_boxes - 1):
-            for j in range(i, -1, -1):
-                if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
-                    _boxes[j + 1][0][0] < _boxes[j][0][0]
-                ):
-                    tmp = _boxes[j]
-                    _boxes[j] = _boxes[j + 1]
-                    _boxes[j + 1] = tmp
-                else:
-                    break
-        return _boxes
-
     def get_minarea_rect_crop(self, img, points):
         """get_minarea_rect_crop"""
         bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
@@ -880,3 +852,38 @@ class CropByPolys(BaseComponent):
             img = np.stack((img,) * 3, axis=-1)
         img_crop, image = rectifier.run(img, new_points_list, mode="homography")
         return img_crop[0]
+
+
+class SortBoxes(BaseComponent):
+
+    YIELD_BATCH = False
+
+    INPUT_KEYS = ["dt_polys"]
+    OUTPUT_KEYS = ["dt_polys"]
+    DEAULT_INPUTS = {"dt_polys": "dt_polys"}
+    DEAULT_OUTPUTS = {"dt_polys": "dt_polys"}
+
+    def apply(self, dt_polys):
+        """
+        Sort text boxes in order from top to bottom, left to right
+        args:
+            dt_boxes(array):detected text boxes with shape [4, 2]
+        return:
+            sorted boxes(array) with shape [4, 2]
+        """
+        dt_boxes = np.array(dt_polys)
+        num_boxes = dt_boxes.shape[0]
+        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+        _boxes = list(sorted_boxes)
+
+        for i in range(num_boxes - 1):
+            for j in range(i, -1, -1):
+                if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
+                    _boxes[j + 1][0][0] < _boxes[j][0][0]
+                ):
+                    tmp = _boxes[j]
+                    _boxes[j] = _boxes[j + 1]
+                    _boxes[j + 1] = tmp
+                else:
+                    break
+        return {"dt_polys": [box.tolist() for box in _boxes]}

+ 15 - 11
paddlex/inference/pipelines/ocr.py

@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .base import BasePipeline
-from ..components import CropByPolys
+from ..components import SortBoxes, CropByPolys
 from ..results import OCRResult
+from .base import BasePipeline
 
 
 class OCRPipeline(BasePipeline):
@@ -23,21 +23,25 @@ class OCRPipeline(BasePipeline):
     entities = "ocr"
 
     def __init__(
-        self, det_model, rec_model, rec_batch_size, predictor_kwargs=None, is_curve=False, **kwargs
+        self,
+        det_model,
+        rec_model,
+        rec_batch_size,
+        predictor_kwargs=None,
     ):
         super().__init__(predictor_kwargs)
         self._det_predict = self._create_predictor(det_model)
         self._rec_predict = self._create_predictor(rec_model, batch_size=rec_batch_size)
-        # TODO: foo
-        if is_curve:
-            det_box_type = 'poly'
-        else:
-            det_box_type = 'quad'
-        self._crop_by_polys = CropByPolys(det_box_type=det_box_type)
+        is_curve = self._det_predict.model_name in [
+            "PP-OCRv4_mobile_seal_det",
+            "PP-OCRv4_server_seal_det",
+        ]
+        self._sort_boxes = SortBoxes()
+        self._crop_by_polys = CropByPolys(det_box_type="poly" if is_curve else "quad")
 
     def predict(self, x):
         for det_res in self._det_predict(x):
-            single_img_res = det_res
+            single_img_res = next(self._sort_boxes(det_res))
             single_img_res["rec_text"] = []
             single_img_res["rec_score"] = []
             if len(single_img_res["dt_polys"]) > 0:
@@ -45,4 +49,4 @@ class OCRPipeline(BasePipeline):
                 for rec_res in self._rec_predict(all_subs_of_img):
                     single_img_res["rec_text"].append(rec_res["rec_text"])
                     single_img_res["rec_score"].append(rec_res["rec_score"])
-            yield single_img_res
+            yield OCRResult(single_img_res)

+ 7 - 5
paddlex/inference/results/ocr.py

@@ -49,7 +49,9 @@ class OCRResult(BaseResult):
             index_b = 3
             index_c = 2
 
-        box = np.array([points[index_a], points[index_b], points[index_c], points[index_d]]).astype(np.int32)
+        box = np.array(
+            [points[index_a], points[index_b], points[index_c], points[index_d]]
+        ).astype(np.int32)
 
         return box
 
@@ -85,9 +87,9 @@ class OCRResult(BaseResult):
                     pts = [(x, y) for x, y in box.tolist()]
                     draw_left.polygon(pts, outline=color, width=8)
                     box = self.get_minarea_rect(box)
-                    height = int(0.5 * (max(box[:,1]) - min(box[:,1])))
-                    box[:2,1] = np.mean(box[:,1])
-                    box[2:,1] = np.mean(box[:,1]) + min(20, height)
+                    height = int(0.5 * (max(box[:, 1]) - min(box[:, 1])))
+                    box[:2, 1] = np.mean(box[:, 1])
+                    box[2:, 1] = np.mean(box[:, 1]) + min(20, height)
                 draw_left.polygon(box, fill=color)
                 img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
                 pts = np.array(box, np.int32).reshape((-1, 1, 2))
@@ -99,7 +101,7 @@ class OCRResult(BaseResult):
         img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
         img_show.paste(img_left, (0, 0, w, h))
         img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
-        return np.array(img_show)
+        return cv2.cvtColor(np.array(img_show), cv2.COLOR_RGB2BGR)
 
 
 def draw_box_txt_fine(img_size, box, txt, font_path=PINGFANG_FONT_FILE_PATH):