
基于BERT的文本分类

通过智能化手段识别企业填报的隐患内容中是否存在“虚报、假报”的情况

背景

企业自主填报安全生产隐患,对于将风险消除在事故萌芽阶段具有重要意义。企业在填报隐患时,往往存在不认真填报的情况,“虚报、假报”隐患内容,增大了企业监管的难度。采用大数据手段分析隐患内容,找出不切实履行主体责任的企业,向监管部门进行推送,实现精准执法,能够提高监管手段的有效性,增强企业安全责任意识。

任务

本赛题提供企业填报隐患数据,参赛选手需通过智能化手段识别其中是否存在“虚报、假报”的情况。

数据简介

本赛题数据集为脱敏后的企业填报自查隐患记录,数据说明如下:

  • 训练集数据包含“【id、level_1(一级标准)、level_2(二级标准)、level_3(三级标准)、level_4(四级标准)、content(隐患内容)和label(标签)】”共7个字段。
    其中“id”为主键,无业务意义;“一级标准、二级标准、三级标准、四级标准”为《深圳市安全隐患自查和巡查基本指引(2016年修订版)》规定的排查指引,一级标准对应不同隐患类型,二至四级标准是对一级标准的细化,企业自主上报隐患时,根据不同类型隐患的四级标准开展隐患自查工作;“隐患内容”为企业上报的具体隐患;“标签”标识的是该条隐患的合格性,“1”表示隐患填报不合格,“0”表示隐患填报合格。

  • 提交的预测结果文件为 results.csv。


评测标准

本赛题采用F1-score作为模型评判标准。
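F1-score 是精确率(Precision)与召回率(Recall)的调和平均,对本题这类正负样本不均衡的数据比单纯的准确率更有参考意义。下面用一小段示意代码演示其计算方式(示例标签为虚构,仅用于说明,线上成绩以比赛平台评测为准):

from sklearn.metrics import f1_score, precision_score, recall_score

# 虚构的真实标签与预测结果,1 表示“隐患填报不合格”
y_true = [0, 0, 1, 1, 0, 1]
y_pred = [0, 1, 1, 1, 0, 0]

p = precision_score(y_true, y_pred)   # TP / (TP + FP)
r = recall_score(y_true, y_pred)      # TP / (TP + FN)
print(2 * p * r / (p + r))            # F1 = 2PR / (P + R)
print(f1_score(y_true, y_pred))       # 与上式结果一致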

具体实现代码如下:

导入所有需要的包

# 导入transformers
import transformers
# from transformers import BertModel, BertTokenizer,BertConfig, AdamW, get_linear_schedule_with_warmup
from transformers import AutoModel, AutoTokenizer,AutoConfig, AdamW, get_linear_schedule_with_warmup

import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# 常用包
import re
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap

import warnings
warnings.filterwarnings("ignore")
%matplotlib inline
%config InlineBackend.figure_format='retina' # 高分辨率(retina)显示

初始化设置

sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8

# 固定随机种子
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
device(type='cuda', index=0)

读取数据

sub=pd.read_csv('./data/02企业隐患排查/sub.csv')
test=pd.read_csv('./data/02企业隐患排查/test.csv')
train=pd.read_csv('./data/02企业隐患排查/train.csv')
train.head()

id level_1 level_2 level_3 level_4 content label
0 0 工业/危化品类(现场)—2016版 (二)电气安全 6、移动用电产品、电动工具及照明 1、移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. 0
1 1 工业/危化品类(现场)—2016版 (一)消防检查 1、防火巡查 3、消防设施、器材和消防安全标志是否在位、完整; 一般 1
2 2 工业/危化品类(现场)—2016版 (一)消防检查 2、防火检查 6、重点工种人员以及其他员工消防知识的掌握情况; 消防知识要加强 0
3 3 工业/危化品类(现场)—2016版 (一)消防检查 1、防火巡查 3、消防设施、器材和消防安全标志是否在位、完整; 消防通道有货物摆放 清理不及时 0
4 4 工业/危化品类(现场)—2016版 (一)消防检查 1、防火巡查 4、常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; 防火门打开状态 0
test.head()

id level_1 level_2 level_3 level_4 content
0 0 交通运输类(现场)—2016版 (一)消防安全 2、防火检查 2、安全疏散通道、疏散指示标志、应急照明和安全出口情况。 RB1洗地机占用堵塞安全通道
1 1 工业/危化品类(选项)—2016版 (二)仓库 1、一般要求 1、库房内储存物品应分类、分堆、限额存放。 未分类堆放
2 2 工业/危化品类(现场)—2016版 (一)消防检查 1、防火巡查 3、消防设施、器材和消防安全标志是否在位、完整; 消防设施、器材和消防安全标志是否在位、完整
3 3 商贸服务教文卫类(现场)—2016版 (二)电气安全 3、电气线路及电源插头插座 3、电源插座、电源插头应按规定正确接线。 插座随意放在电器旁边
4 4 商贸服务教文卫类(现场)—2016版 (一)消防检查 1、防火巡查 6、其他消防安全情况。 检查中发现一瓶灭火器过期

查看数据的形状

print("train.shape,test.shape,sub.shape",train.shape,test.shape,sub.shape)
train.shape,test.shape,sub.shape (12000, 7) (18000, 6) (18000, 2)

查看是否存在空值

train[train['content'].isna()]

id level_1 level_2 level_3 level_4 content label
6193 6193 工业/危化品类(现场)—2016版 (一)消防检查 1、防火巡查 3、消防设施、器材和消防安全标志是否在位、完整; NaN 1
9248 9248 工业/危化品类(现场)—2016版 (一)消防检查 1、防火巡查 4、常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; NaN 1
print('train null nums')
train.shape[0]-train.count()
train null nums

id         0
level_1    0
level_2    0
level_3    0
level_4    0
content    2
label      0
dtype: int64
print('test null nums')
test.shape[0]-test.count()
test null nums

id         0
level_1    0
level_2    0
level_3    0
level_4    0
content    4
dtype: int64

查看标签的分布

train['label'].value_counts()
0    10712
1     1288
Name: label, dtype: int64
# sns.countplot(train.label)
# plt.xlabel('label count')
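可以看到标签分布很不均衡(合格约 10712 条、不合格约 1288 条),而评测指标恰好是 F1-score。一个可选的思路是给交叉熵损失按类别频率加权,下面是一个最小示意(假设沿用上文的 train 与 device;本文后续训练仍使用默认的等权重 CrossEntropyLoss):

# 按类别频率的倒数构造权重,仅作示意,加权方式可按需调整
label_counts = train['label'].value_counts().sort_index()   # 0: 10712, 1: 1288
class_weights = torch.tensor(
    (label_counts.sum() / label_counts).values, dtype=torch.float
)
weighted_loss_fn = nn.CrossEntropyLoss(weight=class_weights.to(device))
# 若想尝试,训练时将后文的 loss_fn 换成 weighted_loss_fn 即可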

数据预处理

train.fillna("空值",inplace=True)
test.fillna("空值",inplace=True)
train.shape[0]-train.count()
id         0
level_1    0
level_2    0
level_3    0
level_4    0
content    0
label      0
dtype: int64
train['level_3'].value_counts()
1、防火巡查             4225
2、防火检查             2911
2、配电箱(柜、板)          710
1、作业通道              664
3、电气线路及电源插头插座       497
                   ... 
3、安全带                 1
4、特种设备及操作人员管理记录       1
4、安全技术交底              1
3、停车场                 1
1、水库安全                1
Name: level_3, Length: 153, dtype: int64

对训练集处理

train['level_1'] = train['level_1'].apply(lambda x:x.split('(')[0])
train['level_2'] = train['level_2'].apply(lambda x:x.split(')')[1])
train['level_3'] = train['level_3'].apply(lambda x:re.split('[0-9]、',x)[-1])
train['level_4'] = train['level_4'].apply(lambda x:re.split('[0-9]、',x)[-1])

对测试集处理

test['level_1'] = test['level_1'].apply(lambda x:x.split('(')[0])
test['level_2'] = test['level_2'].apply(lambda x:x.split(')')[-1])
test['level_3'] = test['level_3'].apply(lambda x:re.split('[0-9]、',x)[-1])
test['level_4'] = test['level_4'].apply(lambda x:re.split('[0-9]、',x)[-1])
train.head()

id level_1 level_2 level_3 level_4 content label
0 0 工业/危化品类 电气安全 移动用电产品、电动工具及照明 移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. 0
1 1 工业/危化品类 消防检查 防火巡查 消防设施、器材和消防安全标志是否在位、完整; 一般 1
2 2 工业/危化品类 消防检查 防火检查 重点工种人员以及其他员工消防知识的掌握情况; 消防知识要加强 0
3 3 工业/危化品类 消防检查 防火巡查 消防设施、器材和消防安全标志是否在位、完整; 消防通道有货物摆放 清理不及时 0
4 4 工业/危化品类 消防检查 防火巡查 常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; 防火门打开状态 0

文本拼接

train['text']=train['content']+'[SEP]'+train['level_1']+'[SEP]'+train['level_2']+'[SEP]'+train['level_3']+'[SEP]'+train['level_4']
test['text']=test['content']+'[SEP]'+test['level_1']+'[SEP]'+test['level_2']+'[SEP]'+test['level_3']+'[SEP]'+test['level_4']
train.head()

id level_1 level_2 level_3 level_4 content label text
0 0 工业/危化品类 电气安全 移动用电产品、电动工具及照明 移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. 0 使用移动手动电动工具,外接线绝缘皮破损,应停止使用.[SEP]工业/危化品类[SEP]电气安...
1 1 工业/危化品类 消防检查 防火巡查 消防设施、器材和消防安全标志是否在位、完整; 一般 1 一般[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]消防设施、器材和消...
2 2 工业/危化品类 消防检查 防火检查 重点工种人员以及其他员工消防知识的掌握情况; 消防知识要加强 0 消防知识要加强[SEP]工业/危化品类[SEP]消防检查[SEP]防火检查[SEP]重点工种...
3 3 工业/危化品类 消防检查 防火巡查 消防设施、器材和消防安全标志是否在位、完整; 消防通道有货物摆放 清理不及时 0 消防通道有货物摆放 清理不及时[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[...
4 4 工业/危化品类 消防检查 防火巡查 常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; 防火门打开状态 0 防火门打开状态[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]常闭式防...
train['text_len']=train['text'].map(len)
train['text'].map(len).describe()# 298-12=286
count    12000.000000
mean        80.444500
std         21.910859
min         43.000000
25%         66.000000
50%         75.000000
75%         92.000000
max        298.000000
Name: text, dtype: float64
test['text'].map(len).describe() # 520-12=518
count    18000.000000
mean        80.762611
std         22.719823
min         43.000000
25%         66.000000
50%         76.000000
75%         92.000000
max        520.000000
Name: text, dtype: float64
train['text_len'].plot(kind='kde')
<AxesSubplot:ylabel='Density'>

(图:train['text_len'] 的核密度分布)

sum(train['text_len']>100) # text文本长度大于100的个数     1878
sum(train['text_len']>200) # text文本长度大于200的个数 11

模型的加载和配置

加载tokenizer

PRE_TRAINED_MODEL_NAME = 'bert-base-chinese'
# PRE_TRAINED_MODEL_NAME = 'hfl/chinese-roberta-wwm-ext'
# PRE_TRAINED_MODEL_NAME = 'hfl/chinese-roberta-wwm'

# tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
tokenizer
PreTrainedTokenizerFast(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})
sample_txt = '今天早上9点半起床,我在学习预训练模型的使用.'
len(sample_txt)
23
tokens = tokenizer.tokenize(sample_txt)
token_ids = tokenizer.convert_tokens_to_ids(tokens)

print(f'文本为: {sample_txt}')
print(f'分词的列表为: {tokens}')
print(f'词对应的唯一id: {token_ids}')
文本为: 今天早上9点半起床,我在学习预训练模型的使用.
分词的列表为: ['今', '天', '早', '上', '9', '点', '半', '起', '床', ',', '我', '在', '学', '习', '预', '训', '练', '模', '型', '的', '使', '用', '.']
词对应的唯一id: [791, 1921, 3193, 677, 130, 4157, 1288, 6629, 2414, 8024, 2769, 1762, 2110, 739, 7564, 6378, 5298, 3563, 1798, 4638, 886, 4500, 119]

查看特殊的Token

tokenizer.sep_token,tokenizer.sep_token_id
('[SEP]', 102)
tokenizer.cls_token,tokenizer.cls_token_id
('[CLS]', 101)
tokenizer.pad_token,tokenizer.pad_token_id
('[PAD]', 0)
tokenizer.mask_token,tokenizer.mask_token_id
('[MASK]', 103)
tokenizer.unk_token,tokenizer.unk_token_id
('[UNK]', 100)

简单的编码测试

encoding = tokenizer.encode_plus(
    sample_txt,
    # sample_txt_another,
    max_length=32,
    add_special_tokens=True,       # 加上[CLS]和[SEP]
    return_token_type_ids=True,
    pad_to_max_length=True,
    return_attention_mask=True,
    return_tensors='pt',           # 返回PyTorch张量
)
encoding
encoding
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.

{'input_ids': tensor([[ 101,  791, 1921, 3193,  677,  130, 4157, 1288, 6629, 2414, 8024, 2769,
         1762, 2110,  739, 7564, 6378, 5298, 3563, 1798, 4638,  886, 4500,  119,
          102,    0,    0,    0,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 0, 0, 0, 0, 0, 0]])}
encoding['attention_mask'][0]
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 0, 0, 0, 0, 0, 0, 0])
token_lens = []
# 统计每条text分词后的token数量
for txt in train.text:
    # print(txt)
    tokens = tokenizer.encode(txt, max_length=512)
    token_lens.append(len(tokens))
# token_lens
sns.distplot(token_lens)
plt.xlim([0, 256]);
plt.xlabel('Token count');

(图:分词后 token 数量的分布直方图)

通过分析,文本分词后的 token 长度基本都在 160 以内,因此将最大长度设置为 160。

MAX_LEN = 160
encoding['input_ids'].flatten()
tensor([ 101,  791, 1921, 3193,  677,  130, 4157, 1288, 6629, 2414, 8024, 2769,
        1762, 2110,  739, 7564, 6378, 5298, 3563, 1798, 4638,  886, 4500,  119,
         102,    0,    0,    0,    0,    0,    0,    0])
encoding['input_ids']
tensor([[ 101,  791, 1921, 3193,  677,  130, 4157, 1288, 6629, 2414, 8024, 2769,
         1762, 2110,  739, 7564, 6378, 5298, 3563, 1798, 4638,  886, 4500,  119,
          102,    0,    0,    0,    0,    0,    0,    0]])

处理数据

class EnterpriseDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        """
        item 为数据索引,迭代取第item条数据
        """
        text = str(self.texts[item])
        label = self.labels[item]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=True,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # print(encoding['input_ids'])
        return {
            'texts': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            # token_type_ids 全为0,此处不返回
            'labels': torch.tensor(label, dtype=torch.long)
        }

分割数据集

df_train, df_test = train_test_split(train, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=RANDOM_SEED)
df_train.shape, df_val.shape, df_test.shape
((10800, 9), (600, 9), (600, 9))

创建DataLoader

def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = EnterpriseDataset(
        texts=df['text'].values,
        labels=df['label'].values,
        tokenizer=tokenizer,
        max_len=max_len
    )

    return DataLoader(
        ds,
        batch_size=batch_size,
        # num_workers=4  # Windows下多进程加载容易出错,这里不开启
    )

BATCH_SIZE = 4

train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)

next(iter(train_data_loader))
{'texts': ['指示标识不清楚[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
  '发现本月有灭火器过期,已安排购买灭火器更换[SEP]商贸服务教文卫类[SEP]消防检查[SEP]防火检查[SEP]灭火器材配置及有效情况。',
  '安全出口标志灯有一个有故障,已买回安装改正。[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
  '堵了消防通道[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;'],
 'input_ids': tensor([[ 101, 2900, 4850, 3403, 6399,  679, 3926, 3504,  102, 2339,  689,  120,
          1314, 1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,  102, 7344, 4125,
          2337, 3389,  102, 2128, 1059, 1139, 1366,  510, 4541, 3141, 6858, 6887,
          3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403,
          2562,  510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039,  102,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 1355, 4385, 3315, 3299, 3300, 4127, 4125, 1690, 6814, 3309, 8024,
          2347, 2128, 2961, 6579,  743, 4127, 4125, 1690, 3291, 2940,  102, 1555,
          6588, 3302, 1218, 3136, 3152, 1310, 5102,  102, 3867, 7344, 3466, 3389,
           102, 7344, 4125, 3466, 3389,  102, 4127, 4125, 1690, 3332, 6981, 5390,
          1350, 3300, 3126, 2658, 1105,  511,  102,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 2128, 1059, 1139, 1366, 3403, 2562, 4128, 3300,  671,  702, 3300,
          3125, 7397, 8024, 2347,  743, 1726, 2128, 6163, 3121, 3633,  511,  102,
          2339,  689,  120, 1314, 1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,
           102, 7344, 4125, 2337, 3389,  102, 2128, 1059, 1139, 1366,  510, 4541,
          3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141,
          2900, 4850, 3403, 2562,  510, 2418, 2593, 4212, 3209, 3221, 1415, 2130,
          1962, 8039,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 1843,  749, 3867, 7344, 6858, 6887,  102, 2339,  689,  120, 1314,
          1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,  102, 7344, 4125, 2337,
          3389,  102, 2128, 1059, 1139, 1366,  510, 4541, 3141, 6858, 6887, 3221,
          1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562,
           510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039,  102,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'labels': tensor([0, 0, 0, 0])}
data = next(iter(train_data_loader))
data.keys()
dict_keys(['texts', 'input_ids', 'attention_mask', 'labels'])
print(data['input_ids'].shape)
print(data['attention_mask'].shape)
print(data['labels'].shape)
torch.Size([4, 160])
torch.Size([4, 160])
torch.Size([4])
bert_model = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
encoding['input_ids']
tensor([[ 101,  791, 1921, 3193,  677,  130, 4157, 1288, 6629, 2414, 8024, 2769,
         1762, 2110,  739, 7564, 6378, 5298, 3563, 1798, 4638,  886, 4500,  119,
          102,    0,    0,    0,    0,    0,    0,    0]])
last_hidden_state, pooled_output = bert_model(
    input_ids=encoding['input_ids'],
    attention_mask=encoding['attention_mask'],
    return_dict=False
)

查看输出结果

last_hidden_state[0][0].shape
torch.Size([768])
pooled_output
tensor([[ 0.9999,  0.9998,  0.9989,  0.9629,  0.3075, -0.1866, -0.9904,  0.8628,
          0.9710, -0.9993,  1.0000,  1.0000,  0.9312, -0.9394,  0.9998, -0.9999,
          0.0417,  0.9999,  0.9458,  0.3190,  1.0000, -1.0000, -0.9062, -0.9048,
          0.1764,  0.9983,  0.9346, -0.8122, -0.9999,  0.9996,  0.7879,  0.9999,
          0.8475, -1.0000, -1.0000,  0.9413, -0.8260,  0.9889, -0.4976, -0.9857,
         -0.9955, -0.9580,  0.5833, -0.9996, -0.8932,  0.8563, -1.0000, -0.9999,
          0.9719,  0.9999, -0.7430, -0.9993,  0.9756, -0.9754,  0.2991,  0.8933,
         -0.9991,  0.9987,  1.0000,  0.4156,  0.9992, -0.9452, -0.8020, -0.9999,
          1.0000, -0.9964, -0.9900,  0.4365,  1.0000,  1.0000, -0.9400,  0.8794,
          1.0000,  0.9105, -0.6616,  1.0000, -0.9999,  0.6892, -1.0000, -0.9817,
          1.0000,  0.9957, -0.8844, -0.8248, -0.9921, -0.9999, -0.9998,  1.0000,
          0.5228,  0.1297,  0.9932, -0.9999, -1.0000,  0.9993, -0.9996, -0.9948,
         -0.9561,  0.9996, -0.5785, -0.9386, -0.2035,  0.9086, -0.9999, -0.9993,
          0.9959,  0.9984,  0.6953, -0.9995,  1.0000,  0.8610, -1.0000, -0.4507,
         -1.0000,  0.2384, -0.9812,  0.9998,  0.9504,  0.5421,  0.9995, -0.9998,
          0.9320, -0.9941, -0.9718, -0.9910,  0.9822,  1.0000,  0.9997, -0.9990,
          1.0000,  1.0000,  0.8608,  0.9964, -0.9997,  0.9799,  0.5985, -0.9098,
          0.5329, -0.6345,  1.0000,  0.9872,  0.9970, -0.9719,  0.9988, -0.9933,
          1.0000, -0.9999,  0.9973, -1.0000, -0.6550,  0.9996,  0.8899,  1.0000,
          0.2969,  0.9999, -0.9983, -0.9991,  0.9906, -0.6590,  0.9872, -1.0000,
          0.7658,  0.7876, -0.8556,  0.6304, -1.0000,  1.0000, -0.7938,  1.0000,
          0.9898,  0.2216, -0.9942, -0.9969,  0.8345, -0.9998, -0.9779,  0.9914,
          0.5227,  0.9992, -0.9893, -0.9889,  0.2325, -0.9887, -0.9999,  0.9885,
          0.0340,  0.9284,  0.5197,  0.4143,  0.8315,  0.1585, -0.5348,  1.0000,
          0.2361,  0.9985,  0.9999, -0.3446,  0.1012, -0.9924, -1.0000, -0.7542,
          0.9999, -0.2807, -0.9999,  0.9490, -1.0000,  0.9906, -0.7288, -0.5263,
         -0.9545, -0.9999,  0.9998, -0.9286, -0.9997, -0.5303,  0.8886,  0.5605,
         -0.9989, -0.3324,  0.9804, -0.9075,  0.9905, -0.9800, -0.9946,  0.6855,
         -0.9393,  0.9929,  0.9874,  1.0000,  0.9997, -0.0714, -0.9440,  1.0000,
          0.1676, -1.0000,  0.5573, -0.9611,  0.8835,  0.9999, -0.9980,  0.9294,
          1.0000,  0.7968,  1.0000, -0.7065, -0.9793, -0.9997,  1.0000,  0.9922,
          0.9999, -0.9984, -0.9995, -0.1701, -0.5426, -1.0000, -1.0000, -0.6334,
          0.9969,  0.9999, -0.1620, -0.9818, -0.9921, -0.9994,  1.0000, -0.9759,
          1.0000,  0.8570, -0.7434, -0.9164,  0.9438, -0.7311, -0.9986, -0.3936,
         -0.9997, -0.9650, -1.0000,  0.9433, -0.9999, -1.0000,  0.6913,  1.0000,
          0.8762, -1.0000,  0.9997,  0.9764,  0.7094, -0.9294,  0.9522, -1.0000,
          1.0000, -0.9965,  0.9428, -0.9972, -0.9897, -0.7680,  0.9922,  0.9999,
         -0.9999, -0.9597, -0.9922, -0.9807, -0.3632,  0.9936, -0.7280,  0.4117,
         -0.9498, -0.9666,  0.9545, -0.9957, -0.9970,  0.4028,  1.0000, -0.9798,
          1.0000,  0.9941,  1.0000,  0.9202, -0.9942,  0.9996,  0.5352, -0.5836,
         -0.8829, -0.9418,  0.9497, -0.0532,  0.6966, -0.9999,  0.9998,  0.9917,
          0.9612,  0.7289,  0.0167,  0.3179,  0.9627, -0.9911,  0.9995, -0.9996,
         -0.6737,  0.9991,  1.0000,  0.9932,  0.4880, -0.7488,  0.9986, -0.9961,
          0.9995, -1.0000,  0.9999, -0.9940,  0.9705, -0.9970, -0.9856,  1.0000,
          0.9846, -0.7932,  0.9997, -0.9386,  0.9938,  0.9738,  0.8173,  0.9913,
          0.9981,  1.0000, -0.9998, -0.9918, -0.9727, -0.9987, -0.9955, -1.0000,
         -0.1038, -1.0000, -0.9874, -0.9287,  0.5109, -0.9056,  0.1022,  0.7864,
         -0.8197,  0.5724, -0.5905,  0.2713, -0.7239, -0.9976, -0.9844, -1.0000,
         -0.9988,  0.8835,  0.9999, -0.9997,  0.9999, -0.9999, -0.9782,  0.9383,
         -0.5609,  0.7721,  0.9999, -1.0000,  0.9585,  0.9987,  1.0000,  0.9960,
          0.9993, -0.9741, -0.9999, -0.9989, -0.9999, -1.0000, -0.9998,  0.9343,
          0.6337, -1.0000,  0.0902,  0.8980,  1.0000,  0.9964, -0.9985, -0.6136,
         -0.9996, -0.8252,  0.9996, -0.0566, -1.0000,  0.9962, -0.8744,  1.0000,
         -0.8865,  0.9879,  0.8897,  0.9571,  0.9823, -1.0000,  0.9145,  1.0000,
          0.0365, -1.0000, -0.9985, -0.9075, -0.9998,  0.0369,  0.8120,  0.9999,
         -1.0000, -0.9155, -0.9975,  0.7988,  0.9922,  0.9998,  0.9982,  0.9267,
          0.9165,  0.5368,  0.1464,  0.9998,  0.4663, -0.9989,  0.9996, -0.7952,
          0.4527, -1.0000,  0.9998,  0.4073,  0.9999,  0.9159, -0.5480, -0.6821,
         -0.9904,  0.9938,  1.0000, -0.4229, -0.4845, -0.9981, -1.0000, -0.9861,
         -0.0950, -0.4625, -0.9629, -0.9998,  0.6675, -0.5244,  1.0000,  1.0000,
          0.9924, -0.9253, -0.9974,  0.9974, -0.9012,  0.9900, -0.2582, -1.0000,
         -0.9919, -0.9986,  1.0000, -0.9716, -0.9262, -0.9911, -0.2593,  0.5919,
         -0.9999, -0.4994, -0.9962,  0.9818,  1.0000, -0.9996,  0.9918, -0.9970,
          0.7085, -0.1369,  0.8077,  0.9955, -0.3394, -0.5860, -0.6887, -0.9841,
          0.9970,  0.9987, -0.9948, -0.8401,  0.9999,  0.0856,  0.9999,  0.5099,
          0.9466,  0.9567,  1.0000,  0.8771,  1.0000, -0.0815,  1.0000,  0.9999,
         -0.9392,  0.5744,  0.8723, -0.9686,  0.5958,  0.9822,  0.9997,  0.8854,
         -0.1952, -0.9967,  0.9994,  1.0000,  1.0000, -0.3391,  0.9883, -0.4452,
          0.9252,  0.4495,  0.9870,  0.3479,  0.2266,  0.9942,  0.9990, -0.9999,
         -0.9999, -1.0000,  1.0000,  0.9996, -0.6637, -1.0000,  0.9999,  0.4543,
          0.7471,  0.9983,  0.3772, -0.9812,  0.9853, -0.9995, -0.3404,  0.9788,
          0.9867,  0.7564,  0.9995, -0.9997,  0.7990,  1.0000,  0.0752,  0.9999,
          0.2912, -0.9941,  0.9970, -0.9935, -0.9995, -0.9743,  0.9991,  0.9981,
         -0.9273, -0.8402,  0.9996, -0.9999,  0.9999, -0.9998,  0.9724, -0.9939,
          1.0000, -0.9752, -0.9998, -0.3806,  0.8830,  0.8352, -0.8892,  1.0000,
         -0.8875, -0.8107,  0.7083, -0.8909, -0.9931, -0.9630,  0.0800, -1.0000,
          0.7777, -0.9611,  0.5867, -0.9947, -0.9999,  1.0000, -0.9084, -0.9414,
          0.9999, -0.8838, -1.0000,  0.9549, -0.9999, -0.6522,  0.7967, -0.6850,
          0.1524, -1.0000,  0.4800,  0.9999, -0.9998, -0.7089, -0.9129, -0.9864,
          0.6220,  0.8855,  0.9855, -0.8651,  0.3988, -0.2548,  0.9793, -0.7212,
         -0.2582, -0.9999, -0.8692, -0.6282, -0.9999, -0.9999, -1.0000,  1.0000,
          0.9996,  0.9999, -0.5600,  0.7442,  0.9460,  0.9927, -0.9999,  0.4407,
         -0.0461,  0.9937, -0.4887, -0.9994, -0.9198, -1.0000, -0.6905,  0.3538,
         -0.7728,  0.6622,  1.0000,  0.9999, -0.9999, -0.9994, -0.9995, -0.9979,
          0.9998,  0.9999,  0.9996, -0.9072, -0.5844,  0.9997,  0.9689,  0.5231,
         -0.9999, -0.9981, -0.9999,  0.7505, -0.9922, -0.9986,  0.9971,  1.0000,
          0.8730, -1.0000, -0.9533,  1.0000,  0.9997,  1.0000, -0.7768,  0.9999,
         -0.9838,  0.9819, -0.9993,  1.0000, -1.0000,  1.0000,  0.9999,  0.9809,
          0.9984, -0.9928,  0.9776, -0.9998, -0.7407,  0.9298, -0.4495, -0.9902,
          0.8053,  0.9996, -0.9952,  1.0000,  0.9243, -0.2028,  0.8002,  0.9873,
          0.9419, -0.6913, -0.9999,  0.8162,  0.9995,  0.9509,  1.0000,  0.9177,
          0.9996, -0.9839, -0.9998,  0.9914, -0.6991, -0.7821, -0.9998,  1.0000,
          1.0000, -0.9999, -0.9227,  0.7483,  0.1186,  1.0000,  0.9963,  0.9971,
          0.9857,  0.3887,  0.9996, -0.9999,  0.8526, -0.9980, -0.8613,  0.9999,
         -0.9899,  0.9999, -0.9981,  1.0000, -0.9858,  0.9944,  0.9989,  0.9684,
         -0.9968,  1.0000,  0.8246, -0.9956, -0.8348, -0.9374, -0.9999,  0.7827]],
       grad_fn=<TanhBackward0>)
last_hidden_state.shape # 每个token的向量表示
torch.Size([1, 32, 768])
bert_model.config
BertConfig {
  "_name_or_path": "bert-base-chinese",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.24.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}

在 BERT 的 pooled_output 后面接一个 Dropout 和全连接层,构成二分类模型

class EnterpriseDangerClassifier(nn.Module):
    def __init__(self, n_classes):
        super(EnterpriseDangerClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)  # 两个类别

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False
        )
        output = self.drop(pooled_output)  # dropout
        return self.out(output)
class_names=[0,1]

将数据和模型送到CUDA

from transformers import BertModel, BertTokenizer,BertConfig, AdamW, get_linear_schedule_with_warmup

model = EnterpriseDangerClassifier(len(class_names))
model = model.to(device)
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)

print(input_ids.shape) # batch size x seq length
print(attention_mask.shape) # batch size x seq length
torch.Size([4, 160])
torch.Size([4, 160])

查看单个batch的模型输出(logits)

model(input_ids, attention_mask)
tensor([[ 0.2120, -0.4050],
        [ 0.3156, -0.4160],
        [ 0.5127, -0.4634],
        [ 0.3168,  0.5057]], device='cuda:0', grad_fn=<AddmmBackward0>)
F.softmax(model(input_ids, attention_mask), dim=1)
tensor([[0.4999, 0.5001],
        [0.5258, 0.4742],
        [0.5899, 0.4101],
        [0.4575, 0.5425]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

训练模型前期配置

EPOCHS = 10  # 训练轮数

optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS

# Warmup预热学习率:在开始训练的若干个epoch或step内使用较小的学习率,让模型先慢慢趋于稳定,
# 之后再按预先设置的学习率训练,可以加快收敛速度、提升模型效果
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

# optimizer: 优化器
# num_warmup_steps: 初始预热步数
# num_training_steps: 整个训练过程的总步数

loss_fn = nn.CrossEntropyLoss().to(device)

定义模型的训练函数

def train_epoch(
    model,
    data_loader,
    loss_fn,
    optimizer,
    device,
    scheduler,
    n_examples
):
    model = model.train()  # train模式
    losses = []
    correct_predictions = 0

    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["labels"].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, targets)

        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()  # 预热学习率
        optimizer.zero_grad()

    return correct_predictions.double() / n_examples, np.mean(losses)

模型的评估函数

def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval()  # 验证预测模式

    losses = []
    correct_predictions = 0

    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)

            loss = loss_fn(outputs, targets)

            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())

    return correct_predictions.double() / n_examples, np.mean(losses)

训练模型:共 10 个 epoch

history = defaultdict(list)  # 记录每一轮的loss和acc
best_accuracy = 0

for epoch in range(EPOCHS):

    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)

    train_acc, train_loss = train_epoch(
        model,
        train_data_loader,
        loss_fn,
        optimizer,
        device,
        scheduler,
        len(df_train)
    )

    print(f'Train loss {train_loss} accuracy {train_acc}')

    val_acc, val_loss = eval_model(
        model,
        val_data_loader,
        loss_fn,
        device,
        len(df_val)
    )

    print(f'Val loss {val_loss} accuracy {val_acc}')
    print()

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)

    if val_acc > best_accuracy:
        torch.save(model.state_dict(), 'best_model_state.bin')
        best_accuracy = val_acc
Epoch 1/10
----------
Train loss 0.4988938277521757 accuracy 0.8899999999999999
Val   loss 0.4194765945523977 accuracy 0.9

Epoch 2/10
----------
Train loss 0.4967574527254328 accuracy 0.8905555555555555
Val   loss 0.43736912585794924 accuracy 0.9

Epoch 3/10
----------
Train loss 0.49347498720511795 accuracy 0.8905555555555555
Val   loss 0.41818931301434836 accuracy 0.9

Epoch 4/10
----------
Train loss 0.4900011462407807 accuracy 0.8905555555555555
Val   loss 0.42409916249414287 accuracy 0.9

Epoch 5/10
----------
Train loss 0.4952681002088098 accuracy 0.8888888888888888
Val   loss 0.31909402589624125 accuracy 0.9

Epoch 6/10
----------
Train loss 0.2478140213253425 accuracy 0.9463888888888888
Val   loss 0.1787985412031412 accuracy 0.9666666666666667

Epoch 7/10
----------
Train loss 0.17434944392257387 accuracy 0.9677777777777777
Val   loss 0.15001839348037416 accuracy 0.9700000000000001

Epoch 8/10
----------
Train loss 0.12048366091100939 accuracy 0.9775925925925926
Val   loss 0.11547344802587758 accuracy 0.9783333333333334

Epoch 9/10
----------
Train loss 0.10136666681817992 accuracy 0.9813888888888889
Val   loss 0.10292303454208498 accuracy 0.9800000000000001

Epoch 10/10
----------
Train loss 0.08721379442805402 accuracy 0.9831481481481481
Val   loss 0.12598223814862042 accuracy 0.9766666666666667

准确率绘图

plt.plot([i.cpu() for i in history['train_acc']], label='train accuracy')
plt.plot([i.cpu() for i in history['val_acc']], label='validation accuracy')

plt.title('Training history')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.ylim([0, 1]);

(图:训练集与验证集准确率随 epoch 的变化曲线)

test_acc, _ = eval_model(
    model,
    test_data_loader,
    loss_fn,
    device,
    len(df_test)
)

test_acc.item()
0.9783333333333334

模型预测

def get_predictions(model, data_loader):
    model = model.eval()

    raw_texts = []
    predictions = []
    prediction_probs = []
    real_values = []

    with torch.no_grad():
        for d in data_loader:
            texts = d["texts"]
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)  # 预测类别

            probs = F.softmax(outputs, dim=1)  # 预测概率

            raw_texts.extend(texts)
            predictions.extend(preds)
            prediction_probs.extend(probs)
            real_values.extend(targets)

    predictions = torch.stack(predictions).cpu()
    prediction_probs = torch.stack(prediction_probs).cpu()
    real_values = torch.stack(real_values).cpu()
    return raw_texts, predictions, prediction_probs, real_values
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
    model,
    test_data_loader
)

print(classification_report(y_test, y_pred, target_names=[str(label) for label in class_names])) # 分类报告
              precision    recall  f1-score   support

           0       0.99      0.99      0.99       554
           1       0.84      0.89      0.86        46

    accuracy                           0.98       600
   macro avg       0.91      0.94      0.93       600
weighted avg       0.98      0.98      0.98       600
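分类报告中已经给出了各类别的 F1;由于比赛按 F1-score 评测,也可以单独计算正类(1)的 F1 作为线下参考,示意如下:

from sklearn.metrics import f1_score

# y_test、y_pred 为上面 get_predictions 返回的真实标签与预测标签
print('f1-score:', f1_score(y_test, y_pred))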

查看混淆矩阵

def show_confusion_matrix(confusion_matrix):
    hmap = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
    hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
    hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
    plt.ylabel('True label')
    plt.xlabel('Predicted label');


cm = confusion_matrix(y_test, y_pred)
df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
show_confusion_matrix(df_cm)

(图:混淆矩阵热力图)

评估单条数据

idx = 2

sample_text = y_texts[idx]
true_label = y_test[idx]
pred_df = pd.DataFrame({
    'class_names': class_names,
    'values': y_pred_probs[idx]
})
pred_df

class_names values
0 0 0.999889
1 1 0.000111
print("\n".join(wrap(sample_text)))
print()
print(f'True label: {class_names[true_label]}')
小A班应急照明灯坏[SEP]商贸服务教文卫类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急
照明是否完好。

True label: 0
sns.barplot(x='values', y='class_names', data=pred_df, orient='h')
plt.ylabel('sentiment')
plt.xlabel('probability')
plt.xlim([0, 1]);

(图:两个类别的预测概率条形图)

sample_text = ' '  # 用仅含一个空格的文本做一次额外测试
encoded_text = tokenizer.encode_plus(
    sample_text,
    max_length=MAX_LEN,
    add_special_tokens=True,
    return_token_type_ids=False,
    pad_to_max_length=True,
    return_attention_mask=True,
    return_tensors='pt',
)
input_ids = encoded_text['input_ids'].to(device)
attention_mask = encoded_text['attention_mask'].to(device)

output = model(input_ids, attention_mask)
_, prediction = torch.max(output, dim=1)

print(f'Sample text: {sample_text}')
print(f'Danger label : {class_names[prediction]}')
Sample text:  
Danger label  : 1
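最后,按照赛题要求还需要对 test 集进行预测并生成提交文件 results.csv。下面给出一个示意写法(假设提交文件的标签列名为 label,且 sub.csv 的行顺序与 test.csv 一致,具体请以比赛提供的样例文件为准):

# EnterpriseDataset 需要 labels 字段,先给测试集补一个占位标签
test_pred = test.copy()
test_pred['label'] = 0

# DataLoader 没有打乱顺序,预测结果与 test 的行顺序一一对应
pred_data_loader = create_data_loader(test_pred, tokenizer, MAX_LEN, BATCH_SIZE)
_, test_preds, _, _ = get_predictions(model, pred_data_loader)

sub['label'] = test_preds.numpy()   # 假设标签列名为 label
sub.to_csv('results.csv', index=False)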