pp_shituv2.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import asyncio
from operator import attrgetter
from typing import Any, Dict, List

from fastapi import FastAPI, HTTPException

from ....pipelines_new.components import IndexData
from ...infra import utils as serving_utils
from ...infra.config import AppConfig
from ...infra.models import ResultResponse
from ...schemas import pp_shituv2 as schema
from .._app import create_app, primary_operation
from ._common import image_recognition as ir_common
  25. def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
  26. app, ctx = create_app(
  27. pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
  28. )
  29. ir_common.update_app_context(ctx)
  30. @primary_operation(
  31. app,
  32. schema.BUILD_INDEX_ENDPOINT,
  33. "buildIndex",
  34. )
  35. async def _build_index(
  36. request: schema.BuildIndexRequest,
  37. ) -> ResultResponse[schema.BuildIndexResult]:
  38. pipeline = ctx.pipeline
  39. aiohttp_session = ctx.aiohttp_session
  40. file_bytes_list = await asyncio.gather(
  41. *(
  42. serving_utils.get_raw_bytes_async(img, aiohttp_session)
  43. for img in map(attrgetter("image"), request.imageLabelPairs)
  44. )
  45. )
  46. images = [serving_utils.image_bytes_to_array(item) for item in file_bytes_list]
  47. labels = [pair.label for pair in request.imageLabelPairs]
  48. # TODO: Support specifying `index_type` and `metric_type` in the
  49. # request
  50. index_data = await pipeline.call(
  51. pipeline.pipeline.build_index,
  52. images,
  53. labels,
  54. index_type="Flat",
  55. metric_type="IP",
  56. )
  57. index_storage = ctx.extra["index_storage"]
  58. index_key = ir_common.generate_index_key()
  59. index_data_bytes = index_data.to_bytes()
  60. await serving_utils.call_async(index_storage.set, index_key, index_data_bytes)
  61. return ResultResponse[schema.BuildIndexResult](
  62. logId=serving_utils.generate_log_id(),
  63. result=schema.BuildIndexResult(indexKey=index_key, idMap=index_data.id_map),
  64. )
  65. @primary_operation(
  66. app,
  67. schema.ADD_IMAGES_TO_INDEX_ENDPOINT,
  68. "addImagesToIndex",
  69. )
  70. async def _add_images_to_index(
  71. request: schema.AddImagesToIndexRequest,
  72. ) -> ResultResponse[schema.AddImagesToIndexResult]:
  73. pipeline = ctx.pipeline
  74. aiohttp_session = ctx.aiohttp_session
  75. file_bytes_list = await asyncio.gather(
  76. *(
  77. serving_utils.get_raw_bytes_async(img, aiohttp_session)
  78. for img in map(attrgetter("image"), request.imageLabelPairs)
  79. )
  80. )
  81. images = [serving_utils.image_bytes_to_array(item) for item in file_bytes_list]
  82. labels = [pair.label for pair in request.imageLabelPairs]
  83. if request.indexKey is not None:
  84. index_storage = ctx.extra["index_storage"]
  85. index_data_bytes = await serving_utils.call_async(
  86. index_storage.get, request.indexKey
  87. )
  88. index_data = IndexData.from_bytes(index_data_bytes)
  89. else:
  90. index_data = None
  91. index_data = await pipeline.call(
  92. pipeline.pipeline.append_index, images, labels, index_data
  93. )
  94. index_data_bytes = index_data.to_bytes()
  95. await serving_utils.call_async(
  96. index_storage.set, request.indexKey, index_data_bytes
  97. )
  98. return ResultResponse[schema.AddImagesToIndexResult](
  99. logId=serving_utils.generate_log_id(),
  100. result=schema.AddImagesToIndexResult(idMap=index_data.id_map),
  101. )
  102. @primary_operation(
  103. app,
  104. schema.REMOVE_IMAGES_FROM_INDEX_ENDPOINT,
  105. "removeImagesFromIndex",
  106. )
  107. async def _remove_images_from_index(
  108. request: schema.RemoveImagesFromIndexRequest,
  109. ) -> ResultResponse[schema.RemoveImagesFromIndexResult]:
  110. pipeline = ctx.pipeline
  111. if request.indexKey is not None:
  112. index_storage = ctx.extra["index_storage"]
  113. index_data_bytes = await serving_utils.call_async(
  114. index_storage.get, request.indexKey
  115. )
  116. index_data = IndexData.from_bytes(index_data_bytes)
  117. else:
  118. index_data = None
  119. index_data = await pipeline.call(
  120. pipeline.pipeline.remove_index, request.ids, index_data
  121. )
  122. index_data_bytes = index_data.to_bytes()
  123. await serving_utils.call_async(
  124. index_storage.set, request.indexKey, index_data_bytes
  125. )
  126. return ResultResponse[schema.RemoveImagesFromIndexResult](
  127. logId=serving_utils.generate_log_id(),
  128. result=schema.RemoveImagesFromIndexResult(idMap=index_data.id_map),
  129. )
  130. @primary_operation(
  131. app,
  132. schema.INFER_ENDPOINT,
  133. "infer",
  134. )
  135. async def _infer(
  136. request: schema.InferRequest,
  137. ) -> ResultResponse[schema.InferResult]:
  138. pipeline = ctx.pipeline
  139. aiohttp_session = ctx.aiohttp_session
  140. image_bytes = await serving_utils.get_raw_bytes_async(
  141. request.image, aiohttp_session
  142. )
  143. image = serving_utils.image_bytes_to_array(image_bytes)
  144. if request.indexKey is not None:
  145. index_storage = ctx.extra["index_storage"]
  146. index_data_bytes = await serving_utils.call_async(
  147. index_storage.get, request.indexKey
  148. )
  149. index_data = IndexData.from_bytes(index_data_bytes)
  150. else:
  151. index_data = None
  152. result = list(
  153. await pipeline.call(
  154. pipeline.pipeline.predict,
  155. image,
  156. index=index_data,
  157. det_threshold=request.detThreshold,
  158. rec_threshold=request.recThreshold,
  159. hamming_radius=request.hammingRadius,
  160. topk=request.topk,
  161. )
  162. )[0]
  163. objs: List[Dict[str, Any]] = []
  164. for obj in result["boxes"]:
  165. rec_results: List[Dict[str, Any]] = []
  166. if obj["rec_scores"] is not None:
  167. for label, score in zip(obj["labels"], obj["rec_scores"]):
  168. rec_results.append(
  169. dict(
  170. label=label,
  171. score=score,
  172. )
  173. )
  174. objs.append(
  175. dict(
  176. bbox=obj["coordinate"],
  177. recResults=rec_results,
  178. score=obj["det_score"],
  179. )
  180. )
  181. if ctx.config.visualize:
  182. output_image_base64 = serving_utils.base64_encode(
  183. serving_utils.image_to_bytes(result.img["res"])
  184. )
  185. else:
  186. output_image_base64 = None
  187. return ResultResponse[schema.InferResult](
  188. logId=serving_utils.generate_log_id(),
  189. result=schema.InferResult(detectedObjects=objs, image=output_image_base64),
  190. )
  191. return app