|
|
@@ -12,8 +12,6 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
-import os
|
|
|
-import tempfile
|
|
|
from typing import Any, Dict, List
|
|
|
|
|
|
from fastapi import FastAPI
|
|
|
@@ -23,35 +21,10 @@ from ...infra.config import AppConfig
|
|
|
from ...infra.models import ResultResponse
|
|
|
from ...schemas import pp_chatocrv3_doc as schema
|
|
|
from .._app import create_app, primary_operation
|
|
|
+from ._common import common
|
|
|
from ._common import ocr as ocr_common
|
|
|
|
|
|
|
|
|
-# XXX: Since the pipeline class does not provide serialization and
|
|
|
-# deserialization methods, these are implemented here based on the save-to-path
|
|
|
-# and load-from-path methods.
|
|
|
-def _serialize_vector_info(pipeline: Any, vector_info: dict) -> str:
|
|
|
- with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as f:
|
|
|
- path = f.name
|
|
|
- try:
|
|
|
- pipeline.save_vector(vector_info, path)
|
|
|
- with open(path, "r", encoding="utf-8") as f:
|
|
|
- return f.read()
|
|
|
- finally:
|
|
|
- os.unlink(path)
|
|
|
-
|
|
|
-
|
|
|
-def _deserialize_vector_info(pipeline: Any, vector_info: str) -> dict:
|
|
|
- with tempfile.NamedTemporaryFile(
|
|
|
- "w", encoding="utf-8", suffix=".json", delete=False
|
|
|
- ) as f:
|
|
|
- f.write(vector_info)
|
|
|
- path = f.name
|
|
|
- try:
|
|
|
- return pipeline.load_vector(path)
|
|
|
- finally:
|
|
|
- os.unlink(path)
|
|
|
-
|
|
|
-
|
|
|
def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
|
|
|
app, ctx = create_app(
|
|
|
pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
|
|
|
@@ -81,46 +54,58 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
|
|
|
use_general_ocr=request.useGeneralOcr,
|
|
|
use_seal_recognition=request.useSealRecognition,
|
|
|
use_table_recognition=request.useTableRecognition,
|
|
|
+ text_det_limit_side_len=request.textDetLimitSideLen,
|
|
|
+ text_det_limit_type=request.textDetLimitType,
|
|
|
+ text_det_thresh=request.textDetThresh,
|
|
|
+ text_det_box_thresh=request.textDetBoxThresh,
|
|
|
+ text_det_unclip_ratio=request.textDetUnclipRatio,
|
|
|
+ text_rec_score_thresh=request.textRecScoreThresh,
|
|
|
+ seal_det_limit_side_len=request.sealDetLimitSideLen,
|
|
|
+ seal_det_limit_type=request.sealDetLimitType,
|
|
|
+ seal_det_thresh=request.sealDetThresh,
|
|
|
+ seal_det_box_thresh=request.sealDetBoxThresh,
|
|
|
+ seal_det_unclip_ratio=request.sealDetUnclipRatio,
|
|
|
+ seal_rec_score_thresh=request.sealRecScoreThresh,
|
|
|
)
|
|
|
|
|
|
- visual_results: List[Dict[str, Any]] = []
|
|
|
- for i, (img, item) in enumerate(zip(images, result["layout_parsing_result"])):
|
|
|
- texts: List[dict] = []
|
|
|
- for poly, text, score in zip(
|
|
|
- item["ocr_result"]["dt_polys"],
|
|
|
- item["ocr_result"]["rec_text"],
|
|
|
- item["ocr_result"]["rec_score"],
|
|
|
- ):
|
|
|
- texts.append(dict(poly=poly, text=text, score=score))
|
|
|
- tables = [
|
|
|
- dict(bbox=r["layout_bbox"], html=r["html"])
|
|
|
- for r in item["table_result"]
|
|
|
- ]
|
|
|
+ layout_parsing_results: List[Dict[str, Any]] = []
|
|
|
+ visual_info: List[dict] = []
|
|
|
+ for i, (img, item) in enumerate(zip(images, result)):
|
|
|
+ pruned_res = common.prune_result(item["layout_parsing_result"].json["res"])
|
|
|
if ctx.config.visualize:
|
|
|
- input_img, layout_img, ocr_img = await ocr_common.postprocess_images(
|
|
|
- log_id=log_id,
|
|
|
- index=i,
|
|
|
- app_context=ctx,
|
|
|
- input_image=img,
|
|
|
- layout_image=item["layout_result"].img,
|
|
|
- ocr_image=item["ocr_result"].img,
|
|
|
+ imgs = {
|
|
|
+ "input_img": img,
|
|
|
+ **item["layout_parsing_result"].img,
|
|
|
+ }
|
|
|
+ imgs = await serving_utils.call_async(
|
|
|
+ common.postprocess_images,
|
|
|
+ imgs,
|
|
|
+ log_id,
|
|
|
+ filename_template=f"{{key}}_{i}.jpg",
|
|
|
+ file_storage=ctx.extra["file_storage"],
|
|
|
+ return_urls=ctx.extra["return_img_urls"],
|
|
|
+ max_img_size=ctx.extra["max_output_img_size"],
|
|
|
)
|
|
|
else:
|
|
|
- input_img, layout_img, ocr_img = None, None, None
|
|
|
- visual_result = dict(
|
|
|
- texts=texts,
|
|
|
- tables=tables,
|
|
|
- inputImage=input_img,
|
|
|
- layoutImage=layout_img,
|
|
|
- ocrImage=ocr_img,
|
|
|
+ imgs = {}
|
|
|
+ layout_parsing_results.append(
|
|
|
+ dict(
|
|
|
+ prunedResult=pruned_res,
|
|
|
+ outputImages=(
|
|
|
+ {k: v for k, v in imgs.items() if k != "input_img"}
|
|
|
+ if imgs
|
|
|
+ else None
|
|
|
+ ),
|
|
|
+ inputImage=imgs.get("input_img"),
|
|
|
+ )
|
|
|
)
|
|
|
- visual_results.append(visual_result)
|
|
|
+ visual_info.append(item["visual_info"])
|
|
|
|
|
|
return ResultResponse[schema.AnalyzeImagesResult](
|
|
|
logId=log_id,
|
|
|
result=schema.AnalyzeImagesResult(
|
|
|
- visualResults=visual_results,
|
|
|
- visualInfo=result["visual_info"],
|
|
|
+ layoutParsingResults=layout_parsing_results,
|
|
|
+ visualInfo=visual_info,
|
|
|
dataInfo=data_info,
|
|
|
),
|
|
|
)
|
|
|
@@ -135,15 +120,16 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
|
|
|
) -> ResultResponse[schema.BuildVectorStoreResult]:
|
|
|
pipeline = ctx.pipeline
|
|
|
|
|
|
+ kwargs: Dict[str, Any] = {"flag_save_bytes_vector": True}
|
|
|
+ if request.minCharacters is not None:
|
|
|
+ kwargs["min_characters"] = request.minCharacters
|
|
|
+ if request.llmRequestInterval is not None:
|
|
|
+ kwargs["llm_request_interval"] = request.llmRequestInterval
|
|
|
+
|
|
|
vector_info = await serving_utils.call_async(
|
|
|
pipeline.pipeline.build_vector,
|
|
|
request.visualInfo,
|
|
|
- min_characters=request.minCharacters,
|
|
|
- llm_request_interval=request.llmRequestInterval,
|
|
|
- )
|
|
|
-
|
|
|
- vector_info = await serving_utils.call_async(
|
|
|
- _serialize_vector_info, pipeline.pipeline, vector_info
|
|
|
+ **kwargs,
|
|
|
)
|
|
|
|
|
|
return ResultResponse[schema.BuildVectorStoreResult](
|
|
|
@@ -161,22 +147,8 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
|
|
|
) -> ResultResponse[schema.ChatResult]:
|
|
|
pipeline = ctx.pipeline
|
|
|
|
|
|
- if request.vectorInfo:
|
|
|
- vector_info = await serving_utils.call_async(
|
|
|
- _deserialize_vector_info,
|
|
|
- pipeline.pipeline,
|
|
|
- request.vectorInfo,
|
|
|
- )
|
|
|
- else:
|
|
|
- vector_info = None
|
|
|
-
|
|
|
- result = await serving_utils.call_async(
|
|
|
- pipeline.pipeline.chat,
|
|
|
- request.keyList,
|
|
|
- request.visualInfo,
|
|
|
- use_vector_retrieval=request.useVectorRetrieval,
|
|
|
- vector_info=vector_info,
|
|
|
- min_characters=request.minCharacters,
|
|
|
+ kwargs: Dict[str, Any] = dict(
|
|
|
+ vector_info=request.vectorInfo,
|
|
|
text_task_description=request.textTaskDescription,
|
|
|
text_output_format=request.textOutputFormat,
|
|
|
text_rules_str=request.textRulesStr,
|
|
|
@@ -188,6 +160,16 @@ def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
|
|
|
table_few_shot_demo_text_content=request.tableFewShotDemoTextContent,
|
|
|
table_few_shot_demo_key_value_list=request.tableFewShotDemoKeyValueList,
|
|
|
)
|
|
|
+ if request.useVectorRetrieval is not None:
|
|
|
+ kwargs["use_vector_retrieval"] = request.useVectorRetrieval
|
|
|
+ if request.minCharacters is not None:
|
|
|
+ kwargs["min_characters"] = request.minCharacters
|
|
|
+
|
|
|
+ result = await serving_utils.call_async(
|
|
|
+ pipeline.pipeline.chat,
|
|
|
+ request.keyList,
|
|
|
+ request.visualInfo,
|
|
|
+ )
|
|
|
|
|
|
return ResultResponse[schema.ChatResult](
|
|
|
logId=serving_utils.generate_log_id(),
|