咨询热线
0898-08980898传真:0000-0000-000
从零写CRNN文字识别 —— (5)优化器和Loss
从零写CRNN文字识别 —— (1)准备工作
从零写CRNN文字识别 —— (2)准备配置文件
从零写CRNN文字识别 —— (3)数据加载器
从零写CRNN文字识别 —— (4)搭建模型
从零写CRNN文字识别 —— (5)优化器和Loss
从零写CRNN文字识别 —— (6)训练
代码地址:github
上一节完成了模型的前向传播,最后从LSTM层拿到了[41,32,106]的输出矩阵,这里32好理解是batch_size,106是字典的数目,41有点理解不了了…先不管了,看看代码对应的loss和优化器怎么设计的。
这里先试试adam优化器,参考的代码还封装了两个其他的优化器,首先在配置文件中设置优化器的选择以及超参:
封装的优化器函数,这段代码加到utils.py中:
这里传入了两个参数,第一个参数是配置文件,用来设置选择的优化器和超参,第二个参数是告诉优化器需要优化那个模型的参数,使用pytorch中实现的optim库就可以很快的搭建优化器,这都是固定写法了:
同时train.py中对应修改:
同时定义学习率变化的策略:
该策略的原理:
这里Loss自然的选择了CTC loss,对于该loss的详解可以参考这里:CTC Loss原理
pytorch中定义loss也比较简单如下一句话即可:
等号右边实例化CTCLoss类的对象,criterion(*params)就可以调用该类的__call__方法了。
第一步,获取CTCLoss的对象:
类初始化参数说明:
- blank:空白标签所在的label值,默认为0,需要根据实际的标签定义进行设定;
- reduction:处理output losses的方式,string类型,可选’none’ 、 ‘mean’ 及 ‘sum’,'none’表示对output losses不做任何处理,‘mean’ 则对output losses取平均值处理,‘sum’则是对output losses求和处理,默认为’mean’ 。
第二步,在迭代中调用CTCLoss()对象计算损失值
CTCLoss()对象调用形参说明:
-
log_probs:shape为(T, N, C)的模型输出张量,其中,T表示CTCLoss的输入长度也即输出序列长度,N表示训练的batch size长度,C则表示包含有空白标签的所有要预测的字符集总长度,log_probs一般需要经过torch.nn.functional.log_softmax处理后再送入到CTCLoss中;
-
targets:shape为(N, S) 或(sum(target_lengths))的张量,其中第一种类型,N表示训练的batch size长度,S则为标签长度,第二种类型,则为所有标签长度之和,但是需要注意的是targets不能包含有空白标签;
-
input_lengths:shape为(N)的张量或元组,但每一个元素的长度必须等于T即输出序列长度,一般来说模型输出序列固定后则该张量或元组的元素值均相同;
-
target_lengths:shape为(N)的张量或元组,其每一个元素指示每个训练输入序列的标签长度,但标签长度是可以变化的;
在我们的案例中输入CTCLoss的形状是[41,32,107],其中41是序列长度,32是batch_size大小,107是预测的类别数。
注意这里我在french_dict.txt的第一行加上了一个"blank",所以后续打印的类别数不再是106而是107了,截图太麻烦了前面的就不修改了
先打印下代码对应的size:
输出结果:
稍微介绍下输出:
- [41,32,107]preds的size
- BEL是我打印的第一个batch中的第一数据的标签值
- [26,29,36]是BEL三个字母在字典中对应的编码
- (177,)是一整个batch_size对应的字符的长度
- length_shape就是batch_size大小
- pred_size的大小是(32,1),每个项是序列的长度
- 43.32037是第一个batch算的loss
稍微写的有点乱,后续润色一下~