face_recognition.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import asyncio
  15. from operator import attrgetter
  16. from typing import Any, Dict, List
  17. from .....utils.deps import function_requires_deps, is_dep_available
  18. from ....pipelines.components import IndexData
  19. from ...infra import utils as serving_utils
  20. from ...infra.config import AppConfig
  21. from ...infra.models import AIStudioResultResponse
  22. from ...schemas import face_recognition as schema
  23. from .._app import create_app, primary_operation
  24. from ._common import image_recognition as ir_common
  25. if is_dep_available("fastapi"):
  26. from fastapi import FastAPI
  27. # XXX: Currently the implementations of the face recognition and PP-ShiTuV2
  28. # pipeline apps overlap significantly. We should aim to facilitate code reuse,
  29. # but is it acceptable to assume a strong similarity between these two
  30. # pipelines?
  31. @function_requires_deps("fastapi")
  32. def create_pipeline_app(pipeline: Any, app_config: AppConfig) -> "FastAPI":
  33. app, ctx = create_app(
  34. pipeline=pipeline, app_config=app_config, app_aiohttp_session=True
  35. )
  36. ir_common.update_app_context(ctx)
  37. @primary_operation(
  38. app,
  39. schema.BUILD_INDEX_ENDPOINT,
  40. "buildIndex",
  41. )
  42. async def _build_index(
  43. request: schema.BuildIndexRequest,
  44. ) -> AIStudioResultResponse[schema.BuildIndexResult]:
  45. pipeline = ctx.pipeline
  46. aiohttp_session = ctx.aiohttp_session
  47. file_bytes_list = await asyncio.gather(
  48. *(
  49. serving_utils.get_raw_bytes_async(img, aiohttp_session)
  50. for img in map(attrgetter("image"), request.imageLabelPairs)
  51. )
  52. )
  53. images = [serving_utils.image_bytes_to_array(item) for item in file_bytes_list]
  54. labels = [pair.label for pair in request.imageLabelPairs]
  55. # TODO: Support specifying `index_type` and `metric_type` in the
  56. # request
  57. index_data = await pipeline.call(
  58. pipeline.pipeline.build_index,
  59. images,
  60. labels,
  61. index_type="Flat",
  62. metric_type="IP",
  63. )
  64. index_storage = ctx.extra["index_storage"]
  65. index_key = ir_common.generate_index_key()
  66. index_data_bytes = index_data.to_bytes()
  67. await serving_utils.call_async(index_storage.set, index_key, index_data_bytes)
  68. return AIStudioResultResponse[schema.BuildIndexResult](
  69. logId=serving_utils.generate_log_id(),
  70. result=schema.BuildIndexResult(
  71. indexKey=index_key, imageCount=len(index_data.id_map)
  72. ),
  73. )
  74. @primary_operation(
  75. app,
  76. schema.ADD_IMAGES_TO_INDEX_ENDPOINT,
  77. "addImagesToIndex",
  78. )
  79. async def _add_images_to_index(
  80. request: schema.AddImagesToIndexRequest,
  81. ) -> AIStudioResultResponse[schema.AddImagesToIndexResult]:
  82. pipeline = ctx.pipeline
  83. aiohttp_session = ctx.aiohttp_session
  84. file_bytes_list = await asyncio.gather(
  85. *(
  86. serving_utils.get_raw_bytes_async(img, aiohttp_session)
  87. for img in map(attrgetter("image"), request.imageLabelPairs)
  88. )
  89. )
  90. images = [serving_utils.image_bytes_to_array(item) for item in file_bytes_list]
  91. labels = [pair.label for pair in request.imageLabelPairs]
  92. index_storage = ctx.extra["index_storage"]
  93. index_data_bytes = await serving_utils.call_async(
  94. index_storage.get, request.indexKey
  95. )
  96. index_data = IndexData.from_bytes(index_data_bytes)
  97. index_data = await pipeline.call(
  98. pipeline.pipeline.append_index, images, labels, index_data
  99. )
  100. index_data_bytes = index_data.to_bytes()
  101. await serving_utils.call_async(
  102. index_storage.set, request.indexKey, index_data_bytes
  103. )
  104. return AIStudioResultResponse[schema.AddImagesToIndexResult](
  105. logId=serving_utils.generate_log_id(),
  106. result=schema.AddImagesToIndexResult(imageCount=len(index_data.id_map)),
  107. )
  108. @primary_operation(
  109. app,
  110. schema.REMOVE_IMAGES_FROM_INDEX_ENDPOINT,
  111. "removeImagesFromIndex",
  112. )
  113. async def _remove_images_from_index(
  114. request: schema.RemoveImagesFromIndexRequest,
  115. ) -> AIStudioResultResponse[schema.RemoveImagesFromIndexResult]:
  116. pipeline = ctx.pipeline
  117. index_storage = ctx.extra["index_storage"]
  118. index_data_bytes = await serving_utils.call_async(
  119. index_storage.get, request.indexKey
  120. )
  121. index_data = IndexData.from_bytes(index_data_bytes)
  122. index_data = await pipeline.call(
  123. pipeline.pipeline.remove_index, request.ids, index_data
  124. )
  125. index_data_bytes = index_data.to_bytes()
  126. await serving_utils.call_async(
  127. index_storage.set, request.indexKey, index_data_bytes
  128. )
  129. return AIStudioResultResponse[schema.RemoveImagesFromIndexResult](
  130. logId=serving_utils.generate_log_id(),
  131. result=schema.RemoveImagesFromIndexResult(
  132. imageCount=len(index_data.id_map)
  133. ),
  134. )
  135. @primary_operation(
  136. app,
  137. schema.INFER_ENDPOINT,
  138. "infer",
  139. )
  140. async def _infer(
  141. request: schema.InferRequest,
  142. ) -> AIStudioResultResponse[schema.InferResult]:
  143. pipeline = ctx.pipeline
  144. aiohttp_session = ctx.aiohttp_session
  145. visualize_enabled = (
  146. request.visualize if request.visualize is not None else ctx.config.visualize
  147. )
  148. image_bytes = await serving_utils.get_raw_bytes_async(
  149. request.image, aiohttp_session
  150. )
  151. image = serving_utils.image_bytes_to_array(image_bytes)
  152. if request.indexKey is not None:
  153. index_storage = ctx.extra["index_storage"]
  154. index_data_bytes = await serving_utils.call_async(
  155. index_storage.get, request.indexKey
  156. )
  157. index_data = IndexData.from_bytes(index_data_bytes)
  158. else:
  159. index_data = None
  160. result = list(
  161. await pipeline.call(
  162. pipeline.pipeline.predict,
  163. image,
  164. index=index_data,
  165. det_threshold=request.detThreshold,
  166. rec_threshold=request.recThreshold,
  167. hamming_radius=request.hammingRadius,
  168. topk=request.topk,
  169. )
  170. )[0]
  171. objs: List[Dict[str, Any]] = []
  172. for obj in result["boxes"]:
  173. rec_results: List[Dict[str, Any]] = []
  174. if obj["rec_scores"] is not None:
  175. for label, score in zip(obj["labels"], obj["rec_scores"]):
  176. rec_results.append(
  177. dict(
  178. label=label,
  179. score=score,
  180. )
  181. )
  182. objs.append(
  183. dict(
  184. bbox=obj["coordinate"],
  185. recResults=rec_results,
  186. score=obj["det_score"],
  187. )
  188. )
  189. if visualize_enabled:
  190. output_image_base64 = serving_utils.base64_encode(
  191. serving_utils.image_to_bytes(result.img["res"])
  192. )
  193. else:
  194. output_image_base64 = None
  195. return AIStudioResultResponse[schema.InferResult](
  196. logId=serving_utils.generate_log_id(),
  197. result=schema.InferResult(faces=objs, image=output_image_base64),
  198. )
  199. return app