__init__.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from pathlib import Path
  15. from typing import Any, Dict, Optional
  16. from .base import BasePipeline
  17. from ..utils.pp_option import PaddlePredictorOption
  18. from .components import BaseChat, BaseRetriever, BaseGeneratePrompt
  19. from ...utils import logging
  20. from ...utils.config import parse_config
  21. from .ocr import OCRPipeline
  22. from .doc_preprocessor import DocPreprocessorPipeline
  23. from .layout_parsing import LayoutParsingPipeline
  24. from .pp_chatocr import PP_ChatOCRv3_Pipeline, PP_ChatOCRv4_Pipeline
  25. from .image_classification import ImageClassificationPipeline
  26. from .object_detection import ObjectDetectionPipeline
  27. from .seal_recognition import SealRecognitionPipeline
  28. from .table_recognition import TableRecognitionPipeline
  29. from .table_recognition import TableRecognitionPipelineV2
  30. from .multilingual_speech_recognition import MultilingualSpeechRecognitionPipeline
  31. from .formula_recognition import FormulaRecognitionPipeline
  32. from .image_multilabel_classification import ImageMultiLabelClassificationPipeline
  33. from .video_classification import VideoClassificationPipeline
  34. from .video_detection import VideoDetectionPipeline
  35. from .anomaly_detection import AnomalyDetectionPipeline
  36. from .ts_forecasting import TSFcPipeline
  37. from .ts_anomaly_detection import TSAnomalyDetPipeline
  38. from .ts_classification import TSClsPipeline
  39. from .pp_shitu_v2 import ShiTuV2Pipeline
  40. from .face_recognition import FaceRecPipeline
  41. from .attribute_recognition import (
  42. PedestrianAttributeRecPipeline,
  43. VehicleAttributeRecPipeline,
  44. )
  45. from .semantic_segmentation import SemanticSegmentationPipeline
  46. from .instance_segmentation import InstanceSegmentationPipeline
  47. from .small_object_detection import SmallObjectDetectionPipeline
  48. from .rotated_object_detection import RotatedObjectDetectionPipeline
  49. from .keypoint_detection import KeypointDetectionPipeline
  50. from .open_vocabulary_detection import OpenVocabularyDetectionPipeline
  51. from .open_vocabulary_segmentation import OpenVocabularySegmentationPipeline
  def get_pipeline_path(pipeline_name: str) -> Optional[Path]:
      """
      Get the full path of the pipeline configuration file based on the provided pipeline name.

      The file looked up is ``configs/pipelines/<pipeline_name>.yaml`` under
      ``Path(__file__).parent.parent.parent``.

      Args:
          pipeline_name (str): The name of the pipeline.

      Returns:
          Optional[Path]: The resolved path to the pipeline configuration file,
              or None if that file does not exist.
      """
      pipeline_path = (
          Path(__file__).parent.parent.parent
          / "configs"
          / "pipelines"
          / f"{pipeline_name}.yaml"
      ).resolve()
      # `pipeline_path` is already a Path, so it can be checked directly.
      if not pipeline_path.exists():
          return None
      return pipeline_path
  def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
      """
      Load the pipeline configuration.

      Args:
          pipeline_name (str): The name of a bundled pipeline, or the path to a
              config file. A value ending in ``.yml`` or ``.yaml`` is treated as
              a config file path; anything else is treated as a pipeline name.

      Returns:
          Dict[str, Any]: The parsed pipeline configuration.

      Raises:
          FileNotFoundError: If ``pipeline_name`` is treated as a pipeline name
              and no bundled config file exists for it. (This subclasses
              ``Exception``, so callers that catch ``Exception`` still work.)
      """
      if pipeline_name.endswith((".yml", ".yaml")):
          # An explicit config file path: hand it straight to the parser.
          pipeline_path = pipeline_name
      else:
          pipeline_path = get_pipeline_path(pipeline_name)
          if pipeline_path is None:
              raise FileNotFoundError(
                  f"The pipeline ({pipeline_name}) does not exist! Please use a pipeline name or a config file path!"
              )
      return parse_config(pipeline_path)
  def create_pipeline(
      pipeline_name: Optional[str] = None,
      config: Optional[Dict[str, Any]] = None,
      device: Optional[str] = None,
      pp_option: Optional[PaddlePredictorOption] = None,
      use_hpip: bool = False,
      *args: Any,
      **kwargs: Any,
  ) -> BasePipeline:
      """
      Create a pipeline instance based on the provided parameters.

      If the input parameter config is not provided, it is obtained from the
      default config corresponding to the pipeline name (or loaded from the
      config file, when ``pipeline_name`` is a ``.yml``/``.yaml`` path).
      The pipeline class is always chosen by ``config["pipeline_name"]``.

      Args:
          pipeline_name (Optional[str], optional): The name of the pipeline to
              create, or the path to the config file. Defaults to None.
          config (Optional[Dict[str, Any]], optional): The pipeline configuration.
              Defaults to None.
          device (Optional[str], optional): The device to run the pipeline on.
              Defaults to None.
          pp_option (Optional[PaddlePredictorOption], optional): The options for
              the PaddlePredictor. Defaults to None.
          use_hpip (bool, optional): Whether to use high-performance inference
              plugin (HPIP) for prediction. Defaults to False.
          *args: Additional positional arguments, forwarded to the pipeline class.
          **kwargs: Additional keyword arguments, forwarded to the pipeline class.

      Returns:
          BasePipeline: The created pipeline instance.

      Raises:
          ValueError: If both ``pipeline_name`` and ``config`` are None.
      """
      if pipeline_name is None and config is None:
          raise ValueError(
              "Both `pipeline_name` and `config` cannot be None at the same time."
          )

      # True when the config was read from the file `pipeline_name` points to.
      # A file path is never equal to the pipeline name stored inside that
      # file, so comparing them would only produce a misleading warning.
      loaded_from_path = False
      if config is None:
          config = load_pipeline_config(pipeline_name)
          loaded_from_path = pipeline_name.endswith((".yml", ".yaml"))

      config_pipeline_name = config["pipeline_name"]
      if (
          pipeline_name is not None
          and not loaded_from_path
          and config_pipeline_name != pipeline_name
      ):
          logging.warning(
              "The pipeline name in the config (%r) is different from the specified pipeline name (%r). %r will be used.",
              config_pipeline_name,
              pipeline_name,
              config_pipeline_name,
          )
      pipeline_name = config_pipeline_name

      # `*args` is listed first to make explicit that it binds positionally;
      # Python treats `f(*args, k=v)` and `f(k=v, *args)` identically.
      pipeline = BasePipeline.get(pipeline_name)(
          *args,
          config=config,
          device=device,
          pp_option=pp_option,
          use_hpip=use_hpip,
          **kwargs,
      )
      return pipeline
  def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
      """Creates an instance of a chat bot based on the provided configuration.

      Args:
          config (Dict): Configuration settings. Must contain an 'api_type' key,
              which selects the registered BaseChat subclass. If it contains a
              'chat_bot_config_error' key, that value is raised as a ValueError
              instead of creating a chat bot.
          *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
          **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.

      Returns:
          BaseChat: An instance of the chat bot class registered under
              config['api_type'], constructed with ``config``.

      Raises:
          ValueError: If ``config`` contains a 'chat_bot_config_error' entry.
          KeyError: If ``config`` has no 'api_type' entry.
      """
      # A configuration problem recorded upstream is reported before any lookup.
      if "chat_bot_config_error" in config:
          raise ValueError(config["chat_bot_config_error"])
      api_type = config["api_type"]
      chat_bot = BaseChat.get(api_type)(config)
      return chat_bot
  def create_retriever(
      config: Dict,
      *args,
      **kwargs,
  ) -> BaseRetriever:
      """
      Creates a retriever instance based on the provided configuration.

      Args:
          config (Dict): Configuration settings. Must contain an 'api_type' key,
              which selects the registered BaseRetriever subclass. If it contains
              a 'retriever_config_error' key, that value is raised as a
              ValueError instead of creating a retriever.
          *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
          **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.

      Returns:
          BaseRetriever: An instance of the retriever class registered under
              config['api_type'], constructed with ``config``.

      Raises:
          ValueError: If ``config`` contains a 'retriever_config_error' entry.
          KeyError: If ``config`` has no 'api_type' entry.
      """
      # A configuration problem recorded upstream is reported before any lookup.
      if "retriever_config_error" in config:
          raise ValueError(config["retriever_config_error"])
      api_type = config["api_type"]
      retriever = BaseRetriever.get(api_type)(config)
      return retriever
  def create_prompt_engeering(
      config: Dict,
      *args,
      **kwargs,
  ) -> BaseGeneratePrompt:
      """Build a prompt-engineering component from a configuration dict.

      Args:
          config (Dict): Configuration settings. Must contain a 'task_type' key,
              which selects the registered BaseGeneratePrompt subclass. If it
              contains a 'pe_config_error' key, that value is raised as a
              ValueError instead of building anything.
          *args: Accepted for signature compatibility; ignored.
          **kwargs: Accepted for signature compatibility; ignored.

      Returns:
          BaseGeneratePrompt: An instance of the class registered under
              config['task_type'], constructed with ``config``.

      Raises:
          ValueError: If ``config`` contains a 'pe_config_error' entry.
      """
      # A configuration problem recorded upstream takes priority over construction.
      if "pe_config_error" in config:
          raise ValueError(config["pe_config_error"])
      prompt_cls = BaseGeneratePrompt.get(config["task_type"])
      return prompt_cls(config)