__init__.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from pathlib import Path
  15. from typing import Any, Dict, Optional, Union
  16. from ...utils import logging
  17. from ...utils.config import parse_config
  18. from ..utils.hpi import HPIConfig
  19. from ..utils.pp_option import PaddlePredictorOption
  20. from .anomaly_detection import AnomalyDetectionPipeline
  21. from .attribute_recognition import (
  22. PedestrianAttributeRecPipeline,
  23. VehicleAttributeRecPipeline,
  24. )
  25. from .base import BasePipeline
  26. from .components import BaseChat, BaseGeneratePrompt, BaseRetriever
  27. from .doc_preprocessor import DocPreprocessorPipeline
  28. from .doc_understanding import DocUnderstandingPipeline
  29. from .face_recognition import FaceRecPipeline
  30. from .formula_recognition import FormulaRecognitionPipeline
  31. from .image_classification import ImageClassificationPipeline
  32. from .image_multilabel_classification import ImageMultiLabelClassificationPipeline
  33. from .instance_segmentation import InstanceSegmentationPipeline
  34. from .keypoint_detection import KeypointDetectionPipeline
  35. from .layout_parsing import LayoutParsingPipeline
  36. from .m_3d_bev_detection import BEVDet3DPipeline
  37. from .multilingual_speech_recognition import MultilingualSpeechRecognitionPipeline
  38. from .object_detection import ObjectDetectionPipeline
  39. from .ocr import OCRPipeline
  40. from .open_vocabulary_detection import OpenVocabularyDetectionPipeline
  41. from .open_vocabulary_segmentation import OpenVocabularySegmentationPipeline
  42. from .paddleocr_vl import PPOCRVLPipeline
  43. from .pp_chatocr import PP_ChatOCRv3_Pipeline, PP_ChatOCRv4_Pipeline
  44. from .pp_doctranslation import PP_DocTranslation_Pipeline
  45. from .pp_shitu_v2 import ShiTuV2Pipeline
  46. from .rotated_object_detection import RotatedObjectDetectionPipeline
  47. from .seal_recognition import SealRecognitionPipeline
  48. from .semantic_segmentation import SemanticSegmentationPipeline
  49. from .small_object_detection import SmallObjectDetectionPipeline
  50. from .table_recognition import TableRecognitionPipeline, TableRecognitionPipelineV2
  51. from .ts_anomaly_detection import TSAnomalyDetPipeline
  52. from .ts_classification import TSClsPipeline
  53. from .ts_forecasting import TSFcPipeline
  54. from .video_classification import VideoClassificationPipeline
  55. from .video_detection import VideoDetectionPipeline
  56. def get_pipeline_path(pipeline_name: str) -> str:
  57. """
  58. Get the full path of the pipeline configuration file based on the provided pipeline name.
  59. Args:
  60. pipeline_name (str): The name of the pipeline.
  61. Returns:
  62. str: The full path to the pipeline configuration file or None if not found.
  63. """
  64. pipeline_path = (
  65. Path(__file__).parent.parent.parent
  66. / "configs/pipelines"
  67. / f"{pipeline_name}.yaml"
  68. ).resolve()
  69. if not Path(pipeline_path).exists():
  70. return None
  71. return pipeline_path
  72. def load_pipeline_config(pipeline: str) -> Dict[str, Any]:
  73. """
  74. Load the pipeline configuration.
  75. Args:
  76. pipeline (str): The name of the pipeline or the path to the config file.
  77. Returns:
  78. Dict[str, Any]: The parsed pipeline configuration.
  79. Raises:
  80. Exception: If the config file of pipeline does not exist.
  81. """
  82. if not (pipeline.endswith(".yml") or pipeline.endswith(".yaml")):
  83. pipeline_path = get_pipeline_path(pipeline)
  84. if pipeline_path is None:
  85. raise Exception(
  86. f"The pipeline ({pipeline}) does not exist! Please use a pipeline name or a config file path!"
  87. )
  88. else:
  89. pipeline_path = pipeline
  90. config = parse_config(pipeline_path)
  91. return config
  92. def create_pipeline(
  93. pipeline: Optional[str] = None,
  94. config: Optional[Dict[str, Any]] = None,
  95. device: Optional[str] = None,
  96. pp_option: Optional[PaddlePredictorOption] = None,
  97. use_hpip: Optional[bool] = None,
  98. hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
  99. *args: Any,
  100. **kwargs: Any,
  101. ) -> BasePipeline:
  102. """
  103. Create a pipeline instance based on the provided parameters.
  104. If the input parameter config is not provided, it is obtained from the
  105. default config corresponding to the pipeline name.
  106. Args:
  107. pipeline (Optional[str], optional): The name of the pipeline to
  108. create, or the path to the config file. Defaults to None.
  109. config (Optional[Dict[str, Any]], optional): The pipeline configuration.
  110. Defaults to None.
  111. device (Optional[str], optional): The device to run the pipeline on.
  112. Defaults to None.
  113. pp_option (Optional[PaddlePredictorOption], optional): The options for
  114. the PaddlePredictor. Defaults to None.
  115. use_hpip (Optional[bool], optional): Whether to use the high-performance
  116. inference plugin (HPIP). If set to None, the setting from the
  117. configuration file or `config` will be used. Defaults to None.
  118. hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional): The
  119. high-performance inference configuration dictionary.
  120. Defaults to None.
  121. *args: Additional positional arguments.
  122. **kwargs: Additional keyword arguments.
  123. Returns:
  124. BasePipeline: The created pipeline instance.
  125. """
  126. if pipeline is None and config is None:
  127. raise ValueError(
  128. "Both `pipeline` and `config` cannot be None at the same time."
  129. )
  130. if config is None:
  131. config = load_pipeline_config(pipeline)
  132. else:
  133. if pipeline is not None and config["pipeline_name"] != pipeline:
  134. logging.warning(
  135. "The pipeline name in the config (%r) is different from the specified pipeline name (%r). %r will be used.",
  136. config["pipeline_name"],
  137. pipeline,
  138. config["pipeline_name"],
  139. )
  140. config = config.copy()
  141. pipeline_name = config["pipeline_name"]
  142. if use_hpip is None:
  143. use_hpip = config.pop("use_hpip", False)
  144. else:
  145. config.pop("use_hpip", None)
  146. if hpi_config is None:
  147. hpi_config = config.pop("hpi_config", None)
  148. else:
  149. config.pop("hpi_config", None)
  150. pipeline = BasePipeline.get(pipeline_name)(
  151. config=config,
  152. device=device,
  153. pp_option=pp_option,
  154. use_hpip=use_hpip,
  155. hpi_config=hpi_config,
  156. *args,
  157. **kwargs,
  158. )
  159. return pipeline
  160. def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
  161. """Creates an instance of a chat bot based on the provided configuration.
  162. Args:
  163. config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
  164. *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
  165. **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
  166. Returns:
  167. BaseChat: An instance of the chat bot class corresponding to the 'model_name' in the config.
  168. """
  169. if "chat_bot_config_error" in config:
  170. raise ValueError(config["chat_bot_config_error"])
  171. api_type = config["api_type"]
  172. chat_bot = BaseChat.get(api_type)(config)
  173. return chat_bot
  174. def create_retriever(
  175. config: Dict,
  176. *args,
  177. **kwargs,
  178. ) -> BaseRetriever:
  179. """
  180. Creates a retriever instance based on the provided configuration.
  181. Args:
  182. config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
  183. *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
  184. **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
  185. Returns:
  186. BaseRetriever: An instance of a retriever class corresponding to the 'model_name' in the config.
  187. """
  188. if "retriever_config_error" in config:
  189. raise ValueError(config["retriever_config_error"])
  190. api_type = config["api_type"]
  191. retriever = BaseRetriever.get(api_type)(config)
  192. return retriever
  193. def create_prompt_engineering(
  194. config: Dict,
  195. *args,
  196. **kwargs,
  197. ) -> BaseGeneratePrompt:
  198. """
  199. Creates a prompt engineering instance based on the provided configuration.
  200. Args:
  201. config (Dict): Configuration settings, expected to be a dictionary with at least a 'task_type' key.
  202. *args: Variable length argument list for additional positional arguments.
  203. **kwargs: Arbitrary keyword arguments.
  204. Returns:
  205. BaseGeneratePrompt: An instance of a prompt engineering class corresponding to the 'task_type' in the config.
  206. """
  207. if "pe_config_error" in config:
  208. raise ValueError(config["pe_config_error"])
  209. task_type = config["task_type"]
  210. pe = BaseGeneratePrompt.get(task_type)(config)
  211. return pe