__init__.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from pathlib import Path
  15. from typing import Any, Dict, Optional, Union
  16. from ...utils import logging
  17. from ...utils.config import parse_config
  18. from ..utils.hpi import HPIConfig
  19. from ..utils.pp_option import PaddlePredictorOption
  20. from .anomaly_detection import AnomalyDetectionPipeline
  21. from .attribute_recognition import (
  22. PedestrianAttributeRecPipeline,
  23. VehicleAttributeRecPipeline,
  24. )
  25. from .base import BasePipeline
  26. from .components import BaseChat, BaseGeneratePrompt, BaseRetriever
  27. from .doc_preprocessor import DocPreprocessorPipeline
  28. from .doc_understanding import DocUnderstandingPipeline
  29. from .face_recognition import FaceRecPipeline
  30. from .formula_recognition import FormulaRecognitionPipeline
  31. from .image_classification import ImageClassificationPipeline
  32. from .image_multilabel_classification import ImageMultiLabelClassificationPipeline
  33. from .instance_segmentation import InstanceSegmentationPipeline
  34. from .keypoint_detection import KeypointDetectionPipeline
  35. from .layout_parsing import LayoutParsingPipeline
  36. from .m_3d_bev_detection import BEVDet3DPipeline
  37. from .multilingual_speech_recognition import MultilingualSpeechRecognitionPipeline
  38. from .object_detection import ObjectDetectionPipeline
  39. from .ocr import OCRPipeline
  40. from .open_vocabulary_detection import OpenVocabularyDetectionPipeline
  41. from .open_vocabulary_segmentation import OpenVocabularySegmentationPipeline
  42. from .pp_chatocr import PP_ChatOCRv3_Pipeline, PP_ChatOCRv4_Pipeline
  43. from .pp_shitu_v2 import ShiTuV2Pipeline
  44. from .rotated_object_detection import RotatedObjectDetectionPipeline
  45. from .seal_recognition import SealRecognitionPipeline
  46. from .semantic_segmentation import SemanticSegmentationPipeline
  47. from .small_object_detection import SmallObjectDetectionPipeline
  48. from .table_recognition import TableRecognitionPipeline, TableRecognitionPipelineV2
  49. from .ts_anomaly_detection import TSAnomalyDetPipeline
  50. from .ts_classification import TSClsPipeline
  51. from .ts_forecasting import TSFcPipeline
  52. from .video_classification import VideoClassificationPipeline
  53. from .video_detection import VideoDetectionPipeline
  54. def get_pipeline_path(pipeline_name: str) -> str:
  55. """
  56. Get the full path of the pipeline configuration file based on the provided pipeline name.
  57. Args:
  58. pipeline_name (str): The name of the pipeline.
  59. Returns:
  60. str: The full path to the pipeline configuration file or None if not found.
  61. """
  62. pipeline_path = (
  63. Path(__file__).parent.parent.parent
  64. / "configs/pipelines"
  65. / f"{pipeline_name}.yaml"
  66. ).resolve()
  67. if not Path(pipeline_path).exists():
  68. return None
  69. return pipeline_path
  70. def load_pipeline_config(pipeline: str) -> Dict[str, Any]:
  71. """
  72. Load the pipeline configuration.
  73. Args:
  74. pipeline (str): The name of the pipeline or the path to the config file.
  75. Returns:
  76. Dict[str, Any]: The parsed pipeline configuration.
  77. Raises:
  78. Exception: If the config file of pipeline does not exist.
  79. """
  80. if not (pipeline.endswith(".yml") or pipeline.endswith(".yaml")):
  81. pipeline_path = get_pipeline_path(pipeline)
  82. if pipeline_path is None:
  83. raise Exception(
  84. f"The pipeline ({pipeline}) does not exist! Please use a pipeline name or a config file path!"
  85. )
  86. else:
  87. pipeline_path = pipeline
  88. config = parse_config(pipeline_path)
  89. return config
  90. def create_pipeline(
  91. pipeline: Optional[str] = None,
  92. config: Optional[Dict[str, Any]] = None,
  93. device: Optional[str] = None,
  94. pp_option: Optional[PaddlePredictorOption] = None,
  95. use_hpip: Optional[bool] = None,
  96. hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
  97. *args: Any,
  98. **kwargs: Any,
  99. ) -> BasePipeline:
  100. """
  101. Create a pipeline instance based on the provided parameters.
  102. If the input parameter config is not provided, it is obtained from the
  103. default config corresponding to the pipeline name.
  104. Args:
  105. pipeline (Optional[str], optional): The name of the pipeline to
  106. create, or the path to the config file. Defaults to None.
  107. config (Optional[Dict[str, Any]], optional): The pipeline configuration.
  108. Defaults to None.
  109. device (Optional[str], optional): The device to run the pipeline on.
  110. Defaults to None.
  111. pp_option (Optional[PaddlePredictorOption], optional): The options for
  112. the PaddlePredictor. Defaults to None.
  113. use_hpip (Optional[bool], optional): Whether to use the high-performance
  114. inference plugin (HPIP). If set to None, the setting from the
  115. configuration file or `config` will be used. Defaults to None.
  116. hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional): The
  117. high-performance inference configuration dictionary.
  118. Defaults to None.
  119. *args: Additional positional arguments.
  120. **kwargs: Additional keyword arguments.
  121. Returns:
  122. BasePipeline: The created pipeline instance.
  123. """
  124. if pipeline is None and config is None:
  125. raise ValueError(
  126. "Both `pipeline` and `config` cannot be None at the same time."
  127. )
  128. if config is None:
  129. config = load_pipeline_config(pipeline)
  130. else:
  131. if pipeline is not None and config["pipeline_name"] != pipeline:
  132. logging.warning(
  133. "The pipeline name in the config (%r) is different from the specified pipeline name (%r). %r will be used.",
  134. config["pipeline_name"],
  135. pipeline,
  136. config["pipeline_name"],
  137. )
  138. config = config.copy()
  139. pipeline_name = config["pipeline_name"]
  140. if use_hpip is None:
  141. use_hpip = config.pop("use_hpip", False)
  142. else:
  143. config.pop("use_hpip", None)
  144. if hpi_config is None:
  145. hpi_config = config.pop("hpi_config", None)
  146. else:
  147. config.pop("hpi_config", None)
  148. pipeline = BasePipeline.get(pipeline_name)(
  149. config=config,
  150. device=device,
  151. pp_option=pp_option,
  152. use_hpip=use_hpip,
  153. hpi_config=hpi_config,
  154. *args,
  155. **kwargs,
  156. )
  157. return pipeline
  158. def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
  159. """Creates an instance of a chat bot based on the provided configuration.
  160. Args:
  161. config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
  162. *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
  163. **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
  164. Returns:
  165. BaseChat: An instance of the chat bot class corresponding to the 'model_name' in the config.
  166. """
  167. if "chat_bot_config_error" in config:
  168. raise ValueError(config["chat_bot_config_error"])
  169. api_type = config["api_type"]
  170. chat_bot = BaseChat.get(api_type)(config)
  171. return chat_bot
  172. def create_retriever(
  173. config: Dict,
  174. *args,
  175. **kwargs,
  176. ) -> BaseRetriever:
  177. """
  178. Creates a retriever instance based on the provided configuration.
  179. Args:
  180. config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
  181. *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
  182. **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
  183. Returns:
  184. BaseRetriever: An instance of a retriever class corresponding to the 'model_name' in the config.
  185. """
  186. if "retriever_config_error" in config:
  187. raise ValueError(config["retriever_config_error"])
  188. api_type = config["api_type"]
  189. retriever = BaseRetriever.get(api_type)(config)
  190. return retriever
  191. def create_prompt_engineering(
  192. config: Dict,
  193. *args,
  194. **kwargs,
  195. ) -> BaseGeneratePrompt:
  196. """
  197. Creates a prompt engineering instance based on the provided configuration.
  198. Args:
  199. config (Dict): Configuration settings, expected to be a dictionary with at least a 'task_type' key.
  200. *args: Variable length argument list for additional positional arguments.
  201. **kwargs: Arbitrary keyword arguments.
  202. Returns:
  203. BaseGeneratePrompt: An instance of a prompt engineering class corresponding to the 'task_type' in the config.
  204. """
  205. if "pe_config_error" in config:
  206. raise ValueError(config["pe_config_error"])
  207. task_type = config["task_type"]
  208. pe = BaseGeneratePrompt.get(task_type)(config)
  209. return pe