Using intelligent methods to identify "falsified or fabricated" hazard reports
Background: Enterprise self-reporting of workplace-safety hazards plays an important role in eliminating risks while accidents are still in the bud. When filling in hazard reports, however, enterprises often do so carelessly, submitting falsified or fabricated hazard descriptions, which makes regulatory oversight harder. Analyzing the reported hazard content with big-data techniques, identifying enterprises that fail to genuinely fulfil their primary responsibility, and pushing the results to regulators enables targeted enforcement, improves the effectiveness of supervision, and strengthens enterprises' sense of safety responsibility.
Task: The competition provides enterprise hazard self-report data; participants must use intelligent methods to identify whether a report is falsified or fabricated.
Data: The dataset consists of anonymized self-inspection hazard records submitted by enterprises. Each record contains an id, four increasingly specific hazard-category fields (level_1 through level_4), the free-text hazard description (content), and, in the training set only, a binary label (1 presumably marking a falsified or fabricated report).
Evaluation: The competition uses the F1-score as the evaluation metric.
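For reference, F1 is the harmonic mean of precision and recall. A minimal sketch of computing it with scikit-learn on toy labels (the competition's own scoring script is not shown here, and binary F1 on the positive class is an assumption):

from sklearn.metrics import f1_score

# toy ground-truth labels and predictions, purely for illustration
y_true = [0, 0, 1, 1, 0, 1]
y_pred = [0, 1, 1, 1, 0, 0]

# precision = 2/3, recall = 2/3, so F1 = 2 * (2/3 * 2/3) / (2/3 + 2/3) ≈ 0.667
print(f1_score(y_true, y_pred))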
The full implementation follows. Import all required packages:

import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig, AdamW, get_linear_schedule_with_warmup
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import re
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap

import warnings
warnings.filterwarnings("ignore")

%matplotlib inline
%config InlineBackend.figure_format='retina'
Initial setup:

sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
device(type='cuda', index=0)
Read the data:

sub = pd.read_csv('./data/02企业隐患排查/sub.csv')
test = pd.read_csv('./data/02企业隐患排查/test.csv')
train = pd.read_csv('./data/02企业隐患排查/train.csv')
train.head()
|   | id | level_1 | level_2 | level_3 | level_4 | content | label |
| 0 | 0 | 工业/危化品类(现场)—2016版 | (二)电气安全 | 6、移动用电产品、电动工具及照明 | 1、移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 | 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. | 0 |
| 1 | 1 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 3、消防设施、器材和消防安全标志是否在位、完整; | 一般 | 1 |
| 2 | 2 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 2、防火检查 | 6、重点工种人员以及其他员工消防知识的掌握情况; | 消防知识要加强 | 0 |
| 3 | 3 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 3、消防设施、器材和消防安全标志是否在位、完整; | 消防通道有货物摆放 清理不及时 | 0 |
| 4 | 4 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 4、常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; | 防火门打开状态 | 0 |
test.head():

|   | id | level_1 | level_2 | level_3 | level_4 | content |
| 0 | 0 | 交通运输类(现场)—2016版 | (一)消防安全 | 2、防火检查 | 2、安全疏散通道、疏散指示标志、应急照明和安全出口情况。 | RB1洗地机占用堵塞安全通道 |
| 1 | 1 | 工业/危化品类(选项)—2016版 | (二)仓库 | 1、一般要求 | 1、库房内储存物品应分类、分堆、限额存放。 | 未分类堆放 |
| 2 | 2 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 3、消防设施、器材和消防安全标志是否在位、完整; | 消防设施、器材和消防安全标志是否在位、完整 |
| 3 | 3 | 商贸服务教文卫类(现场)—2016版 | (二)电气安全 | 3、电气线路及电源插头插座 | 3、电源插座、电源插头应按规定正确接线。 | 插座随意放在电器旁边 |
| 4 | 4 | 商贸服务教文卫类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 6、其他消防安全情况。 | 检查中发现一瓶灭火器过期 |
Check the data shapes:

print("train.shape, test.shape, sub.shape", train.shape, test.shape, sub.shape)
train.shape,test.shape,sub.shape (12000, 7) (18000, 6) (18000, 2)
Check for missing values:

train[train['content'].isna()]
|      | id   | level_1 | level_2 | level_3 | level_4 | content | label |
| 6193 | 6193 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 3、消防设施、器材和消防安全标志是否在位、完整; | NaN | 1 |
| 9248 | 9248 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 4、常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; | NaN | 1 |
print('train null nums')
train.shape[0] - train.count()
train null nums
id 0
level_1 0
level_2 0
level_3 0
level_4 0
content 2
label 0
dtype: int64
print('test null nums')
test.shape[0] - test.count()
test null nums
id 0
level_1 0
level_2 0
level_3 0
level_4 0
content 4
dtype: int64
Check the label distribution:

train['label'].value_counts()
0 10712
1 1288
Name: label, dtype: int64
Data preprocessing — fill missing content with a placeholder string:

train.fillna("空值", inplace=True)
test.fillna("空值", inplace=True)
train.shape[0] - train.count()
id 0
level_1 0
level_2 0
level_3 0
level_4 0
content 0
label 0
dtype: int64
train['level_3'].value_counts()
1、防火巡查 4225
2、防火检查 2911
2、配电箱(柜、板) 710
1、作业通道 664
3、电气线路及电源插头插座 497
...
3、安全带 1
4、特种设备及操作人员管理记录 1
4、安全技术交底 1
3、停车场 1
1、水库安全 1
Name: level_3, Length: 153, dtype: int64
Process the training set — strip the numbering prefixes from the category fields:

train['level_1'] = train['level_1'].apply(lambda x: x.split('(')[0])
train['level_2'] = train['level_2'].apply(lambda x: x.split(')')[1])
train['level_3'] = train['level_3'].apply(lambda x: re.split('[0-9]、', x)[-1])
train['level_4'] = train['level_4'].apply(lambda x: re.split('[0-9]、', x)[-1])
Process the test set in the same way:

test['level_1'] = test['level_1'].apply(lambda x: x.split('(')[0])
test['level_2'] = test['level_2'].apply(lambda x: x.split(')')[-1])
test['level_3'] = test['level_3'].apply(lambda x: re.split('[0-9]、', x)[-1])
test['level_4'] = test['level_4'].apply(lambda x: re.split('[0-9]、', x)[-1])
train.head() after processing:

|   | id | level_1 | level_2 | level_3 | level_4 | content | label |
| 0 | 0 | 工业/危化品类 | 电气安全 | 移动用电产品、电动工具及照明 | 移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 | 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. | 0 |
| 1 | 1 | 工业/危化品类 | 消防检查 | 防火巡查 | 消防设施、器材和消防安全标志是否在位、完整; | 一般 | 1 |
| 2 | 2 | 工业/危化品类 | 消防检查 | 防火检查 | 重点工种人员以及其他员工消防知识的掌握情况; | 消防知识要加强 | 0 |
| 3 | 3 | 工业/危化品类 | 消防检查 | 防火巡查 | 消防设施、器材和消防安全标志是否在位、完整; | 消防通道有货物摆放 清理不及时 | 0 |
| 4 | 4 | 工业/危化品类 | 消防检查 | 防火巡查 | 常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; | 防火门打开状态 | 0 |
Concatenate the text fields — join content and the four category fields with [SEP]:

train['text'] = train['content'] + '[SEP]' + train['level_1'] + '[SEP]' + train['level_2'] + '[SEP]' + train['level_3'] + '[SEP]' + train['level_4']
test['text'] = test['content'] + '[SEP]' + test['level_1'] + '[SEP]' + test['level_2'] + '[SEP]' + test['level_3'] + '[SEP]' + test['level_4']
train.head()
|   | id | level_1 | level_2 | level_3 | level_4 | content | label | text |
| 0 | 0 | 工业/危化品类 | 电气安全 | 移动用电产品、电动工具及照明 | 移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 | 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. | 0 | 使用移动手动电动工具,外接线绝缘皮破损,应停止使用.[SEP]工业/危化品类[SEP]电气安... |
| 1 | 1 | 工业/危化品类 | 消防检查 | 防火巡查 | 消防设施、器材和消防安全标志是否在位、完整; | 一般 | 1 | 一般[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]消防设施、器材和消... |
| 2 | 2 | 工业/危化品类 | 消防检查 | 防火检查 | 重点工种人员以及其他员工消防知识的掌握情况; | 消防知识要加强 | 0 | 消防知识要加强[SEP]工业/危化品类[SEP]消防检查[SEP]防火检查[SEP]重点工种... |
| 3 | 3 | 工业/危化品类 | 消防检查 | 防火巡查 | 消防设施、器材和消防安全标志是否在位、完整; | 消防通道有货物摆放 清理不及时 | 0 | 消防通道有货物摆放 清理不及时[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[... |
| 4 | 4 | 工业/危化品类 | 消防检查 | 防火巡查 | 常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; | 防火门打开状态 | 0 | 防火门打开状态[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]常闭式防... |
train['text_len'] = train['text'].map(len)
train['text'].map(len).describe()
count 12000.000000
mean 80.444500
std 21.910859
min 43.000000
25% 66.000000
50% 75.000000
75% 92.000000
max 298.000000
Name: text, dtype: float64
test['text'].map(len).describe()
count 18000.000000
mean 80.762611
std 22.719823
min 43.000000
25% 66.000000
50% 76.000000
75% 92.000000
max 520.000000
Name: text, dtype: float64
train['text_len'].plot(kind='kde')
<AxesSubplot:ylabel='Density'>
sum(train['text_len'] > 100)
sum(train['text_len'] > 200)
Load and configure the pre-trained model and tokenizer (embeddings):

PRE_TRAINED_MODEL_NAME = 'bert-base-chinese'
tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
tokenizer
PreTrainedTokenizerFast(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})
sample_txt = '今天早上9点半起床,我在学习预训练模型的使用.'
len(sample_txt)
23
tokens = tokenizer.tokenize(sample_txt)
token_ids = tokenizer.convert_tokens_to_ids(tokens)

print(f'文本为: {sample_txt}')
print(f'分词的列表为: {tokens}')
print(f'词对应的唯一id: {token_ids}')
文本为: 今天早上9点半起床,我在学习预训练模型的使用.
分词的列表为: ['今', '天', '早', '上', '9', '点', '半', '起', '床', ',', '我', '在', '学', '习', '预', '训', '练', '模', '型', '的', '使', '用', '.']
词对应的唯一id: [791, 1921, 3193, 677, 130, 4157, 1288, 6629, 2414, 8024, 2769, 1762, 2110, 739, 7564, 6378, 5298, 3563, 1798, 4638, 886, 4500, 119]
Inspect the special tokens:

tokenizer.sep_token, tokenizer.sep_token_id
('[SEP]', 102)
tokenizer.cls_token, tokenizer.cls_token_id
('[CLS]', 101)
tokenizer.pad_token, tokenizer.pad_token_id
('[PAD]', 0)
tokenizer.mask_token, tokenizer.mask_token_id
('[MASK]', 103)
tokenizer.unk_token, tokenizer.unk_token_id
('[UNK]', 100)
A simple encoding test:

encoding = tokenizer.encode_plus(
    sample_txt,
    max_length=32,
    add_special_tokens=True,       # add [CLS] and [SEP]
    return_token_type_ids=True,
    pad_to_max_length=True,        # deprecated; padding='max_length' is the newer argument
    return_attention_mask=True,
    return_tensors='pt',           # return PyTorch tensors
)
encoding
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
{'input_ids': tensor([[ 101, 791, 1921, 3193, 677, 130, 4157, 1288, 6629, 2414, 8024, 2769,
1762, 2110, 739, 7564, 6378, 5298, 3563, 1798, 4638, 886, 4500, 119,
102, 0, 0, 0, 0, 0, 0, 0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 0, 0, 0, 0, 0, 0, 0]])}
encoding['attention_mask'][0]
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 0, 0, 0, 0, 0, 0, 0])
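The truncation warning printed above can be silenced by making the behaviour explicit. A small sketch, not part of the original notebook, using the arguments the warning itself recommends:

encoding = tokenizer.encode_plus(
    sample_txt,
    max_length=32,
    add_special_tokens=True,
    truncation=True,           # explicitly truncate to max_length, as the warning suggests
    padding='max_length',      # newer replacement for the deprecated pad_to_max_length=True
    return_token_type_ids=True,
    return_attention_mask=True,
    return_tensors='pt',
)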
token_lens = []

for txt in train.text:
    tokens = tokenizer.encode(txt, max_length=512)
    token_lens.append(len(tokens))
sns.distplot(token_lens)
plt.xlim([0, 256])
plt.xlabel('Token count');
The analysis shows that token counts generally stay below 160.
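The cells below also rely on a MAX_LEN constant that never appears in this excerpt; judging from the batch shapes printed later (torch.Size([4, 160])), it is presumably defined as:

MAX_LEN = 160   # assumed value, inferred from the [4, 160] input_ids shapes shown below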
encoding['input_ids'].flatten()
tensor([ 101, 791, 1921, 3193, 677, 130, 4157, 1288, 6629, 2414, 8024, 2769,
1762, 2110, 739, 7564, 6378, 5298, 3563, 1798, 4638, 886, 4500, 119,
102, 0, 0, 0, 0, 0, 0, 0])
Without .flatten(), encoding['input_ids'] keeps the batch dimension:
tensor([[ 101,  791, 1921, 3193,  677,  130, 4157, 1288, 6629, 2414, 8024, 2769,
         1762, 2110,  739, 7564, 6378, 5298, 3563, 1798, 4638,  886, 4500,  119,
          102,    0,    0,    0,    0,    0,    0,    0]])
Wrap the data in a Dataset class:

class EnterpriseDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        """item is the index of the sample to fetch"""
        text = str(self.texts[item])
        label = self.labels[item]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=True,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'texts': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
Split the dataset (90% train, 5% validation, 5% test):

df_train, df_test = train_test_split(train, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=RANDOM_SEED)
df_train.shape, df_val.shape, df_test.shape
((10800, 9), (600, 9), (600, 9))
Create the DataLoaders:

def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = EnterpriseDataset(
        texts=df['text'].values,
        labels=df['label'].values,
        tokenizer=tokenizer,
        max_len=max_len
    )

    return DataLoader(
        ds,
        batch_size=batch_size,
    )
BATCH_SIZE = 4

train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)
next(iter(train_data_loader))
{'texts': ['指示标识不清楚[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
'发现本月有灭火器过期,已安排购买灭火器更换[SEP]商贸服务教文卫类[SEP]消防检查[SEP]防火检查[SEP]灭火器材配置及有效情况。',
'安全出口标志灯有一个有故障,已买回安装改正。[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
'堵了消防通道[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;'],
'input_ids': tensor([[ 101, 2900, 4850, 3403, 6399, 679, 3926, 3504, 102, 2339, 689, 120,
1314, 1265, 1501, 5102, 102, 3867, 7344, 3466, 3389, 102, 7344, 4125,
2337, 3389, 102, 2128, 1059, 1139, 1366, 510, 4541, 3141, 6858, 6887,
3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403,
2562, 510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039, 102,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 101, 1355, 4385, 3315, 3299, 3300, 4127, 4125, 1690, 6814, 3309, 8024,
2347, 2128, 2961, 6579, 743, 4127, 4125, 1690, 3291, 2940, 102, 1555,
6588, 3302, 1218, 3136, 3152, 1310, 5102, 102, 3867, 7344, 3466, 3389,
102, 7344, 4125, 3466, 3389, 102, 4127, 4125, 1690, 3332, 6981, 5390,
1350, 3300, 3126, 2658, 1105, 511, 102, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 101, 2128, 1059, 1139, 1366, 3403, 2562, 4128, 3300, 671, 702, 3300,
3125, 7397, 8024, 2347, 743, 1726, 2128, 6163, 3121, 3633, 511, 102,
2339, 689, 120, 1314, 1265, 1501, 5102, 102, 3867, 7344, 3466, 3389,
102, 7344, 4125, 2337, 3389, 102, 2128, 1059, 1139, 1366, 510, 4541,
3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141,
2900, 4850, 3403, 2562, 510, 2418, 2593, 4212, 3209, 3221, 1415, 2130,
1962, 8039, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0],
[ 101, 1843, 749, 3867, 7344, 6858, 6887, 102, 2339, 689, 120, 1314,
1265, 1501, 5102, 102, 3867, 7344, 3466, 3389, 102, 7344, 4125, 2337,
3389, 102, 2128, 1059, 1139, 1366, 510, 4541, 3141, 6858, 6887, 3221,
1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562,
510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039, 102, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0]]),
'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
'labels': tensor([0, 0, 0, 0])}
data = next(iter(train_data_loader))
data.keys()
dict_keys(['texts', 'input_ids', 'attention_mask', 'labels'])
print(data['input_ids'].shape)
print(data['attention_mask'].shape)
print(data['labels'].shape)
torch.Size([4, 160])
torch.Size([4, 160])
torch.Size([4])
bert_model = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
last_hidden_state, pooled_output = bert_model(
    input_ids=encoding['input_ids'],
    attention_mask=encoding['attention_mask'],
    return_dict=False
)
Inspect the outputs:

last_hidden_state[0][0].shape
torch.Size([768])
pooled_output
tensor([[ 0.9999,  0.9998,  0.9989,  ..., -0.9374, -0.9999,  0.7827]],
       grad_fn=<TanhBackward0>)
last_hidden_state.shape
torch.Size([1, 32, 768])
bert_model.config
BertConfig {
"_name_or_path": "bert-base-chinese",
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"classifier_dropout": null,
"directionality": "bidi",
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"position_embedding_type": "absolute",
"transformers_version": "4.24.0",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 21128
}
A fully connected classification head is added on top of BERT:

class EnterpriseDangerClassifier(nn.Module):
    def __init__(self, n_classes):
        super(EnterpriseDangerClassifier, self).__init__()
        # BertModel is imported in the next cell
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False
        )
        output = self.drop(pooled_output)
        return self.out(output)
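The next cell also uses class_names, which is not defined anywhere in this excerpt. Since this is a binary task with labels 0 and 1, a plausible definition is:

class_names = [0, 1]   # assumed: 0 = normal report, 1 = falsified/fabricated report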
Move the model (and later the data) to CUDA:

from transformers import BertModel, BertTokenizer, BertConfig, AdamW, get_linear_schedule_with_warmup

model = EnterpriseDangerClassifier(len(class_names))
model = model.to(device)
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)

print(input_ids.shape)
print(attention_mask.shape)
torch.Size([4, 160])
torch.Size([4, 160])
Get the model outputs (logits) for a single batch:

model(input_ids, attention_mask)
tensor([[ 0.2120, -0.4050],
[ 0.3156, -0.4160],
[ 0.5127, -0.4634],
[ 0.3168, 0.5057]], device='cuda:0', grad_fn=<AddmmBackward0>)
F.softmax(model(input_ids, attention_mask), dim=1)
tensor([[0.4999, 0.5001],
[0.5258, 0.4742],
[0.5899, 0.4101],
[0.4575, 0.5425]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
Training configuration:

EPOCHS = 10

optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

loss_fn = nn.CrossEntropyLoss().to(device)
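One optional refinement, not part of the original pipeline: the training labels are heavily imbalanced (10712 vs 1288), and the competition scores the positive class via F1, so the loss can be weighted toward the rare class. A hedged sketch:

# optional: weight the rare positive class more heavily in the cross-entropy loss
counts = train['label'].value_counts().sort_index()                 # [10712, 1288]
weights = torch.tensor((counts.sum() / counts).values, dtype=torch.float)
loss_fn = nn.CrossEntropyLoss(weight=weights.to(device))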
Define the training loop for one epoch:

def train_epoch(
    model,
    data_loader,
    loss_fn,
    optimizer,
    device,
    scheduler,
    n_examples
):
    model = model.train()

    losses = []
    correct_predictions = 0

    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["labels"].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, targets)

        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    return correct_predictions.double() / n_examples, np.mean(losses)
Model evaluation function:

def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval()

    losses = []
    correct_predictions = 0

    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)

            loss = loss_fn(outputs, targets)
            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())

    return correct_predictions.double() / n_examples, np.mean(losses)
Train the model for 10 epochs:

history = defaultdict(list)
best_accuracy = 0

for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)

    train_acc, train_loss = train_epoch(
        model,
        train_data_loader,
        loss_fn,
        optimizer,
        device,
        scheduler,
        len(df_train)
    )
    print(f'Train loss {train_loss} accuracy {train_acc}')

    val_acc, val_loss = eval_model(
        model,
        val_data_loader,
        loss_fn,
        device,
        len(df_val)
    )
    print(f'Val loss {val_loss} accuracy {val_acc}')
    print()

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)

    if val_acc > best_accuracy:
        torch.save(model.state_dict(), 'best_model_state.bin')
        best_accuracy = val_acc
Epoch 1/10
----------
Train loss 0.4988938277521757 accuracy 0.8899999999999999
Val loss 0.4194765945523977 accuracy 0.9
Epoch 2/10
----------
Train loss 0.4967574527254328 accuracy 0.8905555555555555
Val loss 0.43736912585794924 accuracy 0.9
Epoch 3/10
----------
Train loss 0.49347498720511795 accuracy 0.8905555555555555
Val loss 0.41818931301434836 accuracy 0.9
Epoch 4/10
----------
Train loss 0.4900011462407807 accuracy 0.8905555555555555
Val loss 0.42409916249414287 accuracy 0.9
Epoch 5/10
----------
Train loss 0.4952681002088098 accuracy 0.8888888888888888
Val loss 0.31909402589624125 accuracy 0.9
Epoch 6/10
----------
Train loss 0.2478140213253425 accuracy 0.9463888888888888
Val loss 0.1787985412031412 accuracy 0.9666666666666667
Epoch 7/10
----------
Train loss 0.17434944392257387 accuracy 0.9677777777777777
Val loss 0.15001839348037416 accuracy 0.9700000000000001
Epoch 8/10
----------
Train loss 0.12048366091100939 accuracy 0.9775925925925926
Val loss 0.11547344802587758 accuracy 0.9783333333333334
Epoch 9/10
----------
Train loss 0.10136666681817992 accuracy 0.9813888888888889
Val loss 0.10292303454208498 accuracy 0.9800000000000001
Epoch 10/10
----------
Train loss 0.08721379442805402 accuracy 0.9831481481481481
Val loss 0.12598223814862042 accuracy 0.9766666666666667
Plot the accuracy curves:

plt.plot([i.cpu() for i in history['train_acc']], label='train accuracy')
plt.plot([i.cpu() for i in history['val_acc']], label='validation accuracy')

plt.title('Training history')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.ylim([0, 1]);
test_acc, _ = eval_model(
    model,
    test_data_loader,
    loss_fn,
    device,
    len(df_test)
)
test_acc.item()
0.9783333333333334
Define a prediction helper:

def get_predictions(model, data_loader):
    model = model.eval()

    raw_texts = []
    predictions = []
    prediction_probs = []
    real_values = []

    with torch.no_grad():
        for d in data_loader:
            texts = d["texts"]
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)
            probs = F.softmax(outputs, dim=1)

            raw_texts.extend(texts)
            predictions.extend(preds)
            prediction_probs.extend(probs)
            real_values.extend(targets)

    predictions = torch.stack(predictions).cpu()
    prediction_probs = torch.stack(prediction_probs).cpu()
    real_values = torch.stack(real_values).cpu()
    return raw_texts, predictions, prediction_probs, real_values
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
    model,
    test_data_loader
)
print(classification_report(y_test, y_pred, target_names=[str(label) for label in class_names]))
precision recall f1-score support
0 0.99 0.99 0.99 554
1 0.84 0.89 0.86 46
accuracy 0.98 600
macro avg 0.91 0.94 0.93 600
weighted avg 0.98 0.98 0.98 600
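Since the leaderboard metric is F1 rather than accuracy, it is worth reading it off directly. A short sketch using the held-out predictions above (binary F1 on label 1 is assumed to be the scored quantity):

from sklearn.metrics import f1_score

# matches the f1-score of the "1" row in the classification report above (≈ 0.86)
print(f1_score(y_test, y_pred))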
View the confusion matrix:

def show_confusion_matrix(confusion_matrix):
    hmap = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
    hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
    hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
    plt.ylabel('True label')
    plt.xlabel('Predicted label');

cm = confusion_matrix(y_test, y_pred)
df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
show_confusion_matrix(df_cm)
Evaluate a single sample:

idx = 2

sample_text = y_texts[idx]
true_label = y_test[idx]

pred_df = pd.DataFrame({
    'class_names': class_names,
    'values': y_pred_probs[idx]
})
pred_df
|   | class_names | values   |
| 0 | 0           | 0.999889 |
| 1 | 1           | 0.000111 |
print("\n".join(wrap(sample_text)))
print()
print(f'True label: {class_names[true_label]}')
小A班应急照明灯坏[SEP]商贸服务教文卫类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急
照明是否完好。
True label: 0
sns.barplot(x='values', y='class_names', data=pred_df, orient='h')
plt.ylabel('class')
plt.xlabel('probability')
plt.xlim([0, 1]);
encoded_text = tokenizer.encode_plus(
    sample_text,
    max_length=MAX_LEN,
    add_special_tokens=True,
    return_token_type_ids=False,
    pad_to_max_length=True,
    return_attention_mask=True,
    return_tensors='pt',
)
input_ids = encoded_text['input_ids'].to(device)
attention_mask = encoded_text['attention_mask'].to(device)

output = model(input_ids, attention_mask)
_, prediction = torch.max(output, dim=1)

print(f'Sample text: {sample_text}')
print(f'Danger label : {class_names[prediction]}')
Sample text:
Danger label : 1
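The excerpt ends without generating the competition submission. A possible final step, sketched under assumptions (sub.csv has 18000 rows and 2 columns, so an id/label format and a 'label' column name are assumed), would be to run the trained model over the full test set and write the predictions:

# run the trained model over the test set and fill in the submission file
test_ds = EnterpriseDataset(
    texts=test['text'].values,
    labels=np.zeros(len(test), dtype=int),   # dummy labels: the test set has no ground truth
    tokenizer=tokenizer,
    max_len=MAX_LEN
)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

_, test_preds, _, _ = get_predictions(model, test_loader)
sub['label'] = test_preds.numpy()             # 'label' column name is an assumption
sub.to_csv('submission.csv', index=False)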