pp_shituv2.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. from operator import attrgetter
  16. from typing import Any, Dict, List
  17. from fastapi import FastAPI
  18. from ....pipelines.components import IndexData
  19. from ...infra import utils as serving_utils
  20. from ...infra.config import AppConfig
  21. from ...infra.models import ResultResponse
  22. from ...schemas import pp_shituv2 as schema
  23. from .._app import create_app, primary_operation
  24. from ._common import image_recognition as ir_common
  25. def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> FastAPI:
  26. app, ctx = create_app(
  27. pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
  28. )
  29. ir_common.update_app_context(ctx)
  30. @primary_operation(
  31. app,
  32. schema.BUILD_INDEX_ENDPOINT,
  33. "buildIndex",
  34. )
  35. async def _build_index(
  36. request: schema.BuildIndexRequest,
  37. ) -> ResultResponse[schema.BuildIndexResult]:
  38. pipeline = ctx.pipeline
  39. aiohttp_session = ctx.aiohttp_session
  40. file_bytes_list = await asyncio.gather(
  41. *(
  42. serving_utils.get_raw_bytes_async(img, aiohttp_session)
  43. for img in map(attrgetter("image"), request.imageLabelPairs)
  44. )
  45. )
  46. images = [serving_utils.image_bytes_to_array(item) for item in file_bytes_list]
  47. labels = [pair.label for pair in request.imageLabelPairs]
  48. # TODO: Support specifying `index_type` and `metric_type` in the
  49. # request
  50. index_data = await pipeline.call(
  51. pipeline.pipeline.build_index,
  52. images,
  53. labels,
  54. index_type="Flat",
  55. metric_type="IP",
  56. )
  57. index_storage = ctx.extra["index_storage"]
  58. index_key = ir_common.generate_index_key()
  59. index_data_bytes = index_data.to_bytes()
  60. await serving_utils.call_async(index_storage.set, index_key, index_data_bytes)
  61. return ResultResponse[schema.BuildIndexResult](
  62. logId=serving_utils.generate_log_id(),
  63. result=schema.BuildIndexResult(
  64. indexKey=index_key, imageCount=len(index_data.id_map)
  65. ),
  66. )
  67. @primary_operation(
  68. app,
  69. schema.ADD_IMAGES_TO_INDEX_ENDPOINT,
  70. "addImagesToIndex",
  71. )
  72. async def _add_images_to_index(
  73. request: schema.AddImagesToIndexRequest,
  74. ) -> ResultResponse[schema.AddImagesToIndexResult]:
  75. pipeline = ctx.pipeline
  76. aiohttp_session = ctx.aiohttp_session
  77. file_bytes_list = await asyncio.gather(
  78. *(
  79. serving_utils.get_raw_bytes_async(img, aiohttp_session)
  80. for img in map(attrgetter("image"), request.imageLabelPairs)
  81. )
  82. )
  83. images = [serving_utils.image_bytes_to_array(item) for item in file_bytes_list]
  84. labels = [pair.label for pair in request.imageLabelPairs]
  85. index_storage = ctx.extra["index_storage"]
  86. index_data_bytes = await serving_utils.call_async(
  87. index_storage.get, request.indexKey
  88. )
  89. index_data = IndexData.from_bytes(index_data_bytes)
  90. index_data = await pipeline.call(
  91. pipeline.pipeline.append_index, images, labels, index_data
  92. )
  93. index_data_bytes = index_data.to_bytes()
  94. await serving_utils.call_async(
  95. index_storage.set, request.indexKey, index_data_bytes
  96. )
  97. return ResultResponse[schema.AddImagesToIndexResult](
  98. logId=serving_utils.generate_log_id(),
  99. result=schema.AddImagesToIndexResult(imageCount=len(index_data.id_map)),
  100. )
  101. @primary_operation(
  102. app,
  103. schema.REMOVE_IMAGES_FROM_INDEX_ENDPOINT,
  104. "removeImagesFromIndex",
  105. )
  106. async def _remove_images_from_index(
  107. request: schema.RemoveImagesFromIndexRequest,
  108. ) -> ResultResponse[schema.RemoveImagesFromIndexResult]:
  109. pipeline = ctx.pipeline
  110. index_storage = ctx.extra["index_storage"]
  111. index_data_bytes = await serving_utils.call_async(
  112. index_storage.get, request.indexKey
  113. )
  114. index_data = IndexData.from_bytes(index_data_bytes)
  115. index_data = await pipeline.call(
  116. pipeline.pipeline.remove_index, request.ids, index_data
  117. )
  118. index_data_bytes = index_data.to_bytes()
  119. await serving_utils.call_async(
  120. index_storage.set, request.indexKey, index_data_bytes
  121. )
  122. return ResultResponse[schema.RemoveImagesFromIndexResult](
  123. logId=serving_utils.generate_log_id(),
  124. result=schema.RemoveImagesFromIndexResult(
  125. imageCount=len(index_data.id_map)
  126. ),
  127. )
  128. @primary_operation(
  129. app,
  130. schema.INFER_ENDPOINT,
  131. "infer",
  132. )
  133. async def _infer(
  134. request: schema.InferRequest,
  135. ) -> ResultResponse[schema.InferResult]:
  136. pipeline = ctx.pipeline
  137. aiohttp_session = ctx.aiohttp_session
  138. image_bytes = await serving_utils.get_raw_bytes_async(
  139. request.image, aiohttp_session
  140. )
  141. image = serving_utils.image_bytes_to_array(image_bytes)
  142. if request.indexKey is not None:
  143. index_storage = ctx.extra["index_storage"]
  144. index_data_bytes = await serving_utils.call_async(
  145. index_storage.get, request.indexKey
  146. )
  147. index_data = IndexData.from_bytes(index_data_bytes)
  148. else:
  149. index_data = None
  150. result = list(
  151. await pipeline.call(
  152. pipeline.pipeline.predict,
  153. image,
  154. index=index_data,
  155. det_threshold=request.detThreshold,
  156. rec_threshold=request.recThreshold,
  157. hamming_radius=request.hammingRadius,
  158. topk=request.topk,
  159. )
  160. )[0]
  161. objs: List[Dict[str, Any]] = []
  162. for obj in result["boxes"]:
  163. rec_results: List[Dict[str, Any]] = []
  164. if obj["rec_scores"] != [None]:
  165. for label, score in zip(obj["labels"], obj["rec_scores"]):
  166. rec_results.append(
  167. dict(
  168. label=label,
  169. score=score,
  170. )
  171. )
  172. objs.append(
  173. dict(
  174. bbox=obj["coordinate"],
  175. recResults=rec_results,
  176. score=obj["det_score"],
  177. )
  178. )
  179. if ctx.config.visualize:
  180. output_image_base64 = serving_utils.base64_encode(
  181. serving_utils.image_to_bytes(result.img["res"])
  182. )
  183. else:
  184. output_image_base64 = None
  185. return ResultResponse[schema.InferResult](
  186. logId=serving_utils.generate_log_id(),
  187. result=schema.InferResult(detectedObjects=objs, image=output_image_base64),
  188. )
  189. return app