table_structure.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import time
  15. from typing import Any, Dict, List, Tuple
  16. import numpy as np
  17. from .table_stucture_utils import (
  18. OrtInferSession,
  19. TableLabelDecode,
  20. TablePreprocess,
  21. BatchTablePreprocess,
  22. )
  23. class TableStructurer:
  24. def __init__(self, config: Dict[str, Any]):
  25. self.preprocess_op = TablePreprocess()
  26. self.batch_preprocess_op = BatchTablePreprocess()
  27. self.session = OrtInferSession(config)
  28. self.character = self.session.get_metadata()
  29. self.postprocess_op = TableLabelDecode(self.character)
  30. def process(self, img):
  31. starttime = time.time()
  32. data = {"image": img}
  33. data = self.preprocess_op(data)
  34. img = data[0]
  35. if img is None:
  36. return None, 0
  37. img = np.expand_dims(img, axis=0)
  38. img = img.copy()
  39. outputs = self.session([img])
  40. preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]}
  41. shape_list = np.expand_dims(data[-1], axis=0)
  42. post_result = self.postprocess_op(preds, [shape_list])
  43. bbox_list = post_result["bbox_batch_list"][0]
  44. structure_str_list = post_result["structure_batch_list"][0]
  45. structure_str_list = structure_str_list[0]
  46. structure_str_list = (
  47. ["<html>", "<body>", "<table>"]
  48. + structure_str_list
  49. + ["</table>", "</body>", "</html>"]
  50. )
  51. elapse = time.time() - starttime
  52. return structure_str_list, bbox_list, elapse
  53. def batch_process(
  54. self, img_list: List[np.ndarray]
  55. ) -> List[Tuple[List[str], np.ndarray, float]]:
  56. """批量处理图像列表
  57. Args:
  58. img_list: 图像列表
  59. Returns:
  60. 结果列表,每个元素包含 (table_struct_str, cell_bboxes, elapse)
  61. """
  62. starttime = time.perf_counter()
  63. batch_data = self.batch_preprocess_op(img_list)
  64. preprocessed_images = batch_data[0]
  65. shape_lists = batch_data[1]
  66. preprocessed_images = np.array(preprocessed_images)
  67. bbox_preds, struct_probs = self.session([preprocessed_images])
  68. batch_size = preprocessed_images.shape[0]
  69. results = []
  70. for bbox_pred, struct_prob, shape_list in zip(
  71. bbox_preds, struct_probs, shape_lists
  72. ):
  73. preds = {
  74. "loc_preds": np.expand_dims(bbox_pred, axis=0),
  75. "structure_probs": np.expand_dims(struct_prob, axis=0),
  76. }
  77. shape_list = np.expand_dims(shape_list, axis=0)
  78. post_result = self.postprocess_op(preds, [shape_list])
  79. bbox_list = post_result["bbox_batch_list"][0]
  80. structure_str_list = post_result["structure_batch_list"][0]
  81. structure_str_list = structure_str_list[0]
  82. structure_str_list = (
  83. ["<html>", "<body>", "<table>"]
  84. + structure_str_list
  85. + ["</table>", "</body>", "</html>"]
  86. )
  87. results.append((structure_str_list, bbox_list, 0))
  88. total_elapse = time.perf_counter() - starttime
  89. for i in range(len(results)):
  90. results[i] = (results[i][0], results[i][1], total_elapse / batch_size)
  91. return results