| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from pathlib import Path
- from typing import Any, Dict, Optional
- from .base import BasePipeline
- from ..utils.pp_option import PaddlePredictorOption
- from .components import BaseChat, BaseRetriever, BaseGeneratePrompt
- from ...utils.config import parse_config
- from .ocr import OCRPipeline
- from .doc_preprocessor import DocPreprocessorPipeline
- from .layout_parsing import LayoutParsingPipeline
- from .pp_chatocr import PP_ChatOCRv3_Pipeline, PP_ChatOCRv4_Pipeline
- from .image_classification import ImageClassificationPipeline
- from .seal_recognition import SealRecognitionPipeline
- from .table_recognition import TableRecognitionPipeline
- from .video_classification import VideoClassificationPipeline
def get_pipeline_path(pipeline_name: str) -> Optional[Path]:
    """
    Get the full path of the built-in configuration file for a pipeline name.

    The file is looked up as ``configs/pipelines/<pipeline_name>.yaml`` under
    the directory two levels above the directory containing this module.

    Args:
        pipeline_name (str): The name of the pipeline.

    Returns:
        Optional[Path]: The resolved path to the pipeline configuration file,
            or None if no such file exists.
    """
    pipeline_path = (
        Path(__file__).parent.parent.parent
        / "configs/pipelines"
        / f"{pipeline_name}.yaml"
    ).resolve()
    # `pipeline_path` is already a Path, so it can be queried directly.
    if not pipeline_path.exists():
        return None
    return pipeline_path
def load_pipeline_config(pipeline_name: str) -> Dict[str, Any]:
    """
    Load and parse a pipeline configuration.

    If ``pipeline_name`` is an existing filesystem path, it is parsed directly.
    Otherwise it is treated as a pipeline name and resolved through
    ``get_pipeline_path``.

    Args:
        pipeline_name (str): The name of the pipeline or the path to the config file.

    Returns:
        Dict[str, Any]: The parsed pipeline configuration.

    Raises:
        Exception: If ``pipeline_name`` is neither an existing path nor a name
            that ``get_pipeline_path`` can resolve.
    """
    if Path(pipeline_name).exists():
        # An existing path takes precedence over name-based lookup.
        config_path = pipeline_name
    else:
        config_path = get_pipeline_path(pipeline_name)
        if config_path is None:
            raise Exception(
                f"The pipeline ({pipeline_name}) does not exist! Please use a pipeline name or a config file path!"
            )
    return parse_config(config_path)
def create_pipeline(
    pipeline: str,
    config: Dict = None,
    device: str = None,
    pp_option: PaddlePredictorOption = None,
    use_hpip: bool = False,
    hpi_params: Optional[Dict[str, Any]] = None,
    *args,
    **kwargs,
) -> BasePipeline:
    """
    Build a pipeline instance from a name (or config path) and optional config.

    When ``config`` is None, the configuration is loaded with
    ``load_pipeline_config(pipeline)`` and the pipeline name is read from its
    ``"pipeline_name"`` key. When ``config`` is given, ``pipeline`` itself is
    used as the pipeline name.

    Args:
        pipeline (str): The pipeline name, or anything accepted by
            ``load_pipeline_config`` when ``config`` is None.
        config (Dict, optional): The pipeline configuration dictionary. Defaults to None.
        device (str, optional): The device to run the pipeline on. Defaults to None.
        pp_option (PaddlePredictorOption, optional): The options for the PaddlePredictor. Defaults to None.
        use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
        hpi_params (Optional[Dict[str, Any]], optional): Additional parameters for hpip. Defaults to None.
        *args: Extra positional arguments forwarded to the pipeline class.
        **kwargs: Extra keyword arguments forwarded to the pipeline class.

    Returns:
        BasePipeline: The created pipeline instance.
    """
    if config is not None:
        # The caller supplied the configuration; use `pipeline` as the name as-is.
        pipeline_name = pipeline
    else:
        config = load_pipeline_config(pipeline)
        pipeline_name = config["pipeline_name"]

    pipeline_cls = BasePipeline.get(pipeline_name)
    return pipeline_cls(
        *args,
        config=config,
        device=device,
        pp_option=pp_option,
        use_hpip=use_hpip,
        hpi_params=hpi_params,
        **kwargs,
    )
def create_chat_bot(config: Dict, *args, **kwargs) -> BaseChat:
    """Create a chat bot instance from a configuration dictionary.

    The class is obtained with ``BaseChat.get(config["model_name"])`` and
    instantiated with ``config`` as its only argument.

    Args:
        config (Dict): Configuration settings; must contain a 'model_name' key.
        *args: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        BaseChat: An instance of the chat bot class corresponding to the 'model_name' in the config.
    """
    name = config["model_name"]
    chat_cls = BaseChat.get(name)
    return chat_cls(config)
def create_retriever(
    config: Dict,
    *args,
    **kwargs,
) -> BaseRetriever:
    """
    Create a retriever instance from a configuration dictionary.

    The class is obtained with ``BaseRetriever.get(config["model_name"])`` and
    instantiated with ``config`` as its only argument.

    Args:
        config (Dict): Configuration settings; must contain a 'model_name' key.
        *args: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        BaseRetriever: An instance of a retriever class corresponding to the 'model_name' in the config.
    """
    name = config["model_name"]
    retriever_cls = BaseRetriever.get(name)
    return retriever_cls(config)
def create_prompt_engeering(
    config: Dict,
    *args,
    **kwargs,
) -> BaseGeneratePrompt:
    """
    Create a prompt engineering instance from a configuration dictionary.

    The class is obtained with ``BaseGeneratePrompt.get(config["task_type"])``
    and instantiated with ``config`` as its only argument.

    Args:
        config (Dict): Configuration settings; must contain a 'task_type' key.
        *args: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        BaseGeneratePrompt: An instance of a prompt engineering class corresponding to the 'task_type' in the config.
    """
    # NOTE: the function name is misspelled ("engeering"); it is kept unchanged for callers.
    kind = config["task_type"]
    prompt_cls = BaseGeneratePrompt.get(kind)
    return prompt_cls(config)
|