__init__.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from pathlib import Path
  15. from typing import Any, Dict, Optional
  16. from importlib import import_module
  17. from .base import BasePipeline
  18. from ..utils.pp_option import PaddlePredictorOption
  19. from .components import BaseChat, BaseRetriever, BaseGeneratePrompt
  20. from ...utils import logging
  21. from ...utils.config import parse_config
  22. from .ocr import OCRPipeline
  23. from .doc_preprocessor import DocPreprocessorPipeline
  24. from .layout_parsing import LayoutParsingPipeline
  25. from .pp_chatocr import PP_ChatOCRv3_Pipeline, PP_ChatOCRv4_Pipeline
  26. from .image_classification import ImageClassificationPipeline
  27. from .object_detection import ObjectDetectionPipeline
  28. from .seal_recognition import SealRecognitionPipeline
  29. from .table_recognition import TableRecognitionPipeline
  30. from .table_recognition import TableRecognitionPipelineV2
  31. from .multilingual_speech_recognition import MultilingualSpeechRecognitionPipeline
  32. from .formula_recognition import FormulaRecognitionPipeline
  33. from .image_multilabel_classification import ImageMultiLabelClassificationPipeline
  34. from .video_classification import VideoClassificationPipeline
  35. from .video_detection import VideoDetectionPipeline
  36. from .anomaly_detection import AnomalyDetectionPipeline
  37. from .ts_forecasting import TSFcPipeline
  38. from .ts_anomaly_detection import TSAnomalyDetPipeline
  39. from .ts_classification import TSClsPipeline
  40. from .pp_shitu_v2 import ShiTuV2Pipeline
  41. from .face_recognition import FaceRecPipeline
  42. from .attribute_recognition import (
  43. PedestrianAttributeRecPipeline,
  44. VehicleAttributeRecPipeline,
  45. )
  46. from .semantic_segmentation import SemanticSegmentationPipeline
  47. from .instance_segmentation import InstanceSegmentationPipeline
  48. from .small_object_detection import SmallObjectDetectionPipeline
  49. from .rotated_object_detection import RotatedObjectDetectionPipeline
  50. from .keypoint_detection import KeypointDetectionPipeline
  51. from .open_vocabulary_detection import OpenVocabularyDetectionPipeline
  52. from .open_vocabulary_segmentation import OpenVocabularySegmentationPipeline
  53. module_3d_bev_detection = import_module(
  54. ".3d_bev_detection", "paddlex.inference.pipelines"
  55. )
  56. BEVDet3DPipeline = getattr(module_3d_bev_detection, "BEVDet3DPipeline")
  57. def get_pipeline_path(pipeline_name: str) -> str:
  58. """
  59. Get the full path of the pipeline configuration file based on the provided pipeline name.
  60. Args:
  61. pipeline_name (str): The name of the pipeline.
  62. Returns:
  63. str: The full path to the pipeline configuration file or None if not found.
  64. """
  65. pipeline_path = (
  66. Path(__file__).parent.parent.parent
  67. / "configs/pipelines"
  68. / f"{pipeline_name}.yaml"
  69. ).resolve()
  70. if not Path(pipeline_path).exists():
  71. return None
  72. return pipeline_path
  73. def load_pipeline_config(pipeline: str) -> Dict[str, Any]:
  74. """
  75. Load the pipeline configuration.
  76. Args:
  77. pipeline (str): The name of the pipeline or the path to the config file.
  78. Returns:
  79. Dict[str, Any]: The parsed pipeline configuration.
  80. Raises:
  81. Exception: If the config file of pipeline does not exist.
  82. """
  83. if not (pipeline.endswith(".yml") or pipeline.endswith(".yaml")):
  84. pipeline_path = get_pipeline_path(pipeline)
  85. if pipeline_path is None:
  86. raise Exception(
  87. f"The pipeline ({pipeline}) does not exist! Please use a pipeline name or a config file path!"
  88. )
  89. else:
  90. pipeline_path = pipeline
  91. config = parse_config(pipeline_path)
  92. return config
  93. def create_pipeline(
  94. pipeline: Optional[str] = None,
  95. config: Optional[Dict[str, Any]] = None,
  96. device: Optional[str] = None,
  97. pp_option: Optional[PaddlePredictorOption] = None,
  98. use_hpip: bool = False,
  99. *args: Any,
  100. **kwargs: Any,
  101. ) -> BasePipeline:
  102. """
  103. Create a pipeline instance based on the provided parameters.
  104. If the input parameter config is not provided, it is obtained from the
  105. default config corresponding to the pipeline name.
  106. Args:
  107. pipeline (Optional[str], optional): The name of the pipeline to
  108. create, or the path to the config file. Defaults to None.
  109. config (Optional[Dict[str, Any]], optional): The pipeline configuration.
  110. Defaults to None.
  111. device (Optional[str], optional): The device to run the pipeline on.
  112. Defaults to None.
  113. pp_option (Optional[PaddlePredictorOption], optional): The options for
  114. the PaddlePredictor. Defaults to None.
  115. use_hpip (bool, optional): Whether to use high-performance inference
  116. plugin (HPIP) for prediction. Defaults to False.
  117. *args: Additional positional arguments.
  118. **kwargs: Additional keyword arguments.
  119. Returns:
  120. BasePipeline: The created pipeline instance.
  121. """
  122. if pipeline is None and config is None:
  123. raise ValueError(
  124. "Both `pipeline` and `config` cannot be None at the same time."
  125. )
  126. if config is None:
  127. config = load_pipeline_config(pipeline)
  128. else:
  129. if pipeline is not None and config["pipeline_name"] != pipeline:
  130. logging.warning(
  131. "The pipeline name in the config (%r) is different from the specified pipeline name (%r). %r will be used.",
  132. config["pipeline_name"],
  133. pipeline,
  134. config["pipeline_name"],
  135. )
  136. pipeline_name = config["pipeline_name"]
  137. pipeline = BasePipeline.get(pipeline_name)(
  138. config=config,
  139. device=device,
  140. pp_option=pp_option,
  141. use_hpip=use_hpip,
  142. *args,
  143. **kwargs,
  144. )
  145. return pipeline
  146. def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
  147. """Creates an instance of a chat bot based on the provided configuration.
  148. Args:
  149. config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
  150. *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
  151. **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
  152. Returns:
  153. BaseChat: An instance of the chat bot class corresponding to the 'model_name' in the config.
  154. """
  155. if "chat_bot_config_error" in config:
  156. raise ValueError(config["chat_bot_config_error"])
  157. api_type = config["api_type"]
  158. chat_bot = BaseChat.get(api_type)(config)
  159. return chat_bot
  160. def create_retriever(
  161. config: Dict,
  162. *args,
  163. **kwargs,
  164. ) -> BaseRetriever:
  165. """
  166. Creates a retriever instance based on the provided configuration.
  167. Args:
  168. config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
  169. *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
  170. **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
  171. Returns:
  172. BaseRetriever: An instance of a retriever class corresponding to the 'model_name' in the config.
  173. """
  174. if "retriever_config_error" in config:
  175. raise ValueError(config["retriever_config_error"])
  176. api_type = config["api_type"]
  177. retriever = BaseRetriever.get(api_type)(config)
  178. return retriever
  179. def create_prompt_engineering(
  180. config: Dict,
  181. *args,
  182. **kwargs,
  183. ) -> BaseGeneratePrompt:
  184. """
  185. Creates a prompt engineering instance based on the provided configuration.
  186. Args:
  187. config (Dict): Configuration settings, expected to be a dictionary with at least a 'task_type' key.
  188. *args: Variable length argument list for additional positional arguments.
  189. **kwargs: Arbitrary keyword arguments.
  190. Returns:
  191. BaseGeneratePrompt: An instance of a prompt engineering class corresponding to the 'task_type' in the config.
  192. """
  193. if "pe_config_error" in config:
  194. raise ValueError(config["pe_config_error"])
  195. task_type = config["task_type"]
  196. pe = BaseGeneratePrompt.get(task_type)(config)
  197. return pe