@@ -422,16 +422,9 @@ class PPChatOCRPipeline(_TableRecPipeline):
         if not any([visual_info, self.visual_info]):
             return VectorResult({"vector": None})
 
-        if visual_info:
-            # use for serving or local
-            _visual_info = visual_info
-        else:
-            # use for local
-            _visual_info = self.visual_info
-
-        ocr_text = _visual_info["ocr_text"]
-        html_list = _visual_info["table_html"]
-        table_text_list = _visual_info["table_text"]
+        ocr_text = visual_info["ocr_text"]
+        html_list = visual_info["table_html"]
+        table_text_list = visual_info["table_text"]
 
         # add table text to ocr text
         for html, table_text_rec in zip(html_list, table_text_list):
@@ -457,36 +450,16 @@ class PPChatOCRPipeline(_TableRecPipeline):
     def retrieval(
         self,
         key_list,
-        visual_info=None,
-        vector=None,
+        vector,
         llm_name=None,
         llm_params={},
         llm_request_interval=0.1,
     ):
-
-        if not any([visual_info, vector, self.visual_info, self.vector]):
-            return RetrievalResult({"retrieval": None})
-
+        assert "vector" in vector
         key_list = format_key(key_list)
 
-        is_seving = visual_info and llm_name
-
-        if self.visual_flag and not is_seving:
-            self.vector = self.build_vector()
-
-        if not any([vector, self.vector]):
-            logging.warning(
-                "The vector library is not created, and is being created automatically"
-            )
-            if is_seving:
-                # for serving
-                vector = self.build_vector(
-                    llm_name=llm_name, llm_params=llm_params, visual_info=visual_info
-                )
-            else:
-                self.vector = self.build_vector()
-
-        if vector and llm_name:
+        # for serving
+        if llm_name:
             _vector = vector["vector"]
             llm_api = create_llm_api(llm_name, llm_params)
             retrieval = llm_api.caculate_similar(
@@ -496,7 +469,7 @@ class PPChatOCRPipeline(_TableRecPipeline):
                 sleep_time=llm_request_interval,
             )
         else:
-            _vector = self.vector["vector"]
+            _vector = vector["vector"]
             retrieval = self.llm_api.caculate_similar(
                 vector=_vector, key_list=key_list, sleep_time=llm_request_interval
             )
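With this change retrieval() no longer falls back to self.vector or rebuilds the vector store itself: the caller passes the vector result in explicitly, and llm_name decides between the serving path (a fresh LLM client from create_llm_api) and the local path (self.llm_api). A minimal caller sketch under that reading; the pipeline instance, the key name, and the LLM settings below are placeholder assumptions, not part of this patch:

    # hypothetical caller code, not taken from this patch
    # vector: the VectorResult produced by build_vector(), i.e. a mapping with a "vector" entry
    res = pipeline.retrieval(key_list=["InvoiceNo"], vector=vector)  # local mode, llm_name stays None
    res_serving = pipeline.retrieval(                                # serving mode, explicit LLM
        key_list=["InvoiceNo"],
        vector=vector,
        llm_name="ernie-3.5",
        llm_params={},
    )
    print(res["retrieval"])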
@@ -512,33 +485,24 @@ class PPChatOCRPipeline(_TableRecPipeline):
         user_task_description="",
         rules="",
         few_shot="",
-        use_retrieval=True,
         save_prompt=False,
-        llm_name="ernie-3.5",
+        llm_name=None,
         llm_params={},
     ):
         """
         chat with key
 
         """
-        if not any(
-            [vector, visual_info, retrieval_result, self.visual_info, self.vector]
-        ):
+        if not any([vector, visual_info, retrieval_result]):
             return ChatResult(
                 {"chat_res": "请先完成图像解析再开始再对话", "prompt": ""}
             )
         key_list = format_key(key_list)
         # first get from table, then get from text in table, last get from all ocr
-        if visual_info:
-            # use for serving or local
-            _visual_info = visual_info
-        else:
-            # use for local
-            _visual_info = self.visual_info
 
-        ocr_text = _visual_info["ocr_text"]
-        html_list = _visual_info["table_html"]
-        table_text_list = _visual_info["table_text"]
+        ocr_text = visual_info["ocr_text"]
+        html_list = visual_info["table_html"]
+        table_text_list = visual_info["table_text"]
 
         prompt_res = {"ocr_prompt": "str", "table_prompt": [], "html_prompt": []}
 
@@ -571,18 +535,21 @@ class PPChatOCRPipeline(_TableRecPipeline):
         logging.debug("get result from ocr")
         if retrieval_result:
             ocr_text = retrieval_result.get("retrieval")
-        elif use_retrieval and any([visual_info, vector]):
-            # for serving or local
-            ocr_text = self.retrieval(
-                key_list=key_list,
-                visual_info=visual_info,
-                vector=vector,
-                llm_name=llm_name,
-                llm_params=llm_params,
-            )["retrieval"]
-        else:
+        elif vector:
+            # for serving
+            if llm_name:
+                ocr_text = self.retrieval(
+                    key_list=key_list,
+                    vector=vector,
+                    llm_name=llm_name,
+                    llm_params=llm_params,
+                )["retrieval"]
             # for local
-            ocr_text = self.retrieval(key_list=key_list)["retrieval"]
+            else:
+                ocr_text = self.retrieval(key_list=key_list, vector=vector)[
+                    "retrieval"
+                ]
+
         prompt = self.get_prompt_for_ocr(
             ocr_text,
             key_list,