__init__.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # from .single_model_pipeline import (
  15. # _SingleModelPipeline,
  16. # ImageClassification,
  17. # ObjectDetection,
  18. # InstanceSegmentation,
  19. # SemanticSegmentation,
  20. # TSFc,
  21. # TSAd,
  22. # TSCls,
  23. # MultiLableImageClas,
  24. # SmallObjDet,
  25. # AnomalyDetection,
  26. # )
  27. # from .ocr import OCRPipeline
  28. # from .formula_recognition import FormulaRecognitionPipeline
  29. # from .table_recognition import TableRecPipeline
  30. # from .face_recognition import FaceRecPipeline
  31. # from .seal_recognition import SealOCRPipeline
  32. # from .ppchatocrv3 import PPChatOCRPipeline
  33. # from .layout_parsing import LayoutParsingPipeline
  34. # from .pp_shitu_v2 import ShiTuV2Pipeline
  35. # from .attribute_recognition import AttributeRecPipeline
  36. from pathlib import Path
  37. from typing import Any, Dict, Optional
  38. from .base import BasePipeline
  39. from ..utils.pp_option import PaddlePredictorOption
  40. from .components import BaseChat, BaseRetriever, BaseGeneratePrompt
  41. from ...utils.config import parse_config
  42. from .ocr import OCRPipeline
  43. from .doc_preprocessor import DocPreprocessorPipeline
  44. from .layout_parsing import LayoutParsingPipeline
  45. from .pp_chatocrv3_doc import PP_ChatOCRv3_doc_Pipeline
  46. def get_pipeline_path(pipeline_name: str) -> str:
  47. """
  48. Get the full path of the pipeline configuration file based on the provided pipeline name.
  49. Args:
  50. pipeline_name (str): The name of the pipeline.
  51. Returns:
  52. str: The full path to the pipeline configuration file or None if not found.
  53. """
  54. pipeline_path = (
  55. Path(__file__).parent.parent.parent
  56. / "configs/pipelines"
  57. / f"{pipeline_name}.yaml"
  58. ).resolve()
  59. if not Path(pipeline_path).exists():
  60. return None
  61. return pipeline_path
  62. def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
  63. """
  64. Load the pipeline configuration.
  65. Args:
  66. pipeline_name (str): The name of the pipeline or the path to the config file.
  67. Returns:
  68. Dict[str, Any]: The parsed pipeline configuration.
  69. Raises:
  70. Exception: If the config file of pipeline does not exist.
  71. """
  72. if not Path(pipeline_name).exists():
  73. pipeline_path = get_pipeline_path(pipeline_name)
  74. if pipeline_path is None:
  75. raise Exception(
  76. f"The pipeline ({pipeline_name}) does not exist! Please use a pipeline name or a config file path!"
  77. )
  78. else:
  79. pipeline_path = pipeline_name
  80. config = parse_config(pipeline_path)
  81. return config
  82. def create_pipeline(
  83. pipeline: str,
  84. config: Dict = None,
  85. device: str = None,
  86. pp_option: PaddlePredictorOption = None,
  87. use_hpip: bool = False,
  88. hpi_params: Optional[Dict[str, Any]] = None,
  89. *args,
  90. **kwargs,
  91. ) -> BasePipeline:
  92. """
  93. Create a pipeline instance based on the provided parameters.
  94. If the input parameter config is not provided,
  95. it is obtained from the default config corresponding to the pipeline name.
  96. Args:
  97. pipeline (str): The name of the pipeline to create.
  98. config (Dict, optional): The path to the pipeline configuration file. Defaults to None.
  99. device (str, optional): The device to run the pipeline on. Defaults to None.
  100. pp_option (PaddlePredictorOption, optional): The options for the PaddlePredictor. Defaults to None.
  101. use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
  102. hpi_params (Optional[Dict[str, Any]], optional): Additional parameters for hpip. Defaults to None.
  103. *args: Additional positional arguments.
  104. **kwargs: Additional keyword arguments.
  105. Returns:
  106. BasePipeline: The created pipeline instance.
  107. """
  108. pipeline_name = pipeline
  109. if config is None:
  110. config = load_pipeline_config(pipeline_name)
  111. assert pipeline_name == config["pipeline_name"]
  112. pipeline = BasePipeline.get(pipeline_name)(
  113. config=config,
  114. device=device,
  115. pp_option=pp_option,
  116. use_hpip=use_hpip,
  117. hpi_params=hpi_params,
  118. )
  119. return pipeline
  120. def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
  121. """Creates an instance of a chat bot based on the provided configuration.
  122. Args:
  123. config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
  124. *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
  125. **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
  126. Returns:
  127. BaseChat: An instance of the chat bot class corresponding to the 'model_name' in the config.
  128. """
  129. model_name = config["model_name"]
  130. chat_bot = BaseChat.get(model_name)(config)
  131. return chat_bot
  132. def create_retriever(
  133. config: Dict,
  134. *args,
  135. **kwargs,
  136. ) -> BaseRetriever:
  137. """
  138. Creates a retriever instance based on the provided configuration.
  139. Args:
  140. config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
  141. *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
  142. **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
  143. Returns:
  144. BaseRetriever: An instance of a retriever class corresponding to the 'model_name' in the config.
  145. """
  146. model_name = config["model_name"]
  147. retriever = BaseRetriever.get(model_name)(config)
  148. return retriever
  149. def create_prompt_engeering(
  150. config: Dict,
  151. *args,
  152. **kwargs,
  153. ) -> BaseGeneratePrompt:
  154. """
  155. Creates a prompt engineering instance based on the provided configuration.
  156. Args:
  157. config (Dict): Configuration settings, expected to be a dictionary with at least a 'task_type' key.
  158. *args: Variable length argument list for additional positional arguments.
  159. **kwargs: Arbitrary keyword arguments.
  160. Returns:
  161. BaseGeneratePrompt: An instance of a prompt engineering class corresponding to the 'task_type' in the config.
  162. """
  163. task_type = config["task_type"]
  164. pe = BaseGeneratePrompt.get(task_type)(config)
  165. return pe