| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from .base import BaseGeneratePrompt
- class GenerateKIEPrompt(BaseGeneratePrompt):
- """Generate KIE Prompt"""
- entities = [
- "text_kie_prompt",
- "table_kie_prompt"
- ]
- def __init__(self, config):
- super().__init__()
- task_type = config.get('task_type', "")
- task_description = config.get('task_description', "")
- output_format = config.get('output_format', "")
- rules_str = config.get('rules_str', "")
- few_shot_demo_text_content = config.get('few_shot_demo_text_content', "")
- few_shot_demo_key_value_list = config.get('few_shot_demo_key_value_list', "")
- if task_description is None:
- task_description = ""
-
- if output_format is None:
- output_format = ""
-
- if rules_str is None:
- rules_str = ""
-
- if few_shot_demo_text_content is None:
- few_shot_demo_text_content = ""
-
- if few_shot_demo_key_value_list is None:
- few_shot_demo_key_value_list = ""
- if task_type not in self.entities:
- raise ValueError(f"task type must be in {self.entities} of GenerateKIEPrompt.")
- self.task_type = task_type
- self.task_description = task_description
- self.output_format = output_format
- self.rules_str = rules_str
- self.few_shot_demo_text_content = few_shot_demo_text_content
- self.few_shot_demo_key_value_list = few_shot_demo_key_value_list
-
- def generate_prompt(self, text_content,
- key_list,
- task_description=None,
- output_format=None,
- rules_str=None,
- few_shot_demo_text_content=None,
- few_shot_demo_key_value_list=None):
- """
- args:
- return:
- """
- if task_description is None:
- task_description = self.task_description
- if output_format is None:
- output_format = self.output_format
- if rules_str is None:
- rules_str = self.rules_str
- if few_shot_demo_text_content is None:
- few_shot_demo_text_content = self.few_shot_demo_text_content
-
- if few_shot_demo_key_value_list is None:
- few_shot_demo_key_value_list = self.few_shot_demo_key_value_list
- prompt = f"""{task_description}{output_format}{rules_str}{few_shot_demo_text_content}{few_shot_demo_key_value_list}"""
- if self.task_type == "table_kie_prompt":
- prompt += f"""\n结合上面,下面正式开始:\
- 表格内容:```{text_content}```\
- 关键词列表:{key_list}。""".replace(
- " ", "")
- elif self.task_type == "text_kie_prompt":
- prompt += f"""\n结合上面的例子,下面正式开始:\
- OCR文字:```{text_content}```\
- 关键词列表:{key_list}。""".replace(
- " ", "")
- else:
- raise ValueError(f"{self.task_type} is currently not supported.")
- return prompt
|