tablemaster_paddle.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import os
  2. import cv2
  3. import numpy as np
  4. from paddleocr.ppstructure.table.predict_table import TableSystem
  5. from paddleocr.ppstructure.utility import init_args
  6. from PIL import Image
  7. from magic_pdf.config.constants import * # noqa: F403
  8. class TableMasterPaddleModel(object):
  9. """This class is responsible for converting image of table into HTML format
  10. using a pre-trained model.
  11. Attributes:
  12. - table_sys: An instance of TableSystem initialized with parsed arguments.
  13. Methods:
  14. - __init__(config): Initializes the model with configuration parameters.
  15. - img2html(image): Converts a PIL Image or NumPy array to HTML string.
  16. - parse_args(**kwargs): Parses configuration arguments.
  17. """
  18. def __init__(self, config):
  19. """
  20. Parameters:
  21. - config (dict): Configuration dictionary containing model_dir and device.
  22. """
  23. args = self.parse_args(**config)
  24. self.table_sys = TableSystem(args)
  25. def img2html(self, image):
  26. """
  27. Parameters:
  28. - image (PIL.Image or np.ndarray): The image of the table to be converted.
  29. Return:
  30. - HTML (str): A string representing the HTML structure with content of the table.
  31. """
  32. if isinstance(image, Image.Image):
  33. image = np.asarray(image)
  34. image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
  35. pred_res, _ = self.table_sys(image)
  36. pred_html = pred_res['html']
  37. # res = '<td><table border="1">' + pred_html.replace("<html><body><table>", "").replace(
  38. # "</table></body></html>","") + "</table></td>\n"
  39. return pred_html
  40. def parse_args(self, **kwargs):
  41. parser = init_args()
  42. model_dir = kwargs.get('model_dir')
  43. table_model_dir = os.path.join(model_dir, TABLE_MASTER_DIR) # noqa: F405
  44. table_char_dict_path = os.path.join(model_dir, TABLE_MASTER_DICT) # noqa: F405
  45. det_model_dir = os.path.join(model_dir, DETECT_MODEL_DIR) # noqa: F405
  46. rec_model_dir = os.path.join(model_dir, REC_MODEL_DIR) # noqa: F405
  47. rec_char_dict_path = os.path.join(model_dir, REC_CHAR_DICT) # noqa: F405
  48. device = kwargs.get('device', 'cpu')
  49. use_gpu = True if device.startswith('cuda') else False
  50. config = {
  51. 'use_gpu': use_gpu,
  52. 'table_max_len': kwargs.get('table_max_len', TABLE_MAX_LEN), # noqa: F405
  53. 'table_algorithm': 'TableMaster',
  54. 'table_model_dir': table_model_dir,
  55. 'table_char_dict_path': table_char_dict_path,
  56. 'det_model_dir': det_model_dir,
  57. 'rec_model_dir': rec_model_dir,
  58. 'rec_char_dict_path': rec_char_dict_path,
  59. }
  60. parser.set_defaults(**config)
  61. return parser.parse_args([])