微信扫码
与创始人交个朋友
我要投稿
前面文档介绍了文档智能上多种思路及核心技术实现《【文档智能 & RAG】RAG增强之路:增强PDF解析并结构化技术路线方案及思路》,
表格识别作为文档智能的重要组成部分,面临着复杂结构和多样化格式的挑战。本文介绍的轻量级的表格识别算法模型——SLANet,旨在在保证准确率的同时提升推理速度,方便生产落地。SLANet综合了PP-LCNet作为基础网络,采用CSP-PAN进行特征融合,并引入Attention机制以实现结构与位置信息的精确解码。通过这一框架,SLANet不仅有效减少了计算资源的消耗,还增强了模型在实际应用场景中的适用性与灵活性。
PP-LCNet是一种一种轻量级的CPU卷积神经网络,在图像分类的任务上表现良好,具有很高的落地意义。PP-LCNet的准确度显著优于具有相同推理时间的先前网络结构。
DepthSepConv块: 使用MobileNetV1中的DepthSepConv作为基本块,该块没有快捷操作,减少了额外的拼接或逐元素相加操作,从而提高了推理速度。
更好的激活函数: 将BaseNet中的ReLU激活函数替换为H-Swish,提升了网络性能,同时推理时间几乎没有变化。
SE模块的适当位置: 在网络的尾部添加SE模块,以提高特征权重,从而实现更好的准确性和速度平衡。SE 模块是 SENet 提出的一种通道注意力机制,可以有效提升模型的精度。但是在 Intel CPU 端,该模块同样会带来较大的延时,如何平衡精度和速度是我们要解决的一个问题。虽然在 MobileNetV3 等基于 NAS 搜索的网络中对 SE 模块的位置进行了搜索,但是并没有得出一般的结论,我们通过实验发现,SE 模块越靠近网络的尾部对模型精度的提升越大。
更大的卷积核: 在网络的尾部使用5x5卷积核替代3x3卷积核,以在低延迟和高准确性之间取得平衡。
实验表明,更大的卷积核放在网络的中后部即可达到放在所有位置的精度,与此同时,获得更快的推理速度。PP-LCNet 最终选用了表格中第三行的方案。
更大的1x1卷积层: 在全局平均池化(GAP)层后添加一个1280维的1x1卷积层,以增强模型的拟合能力,同时推理时间增加不多。在 GoogLeNet 之后,GAP(Global-Average-Pooling)后往往直接接分类层,但是在轻量级网络中,这样会导致 GAP 后提取的特征没有得到进一步的融合和加工。如果在此后使用一个更大的 1x1 卷积层(等同于 FC 层),GAP 后的特征便不会直接经过分类层,而是先进行了融合,并将融合的特征进行分类。这样可以在不影响模型推理速度的同时大大提升准确率。
PAN结构图:相比于原始的FPN多了自下而上的特征金字塔。
CSPNet是一种处理的思想,可以和ResNet、ResNeXt和DenseNet结合。用 CSP 网络进行相邻 feature maps 之间的特征连接和融合。
CSP-PAN的引入主要有下面三个目的:
原理:
从上图看,SLANet主要由PP-LCNet + CSP-PAN + Attention组合得到。
import torch
from torch import nn
from torch.nn import functional as F
class SLAHead(nn.Module):
def __init__(self, in_channels=96, is_train=False) -> None:
super().__init__()
self.max_text_length = 500
self.hidden_size = 256
self.loc_reg_num = 4
self.out_channels = 30
self.num_embeddings = self.out_channels
self.is_train = is_train
self.structure_attention_cell = AttentionGRUCell(in_channels,
self.hidden_size,
self.num_embeddings)
self.structure_generator = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size),
nn.Linear(self.hidden_size, self.out_channels)
)
self.loc_generator = nn.Sequential(
nn.Linear(self.hidden_size, self.hidden_size),
nn.Linear(self.hidden_size, self.loc_reg_num)
)
def forward(self, fea):
batch_size = fea.shape[0]
# 1 x 96 x 16 x 16 → 1 x 96 x 256
fea = torch.reshape(fea, [fea.shape[0], fea.shape[1], -1])
# 1 x 256 x 96
fea = fea.permute(0, 2, 1)
# infer 1 x 501 x 30
structure_preds = torch.zeros(batch_size, self.max_text_length + 1,
self.num_embeddings)
# 1 x 501 x 4
loc_preds = torch.zeros(batch_size, self.max_text_length + 1,
self.loc_reg_num)
hidden = torch.zeros(batch_size, self.hidden_size)
pre_chars = torch.zeros(batch_size, dtype=torch.int64)
loc_step, structure_step = None, None
for i in range(self.max_text_length + 1):
hidden, structure_step, loc_step = self._decode(pre_chars,
fea, hidden)
pre_chars = structure_step.argmax(dim=1)
structure_preds[:, i, :] = structure_step
loc_preds[:, i, :] = loc_step
if not self.is_train:
structure_preds = F.softmax(structure_preds, dim=-1)
# structure_preds: 1 x 501 x 30
# loc_preds: 1 x 501 x 4
return structure_preds, loc_preds
def _decode(self, pre_chars, features, hidden):
emb_features = F.one_hot(pre_chars, num_classes=self.num_embeddings)
(output, hidden), alpha = self.structure_attention_cell(hidden,
features,
emb_features)
structure_step = self.structure_generator(output)
loc_step = self.loc_generator(output)
return hidden, structure_step, loc_step
class AttentionGRUCell(nn.Module):
def __init__(self, input_size, hidden_size, num_embedding) -> None:
super().__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias=False)
self.gru = nn.GRU(input_size=input_size + num_embedding,
hidden_size=hidden_size,)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
# 这里实现参考论文https://arxiv.org/pdf/1704.03549.pdf
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = torch.unsqueeze(self.h2h(prev_hidden), dim=1)
res = torch.add(batch_H_proj, prev_hidden_proj)
res = F.tanh(res)
e = self.score(res)
alpha = F.softmax(e, dim=1)
alpha = alpha.permute(0, 2, 1)
context = torch.squeeze(torch.matmul(alpha, batch_H), dim=1)
concat_context = torch.concat([context, char_onehots], 1)
cur_hidden = self.gru(concat_context, prev_hidden)
return cur_hidden, alpha
class SLALoss(nn.Module):
def __init__(self) -> None:
super().__init__()
self.loss_func = nn.CrossEntropyLoss()
self.structure_weight = 1.0
self.loc_weight = 2.0
self.eps = 1e-12
def forward(self, pred):
structure_probs = pred[0]
structure_probs = structure_probs.permute(0, 2, 1)
# 1 x 30 x 501
# 1 x 501
structure_target = torch.empty(1, 501, dtype=torch.long).random_(30)
structure_loss = self.loss_func(structure_probs, structure_target)
structure_loss = structure_loss * self.structure_weight
loc_preds = pred[1] # 1 x 501 x 4
loc_targets = torch.randn(1, 501, 4)
loc_target_mask = torch.randn(1, 501, 1)
loc_loss = F.smooth_l1_loss(loc_preds * loc_target_mask,
loc_targets * loc_target_mask,
reduction='mean')
loc_loss *= self.loc_weight
loc_loss = loc_loss / (loc_target_mask.sum() + self.eps)
total_loss = structure_loss + loc_loss
return total_loss
53AI,企业落地应用大模型首选服务商
产品:大模型应用平台+智能体定制开发+落地咨询服务
承诺:先做场景POC验证,看到效果再签署服务协议。零风险落地应用大模型,已交付160+中大型企业
2024-11-08
NebulaGraph 在中医药领域的应用:构建鼻炎知识图谱
2024-11-07
轻松搭建AI版“谁是卧底”游戏,muAgent框架让知识图谱秒变编排引擎,支持复杂推理+在线协同
2024-11-06
GraphRAG 0.4来袭:增量更新+DRIFT,起飞~
2024-11-05
Obsidian AI 自动生成知识图谱辅助学习
2024-11-05
专题解读 | 图检索增强生成研究进展
2024-11-04
从CSV到Neo4j:如何用LLM实现自动化数据建模?
2024-11-03
一文读懂GraphRAG大模型知识图谱
2024-11-01
KGLA:基于知识图谱的推荐系统
2024-07-17
2024-07-11
2024-07-13
2024-08-13
2024-07-08
2024-07-12
2024-07-26
2024-07-04
2024-06-10
2024-04-10
2024-11-04
2024-10-10
2024-10-03
2024-09-27
2024-09-08
2024-09-05
2024-08-27
2024-08-24