瀏覽代碼

Fix: read images as RGB in the chatocrv3 pipeline

gaotingquan 1 年之前
父節點
當前提交
c2e1581c45

+ 1 - 1
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -111,7 +111,7 @@ class PPChatOCRPipeline(TableRecPipeline):
         else:
             self.user_prompt_dict = None
         self.recovery = recovery
-        self.img_reader = ReadImage()
+        self.img_reader = ReadImage(format="RGB")
         self.visual_info = None
         self.vector = None
         self.visual_flag = False

+ 9 - 0
paddlex/inference/results/chat_ocr.py

@@ -60,25 +60,34 @@ class VisualResult(BaseResult):
         oricls_save_path = f"{save_path}_oricls.jpg"
         oricls_result = self["oricls_result"]
         if oricls_result:
+            oricls_result._HARD_FLAG = True
             oricls_result.save_to_img(oricls_save_path)
         uvdoc_save_path = f"{save_path}_uvdoc.jpg"
         uvdoc_result = self["uvdoc_result"]
         if uvdoc_result:
+            # uvdoc_result._HARD_FLAG = True
             uvdoc_result.save_to_img(uvdoc_save_path)
         curve_save_path = f"{save_path}_curve.jpg"
         curve_results = self["curve_result"]
+        # TODO(): support list of result
+        if isinstance(curve_results, dict):
+            curve_results = [curve_results]
         for curve_result in curve_results:
+            curve_result._HARD_FLAG = True if not uvdoc_result else False
             curve_result.save_to_img(curve_save_path)
         layout_save_path = f"{save_path}_layout.jpg"
         layout_result = self["layout_result"]
         if layout_result:
+            layout_result._HARD_FLAG = True if not uvdoc_result else False
             layout_result.save_to_img(layout_save_path)
         ocr_save_path = f"{save_path}_ocr.jpg"
         table_save_path = f"{save_path}_table.jpg"
         ocr_result = self["ocr_result"]
         if ocr_result:
+            ocr_result._HARD_FLAG = True if not uvdoc_result else False
             ocr_result.save_to_img(ocr_save_path)
         for table_result in self["table_result"]:
+            table_result._HARD_FLAG = True if not uvdoc_result else False
             table_result.save_to_img(table_save_path)
 
 

+ 4 - 0
paddlex/inference/results/clas.py

@@ -24,6 +24,7 @@ from .base import CVResult
 
 
 class TopkResult(CVResult):
+    _HARD_FLAG = False
 
     def _to_img(self):
         """Draw label on image"""
@@ -31,6 +32,9 @@ class TopkResult(CVResult):
         label_str = f"{labels[0]} {self['scores'][0]:.2f}"
 
         image = self._img_reader.read(self["input_path"])
+        if self._HARD_FLAG:
+            image_np = np.array(image)
+            image = Image.fromarray(image_np[:, :, ::-1])
         image_size = image.size
         draw = ImageDraw.Draw(image)
         min_font_size = int(image_size[0] * 0.02)

+ 6 - 0
paddlex/inference/results/det.py

@@ -14,6 +14,7 @@
 
 import os
 import cv2
+import numpy as np
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 
@@ -76,9 +77,14 @@ def draw_box(img, boxes):
 class DetResult(CVResult):
     """Save Result Transform"""
 
+    _HARD_FLAG = False
+
     def _to_img(self):
         """apply"""
         boxes = self["boxes"]
         image = self._img_reader.read(self["input_path"])
+        if self._HARD_FLAG:
+            image_np = np.array(image)
+            image = Image.fromarray(image_np[:, :, ::-1])
         image = draw_box(image, boxes)
         return image

+ 4 - 0
paddlex/inference/results/ocr.py

@@ -24,6 +24,7 @@ from .base import CVResult
 
 
 class OCRResult(CVResult):
+    _HARD_FLAG = False
 
     def get_minarea_rect(self, points):
         bounding_box = cv2.minAreaRect(points)
@@ -60,6 +61,9 @@ class OCRResult(CVResult):
         txts = self["rec_text"]
         scores = self["rec_score"]
         image = self._img_reader.read(self["input_path"])
+        if self._HARD_FLAG:
+            image_np = np.array(image)
+            image = Image.fromarray(image_np[:, :, ::-1])
         h, w = image.height, image.width
         img_left = image.copy()
         img_right = np.ones((h, w, 3), dtype=np.uint8) * 255

+ 7 - 0
paddlex/inference/results/table_rec.py

@@ -15,6 +15,8 @@
 import cv2
 import numpy as np
 from pathlib import Path
+import PIL
+from PIL import Image, ImageDraw, ImageFont
 
 from .utils.mixin import HtmlMixin, XlsxMixin
 from .base import BaseResult, CVResult
@@ -23,6 +25,8 @@ from .base import BaseResult, CVResult
 class TableRecResult(CVResult, HtmlMixin):
     """SaveTableResults"""
 
+    _HARD_FLAG = False
+
     def __init__(self, data):
         super().__init__(data)
         HtmlMixin.__init__(self)
@@ -33,6 +37,9 @@ class TableRecResult(CVResult, HtmlMixin):
 
     def _to_img(self):
         image = self._img_reader.read(self["input_path"])
+        if self._HARD_FLAG:
+            image_np = np.array(image)
+            image = Image.fromarray(image_np[:, :, ::-1])
         bbox_res = self["bbox"]
         if len(bbox_res) > 0 and len(bbox_res[0]) == 4:
             vis_img = self.draw_rectangle(image, bbox_res)