StructTableModel.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. import re
  2. import torch
  3. from struct_eqtable import build_model
  4. class StructTableModel:
  5. def __init__(self, model_path, max_new_tokens=1024, max_time=60):
  6. # init
  7. assert torch.cuda.is_available(), "CUDA must be available for StructEqTable model."
  8. self.model = build_model(
  9. model_ckpt=model_path,
  10. max_new_tokens=max_new_tokens,
  11. max_time=max_time,
  12. lmdeploy=False,
  13. flash_attn=False,
  14. batch_size=1,
  15. ).cuda()
  16. self.default_format = "html"
  17. def predict(self, images, output_format=None, **kwargs):
  18. if output_format is None:
  19. output_format = self.default_format
  20. else:
  21. if output_format not in ['latex', 'markdown', 'html']:
  22. raise ValueError(f"Output format {output_format} is not supported.")
  23. results = self.model(
  24. images, output_format=output_format
  25. )
  26. if output_format == "html":
  27. results = [self.minify_html(html) for html in results]
  28. return results
  29. def minify_html(self, html):
  30. # 移除多余的空白字符
  31. html = re.sub(r'\s+', ' ', html)
  32. # 移除行尾的空白字符
  33. html = re.sub(r'\s*>\s*', '>', html)
  34. # 移除标签前的空白字符
  35. html = re.sub(r'\s*<\s*', '<', html)
  36. return html.strip()