OCRFlux 学习

最近ChatDOC团队发布了一款非常实用的多模态 OCR 大模型:OCRFlux-3B,这是一个基于 Qwen2.5-VL-3B-Instruct 微调得到的模型,专为文档解析任务优化,在解析 PDF、图片内容为 Markdown文本的效果上非常亮眼,尤其值得一提的是,它原生支持跨页表格与段落合并,这是目前开源 OCR 项目中首次实现该能力的模型。

参考

https://blog.csdn.net/qq_39078903/article/details/149326447
https://zhuanlan.zhihu.com/p/25543305218

判断文本是否为 HTML 表格

def is_html_table(text):
    soup = BeautifulSoup(text, "html.parser")
    return soup.find('table') is not None

9-62 行:table_matrix2html函数(将矩阵形式表格转换为 HTML 表格)

def table_matrix2html(matrix_table):  # 定义函数,参数为"矩阵形式"的表格(含合并标记的HTML)
    soup = BeautifulSoup(matrix_table, 'html.parser')  # 解析矩阵形式的表格为HTML对象
    table = soup.find('table')  # 获取<table>标签内容
    rownum = 0  # 初始化行数(后续动态计算)
    colnum = 0  # 初始化列数(后续动态计算)
    cell_dict = {}  # 用字典存储单元格内容,键为(行索引, 列索引),值为单元格文本
    rid = 0  # 当前行索引(从0开始)
    for tr in table.find_all('tr'):  # 遍历表格中的每一行<tr>
        cid = 0  # 当前列索引(从0开始)
        for td in tr.find_all('td'):  # 遍历当前行中的每一个单元格<td>
            # 检查单元格是否包含合并标记(<l>列合并占位符、<t>行合并占位符、<lt>行列合并占位符)
            if td.find('l'):
                cell_dict[(rid, cid)] = '<l>'  # 记录列合并占位符
            elif td.find('t'):
                cell_dict[(rid, cid)] = '<t>'  # 记录行合并占位符
            elif td.find('lt'):
                cell_dict[(rid, cid)] = '<lt>'  # 记录行列合并占位符
            else:
                text = td.get_text(strip=True)  # 获取单元格实际文本(去除首尾空格)
                cell_dict[(rid, cid)] = text  # 记录实际文本
            cid += 1  # 列索引自增
        if colnum == 0:  # 首次遍历行时,记录总列数
            colnum = cid
        elif cid != colnum:  # 后续行若列数与首行不一致,抛出异常(确保表格结构合法)
            raise Exception('colnum not match')
        rid += 1  # 行索引自增
    rownum = rid  # 记录总行数
    html_table = ['<table>']  # 初始化HTML表格字符串列表(用列表拼接效率更高)
    for rid in range(rownum):  # 遍历每一行(按行索引)
        html_table.append('<tr>')  # 添加行标签<tr>
        for cid in range(colnum):  # 遍历每一列(按列索引)
            if (rid, cid) not in cell_dict.keys():  # 跳过已被合并的单元格(已被删除)
                continue
            text = cell_dict[(rid, cid)]  # 获取当前单元格内容
            if text == '<l>' or text == '<t>' or text == '<lt>':  # 若仍存在占位符,说明结构错误
                raise Exception('cell not match')
            rowspan = 1  # 初始化行合并数(默认为1,即不合并)
            colspan = 1  # 初始化列合并数(默认为1,即不合并)
            # 计算行合并数:检查当前单元格下方行是否有<t>占位符
            for r in range(rid+1, rownum):
                if (r, cid) in cell_dict.keys() and cell_dict[(r, cid)] == '<t>':
                    rowspan += 1  # 行合并数+1
                    del cell_dict[(r, cid)]  # 删除占位符(避免重复处理)
                else:
                    break  # 无更多行合并,退出循环
            # 计算列合并数:检查当前单元格右侧列是否有<l>占位符
            for c in range(cid+1, colnum):
                if (rid, c) in cell_dict.keys() and cell_dict[(rid, c)] == '<l>':
                    colspan += 1  # 列合并数+1
                    del cell_dict[(rid, c)]  # 删除占位符(避免重复处理)
                else:
                    break  # 无更多列合并,退出循环
            # 检查行列交叉合并的占位符<lt>是否存在(确保合并区域正确)
            for r in range(rid+1, rid+rowspan):
                for c in range(cid+1, cid+colspan):
                    if cell_dict[(r, c)] != '<lt>':  # 若不是<lt>,说明结构错误
                        raise Exception('cell not match')
                    del cell_dict[(r, c)]  # 删除占位符(避免重复处理)
            attr = ''  # 初始化<td>标签的属性字符串(用于存储rowspan和colspan)
            if rowspan > 1:  # 若行合并数>1,添加rowspan属性
                attr += ' rowspan="{}"'.format(rowspan)
            if colspan > 1:  # 若列合并数>1,添加colspan属性
                attr += ' colspan="{}"'.format(colspan)
            # 将带合并属性的<td>标签添加到HTML列表
            html_table.append("<td{}>{}</td>".format(attr, text))
        html_table.append('</tr>')  # 结束当前行标签
    html_table.append('</table>')  # 结束表格标签
    return "".join(html_table)  # 将列表拼接为字符串,返回完整HTML表格

table_html2matrix函数(将 HTML 表格转换为矩阵形式表格)

def table_html2matrix(html_table):  # 定义函数,参数为标准HTML表格
    soup = BeautifulSoup(html_table, 'html.parser')  # 解析HTML表格为对象
    table = soup.find('table')  # 获取<table>标签内容
    rownum = len(table.find_all('tr'))  # 总行数=<tr>标签的数量
    colnum = 0  # 初始化总列数(后续计算)
    tr = table.find_all('tr')[0]  # 获取第一行<tr>(用于计算总列数)
    for td in tr.find_all('td'):  # 遍历第一行的每个单元格
        colnum += td.get('colspan', 1)  # 总列数=各单元格colspan之和(默认colspan=1)
    # 初始化矩阵:rownum行×colnum列,初始值为None(用于存储单元格内容或合并标记)
    matrix = [[None for _ in range(colnum)] for _ in range(rownum)]

    rid = 0  # 当前行索引(从0开始)
    for tr in table.find_all('tr'):  # 遍历每一行<tr>
        cid = 0  # 当前列索引(从0开始)
        for td in tr.find_all('td'):  # 遍历当前行的每个单元格<td>
            # 找到当前行中第一个未填充的列(处理列合并导致的索引偏移)
            for c in range(cid, colnum):
                if matrix[rid][c] is None:
                    break
            cid = c  # 更新当前列索引为找到的空列
            rowspan = td.get('rowspan', 1)  # 获取行合并数(默认1)
            colspan = td.get('colspan', 1)  # 获取列合并数(默认1)
            cell_text = td.get_text(strip=True)  # 获取单元格实际文本
            # 填充矩阵:覆盖合并区域的所有单元格
            for r in range(rid, rid+rowspan):  # 遍历合并的行范围
                if r >= rownum:  # 若超出总行数,抛出异常(结构错误)
                    raise Exception('rownum not match')
                for c in range(cid, cid+colspan):  # 遍历合并的列范围
                    if c >= colnum:  # 若超出总列数,抛出异常(结构错误)
                        raise Exception('colnum not match')
                    if matrix[r][c] is not None:  # 若单元格已填充,抛出异常(重复填充)
                        raise Exception('cell not match')
                    # 根据位置设置矩阵值(原始单元格/合并占位符)
                    if r == rid and c == cid:
                        matrix[r][c] = cell_text  # 原始单元格:存实际文本
                    elif r == rid:
                        matrix[r][c] = '<l>'  # 同一行右侧合并:列合并占位符
                    elif c == cid:
                        matrix[r][c] = '<t>'  # 同一列下方合并:行合并占位符
                    else:
                        matrix[r][c] = '<lt>'  # 行列交叉合并:行列合并占位符
            cid += colspan  # 列索引偏移(跳过合并的列)
        rid += 1  # 行索引自增

    # 生成矩阵形式的HTML表格字符串
    matrix_table = ['<table>']  # 初始化列表
    for rid in range(rownum):  # 遍历每一行
        matrix_table.append('<tr>')  # 添加行标签
        for cid in range(colnum):  # 遍历每一列
            matrix_table.append('<td>')  # 添加单元格标签
            cell_text = matrix[rid][cid]  # 获取矩阵中的值(文本或占位符)
            matrix_table.append(cell_text)  # 添加单元格内容
            matrix_table.append('</td>')  # 结束单元格标签
        matrix_table.append('</tr>')  # 结束行标签
    matrix_table.append('</table>')  # 结束表格标签
    return "".join(matrix_table)  # 拼接为字符串,返回矩阵形式表格

113-116 行:转换函数映射字典

trans_func = {  # 定义转换类型与对应函数的映射
    "html2matrix": table_html2matrix,  # "html2matrix"对应HTML转矩阵函数
    "matrix2html": table_matrix2html,  # "matrix2html"对应矩阵转HTML函数
}

118-128 行:trans_markdown_text函数(批量转换 Markdown 文本中的表格)

def trans_markdown_text(markdown_text, trans_type):  # 定义函数,参数为Markdown文本和转换类型
    if markdown_text == None:  # 若输入文本为None,直接返回None
        return None
    text_list = markdown_text.split('\n\n')  # 按空行分割Markdown文本(Markdown中段落/表格以空行分隔)
    for i, text in enumerate(text_list):  # 遍历每个分割后的元素
        if is_html_table(text):  # 若元素是HTML表格
            # 用trans_func中对应的转换函数处理表格
            text_list[i] = trans_func[trans_type](text)
    return "\n\n".join(text_list)  # 将处理后的元素拼接回Markdown文本并返回
该文件核心功能是实现两种表格格式的相互转换:

标准 HTML 表格(带 rowspan/colspan 属性)
矩阵形式表格(用<l>、<t>、<lt>标记合并区域)

通过trans_markdown_text函数,可批量处理 Markdown 文本中的表格,为 OCRFlux 中表格的解析、合并等后续操作提供格式支持。

turn_header_to_h1:统一 Markdown 标题格式

def turn_header_to_h1(line):
    # 检查是否是以一个或多个 '#' 开头的标题行
    if line.lstrip().startswith('#'):
        # 去掉开头的 '#' 和其后的空格,统一转为H1格式(仅保留一个#)
        new_line = "# " + line.lstrip().lstrip('#').lstrip()
        return new_line
    else:
        return line

作用:将 Markdown 中以多个#开头的标题(如## 标题、### 标题)统一转换为 H1 格式(# 标题),避免标题层级差异对评估的干扰。

replace_single_dollar:替换单美元符号公式

def replace_single_dollar(markdown_text):
    pattern = r'\$(.*?)\$'  # 匹配单美元符号包裹的公式(如$a+b$)
    def replace_with_brackets(match):
        formula_content = match.group(1)  # 获取匹配到的公式内容
        return f'\\({formula_content}\\)'  # 替换为LaTeX格式(\(a+b\))

    replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
    return replaced_text
作用:将 Markdown 中用$...$表示的行内公式转换为标准 LaTeX 格式\(...\),统一公式表示方式。

replace_double_dollar:替换双美元符号公式

def replace_double_dollar(markdown_text):
    pattern = r'\$\$(.*?)\$\$'  # 匹配双美元符号包裹的公式(如$$a+b$$)
    def replace_with_brackets(match):
        formula_content = match.group(1) 
        return f'\\[{formula_content}\\]'  # 替换为LaTeX格式(\[a+b\])
    replaced_text = re.sub(pattern, replace_with_brackets, markdown_text, flags=re.DOTALL)
    return replaced_text
作用:将 Markdown 中用$$...$$表示的块级公式转换为标准 LaTeX 格式\[...\],与行内公式处理逻辑一致。

TableTree:表格树结构类

class TableTree(Tree):
    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag  # 节点标签(如'table'、'tr'、'td')
        self.colspan = colspan  # 单元格跨列数(仅td节点)
        self.rowspan = rowspan  # 单元格跨行数(仅td节点)
        self.content = content  # 单元格内容(仅td节点,token化后的文本)
        self.children = list(children)  # 子节点列表

    def bracket(self):
        """用括号表示法展示树结构(用于调试和可视化)"""
        if self.tag == 'td':
            # 对td节点,包含标签、跨列、跨行和内容信息
            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
                     (self.tag, self.colspan, self.rowspan, self.content)
        else:
            # 非td节点(如table、tr)仅包含标签信息
            result = '"tag": %s' % self.tag
        # 递归拼接子节点
        for child in self.children:
            result += child.bracket()
        return "{{{}}}".format(result)

作用:继承Tree类,定义表格特有的树结构,包含 HTML 标签、单元格跨行列属性、内容等信息,为后续树编辑距离计算提供数据结构。

CustomConfig:树编辑距离计算配置

class CustomConfig(Config):
    @staticmethod
    def maximum(*sequences):
        """获取序列的最大长度(用于归一化编辑距离)"""
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """计算归一化的编辑距离(0~1之间)"""
        return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)

    def rename(self, node1, node2):
        """计算两个节点的差异度(0表示相同,1表示完全不同)"""
        # 标签、跨列、跨行不同,则差异度为1
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        # 若为td节点,还需比较内容差异
        if node1.tag == 'td':
            if node1.content or node2.content:
                return self.normalized_distance(node1.content, node2.content)
        return 0.

作用:自定义APTED库的配置,定义节点差异度计算规则(用于树编辑距离的代价函数),是计算 TEDS(Tree Edit Distance-based Similarity)的核心配置。

指标:TEDS(Tree Edit Distance-based Similarity)

不用th 5. 表格简化函数

def simplify_html_table(html_table):
    # 使用BeautifulSoup解析HTML表格
    soup = BeautifulSoup(html_table, 'html.parser')
    # 查找table标签
    table = soup.find('table')
    if not table:
        raise ValueError("输入的HTML不包含有效的<table>标签")
    # 创建新的table标签
    new_table = BeautifulSoup('<table></table>', 'html.parser').table
    # 提取所有行(包括thead和tbody中的tr)
    rows = table.find_all(['tr'], recursive=True)
    for row in rows:
        new_row = soup.new_tag('tr')  # 新行
        # 处理行中的单元格(th或td)
        cells = row.find_all(['th', 'td'])
        for cell in cells:
            new_cell = soup.new_tag('td')  # 将th转换为td
            # 保留跨列和跨行属性
            if cell.has_attr('rowspan'):
                new_cell['rowspan'] = cell['rowspan']
            if cell.has_attr('colspan'):
                new_cell['colspan'] = cell['colspan']
            new_cell.string = cell.get_text(strip=True)  # 提取单元格文本
            new_row.append(new_cell)
        new_table.append(new_row)  # 将新行添加到新表格
    return str(new_table)  # 返回简化后的HTML表格字符串

作用:简化 HTML 表格结构,统一将转换为,保留跨行列属性,仅保留文本内容,减少无关格式(如样式、嵌套标签)对评估的干扰。

主函数(入口逻辑)

def main():
    # 解析命令行参数
    parser = argparse.ArgumentParser(description="评估table_to_html任务")
    parser.add_argument(
        "workspace",
        help="存储预测结果的目录路径"
    )
    parser.add_argument(
        "--gt_file",
        help="Ground Truth文件路径"
    )
    parser.add_argument("--n_jobs", type=int, default=40, help="并行任务数")
    args = parser.parse_args()
    
    # 加载预测数据
    pred_data = {}
    for file in os.listdir(args.workspace):
        file_path = os.path.join(args.workspace, file)
        # 生成对应的图片文件名(假设预测文件名为xxx.txt,对应图片为xxx.png)
        pdf_name = file.split('.')[0] + ".png"
        with open(file_path, "r") as f:
            document_text = f.read()
            # 处理公式格式
            document_text = replace_single_dollar(replace_double_dollar(document_text))
            # 按空行分割Markdown内容
            markdown_text_list = document_text.split("\n\n")
            new_markdown_text_list = []
            for text in markdown_text_list:
                text = text.strip()
                # 过滤无关标签(水印、图片、页码、签名)
                if (text.startswith("<watermark>") and text.endswith("</watermark>")) or \
                   (text.startswith("<img>") and text.endswith("</img>")) or \
                   (text.startswith("<page_number>") and text.endswith("</page_number>")) or \
                   (text.startswith("<signature>") and text.endswith("</signature>")):
                    continue
                else:
                    # 将Markdown转换为HTML
                    html_text = str(markdown2.markdown(text, extras=["tables"]))
                    html_text = html_text.strip()
                    # 若为表格,简化结构
                    if html_text.startswith("<table>") and html_text.endswith("</table>"):
                        html_table = simplify_html_table(html_text)
                        new_markdown_text_list.append(html_table)
                    else:
                        # 统一标题格式为H1
                        text = turn_header_to_h1(text)
                        new_markdown_text_list.append(text)
            # 存储处理后的预测结果(键为图片名,值为HTML内容)
            pred_data[os.path.basename(pdf_name)] = "\n\n".join(new_markdown_text_list)
    
    # 加载Ground Truth数据
    gt_data = {}
    with open(args.gt_file, "r") as f:
        for line in f:
            data = json.loads(line)
            # 存储真实数据(键为图片名,值为{html: 真实表格HTML, type: 表格类型})
            gt_data[data['image_name']] = {'html': data['gt_table'], 'type': data['type']}
    
    # 初始化TEDS并批量评估
    teds = TEDS(n_jobs=args.n_jobs, ignore_nodes=['b', 'thead', 'tbody'])
    teds.batch_evaluate(pred_data, gt_data)

if __name__ == "__main__":
    main()

作用:作为脚本入口,负责解析命令行参数、加载并预处理预测结果和 Ground Truth 数据,最终调用TEDS类进行批量评估并输出结果。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

向上Claire

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值