最近ChatDOC团队发布了一款非常实用的多模态 OCR 大模型:OCRFlux-3B,这是一个基于 Qwen2.5-VL-3B-Instruct 微调得到的模型,专为文档解析任务优化,在解析 PDF、图片内容为 Markdown文本的效果上非常亮眼,尤其值得一提的是,它原生支持跨页表格与段落合并,这是目前开源 OCR 项目中首次实现该能力的模型。
参考
https://blog.csdn.net/qq_39078903/article/details/149326447
https://zhuanlan.zhihu.com/p/25543305218
判断文本是否为 HTML 表格
def is_html_table(text):
soup = BeautifulSoup(text, "html.parser")
return soup.find('table') is not None
9-62 行:table_matrix2html函数(将矩阵形式表格转换为 HTML 表格)
def table_matrix2html(matrix_table): # 定义函数,参数为"矩阵形式"的表格(含合并标记的HTML)
soup = BeautifulSoup(matrix_table, 'html.parser') # 解析矩阵形式的表格为HTML对象
table = soup.find('table') # 获取<table>标签内容
rownum = 0 # 初始化行数(后续动态计算)
colnum = 0 # 初始化列数(后续动态计算)
cell_dict = {} # 用字典存储单元格内容,键为(行索引, 列索引),值为单元格文本
rid = 0 # 当前行索引(从0开始)
for tr in table.find_all('tr'): # 遍历表格中的每一行<tr>
cid = 0 # 当前列索引(从0开始)
for td in tr.find_all('td'): # 遍历当前行中的每一个单元格<td>
# 检查单元格是否包含合并标记(<l>列合并占位符、<t>行合并占位符、<lt>行列合并占位符)
if td.find('l'):
cell_dict[(rid, cid)] = '<l>' # 记录列合并占位符
elif td.find('t'):
cell_dict[(rid, cid)] = '<t>' # 记录行合并占位符
elif td.find('lt'):
cell_dict[(rid, cid)] = '<lt>' # 记录行列合并占位符
else:
text = td.get_text(strip=True) # 获取单元格实际文本(去除首尾空格)
cell_dict[(rid, cid)] = text # 记录实际文本
cid += 1 # 列索引自增
if colnum == 0: # 首次遍历行时,记录总列数
colnum = cid
elif cid != colnum: # 后续行若列数与首行不一致,抛出异常(确保表格结构合法)
raise Exception('colnum not match')
rid += 1 # 行索引自增
rownum = rid # 记录总行数
html_table = ['<table>'] # 初始化HTML表格字符串列表(用列表拼接效率更高)
for rid in range(rownum): # 遍历每一行(按行索引)
html_table.append('<tr>') # 添加行标签<tr>
for cid in range(colnum): # 遍历每一列(按列索引)
if (rid, cid) not in cell_dict.keys(): # 跳过已被合并的单元格(已被删除)
continue
text = cell_dict[(rid, cid)] # 获取当前单元格内容
if text == '<l>' or text == '<t>' or text == '<lt>': # 若仍存在占位符,说明结构错误
raise Exception('cell not match')
rowspan = 1 # 初始化行合并数(默认为1,即不合并)
colspan = 1 # 初始化列合并数(默认为1,即不合并)
# 计算行合并数:检查当前单元格下方行是否有<t>占位符
for r in range(rid+1, rownum):
if (r, cid) in cell_dict.keys() and cell_dict[(r, cid)] == '<t>':
rowspan += 1 # 行合并数+1
del cell_dict[(r, cid)] # 删除占位符(避免重复处理)
else:
break # 无更多行合并,退出循环
# 计算列合并数:检查当前单元格右侧列是否有<l>占位符
for c in range(cid+1, colnum):
if (rid, c) in cell_dict.keys() and cell_dict[(rid, c)] == '<l>':
colspan += 1 # 列合并数+1
del cell_dict[(rid, c)] # 删除占位符(避免重复处理)
else:
break # 无更多列合并,退出循环
# 检查行列交叉合并的占位符<lt>是否存在(确保合并区域正确)
for r in range(rid+1, rid+rowspan):
for c in range(cid+1, cid+colspan):
if cell_dict[(r, c)] != '<lt>': # 若不是<lt>,说明结构错误
raise Exception('cell not match')
del cell_dict[(r, c)] # 删除占位符(避免重复处理)
attr = '' # 初始化<td>标签的属性字符串(用于存储rowspan和colspan)
if rowspan > 1: # 若行合并数>1,添加rowspan属性
attr += ' rowspan="{}"'.format(rowspan)
if colspan > 1: # 若列合并数>1,添加colspan属性
attr += ' colspan="{}"'.format(colspan)
# 将带合并属性的<td>标签添加到HTML列表
html_table.append("<td{}>{}</td>".format(attr, text))
html_table.append('</tr>') # 结束当前行标签
html_table.append('</table>') # 结束表格标签
return "".join(html_table) # 将列表拼接为字符串,返回完整HTML表格
table_html2matrix函数(将 HTML 表格转换为矩阵形式表格)
def table_html2matrix(html_table): # 定义函数,参数为标准HTML表格
soup = BeautifulSoup(html_table, 'html.parser') # 解析HTML表格为对象
table = soup.find('table') # 获取<table>标签内容
rownum = len(table.find_all('tr')) # 总行数=<tr>标签的数量
colnum = 0 # 初始化总列数(后续计算)
tr = table.find_all('tr')[0] # 获取第一行<tr>(用于计算总列数)
for td in tr.find_all('td'): # 遍历第一行的每个单元格
colnum += td.get('colspan', 1) # 总列数=各单元格colspan之和(默认colspan=1)
# 初始化矩阵:rownum行×colnum列,初始值为None(用于存储单元格内容或合并标记)
matrix = [[None for _ in range(colnum)] for _ in range(rownum)]
rid = 0 # 当前行索引(从0开始)
for tr in table.find_all('tr'): # 遍历每一行<tr>
cid = 0 # 当前列索引(从0开始)
for td in tr.find_all('td'): # 遍历当前行的每个单元格<td>
# 找到当前行中第一个未填充的列(处理列合并导致的索引偏移)
for c in range(cid, colnum):
if matrix[rid][c] is None:
break
cid = c # 更新当前列索引为找到的空列
rowspan = td.get('rowspan', 1) # 获取行合并数(默认1)
colspan = td.get('colspan', 1) # 获取列合并数(默认1)
cell_text = td.get_text(strip=True) # 获取单元格实际文本
# 填充矩阵:覆盖合并区域的所有单元格
for r in range(rid, rid+rowspan): # 遍历合并的行范围
if r >= rownum: # 若超出总行数,抛出异常(结构错误)
raise Exception('rownum not match')
for c in range(cid, cid+colspan): # 遍历合并的列范围
if c >= colnum: # 若超出总列数,抛出异常(结构错误)
raise Exception('colnum not match')
if matrix[r][c] is not None: # 若单元格已填充,抛出异常(重复填充)
raise Exception('cell not match')
# 根据位置设置矩阵值(原始单元格/合并占位符)
if r == rid and c == cid:
matrix[r][c] = cell_text # 原始单元格:存实际文本
elif r == rid:
matrix[r][c] = '<l>' # 同一行右侧合并:列合并占位符
elif c == cid:
matrix[r][c] = '<t>' # 同一列下方合并:行合并占位符
else:
matrix[r][c] = '<lt>' # 行列交叉合并:行列合并占位符
cid += colspan # 列索引偏移(跳过合并的列)
rid += 1 # 行索引自增
# 生成矩阵形式的HTML表格字符串
matrix_table = ['<table>'] # 初始化列表
for rid in range(rownum): # 遍历每一行
matrix_table.append('<tr>') # 添加行标签
for cid in range(colnum): # 遍历每一列
matrix_table.append('<td>') # 添加单元格标签
cell_text = matrix[rid][cid] # 获取矩阵中的值(文本或占位符)
matrix_table.append(cell_text) # 添加单元格内容
matrix_table.append('</td>') # 结束单元格标签
matrix_table.append('</tr>') # 结束行标签
matrix_table.append('</table>') # 结束表格标签
return "".join(matrix_table) # 拼接为字符串,返回矩阵形式表格
113-116 行:转换函数映射字典
trans_func = { # 定义转换类型与对应函数的映射
"html2matrix": table_html2matrix, # "html2matrix"对应HTML转矩阵函数
"matrix2html": table_matrix2html, # "matrix2html"对应矩阵转HTML函数
}
118-128 行:trans_markdown_text函数(批量转换 Markdown 文本中的表格)
def trans_markdown_text(markdown_text, trans_type): # 定义函数,参数为Markdown文本和转换类型
if markdown_text == None: # 若输入文本为None,直接返回None
return None
text_list = markdown_text.split('\n\n') # 按空行分割Markdown文本(Markdown中段落/表格以空行分隔)
for i, text in enumerate(text_list): # 遍历每个分割后的元素
if is_html_table(text): # 若元素是HTML表格
# 用trans_func中对应的转换函数处理表格
text_list[i] = trans_func[trans_type](text)
return "\n\n".join(text_list) # 将处理后的元素拼接回Markdown文本并返回
该文件核心功能是实现两种表格格式的相互转换:
标准 HTML 表格(带 rowspan/colspan 属性)
矩阵形式表格(用<l>、<t>、<lt>标记合并区域)
通过trans_markdown_text函数,可批量处理 Markdown 文本中的表格,为 OCRFlux 中表格的解析、合并等后续操作提供格式支持。
turn_header_to_h1:统一 Markdown 标题格式
def turn_header_to_h1(line):
# 检查是否是以一个或多个 '#' 开头的标题行
if line.lstrip().startswith('#'):
# 去掉开头的 '#' 和其后的空格,统一转为H1格式(仅保留一个#)
new_line = "# " + line.lstrip().lstrip('#').lstrip()
return new_line
else:
return line
作用:将 Markdown 中以多个#开头的标题(如## 标题、### 标题)统一转换为 H1 格式(# 标题),避免标题层级差异对评估的干扰。
replace_single_dollar:替换单美元符号公式
def replace_single_dollar(markdown_text):
pattern = r'\$(.*?)\$' # 匹配单美元符号包裹的公式(如$a+b$)
def replace_with_brackets(match):
formula_content = match.group(1) # 获取匹配到的公式内容
return f'\\({formula_content}\\)' # 替换为LaTeX格式(\(a+b\))
replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
return replaced_text
作用:将 Markdown 中用$...$表示的行内公式转换为标准 LaTeX 格式\(...\),统一公式表示方式。
replace_double_dollar:替换双美元符号公式
def replace_double_dollar(markdown_text):
pattern = r'\$\$(.*?)\$\$' # 匹配双美元符号包裹的公式(如$$a+b$$)
def replace_with_brackets(match):
formula_content = match.group(1)
return f'\\[{formula_content}\\]' # 替换为LaTeX格式(\[a+b\])
replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
return replaced_text
作用:将 Markdown 中用$$...$$表示的块级公式转换为标准 LaTeX 格式\[...\],与行内公式处理逻辑一致。
TableTree:表格树结构类
class TableTree(Tree):
def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
self.tag = tag # 节点标签(如'table'、'tr'、'td')
self.colspan = colspan # 单元格跨列数(仅td节点)
self.rowspan = rowspan # 单元格跨行数(仅td节点)
self.content = content # 单元格内容(仅td节点,token化后的文本)
self.children = list(children) # 子节点列表
def bracket(self):
"""用括号表示法展示树结构(用于调试和可视化)"""
if self.tag == 'td':
# 对td节点,包含标签、跨列、跨行和内容信息
result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
(self.tag, self.colspan, self.rowspan, self.content)
else:
# 非td节点(如table、tr)仅包含标签信息
result = '"tag": %s' % self.tag
# 递归拼接子节点
for child in self.children:
result += child.bracket()
return "{{{}}}".format(result)
作用:继承Tree类,定义表格特有的树结构,包含 HTML 标签、单元格跨行列属性、内容等信息,为后续树编辑距离计算提供数据结构。
CustomConfig:树编辑距离计算配置
class CustomConfig(Config):
@staticmethod
def maximum(*sequences):
"""获取序列的最大长度(用于归一化编辑距离)"""
return max(map(len, sequences))
def normalized_distance(self, *sequences):
"""计算归一化的编辑距离(0~1之间)"""
return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
def rename(self, node1, node2):
"""计算两个节点的差异度(0表示相同,1表示完全不同)"""
# 标签、跨列、跨行不同,则差异度为1
if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
return 1.
# 若为td节点,还需比较内容差异
if node1.tag == 'td':
if node1.content or node2.content:
return self.normalized_distance(node1.content, node2.content)
return 0.
作用:自定义APTED库的配置,定义节点差异度计算规则(用于树编辑距离的代价函数),是计算 TEDS(Tree Edit Distance-based Similarity)的核心配置。
指标:TEDS(Tree Edit Distance-based Similarity)
不用th 5. 表格简化函数
def simplify_html_table(html_table):
# 使用BeautifulSoup解析HTML表格
soup = BeautifulSoup(html_table, 'html.parser')
# 查找table标签
table = soup.find('table')
if not table:
raise ValueError("输入的HTML不包含有效的<table>标签")
# 创建新的table标签
new_table = BeautifulSoup('<table></table>', 'html.parser').table
# 提取所有行(包括thead和tbody中的tr)
rows = table.find_all(['tr'], recursive=True)
for row in rows:
new_row = soup.new_tag('tr') # 新行
# 处理行中的单元格(th或td)
cells = row.find_all(['th', 'td'])
for cell in cells:
new_cell = soup.new_tag('td') # 将th转换为td
# 保留跨列和跨行属性
if cell.has_attr('rowspan'):
new_cell['rowspan'] = cell['rowspan']
if cell.has_attr('colspan'):
new_cell['colspan'] = cell['colspan']
new_cell.string = cell.get_text(strip=True) # 提取单元格文本
new_row.append(new_cell)
new_table.append(new_row) # 将新行添加到新表格
return str(new_table) # 返回简化后的HTML表格字符串
作用:简化 HTML 表格结构,统一将转换为,保留跨行列属性,仅保留文本内容,减少无关格式(如样式、嵌套标签)对评估的干扰。
主函数(入口逻辑)
def main():
# 解析命令行参数
parser = argparse.ArgumentParser(description="评估table_to_html任务")
parser.add_argument(
"workspace",
help="存储预测结果的目录路径"
)
parser.add_argument(
"--gt_file",
help="Ground Truth文件路径"
)
parser.add_argument("--n_jobs", type=int, default=40, help="并行任务数")
args = parser.parse_args()
# 加载预测数据
pred_data = {}
for file in os.listdir(args.workspace):
file_path = os.path.join(args.workspace, file)
# 生成对应的图片文件名(假设预测文件名为xxx.txt,对应图片为xxx.png)
pdf_name = file.split('.')[0] + ".png"
with open(file_path, "r") as f:
document_text = f.read()
# 处理公式格式
document_text = replace_single_dollar(replace_double_dollar(document_text))
# 按空行分割Markdown内容
markdown_text_list = document_text.split("\n\n")
new_markdown_text_list = []
for text in markdown_text_list:
text = text.strip()
# 过滤无关标签(水印、图片、页码、签名)
if (text.startswith("<watermark>") and text.endswith("</watermark>")) or \
(text.startswith("<img>") and text.endswith("</img>")) or \
(text.startswith("<page_number>") and text.endswith("</page_number>")) or \
(text.startswith("<signature>") and text.endswith("</signature>")):
continue
else:
# 将Markdown转换为HTML
html_text = str(markdown2.markdown(text, extras=["tables"]))
html_text = html_text.strip()
# 若为表格,简化结构
if html_text.startswith("<table>") and html_text.endswith("</table>"):
html_table = simplify_html_table(html_text)
new_markdown_text_list.append(html_table)
else:
# 统一标题格式为H1
text = turn_header_to_h1(text)
new_markdown_text_list.append(text)
# 存储处理后的预测结果(键为图片名,值为HTML内容)
pred_data[os.path.basename(pdf_name)] = "\n\n".join(new_markdown_text_list)
# 加载Ground Truth数据
gt_data = {}
with open(args.gt_file, "r") as f:
for line in f:
data = json.loads(line)
# 存储真实数据(键为图片名,值为{html: 真实表格HTML, type: 表格类型})
gt_data[data['image_name']] = {'html': data['gt_table'], 'type': data['type']}
# 初始化TEDS并批量评估
teds = TEDS(n_jobs=args.n_jobs, ignore_nodes=['b', 'thead', 'tbody'])
teds.batch_evaluate(pred_data, gt_data)
if __name__ == "__main__":
main()
作用:作为脚本入口,负责解析命令行参数、加载并预处理预测结果和 Ground Truth 数据,最终调用TEDS类进行批量评估并输出结果。