# File: tablemaster_paddle.py
  1. import os
  2. import cv2
  3. import numpy as np
  4. from paddleocr import PaddleOCR
  5. from ppstructure.table.predict_table import TableSystem
  6. from ppstructure.utility import init_args
  7. from PIL import Image
  8. from magic_pdf.config.constants import * # noqa: F403
  9. class TableMasterPaddleModel(object):
  10. """This class is responsible for converting image of table into HTML format
  11. using a pre-trained model.
  12. Attributes:
  13. - table_sys: An instance of TableSystem initialized with parsed arguments.
  14. Methods:
  15. - __init__(config): Initializes the model with configuration parameters.
  16. - img2html(image): Converts a PIL Image or NumPy array to HTML string.
  17. - parse_args(**kwargs): Parses configuration arguments.
  18. """
  19. def __init__(self, config):
  20. """
  21. Parameters:
  22. - config (dict): Configuration dictionary containing model_dir and device.
  23. """
  24. args = self.parse_args(**config)
  25. self.table_sys = TableSystem(args)
  26. def img2html(self, image):
  27. """
  28. Parameters:
  29. - image (PIL.Image or np.ndarray): The image of the table to be converted.
  30. Return:
  31. - HTML (str): A string representing the HTML structure with content of the table.
  32. """
  33. if isinstance(image, Image.Image):
  34. image = np.asarray(image)
  35. image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  36. pred_res, _ = self.table_sys(image)
  37. pred_html = pred_res['html']
  38. # res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace(
  39. # "</table></body></html>","") + "</table></td>\n"
  40. return pred_html
  41. def parse_args(self, **kwargs):
  42. parser = init_args()
  43. model_dir = kwargs.get('model_dir')
  44. table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR) # noqa: F405
  45. table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT) # noqa: F405
  46. det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR) # noqa: F405
  47. rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR) # noqa: F405
  48. rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT) # noqa: F405
  49. device = kwargs.get('device', 'cpu')
  50. use_gpu = True if device.startswith('cuda') else False
  51. config = {
  52. 'use_gpu': use_gpu,
  53. 'table_max_len': kwargs.get('table_max_len', TABLE_MAX_LEN), # noqa: F405
  54. 'table_algorithm': 'TableMaster',
  55. 'table_model_dir': table_model_dir,
  56. 'table_char_dict_path': table_char_dict_path,
  57. 'det_model_dir': det_model_dir,
  58. 'rec_model_dir': rec_model_dir,
  59. 'rec_char_dict_path': rec_char_dict_path,
  60. }
  61. parser.set_defaults(**config)
  62. return parser.parse_args([])