# modeling_unimernet.py

import os
import re
import warnings
from typing import Optional

import torch
from ftfy import fix_text
from loguru import logger
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import logger as base_model_logger

from .unimer_swin import UnimerSwinConfig, UnimerSwinModel, UnimerSwinImageProcessor
from .unimer_mbart import UnimerMBartConfig, UnimerMBartForCausalLM

AutoConfig.register(UnimerSwinConfig.model_type, UnimerSwinConfig)
AutoConfig.register(UnimerMBartConfig.model_type, UnimerMBartConfig)
AutoModel.register(UnimerSwinConfig, UnimerSwinModel)
AutoModelForCausalLM.register(UnimerMBartConfig, UnimerMBartForCausalLM)

# TODO: rewrite tokenizer
class TokenizerWrapper:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.pad_token_id = self.tokenizer.pad_token_id
        self.bos_token_id = self.tokenizer.bos_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

    def __len__(self):
        return len(self.tokenizer)

    def tokenize(self, text, **kwargs):
        return self.tokenizer(
            text,
            return_token_type_ids=False,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            **kwargs,
        )

    def token2str(self, tokens) -> list:
        generated_text = self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
        generated_text = [fix_text(text) for text in generated_text]
        return generated_text

    def detokenize(self, tokens):
        toks = [self.tokenizer.convert_ids_to_tokens(tok) for tok in tokens]
        for b in range(len(toks)):
            for i in reversed(range(len(toks[b]))):
                if toks[b][i] is None:
                    toks[b][i] = ''
                toks[b][i] = toks[b][i].replace('Ġ', ' ').strip()
                if toks[b][i] in [self.tokenizer.bos_token, self.tokenizer.eos_token, self.tokenizer.pad_token]:
                    del toks[b][i]
        return toks
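
# Usage sketch (added for illustration; the checkpoint path is a placeholder):
#   wrapper = TokenizerWrapper(AutoTokenizer.from_pretrained("path/to/unimernet"))
#   batch = wrapper.tokenize(["\\frac{a}{b}"])    # BatchEncoding with input_ids / attention_mask
#   print(wrapper.token2str(batch["input_ids"]))  # decoded, ftfy-fixed strings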

# The lookahead keeps \left/\right from matching as prefixes of other
# commands (\lefteqn, \leftarrow, \rightarrow, ...).
LEFT_PATTERN = re.compile(r'(\\left)(?![a-zA-Z])(\S*)')
RIGHT_PATTERN = re.compile(r'(\\right)(?![a-zA-Z])(\S*)')
LEFT_COUNT_PATTERN = re.compile(r'\\left(?![a-zA-Z])')
RIGHT_COUNT_PATTERN = re.compile(r'\\right(?![a-zA-Z])')
LEFT_RIGHT_REMOVE_PATTERN = re.compile(r'\\left\.?|\\right\.?')

def fix_latex_left_right(s):
    r"""
    Fix \left and \right commands in LaTeX:
    1. make sure each is followed by a valid delimiter;
    2. balance the numbers of \left and \right.
    """
    # Whitelist of valid delimiters ('.' is the null delimiter, as in "\left.").
    valid_delims_list = [r'(', r')', r'[', r']', r'{', r'}', r'/', r'|', r'.',
                         r'\{', r'\}', r'\lceil', r'\rceil', r'\lfloor',
                         r'\rfloor', r'\backslash', r'\uparrow', r'\downarrow',
                         r'\Uparrow', r'\Downarrow', r'\|', r'\.']

    # Insert the "." null delimiter where \left/\right lacks a valid one.
    def fix_delim(match):
        cmd = match.group(1)  # \left or \right
        rest = match.group(2)
        # Keep the command when a valid delimiter follows (rest may carry
        # trailing text, e.g. "(x" in "\left(x").
        if rest and any(rest.startswith(d) for d in valid_delims_list):
            return match.group(0)
        # Otherwise add ".", preserving whatever followed the command.
        return cmd + "." + rest

    # The precompiled patterns match \left/\right only as standalone commands.
    s = LEFT_PATTERN.sub(fix_delim, s)
    s = RIGHT_PATTERN.sub(fix_delim, s)

    # Count \left and \right precisely.
    left_count = len(LEFT_COUNT_PATTERN.findall(s))    # does not count \lefteqn etc.
    right_count = len(RIGHT_COUNT_PATTERN.findall(s))  # does not count \rightarrow etc.
    if left_count == right_count:
        # Counts are equal; check that each pair sits in the same brace group.
        return fix_left_right_pairs(s)
    else:
        # Counts differ; remove all \left and \right commands.
        # logger.debug(f"latex:{s}")
        # logger.warning(f"left_count: {left_count}, right_count: {right_count}")
        return LEFT_RIGHT_REMOVE_PATTERN.sub('', s)
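
# Worked examples for fix_latex_left_right (added for illustration):
#   fix_latex_left_right(r'\left x \right)')  ->  r'\left. x \right)'  (missing delimiter filled with ".")
#   fix_latex_left_right(r'\left( x')         ->  '( x'               (unbalanced pair stripped)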

def fix_left_right_pairs(latex_formula):
    r"""
    Detect \left and \right commands that do not sit in the same brace
    group, and move the \right to the end of the \left's group.

    Args:
        latex_formula (str): input LaTeX formula

    Returns:
        str: fixed LaTeX formula
    """
    # Tracks the nesting level of curly braces.
    brace_stack = []
    # Stores \left info: (position, depth, delimiter).
    left_stack = []
    # Stores \right commands to relocate: (start, end, target position).
    adjustments = []

    i = 0
    while i < len(latex_formula):
        # Skip characters escaped by an odd number of backslashes.
        if i > 0 and latex_formula[i - 1] == '\\':
            backslash_count = 0
            j = i - 1
            while j >= 0 and latex_formula[j] == '\\':
                backslash_count += 1
                j -= 1
            if backslash_count % 2 == 1:
                i += 1
                continue

        # Detect a \left command.
        if i + 5 < len(latex_formula) and latex_formula[i:i + 5] == "\\left":
            delimiter = latex_formula[i + 5]
            left_stack.append((i, len(brace_stack), delimiter))
            i += 6  # skip \left and its delimiter
            continue
        # Detect a \right command.
        elif i + 6 < len(latex_formula) and latex_formula[i:i + 6] == "\\right":
            delimiter = latex_formula[i + 6]
            if left_stack:
                left_pos, left_depth, left_delim = left_stack.pop()
                # \left and \right are at different brace depths.
                if left_depth != len(brace_stack):
                    # Find the end of the brace group that contains the \left.
                    target_pos = find_group_end(latex_formula, left_pos, left_depth)
                    if target_pos != -1:
                        # Record the \right that has to move.
                        adjustments.append((i, i + 7, target_pos))
            i += 7  # skip \right and its delimiter
            continue

        # Track curly braces.
        if latex_formula[i] == '{':
            brace_stack.append(i)
        elif latex_formula[i] == '}':
            if brace_stack:
                brace_stack.pop()
        i += 1

    # Apply the adjustments from back to front so indices stay valid.
    if not adjustments:
        return latex_formula
    result = list(latex_formula)
    adjustments.sort(reverse=True, key=lambda x: x[0])
    for start, end, target in adjustments:
        # Extract the \right part,
        right_part = result[start:end]
        # remove it from its original position,
        del result[start:end]
        # and insert it at the target position.
        result.insert(target, ''.join(right_part))
    return ''.join(result)
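
# Worked example (added for illustration): this \right sits outside the brace
# group of its \left, so it is moved to the end of that group:
#   fix_left_right_pairs('{\\left( x} \\right)')  ->  '{\\left( x\\right)} '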

def find_group_end(text, pos, depth):
    """Find the closing position of the brace group at the given depth."""
    current_depth = depth
    i = pos
    while i < len(text):
        if text[i] == '{' and (i == 0 or not is_escaped(text, i)):
            current_depth += 1
        elif text[i] == '}' and (i == 0 or not is_escaped(text, i)):
            current_depth -= 1
            if current_depth < depth:
                return i
        i += 1
    return -1  # no matching closing brace found

def is_escaped(text, pos):
    """Check whether the character at pos is escaped by a backslash."""
    backslash_count = 0
    j = pos - 1
    while j >= 0 and text[j] == '\\':
        backslash_count += 1
        j -= 1
    return backslash_count % 2 == 1
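
# Examples (added for illustration):
#   is_escaped(r'\{', 1)   ->  True   (the brace is escaped)
#   is_escaped(r'\\{', 2)  ->  False  (the backslash itself is escaped)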

def fix_unbalanced_braces(latex_formula):
    """
    Detect unbalanced curly braces in a LaTeX formula and delete the
    braces that cannot be paired.

    Args:
        latex_formula (str): input LaTeX formula

    Returns:
        str: the formula with unpaired braces removed
    """
    stack = []         # indices of open braces
    unmatched = set()  # indices of unmatched braces
    i = 0
    while i < len(latex_formula):
        if latex_formula[i] in ['{', '}']:
            # Count the consecutive backslashes in front of the brace.
            backslash_count = 0
            j = i - 1
            while j >= 0 and latex_formula[j] == '\\':
                backslash_count += 1
                j -= 1
            # An odd number of backslashes means the brace is escaped and
            # takes no part in matching.
            if backslash_count % 2 == 1:
                i += 1
                continue
            # Otherwise the brace takes part in matching.
            if latex_formula[i] == '{':
                stack.append(i)
            else:  # latex_formula[i] == '}'
                if stack:  # there is a matching open brace
                    stack.pop()
                else:      # no matching open brace
                    unmatched.add(i)
        i += 1
    # Whatever is left on the stack is an unmatched open brace.
    unmatched.update(stack)
    # Rebuild the string without the unmatched braces.
    return ''.join(char for i, char in enumerate(latex_formula) if i not in unmatched)
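
# Worked example (added for illustration): the stray '}' and the unclosed '{'
# are both removed:
#   fix_unbalanced_braces('{a} } {b')  ->  '{a}  b'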

def process_latex(input_string):
    r"""
    Handle backslashes in a LaTeX formula:
    1. if '\' is followed by a special character (#$%&~_^|\{}) or
       whitespace, leave it unchanged;
    2. if '\' is followed by at least two letters, leave it unchanged;
    3. otherwise insert a space after the '\'.

    Args:
        input_string (str): input LaTeX formula

    Returns:
        str: processed LaTeX formula
    """
    def replace_func(match):
        # The character right after the backslash.
        next_char = match.group(1)
        # Special characters and whitespace stay unchanged.
        if next_char in "#$%&~_^|\\{} \t\n\r\v\f":
            return match.group(0)
        # For a letter, look at the character after it.
        if 'a' <= next_char <= 'z' or 'A' <= next_char <= 'Z':
            pos = match.start() + 2  # position after "\x"
            if pos < len(input_string) and ('a' <= input_string[pos] <= 'z' or 'A' <= input_string[pos] <= 'Z'):
                # The next character is also a letter: keep the command.
                return match.group(0)
        # Everything else: insert a space after the backslash.
        return '\\' + ' ' + next_char
    # Match a backslash followed by any single character.
    pattern = r'\\(.)'
    return re.sub(pattern, replace_func, input_string)
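
# Examples (added for illustration):
#   process_latex(r'\alpha+\beta')  ->  r'\alpha+\beta'  (multi-letter commands untouched)
#   process_latex(r'\x1')           ->  '\ x1'           (space inserted after single-letter command)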

# Math environments commonly available in KaTeX/MathJax.
ENV_TYPES = ['array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix',
             'Bmatrix', 'Vmatrix', 'cases', 'aligned', 'gathered']
ENV_BEGIN_PATTERNS = {env: re.compile(r'\\begin\{' + env + r'\}') for env in ENV_TYPES}
ENV_END_PATTERNS = {env: re.compile(r'\\end\{' + env + r'\}') for env in ENV_TYPES}
ENV_FORMAT_PATTERNS = {env: re.compile(r'\\begin\{' + env + r'\}\{([^}]*)\}') for env in ENV_TYPES}

def fix_latex_environments(s):
    r"""
    Check that the \begin and \end tags of LaTeX environments (e.g. array) match:
    1. prepend a \begin tag if it is missing;
    2. append an \end tag if it is missing.
    """
    for env in ENV_TYPES:
        begin_count = len(ENV_BEGIN_PATTERNS[env].findall(s))
        end_count = len(ENV_END_PATTERNS[env].findall(s))
        if begin_count != end_count:
            if end_count > begin_count:
                # Reuse an existing column format if one is present.
                format_match = ENV_FORMAT_PATTERNS[env].search(s)
                default_format = '{c}' if env == 'array' else ''
                format_str = '{' + format_match.group(1) + '}' if format_match else default_format
                missing_count = end_count - begin_count
                begin_command = '\\begin{' + env + '}' + format_str + ' '
                s = begin_command * missing_count + s
            else:
                missing_count = begin_count - end_count
                s = s + (' \\end{' + env + '}') * missing_count
    return s
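
# Worked example (added for illustration): a missing \begin tag is prepended:
#   fix_latex_environments(r'x \end{matrix}')  ->  r'\begin{matrix} x \end{matrix}'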

UP_PATTERN = re.compile(r'\\up([a-zA-Z]+)')
COMMANDS_TO_REMOVE_PATTERN = re.compile(
    r'\\(?:lefteqn|boldmath|ensuremath|centering|textsubscript|sides|textsl|textcent|emph|protect|null)')
REPLACEMENTS_PATTERNS = {
    re.compile(r'\\underbar'): r'\\underline',
    re.compile(r'\\Bar'): r'\\hat',
    re.compile(r'\\Hat'): r'\\hat',
    re.compile(r'\\Tilde'): r'\\tilde',
    re.compile(r'\\slash'): r'/',
    re.compile(r'\\textperthousand'): r'‰',
    re.compile(r'\\sun'): r'☉',
    re.compile(r'\\textunderscore'): r'\\_',
    re.compile(r'\\fint'): r'⨏',
    re.compile(r'\\up '): r'\\ ',
    re.compile(r'\\vline = '): r'\\models ',
    re.compile(r'\\vDash '): r'\\models ',
    re.compile(r'\\sq \\sqcup '): r'\\square ',
    re.compile(r'\\copyright'): r'©',
}
QQUAD_PATTERN = re.compile(r'\\qquad(?!\s)')

def latex_rm_whitespace(s: str):
    """Remove unnecessary whitespace from LaTeX code."""
    s = fix_unbalanced_braces(s)
    s = fix_latex_left_right(s)
    s = fix_latex_environments(s)

    # Strip a bogus \up prefix, keeping \uparrow, \updownarrow, \uplus, \upsilon.
    s = UP_PATTERN.sub(
        lambda m: m.group(0) if m.group(1) in ["arrow", "downarrow", "lus", "silon"] else f"\\{m.group(1)}", s
    )
    s = COMMANDS_TO_REMOVE_PATTERN.sub('', s)
    # Apply all simple pattern replacements.
    for pattern, replacement in REPLACEMENTS_PATTERNS.items():
        s = pattern.sub(replacement, s)
    # Normalize backslashes and spacing.
    s = process_latex(s)
    # Make sure \qquad is followed by a space.
    s = QQUAD_PATTERN.sub(r'\\qquad ', s)
    # Drop trailing backslashes.
    while s.endswith('\\'):
        s = s[:-1]
    return s
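
# Worked examples (added for illustration):
#   latex_rm_whitespace(r'\upalpha')  ->  r'\alpha'  (bogus \up prefix stripped)
#   latex_rm_whitespace(r'\left( x')  ->  '( x'      (unbalanced \left removed)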

class UnimernetModel(VisionEncoderDecoderModel):
    def __init__(
        self,
        config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,
    ):
        # VisionEncoderDecoderModel's config check logs spuriously;
        # silence its logger temporarily.
        base_model_logger.disabled = True
        try:
            super().__init__(config, encoder, decoder)
        finally:
            base_model_logger.disabled = False

        if not config or not hasattr(config, "_name_or_path"):
            raise RuntimeError("config._name_or_path is required by UnimernetModel.")

        model_path = config._name_or_path
        self.transform = UnimerSwinImageProcessor()
        self.tokenizer = TokenizerWrapper(AutoTokenizer.from_pretrained(model_path))
        self._post_check()

    def _post_check(self):
        tokenizer = self.tokenizer

        if tokenizer.tokenizer.model_max_length != self.config.decoder.max_position_embeddings:
            warnings.warn(
                f"decoder.max_position_embeddings={self.config.decoder.max_position_embeddings},"
                f" but tokenizer.model_max_length={tokenizer.tokenizer.model_max_length}, will set"
                f" tokenizer.model_max_length to {self.config.decoder.max_position_embeddings}.")
            tokenizer.tokenizer.model_max_length = self.config.decoder.max_position_embeddings

        assert self.config.decoder.vocab_size == len(tokenizer)
        assert self.config.decoder_start_token_id == tokenizer.bos_token_id
        assert self.config.pad_token_id == tokenizer.pad_token_id

    @classmethod
    def from_checkpoint(cls, model_path: str, model_filename: str = "pytorch_model.pth",
                        state_dict_strip_prefix="model.model."):
        config = VisionEncoderDecoderConfig.from_pretrained(model_path)
        config._name_or_path = model_path
        config.encoder = UnimerSwinConfig(**vars(config.encoder))
        config.decoder = UnimerMBartConfig(**vars(config.decoder))

        encoder = UnimerSwinModel(config.encoder)
        decoder = UnimerMBartForCausalLM(config.decoder)
        model = cls(config, encoder, decoder)

        # Load the model weights.
        model_file_path = os.path.join(model_path, model_filename)
        checkpoint = torch.load(model_file_path, map_location="cpu", weights_only=True)
        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
        if not state_dict:
            raise RuntimeError("state_dict is empty.")
        if state_dict_strip_prefix:
            state_dict = {
                k[len(state_dict_strip_prefix):] if k.startswith(state_dict_strip_prefix) else k: v
                for k, v in state_dict.items()
            }

        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if len(unexpected_keys) > 0:
            warnings.warn("Unexpected key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in unexpected_keys)))
        if len(missing_keys) > 0:
            raise RuntimeError("Missing key(s) in state_dict: {}.".format(", ".join(f'"{k}"' for k in missing_keys)))
        return model
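
    # Usage sketch (added for illustration; the path is a placeholder):
    #   model = UnimernetModel.from_checkpoint("path/to/unimernet_dir")
    #   model = model.eval().to("cuda")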

    def forward_bak(self, samples):
        # Legacy training forward, kept for reference; not used at inference.
        pixel_values, text = samples["image"], samples["text_input"]

        text_inputs = self.tokenizer.tokenize(text).to(pixel_values.device)
        decoder_input_ids, decoder_attention_mask = text_inputs["input_ids"], text_inputs["attention_mask"]

        # Expand grayscale images to three channels.
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)

        # Ignore padding positions in the loss.
        labels = decoder_input_ids * 1
        labels = labels.masked_fill(labels == self.tokenizer.pad_token_id, -100)

        loss = super().forward(
            pixel_values=pixel_values,
            decoder_input_ids=decoder_input_ids[:, :-1],
            decoder_attention_mask=decoder_attention_mask[:, :-1],
            labels=labels[:, 1:],
        ).loss
        return {"loss": loss}

    def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95, batch_size=64):
        pixel_values = samples["image"]
        # Expand grayscale images to three channels.
        num_channels = pixel_values.shape[1]
        if num_channels == 1:
            pixel_values = pixel_values.repeat(1, 3, 1, 1)

        kwargs = {}
        if do_sample:
            kwargs["temperature"] = temperature
            kwargs["top_p"] = top_p

        # Cap the generation length based on batch size to bound GPU memory.
        if self.tokenizer.tokenizer.model_max_length > 1152:
            if batch_size <= 32:
                self.tokenizer.tokenizer.model_max_length = 1152  # ~6 GB
            else:
                self.tokenizer.tokenizer.model_max_length = 1344  # ~8 GB

        outputs = super().generate(
            pixel_values=pixel_values,
            max_new_tokens=self.tokenizer.tokenizer.model_max_length,  # required
            decoder_start_token_id=self.tokenizer.tokenizer.bos_token_id,
            do_sample=do_sample,
            **kwargs,
        )

        outputs = outputs[:, 1:].cpu().numpy()  # drop the start token
        pred_tokens = self.tokenizer.detokenize(outputs)
        pred_str = self.tokenizer.token2str(outputs)
        fixed_str = [latex_rm_whitespace(s) for s in pred_str]
        return {"pred_ids": outputs, "pred_tokens": pred_tokens, "pred_str": pred_str, "fixed_str": fixed_str}
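
# End-to-end usage sketch (added for illustration; the path is a placeholder
# and the image-processor call assumes the standard Hugging Face convention):
#   model = UnimernetModel.from_checkpoint("path/to/unimernet_dir").eval()
#   pixel_values = model.transform(images, return_tensors="pt")["pixel_values"]
#   with torch.no_grad():
#       result = model.generate({"image": pixel_values})
#   print(result["fixed_str"])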