__init__.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from pathlib import Path
  15. from typing import Any, Dict, Optional
  16. from .base import BasePipeline
  17. from ..utils.pp_option import PaddlePredictorOption
  18. from .components import BaseChat, BaseRetriever, BaseGeneratePrompt
  19. from ...utils.config import parse_config
  20. from .ocr import OCRPipeline
  21. from .doc_preprocessor import DocPreprocessorPipeline
  22. from .layout_parsing import LayoutParsingPipeline
  23. from .pp_chatocr import PP_ChatOCRv3_Pipeline, PP_ChatOCRv4_Pipeline
  24. from .image_classification import ImageClassificationPipeline
  25. from .object_detection import ObjectDetectionPipeline
  26. from .seal_recognition import SealRecognitionPipeline
  27. from .table_recognition import TableRecognitionPipeline
  28. from .table_recognition import TableRecognitionPipelineV2
  29. from .multilingual_speech_recognition import MultilingualSpeechRecognitionPipeline
  30. from .formula_recognition import FormulaRecognitionPipeline
  31. from .image_multilabel_classification import ImageMultiLabelClassificationPipeline
  32. from .video_classification import VideoClassificationPipeline
  33. from .video_detection import VideoDetectionPipeline
  34. from .anomaly_detection import AnomalyDetectionPipeline
  35. from .ts_forecasting import TSFcPipeline
  36. from .ts_anomaly_detection import TSAnomalyDetPipeline
  37. from .ts_classification import TSClsPipeline
  38. from .pp_shitu_v2 import ShiTuV2Pipeline
  39. from .face_recognition import FaceRecPipeline
  40. from .attribute_recognition import (
  41. PedestrianAttributeRecPipeline,
  42. VehicleAttributeRecPipeline,
  43. )
  44. from .semantic_segmentation import SemanticSegmentationPipeline
  45. from .instance_segmentation import InstanceSegmentationPipeline
  46. from .small_object__detection import SmallObjectDetectionPipeline
  47. from .rotated_object__detection import RotatedObjectDetectionPipeline
  48. from .keypoint_detection import KeypointDetectionPipeline
  49. def get_pipeline_path(pipeline_name: str) -> str:
  50. """
  51. Get the full path of the pipeline configuration file based on the provided pipeline name.
  52. Args:
  53. pipeline_name (str): The name of the pipeline.
  54. Returns:
  55. str: The full path to the pipeline configuration file or None if not found.
  56. """
  57. pipeline_path = (
  58. Path(__file__).parent.parent.parent
  59. / "configs/pipelines"
  60. / f"{pipeline_name}.yaml"
  61. ).resolve()
  62. if not Path(pipeline_path).exists():
  63. return None
  64. return pipeline_path
  65. def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
  66. """
  67. Load the pipeline configuration.
  68. Args:
  69. pipeline_name (str): The name of the pipeline or the path to the config file.
  70. Returns:
  71. Dict[str, Any]: The parsed pipeline configuration.
  72. Raises:
  73. Exception: If the config file of pipeline does not exist.
  74. """
  75. if not Path(pipeline_name).exists():
  76. pipeline_path = get_pipeline_path(pipeline_name)
  77. if pipeline_path is None:
  78. raise Exception(
  79. f"The pipeline ({pipeline_name}) does not exist! Please use a pipeline name or a config file path!"
  80. )
  81. else:
  82. pipeline_path = pipeline_name
  83. config = parse_config(pipeline_path)
  84. return config
  85. def create_pipeline(
  86. pipeline: str,
  87. config: Dict = None,
  88. device: str = None,
  89. pp_option: PaddlePredictorOption = None,
  90. use_hpip: bool = False,
  91. *args,
  92. **kwargs,
  93. ) -> BasePipeline:
  94. """
  95. Create a pipeline instance based on the provided parameters.
  96. If the input parameter config is not provided,
  97. it is obtained from the default config corresponding to the pipeline name.
  98. Args:
  99. pipeline (str): The name of the pipeline to create.
  100. config (Dict, optional): The path to the pipeline configuration file. Defaults to None.
  101. device (str, optional): The device to run the pipeline on. Defaults to None.
  102. pp_option (PaddlePredictorOption, optional): The options for the PaddlePredictor. Defaults to None.
  103. use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
  104. *args: Additional positional arguments.
  105. **kwargs: Additional keyword arguments.
  106. Returns:
  107. BasePipeline: The created pipeline instance.
  108. """
  109. if config is None:
  110. config = load_pipeline_config(pipeline)
  111. pipeline_name = config["pipeline_name"]
  112. else:
  113. pipeline_name = pipeline
  114. pipeline = BasePipeline.get(pipeline_name)(
  115. config=config,
  116. device=device,
  117. pp_option=pp_option,
  118. use_hpip=use_hpip,
  119. *args,
  120. **kwargs,
  121. )
  122. return pipeline
  123. def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
  124. """Creates an instance of a chat bot based on the provided configuration.
  125. Args:
  126. config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
  127. *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
  128. **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
  129. Returns:
  130. BaseChat: An instance of the chat bot class corresponding to the 'model_name' in the config.
  131. """
  132. api_type = config["api_type"]
  133. chat_bot = BaseChat.get(api_type)(config)
  134. return chat_bot
  135. def create_retriever(
  136. config: Dict,
  137. *args,
  138. **kwargs,
  139. ) -> BaseRetriever:
  140. """
  141. Creates a retriever instance based on the provided configuration.
  142. Args:
  143. config (Dict): Configuration settings, expected to be a dictionary with at least a 'model_name' key.
  144. *args: Additional positional arguments. Not used in this function but allowed for future compatibility.
  145. **kwargs: Additional keyword arguments. Not used in this function but allowed for future compatibility.
  146. Returns:
  147. BaseRetriever: An instance of a retriever class corresponding to the 'model_name' in the config.
  148. """
  149. api_type = config["api_type"]
  150. retriever = BaseRetriever.get(api_type)(config)
  151. return retriever
  152. def create_prompt_engeering(
  153. config: Dict,
  154. *args,
  155. **kwargs,
  156. ) -> BaseGeneratePrompt:
  157. """
  158. Creates a prompt engineering instance based on the provided configuration.
  159. Args:
  160. config (Dict): Configuration settings, expected to be a dictionary with at least a 'task_type' key.
  161. *args: Variable length argument list for additional positional arguments.
  162. **kwargs: Arbitrary keyword arguments.
  163. Returns:
  164. BaseGeneratePrompt: An instance of a prompt engineering class corresponding to the 'task_type' in the config.
  165. """
  166. task_type = config["task_type"]
  167. pe = BaseGeneratePrompt.get(task_type)(config)
  168. return pe