Using intelligent methods to identify "falsified or fabricated" hazard reports
 
Background: Having enterprises self-report workplace-safety hazards is an important way to eliminate risks before they grow into accidents. In practice, however, enterprises often fill in these reports carelessly, submitting falsified or fabricated hazard descriptions, which makes regulatory supervision harder. Analyzing the reported content with big-data techniques makes it possible to identify enterprises that do not genuinely fulfill their safety responsibilities and to flag them to regulators, enabling targeted enforcement, more effective supervision, and a stronger sense of corporate safety responsibility.
Task: The competition provides enterprises' self-reported hazard records; participants must use intelligent methods to identify falsified or fabricated reports.
Data: The dataset consists of de-identified self-inspection hazard records filed by enterprises. As the data preview below shows, each record has an id, a four-level hazard-category hierarchy (level_1 through level_4), a free-text hazard description (content), and, in the training set, a binary label (1 marks a falsified or fabricated report).
Evaluation: F1-score is used as the evaluation metric.
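For reference, F1 is the harmonic mean of precision and recall. A minimal sketch of computing it locally with scikit-learn, assuming binary F1 on the positive class (the post does not spell out the averaging mode):

```python
from sklearn.metrics import f1_score

# Toy labels, for illustration only
y_true = [0, 0, 1, 1, 0, 1]
y_pred = [0, 1, 1, 1, 0, 0]

# F1 = 2 * precision * recall / (precision + recall)
print(f1_score(y_true, y_pred))  # 0.666...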
The full implementation follows. First, import all required packages:

```python
import transformers
from transformers import AutoModel, AutoTokenizer, AutoConfig, AdamW, get_linear_schedule_with_warmup
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import re
import numpy as np
import pandas as pd
import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import defaultdict
from textwrap import wrap

import warnings
warnings.filterwarnings("ignore")

%matplotlib inline
%config InlineBackend.figure_format='retina'
```
 
Initial setup: plotting style, random seeds, and device selection.

```python
sns.set(style='whitegrid', palette='muted', font_scale=1.2)

HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device
```
 
device(type='cuda', index=0)
 
Read the data:

```python
sub = pd.read_csv('./data/02企业隐患排查/sub.csv')
test = pd.read_csv('./data/02企业隐患排查/test.csv')
train = pd.read_csv('./data/02企业隐患排查/train.csv')
```
 
 
  
    
       
The first rows of `train`:

|   | id | level_1 | level_2 | level_3 | level_4 | content | label |
|---|----|---------|---------|---------|---------|---------|-------|
| 0 | 0 | 工业/危化品类(现场)—2016版 | (二)电气安全 | 6、移动用电产品、电动工具及照明 | 1、移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 | 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. | 0 |
| 1 | 1 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 3、消防设施、器材和消防安全标志是否在位、完整; | 一般 | 1 |
| 2 | 2 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 2、防火检查 | 6、重点工种人员以及其他员工消防知识的掌握情况; | 消防知识要加强 | 0 |
| 3 | 3 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 3、消防设施、器材和消防安全标志是否在位、完整; | 消防通道有货物摆放 清理不及时 | 0 |
| 4 | 4 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 4、常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; | 防火门打开状态 | 0 |
 
  
    
       
The first rows of `test` (no label column):

|   | id | level_1 | level_2 | level_3 | level_4 | content |
|---|----|---------|---------|---------|---------|---------|
| 0 | 0 | 交通运输类(现场)—2016版 | (一)消防安全 | 2、防火检查 | 2、安全疏散通道、疏散指示标志、应急照明和安全出口情况。 | RB1洗地机占用堵塞安全通道 |
| 1 | 1 | 工业/危化品类(选项)—2016版 | (二)仓库 | 1、一般要求 | 1、库房内储存物品应分类、分堆、限额存放。 | 未分类堆放 |
| 2 | 2 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 3、消防设施、器材和消防安全标志是否在位、完整; | 消防设施、器材和消防安全标志是否在位、完整 |
| 3 | 3 | 商贸服务教文卫类(现场)—2016版 | (二)电气安全 | 3、电气线路及电源插头插座 | 3、电源插座、电源插头应按规定正确接线。 | 插座随意放在电器旁边 |
| 4 | 4 | 商贸服务教文卫类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 6、其他消防安全情况。 | 检查中发现一瓶灭火器过期 |
Check the shapes of the datasets:

```python
print("train.shape,test.shape,sub.shape", train.shape, test.shape, sub.shape)
```
 
train.shape,test.shape,sub.shape (12000, 7) (18000, 6) (18000, 2)
 
Check for missing values:

```python
train[train['content'].isna()]
```
 
  
    
       
|      | id   | level_1 | level_2 | level_3 | level_4 | content | label |
|------|------|---------|---------|---------|---------|---------|-------|
| 6193 | 6193 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 3、消防设施、器材和消防安全标志是否在位、完整; | NaN | 1 |
| 9248 | 9248 | 工业/危化品类(现场)—2016版 | (一)消防检查 | 1、防火巡查 | 4、常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; | NaN | 1 |
```python
print('train null nums')
train.shape[0] - train.count()
```
 
train null nums
id         0
level_1    0
level_2    0
level_3    0
level_4    0
content    2
label      0
dtype: int64
 
```python
print('test null nums')
test.shape[0] - test.count()
```
 
test null nums
id         0
level_1    0
level_2    0
level_3    0
level_4    0
content    4
dtype: int64
 
Check the label distribution:

```python
train['label'].value_counts()
```
 
0    10712
1     1288
Name: label, dtype: int64
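Only 1288 of the 12000 training samples (about 10.7%) are positive, which is one reason F1-score rather than accuracy is the competition metric. A class-weighted loss is one optional way to compensate for the imbalance; a minimal sketch, not part of the original pipeline:

```python
# Optional: weight the loss inversely to class frequency (not used in the original post)
counts = train['label'].value_counts().sort_index()  # [10712, 1288]
weights = torch.tensor((counts.sum() / counts).values, dtype=torch.float)
weighted_loss_fn = nn.CrossEntropyLoss(weight=weights.to(device))
```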
 
 
Data preprocessing: fill missing content with a placeholder string.

```python
train.fillna("空值", inplace=True)
test.fillna("空值", inplace=True)
```
 
```python
train.shape[0] - train.count()
```
 
id         0
level_1    0
level_2    0
level_3    0
level_4    0
content    0
label      0
dtype: int64
 
```python
train['level_3'].value_counts()
```
 
1、防火巡查             4225
2、防火检查             2911
2、配电箱(柜、板)          710
1、作业通道              664
3、电气线路及电源插头插座       497
                   ... 
3、安全带                 1
4、特种设备及操作人员管理记录       1
4、安全技术交底              1
3、停车场                 1
1、水库安全                1
Name: level_3, Length: 153, dtype: int64
 
Clean the training set: strip the bracketed suffix from level_1, the "(一)"-style prefix from level_2, and the "1、"-style numbering from level_3 and level_4. (The original used `split(')')[1]` here but `[-1]` on the test set; `[-1]` is used for both, since it is equivalent on the values above and does not fail when no ')' is present.)

```python
train['level_1'] = train['level_1'].apply(lambda x: x.split('(')[0])
train['level_2'] = train['level_2'].apply(lambda x: x.split(')')[-1])  # [-1] for consistency with the test set
train['level_3'] = train['level_3'].apply(lambda x: re.split('[0-9]、', x)[-1])
train['level_4'] = train['level_4'].apply(lambda x: re.split('[0-9]、', x)[-1])
```
 
Apply the same cleaning to the test set:

```python
test['level_1'] = test['level_1'].apply(lambda x: x.split('(')[0])
test['level_2'] = test['level_2'].apply(lambda x: x.split(')')[-1])
test['level_3'] = test['level_3'].apply(lambda x: re.split('[0-9]、', x)[-1])
test['level_4'] = test['level_4'].apply(lambda x: re.split('[0-9]、', x)[-1])
```
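A quick sanity check of what this cleaning does, using raw values from the data preview above (note the fields use full-width parentheses and the "、" separator):

```python
import re

print('工业/危化品类(现场)—2016版'.split('(')[0])  # -> 工业/危化品类
print('(一)消防检查'.split(')')[-1])               # -> 消防检查
print(re.split('[0-9]、', '1、防火巡查')[-1])         # -> 防火巡查
```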
 
 
  
    
       
The processed `train` head:

|   | id | level_1 | level_2 | level_3 | level_4 | content | label |
|---|----|---------|---------|---------|---------|---------|-------|
| 0 | 0 | 工业/危化品类 | 电气安全 | 移动用电产品、电动工具及照明 | 移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 | 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. | 0 |
| 1 | 1 | 工业/危化品类 | 消防检查 | 防火巡查 | 消防设施、器材和消防安全标志是否在位、完整; | 一般 | 1 |
| 2 | 2 | 工业/危化品类 | 消防检查 | 防火检查 | 重点工种人员以及其他员工消防知识的掌握情况; | 消防知识要加强 | 0 |
| 3 | 3 | 工业/危化品类 | 消防检查 | 防火巡查 | 消防设施、器材和消防安全标志是否在位、完整; | 消防通道有货物摆放 清理不及时 | 0 |
| 4 | 4 | 工业/危化品类 | 消防检查 | 防火巡查 | 常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; | 防火门打开状态 | 0 |
Text concatenation: join the content with the four category levels, separated by [SEP]:

```python
train['text'] = train['content'] + '[SEP]' + train['level_1'] + '[SEP]' + train['level_2'] + '[SEP]' + train['level_3'] + '[SEP]' + train['level_4']
test['text'] = test['content'] + '[SEP]' + test['level_1'] + '[SEP]' + test['level_2'] + '[SEP]' + test['level_3'] + '[SEP]' + test['level_4']
train.head()
```
 
  
    
       
|   | id | level_1 | level_2 | level_3 | level_4 | content | label | text |
|---|----|---------|---------|---------|---------|---------|-------|------|
| 0 | 0 | 工业/危化品类 | 电气安全 | 移动用电产品、电动工具及照明 | 移动使用的用电产品和I类电动工具的绝缘线,必须采用三芯(单相)或四芯(三相)多股铜芯橡套软线。 | 使用移动手动电动工具,外接线绝缘皮破损,应停止使用. | 0 | 使用移动手动电动工具,外接线绝缘皮破损,应停止使用.[SEP]工业/危化品类[SEP]电气安... |
| 1 | 1 | 工业/危化品类 | 消防检查 | 防火巡查 | 消防设施、器材和消防安全标志是否在位、完整; | 一般 | 1 | 一般[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]消防设施、器材和消... |
| 2 | 2 | 工业/危化品类 | 消防检查 | 防火检查 | 重点工种人员以及其他员工消防知识的掌握情况; | 消防知识要加强 | 0 | 消防知识要加强[SEP]工业/危化品类[SEP]消防检查[SEP]防火检查[SEP]重点工种... |
| 3 | 3 | 工业/危化品类 | 消防检查 | 防火巡查 | 消防设施、器材和消防安全标志是否在位、完整; | 消防通道有货物摆放 清理不及时 | 0 | 消防通道有货物摆放 清理不及时[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[... |
| 4 | 4 | 工业/危化品类 | 消防检查 | 防火巡查 | 常闭式防火门是否处于关闭状态,防火卷帘下是否堆放物品影响使用; | 防火门打开状态 | 0 | 防火门打开状态[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]常闭式防... |
Compute the character lengths of the concatenated texts:

```python
train['text_len'] = train['text'].map(len)
train['text'].map(len).describe()
```
 
count    12000.000000
mean        80.444500
std         21.910859
min         43.000000
25%         66.000000
50%         75.000000
75%         92.000000
max        298.000000
Name: text, dtype: float64
 
```python
test['text'].map(len).describe()
```
 
count    18000.000000
mean        80.762611
std         22.719823
min         43.000000
25%         66.000000
50%         76.000000
75%         92.000000
max        520.000000
Name: text, dtype: float64
 
```python
train['text_len'].plot(kind='kde')
```
 
<AxesSubplot:ylabel='Density'>
 
Count how many texts exceed 100 and 200 characters:

```python
sum(train['text_len'] > 100)
sum(train['text_len'] > 200)
```
 
Load and configure the pretrained model, starting with the tokenizer:

```python
PRE_TRAINED_MODEL_NAME = 'bert-base-chinese'
tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
tokenizer
```
 
 
PreTrainedTokenizerFast(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})
 
```python
sample_txt = '今天早上9点半起床,我在学习预训练模型的使用.'
len(sample_txt)
```
 
23
 
```python
tokens = tokenizer.tokenize(sample_txt)
token_ids = tokenizer.convert_tokens_to_ids(tokens)

print(f'文本为: {sample_txt}')
print(f'分词的列表为: {tokens}')
print(f'词对应的唯一id: {token_ids}')
```
 
文本为: 今天早上9点半起床,我在学习预训练模型的使用.
分词的列表为: ['今', '天', '早', '上', '9', '点', '半', '起', '床', ',', '我', '在', '学', '习', '预', '训', '练', '模', '型', '的', '使', '用', '.']
词对应的唯一id: [791, 1921, 3193, 677, 130, 4157, 1288, 6629, 2414, 8024, 2769, 1762, 2110, 739, 7564, 6378, 5298, 3563, 1798, 4638, 886, 4500, 119]
 
Inspect the special tokens:

```python
tokenizer.sep_token, tokenizer.sep_token_id
```
 
('[SEP]', 102)
 
```python
tokenizer.cls_token, tokenizer.cls_token_id
```
 
('[CLS]', 101)
 
```python
tokenizer.pad_token, tokenizer.pad_token_id
```
 
('[PAD]', 0)
 
```python
tokenizer.mask_token, tokenizer.mask_token_id
```
 
('[MASK]', 103)
 
```python
tokenizer.unk_token, tokenizer.unk_token_id
```
 
('[UNK]', 100)
 
A simple encoding test:

```python
encoding = tokenizer.encode_plus(
    sample_txt,
    max_length=32,
    add_special_tokens=True,   # add [CLS] and [SEP]
    return_token_type_ids=True,
    pad_to_max_length=True,    # deprecated argument; see the note after the output
    return_attention_mask=True,
    return_tensors='pt',       # return PyTorch tensors
)
encoding
```
 
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
{'input_ids': tensor([[ 101,  791, 1921, 3193,  677,  130, 4157, 1288, 6629, 2414, 8024, 2769,
         1762, 2110,  739, 7564, 6378, 5298, 3563, 1798, 4638,  886, 4500,  119,
          102,    0,    0,    0,    0,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 0, 0, 0, 0, 0, 0]])}
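The warning above appears because `pad_to_max_length` is deprecated and truncation is not explicitly enabled. On recent transformers versions, the equivalent explicit call would look like this:

```python
# Equivalent call with current (non-deprecated) arguments
encoding = tokenizer(
    sample_txt,
    max_length=32,
    padding='max_length',  # replaces pad_to_max_length=True
    truncation=True,       # silences the truncation warning
    return_token_type_ids=True,
    return_attention_mask=True,
    return_tensors='pt',
)
```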
 
```python
encoding['attention_mask'][0]
```
 
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 0, 0, 0, 0, 0, 0, 0])
 
Tokenize every training text to inspect the token-length distribution:

```python
token_lens = []

for txt in train.text:
    tokens = tokenizer.encode(txt, max_length=512)
    token_lens.append(len(tokens))
```
 
```python
sns.distplot(token_lens)
plt.xlim([0, 256])
plt.xlabel('Token count')
```

(Figure: distribution of token counts.) The analysis shows that the texts are generally under 160 tokens.
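`MAX_LEN` is used in the data loaders below but the post never shows the cell that defines it; the batch shapes printed later (torch.Size([4, 160])) imply a value of 160, so it is defined here:

```python
MAX_LEN = 160  # inferred from the token-length analysis and the batch shapes below
```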
```python
encoding['input_ids'].flatten()
```
 
tensor([ 101,  791, 1921, 3193,  677,  130, 4157, 1288, 6629, 2414, 8024, 2769,
        1762, 2110,  739, 7564, 6378, 5298, 3563, 1798, 4638,  886, 4500,  119,
         102,    0,    0,    0,    0,    0,    0,    0])
 
 
 
Wrap the data in a PyTorch `Dataset`:

```python
class EnterpriseDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        """item is the index of the record to fetch."""
        text = str(self.texts[item])
        label = self.labels[item]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=True,
            pad_to_max_length=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'texts': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
```
 
Split the data into train, validation, and test sets:

```python
df_train, df_test = train_test_split(train, test_size=0.1, random_state=RANDOM_SEED)
df_val, df_test = train_test_split(df_test, test_size=0.5, random_state=RANDOM_SEED)
df_train.shape, df_val.shape, df_test.shape
```
 
((10800, 9), (600, 9), (600, 9))
 
Create the DataLoaders:

```python
def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = EnterpriseDataset(
        texts=df['text'].values,
        labels=df['label'].values,
        tokenizer=tokenizer,
        max_len=max_len
    )

    return DataLoader(
        ds,
        batch_size=batch_size,
    )
```
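Note that this loader never shuffles. For the training split, shuffling each epoch is usually preferred; a hypothetical variant, not part of the original post:

```python
def create_shuffled_data_loader(df, tokenizer, max_len, batch_size):
    # Same as create_data_loader, but reshuffles each epoch (common for training)
    ds = EnterpriseDataset(
        texts=df['text'].values,
        labels=df['label'].values,
        tokenizer=tokenizer,
        max_len=max_len
    )
    return DataLoader(ds, batch_size=batch_size, shuffle=True)
```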
 
```python
BATCH_SIZE = 4

train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)
```
 
```python
next(iter(train_data_loader))
```
 
{'texts': ['指示标识不清楚[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
  '发现本月有灭火器过期,已安排购买灭火器更换[SEP]商贸服务教文卫类[SEP]消防检查[SEP]防火检查[SEP]灭火器材配置及有效情况。',
  '安全出口标志灯有一个有故障,已买回安装改正。[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;',
  '堵了消防通道[SEP]工业/危化品类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急照明是否完好;'],
 'input_ids': tensor([[ 101, 2900, 4850, 3403, 6399,  679, 3926, 3504,  102, 2339,  689,  120,
          1314, 1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,  102, 7344, 4125,
          2337, 3389,  102, 2128, 1059, 1139, 1366,  510, 4541, 3141, 6858, 6887,
          3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403,
          2562,  510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039,  102,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 1355, 4385, 3315, 3299, 3300, 4127, 4125, 1690, 6814, 3309, 8024,
          2347, 2128, 2961, 6579,  743, 4127, 4125, 1690, 3291, 2940,  102, 1555,
          6588, 3302, 1218, 3136, 3152, 1310, 5102,  102, 3867, 7344, 3466, 3389,
           102, 7344, 4125, 3466, 3389,  102, 4127, 4125, 1690, 3332, 6981, 5390,
          1350, 3300, 3126, 2658, 1105,  511,  102,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 2128, 1059, 1139, 1366, 3403, 2562, 4128, 3300,  671,  702, 3300,
          3125, 7397, 8024, 2347,  743, 1726, 2128, 6163, 3121, 3633,  511,  102,
          2339,  689,  120, 1314, 1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,
           102, 7344, 4125, 2337, 3389,  102, 2128, 1059, 1139, 1366,  510, 4541,
          3141, 6858, 6887, 3221, 1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141,
          2900, 4850, 3403, 2562,  510, 2418, 2593, 4212, 3209, 3221, 1415, 2130,
          1962, 8039,  102,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0],
         [ 101, 1843,  749, 3867, 7344, 6858, 6887,  102, 2339,  689,  120, 1314,
          1265, 1501, 5102,  102, 3867, 7344, 3466, 3389,  102, 7344, 4125, 2337,
          3389,  102, 2128, 1059, 1139, 1366,  510, 4541, 3141, 6858, 6887, 3221,
          1415, 4517, 6858, 8024, 2128, 1059, 4541, 3141, 2900, 4850, 3403, 2562,
           510, 2418, 2593, 4212, 3209, 3221, 1415, 2130, 1962, 8039,  102,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'labels': tensor([0, 0, 0, 0])}
 
```python
data = next(iter(train_data_loader))
data.keys()
```
 
dict_keys(['texts', 'input_ids', 'attention_mask', 'labels'])
 
```python
print(data['input_ids'].shape)
print(data['attention_mask'].shape)
print(data['labels'].shape)
```
 
torch.Size([4, 160])
torch.Size([4, 160])
torch.Size([4])
 
```python
bert_model = AutoModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
```
 
 
 
```python
last_hidden_state, pooled_output = bert_model(
    input_ids=encoding['input_ids'],
    attention_mask=encoding['attention_mask'],
    return_dict=False
)
```
 
Inspect the outputs. The hidden state of the first token ([CLS]) has size 768:

```python
last_hidden_state[0][0].shape
```
 
torch.Size([768])
 
 
The `pooled_output` for the sample (shape [1, 768], produced by the BERT pooler's tanh layer):

tensor([[ 0.9999,  0.9998,  0.9989,  0.9629,  0.3075, -0.1866, -0.9904,  0.8628,
          0.9710, -0.9993,  1.0000,  1.0000,  0.9312, -0.9394,  0.9998, -0.9999,
          0.0417,  0.9999,  0.9458,  0.3190,  1.0000, -1.0000, -0.9062, -0.9048,
          0.1764,  0.9983,  0.9346, -0.8122, -0.9999,  0.9996,  0.7879,  0.9999,
          0.8475, -1.0000, -1.0000,  0.9413, -0.8260,  0.9889, -0.4976, -0.9857,
         -0.9955, -0.9580,  0.5833, -0.9996, -0.8932,  0.8563, -1.0000, -0.9999,
          0.9719,  0.9999, -0.7430, -0.9993,  0.9756, -0.9754,  0.2991,  0.8933,
         -0.9991,  0.9987,  1.0000,  0.4156,  0.9992, -0.9452, -0.8020, -0.9999,
          1.0000, -0.9964, -0.9900,  0.4365,  1.0000,  1.0000, -0.9400,  0.8794,
          1.0000,  0.9105, -0.6616,  1.0000, -0.9999,  0.6892, -1.0000, -0.9817,
          1.0000,  0.9957, -0.8844, -0.8248, -0.9921, -0.9999, -0.9998,  1.0000,
          0.5228,  0.1297,  0.9932, -0.9999, -1.0000,  0.9993, -0.9996, -0.9948,
         -0.9561,  0.9996, -0.5785, -0.9386, -0.2035,  0.9086, -0.9999, -0.9993,
          0.9959,  0.9984,  0.6953, -0.9995,  1.0000,  0.8610, -1.0000, -0.4507,
         -1.0000,  0.2384, -0.9812,  0.9998,  0.9504,  0.5421,  0.9995, -0.9998,
          0.9320, -0.9941, -0.9718, -0.9910,  0.9822,  1.0000,  0.9997, -0.9990,
          1.0000,  1.0000,  0.8608,  0.9964, -0.9997,  0.9799,  0.5985, -0.9098,
          0.5329, -0.6345,  1.0000,  0.9872,  0.9970, -0.9719,  0.9988, -0.9933,
          1.0000, -0.9999,  0.9973, -1.0000, -0.6550,  0.9996,  0.8899,  1.0000,
          0.2969,  0.9999, -0.9983, -0.9991,  0.9906, -0.6590,  0.9872, -1.0000,
          0.7658,  0.7876, -0.8556,  0.6304, -1.0000,  1.0000, -0.7938,  1.0000,
          0.9898,  0.2216, -0.9942, -0.9969,  0.8345, -0.9998, -0.9779,  0.9914,
          0.5227,  0.9992, -0.9893, -0.9889,  0.2325, -0.9887, -0.9999,  0.9885,
          0.0340,  0.9284,  0.5197,  0.4143,  0.8315,  0.1585, -0.5348,  1.0000,
          0.2361,  0.9985,  0.9999, -0.3446,  0.1012, -0.9924, -1.0000, -0.7542,
          0.9999, -0.2807, -0.9999,  0.9490, -1.0000,  0.9906, -0.7288, -0.5263,
         -0.9545, -0.9999,  0.9998, -0.9286, -0.9997, -0.5303,  0.8886,  0.5605,
         -0.9989, -0.3324,  0.9804, -0.9075,  0.9905, -0.9800, -0.9946,  0.6855,
         -0.9393,  0.9929,  0.9874,  1.0000,  0.9997, -0.0714, -0.9440,  1.0000,
          0.1676, -1.0000,  0.5573, -0.9611,  0.8835,  0.9999, -0.9980,  0.9294,
          1.0000,  0.7968,  1.0000, -0.7065, -0.9793, -0.9997,  1.0000,  0.9922,
          0.9999, -0.9984, -0.9995, -0.1701, -0.5426, -1.0000, -1.0000, -0.6334,
          0.9969,  0.9999, -0.1620, -0.9818, -0.9921, -0.9994,  1.0000, -0.9759,
          1.0000,  0.8570, -0.7434, -0.9164,  0.9438, -0.7311, -0.9986, -0.3936,
         -0.9997, -0.9650, -1.0000,  0.9433, -0.9999, -1.0000,  0.6913,  1.0000,
          0.8762, -1.0000,  0.9997,  0.9764,  0.7094, -0.9294,  0.9522, -1.0000,
          1.0000, -0.9965,  0.9428, -0.9972, -0.9897, -0.7680,  0.9922,  0.9999,
         -0.9999, -0.9597, -0.9922, -0.9807, -0.3632,  0.9936, -0.7280,  0.4117,
         -0.9498, -0.9666,  0.9545, -0.9957, -0.9970,  0.4028,  1.0000, -0.9798,
          1.0000,  0.9941,  1.0000,  0.9202, -0.9942,  0.9996,  0.5352, -0.5836,
         -0.8829, -0.9418,  0.9497, -0.0532,  0.6966, -0.9999,  0.9998,  0.9917,
          0.9612,  0.7289,  0.0167,  0.3179,  0.9627, -0.9911,  0.9995, -0.9996,
         -0.6737,  0.9991,  1.0000,  0.9932,  0.4880, -0.7488,  0.9986, -0.9961,
          0.9995, -1.0000,  0.9999, -0.9940,  0.9705, -0.9970, -0.9856,  1.0000,
          0.9846, -0.7932,  0.9997, -0.9386,  0.9938,  0.9738,  0.8173,  0.9913,
          0.9981,  1.0000, -0.9998, -0.9918, -0.9727, -0.9987, -0.9955, -1.0000,
         -0.1038, -1.0000, -0.9874, -0.9287,  0.5109, -0.9056,  0.1022,  0.7864,
         -0.8197,  0.5724, -0.5905,  0.2713, -0.7239, -0.9976, -0.9844, -1.0000,
         -0.9988,  0.8835,  0.9999, -0.9997,  0.9999, -0.9999, -0.9782,  0.9383,
         -0.5609,  0.7721,  0.9999, -1.0000,  0.9585,  0.9987,  1.0000,  0.9960,
          0.9993, -0.9741, -0.9999, -0.9989, -0.9999, -1.0000, -0.9998,  0.9343,
          0.6337, -1.0000,  0.0902,  0.8980,  1.0000,  0.9964, -0.9985, -0.6136,
         -0.9996, -0.8252,  0.9996, -0.0566, -1.0000,  0.9962, -0.8744,  1.0000,
         -0.8865,  0.9879,  0.8897,  0.9571,  0.9823, -1.0000,  0.9145,  1.0000,
          0.0365, -1.0000, -0.9985, -0.9075, -0.9998,  0.0369,  0.8120,  0.9999,
         -1.0000, -0.9155, -0.9975,  0.7988,  0.9922,  0.9998,  0.9982,  0.9267,
          0.9165,  0.5368,  0.1464,  0.9998,  0.4663, -0.9989,  0.9996, -0.7952,
          0.4527, -1.0000,  0.9998,  0.4073,  0.9999,  0.9159, -0.5480, -0.6821,
         -0.9904,  0.9938,  1.0000, -0.4229, -0.4845, -0.9981, -1.0000, -0.9861,
         -0.0950, -0.4625, -0.9629, -0.9998,  0.6675, -0.5244,  1.0000,  1.0000,
          0.9924, -0.9253, -0.9974,  0.9974, -0.9012,  0.9900, -0.2582, -1.0000,
         -0.9919, -0.9986,  1.0000, -0.9716, -0.9262, -0.9911, -0.2593,  0.5919,
         -0.9999, -0.4994, -0.9962,  0.9818,  1.0000, -0.9996,  0.9918, -0.9970,
          0.7085, -0.1369,  0.8077,  0.9955, -0.3394, -0.5860, -0.6887, -0.9841,
          0.9970,  0.9987, -0.9948, -0.8401,  0.9999,  0.0856,  0.9999,  0.5099,
          0.9466,  0.9567,  1.0000,  0.8771,  1.0000, -0.0815,  1.0000,  0.9999,
         -0.9392,  0.5744,  0.8723, -0.9686,  0.5958,  0.9822,  0.9997,  0.8854,
         -0.1952, -0.9967,  0.9994,  1.0000,  1.0000, -0.3391,  0.9883, -0.4452,
          0.9252,  0.4495,  0.9870,  0.3479,  0.2266,  0.9942,  0.9990, -0.9999,
         -0.9999, -1.0000,  1.0000,  0.9996, -0.6637, -1.0000,  0.9999,  0.4543,
          0.7471,  0.9983,  0.3772, -0.9812,  0.9853, -0.9995, -0.3404,  0.9788,
          0.9867,  0.7564,  0.9995, -0.9997,  0.7990,  1.0000,  0.0752,  0.9999,
          0.2912, -0.9941,  0.9970, -0.9935, -0.9995, -0.9743,  0.9991,  0.9981,
         -0.9273, -0.8402,  0.9996, -0.9999,  0.9999, -0.9998,  0.9724, -0.9939,
          1.0000, -0.9752, -0.9998, -0.3806,  0.8830,  0.8352, -0.8892,  1.0000,
         -0.8875, -0.8107,  0.7083, -0.8909, -0.9931, -0.9630,  0.0800, -1.0000,
          0.7777, -0.9611,  0.5867, -0.9947, -0.9999,  1.0000, -0.9084, -0.9414,
          0.9999, -0.8838, -1.0000,  0.9549, -0.9999, -0.6522,  0.7967, -0.6850,
          0.1524, -1.0000,  0.4800,  0.9999, -0.9998, -0.7089, -0.9129, -0.9864,
          0.6220,  0.8855,  0.9855, -0.8651,  0.3988, -0.2548,  0.9793, -0.7212,
         -0.2582, -0.9999, -0.8692, -0.6282, -0.9999, -0.9999, -1.0000,  1.0000,
          0.9996,  0.9999, -0.5600,  0.7442,  0.9460,  0.9927, -0.9999,  0.4407,
         -0.0461,  0.9937, -0.4887, -0.9994, -0.9198, -1.0000, -0.6905,  0.3538,
         -0.7728,  0.6622,  1.0000,  0.9999, -0.9999, -0.9994, -0.9995, -0.9979,
          0.9998,  0.9999,  0.9996, -0.9072, -0.5844,  0.9997,  0.9689,  0.5231,
         -0.9999, -0.9981, -0.9999,  0.7505, -0.9922, -0.9986,  0.9971,  1.0000,
          0.8730, -1.0000, -0.9533,  1.0000,  0.9997,  1.0000, -0.7768,  0.9999,
         -0.9838,  0.9819, -0.9993,  1.0000, -1.0000,  1.0000,  0.9999,  0.9809,
          0.9984, -0.9928,  0.9776, -0.9998, -0.7407,  0.9298, -0.4495, -0.9902,
          0.8053,  0.9996, -0.9952,  1.0000,  0.9243, -0.2028,  0.8002,  0.9873,
          0.9419, -0.6913, -0.9999,  0.8162,  0.9995,  0.9509,  1.0000,  0.9177,
          0.9996, -0.9839, -0.9998,  0.9914, -0.6991, -0.7821, -0.9998,  1.0000,
          1.0000, -0.9999, -0.9227,  0.7483,  0.1186,  1.0000,  0.9963,  0.9971,
          0.9857,  0.3887,  0.9996, -0.9999,  0.8526, -0.9980, -0.8613,  0.9999,
         -0.9899,  0.9999, -0.9981,  1.0000, -0.9858,  0.9944,  0.9989,  0.9684,
         -0.9968,  1.0000,  0.8246, -0.9956, -0.8348, -0.9374, -0.9999,  0.7827]],
       grad_fn=<TanhBackward0>)
 
 
`last_hidden_state.shape` is (batch size, sequence length, hidden size):

torch.Size([1, 32, 768])
 
 
The model configuration (`bert_model.config`):

BertConfig {
  "_name_or_path": "bert-base-chinese",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "transformers_version": "4.24.0",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 21128
}
 
A fully connected classification head is added on top of BERT:

```python
class EnterpriseDangerClassifier(nn.Module):
    def __init__(self, n_classes):
        super(EnterpriseDangerClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False
        )
        output = self.drop(pooled_output)
        return self.out(output)
```
 
 
Instantiate the model and move it to the GPU. Note that `class_names` is never defined in the original post; the outputs below imply the two classes are simply 0 and 1, so it is defined here:

```python
from transformers import BertModel, BertTokenizer, BertConfig, AdamW, get_linear_schedule_with_warmup

class_names = [0, 1]  # not shown in the original post; inferred from the label values

model = EnterpriseDangerClassifier(len(class_names))
model = model.to(device)
```
 
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
 
```python
input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)

print(input_ids.shape)
print(attention_mask.shape)
```
 
torch.Size([4, 160])
torch.Size([4, 160])
 
Get the model outputs (logits) for a single batch:

```python
model(input_ids, attention_mask)
```
 
tensor([[ 0.2120, -0.4050],
        [ 0.3156, -0.4160],
        [ 0.5127, -0.4634],
        [ 0.3168,  0.5057]], device='cuda:0', grad_fn=<AddmmBackward0>)
 
```python
F.softmax(model(input_ids, attention_mask), dim=1)
```
 
tensor([[0.4999, 0.5001],
        [0.5258, 0.4742],
        [0.5899, 0.4101],
        [0.4575, 0.5425]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
 
Training configuration: optimizer, linear warmup schedule, and loss function.

```python
EPOCHS = 10

optimizer = AdamW(model.parameters(), lr=2e-5, correct_bias=False)
total_steps = len(train_data_loader) * EPOCHS

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

loss_fn = nn.CrossEntropyLoss().to(device)
```
 
Define one training epoch:

```python
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
    model = model.train()

    losses = []
    correct_predictions = 0

    for d in data_loader:
        input_ids = d["input_ids"].to(device)
        attention_mask = d["attention_mask"].to(device)
        targets = d["labels"].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        _, preds = torch.max(outputs, dim=1)
        loss = loss_fn(outputs, targets)

        correct_predictions += torch.sum(preds == targets)
        losses.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    return correct_predictions.double() / n_examples, np.mean(losses)
```
 
The evaluation function:

```python
def eval_model(model, data_loader, loss_fn, device, n_examples):
    model = model.eval()

    losses = []
    correct_predictions = 0

    with torch.no_grad():
        for d in data_loader:
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)
            loss = loss_fn(outputs, targets)

            correct_predictions += torch.sum(preds == targets)
            losses.append(loss.item())

    return correct_predictions.double() / n_examples, np.mean(losses)
```
 
Train for 10 epochs, saving the checkpoint with the best validation accuracy:

```python
history = defaultdict(list)
best_accuracy = 0

for epoch in range(EPOCHS):
    print(f'Epoch {epoch + 1}/{EPOCHS}')
    print('-' * 10)

    train_acc, train_loss = train_epoch(
        model,
        train_data_loader,
        loss_fn,
        optimizer,
        device,
        scheduler,
        len(df_train)
    )
    print(f'Train loss {train_loss} accuracy {train_acc}')

    val_acc, val_loss = eval_model(
        model,
        val_data_loader,
        loss_fn,
        device,
        len(df_val)
    )
    print(f'Val   loss {val_loss} accuracy {val_acc}')
    print()

    history['train_acc'].append(train_acc)
    history['train_loss'].append(train_loss)
    history['val_acc'].append(val_acc)
    history['val_loss'].append(val_loss)

    if val_acc > best_accuracy:
        torch.save(model.state_dict(), 'best_model_state.bin')
        best_accuracy = val_acc
```
 
Epoch 1/10
----------
Train loss 0.4988938277521757 accuracy 0.8899999999999999
Val   loss 0.4194765945523977 accuracy 0.9
Epoch 2/10
----------
Train loss 0.4967574527254328 accuracy 0.8905555555555555
Val   loss 0.43736912585794924 accuracy 0.9
Epoch 3/10
----------
Train loss 0.49347498720511795 accuracy 0.8905555555555555
Val   loss 0.41818931301434836 accuracy 0.9
Epoch 4/10
----------
Train loss 0.4900011462407807 accuracy 0.8905555555555555
Val   loss 0.42409916249414287 accuracy 0.9
Epoch 5/10
----------
Train loss 0.4952681002088098 accuracy 0.8888888888888888
Val   loss 0.31909402589624125 accuracy 0.9
Epoch 6/10
----------
Train loss 0.2478140213253425 accuracy 0.9463888888888888
Val   loss 0.1787985412031412 accuracy 0.9666666666666667
Epoch 7/10
----------
Train loss 0.17434944392257387 accuracy 0.9677777777777777
Val   loss 0.15001839348037416 accuracy 0.9700000000000001
Epoch 8/10
----------
Train loss 0.12048366091100939 accuracy 0.9775925925925926
Val   loss 0.11547344802587758 accuracy 0.9783333333333334
Epoch 9/10
----------
Train loss 0.10136666681817992 accuracy 0.9813888888888889
Val   loss 0.10292303454208498 accuracy 0.9800000000000001
Epoch 10/10
----------
Train loss 0.08721379442805402 accuracy 0.9831481481481481
Val   loss 0.12598223814862042 accuracy 0.9766666666666667
 
Plot the accuracy curves:

```python
plt.plot([i.cpu() for i in history['train_acc']], label='train accuracy')
plt.plot([i.cpu() for i in history['val_acc']], label='validation accuracy')

plt.title('Training history')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.ylim([0, 1]);
```
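One caveat: the training loop saves the best checkpoint but never reloads it, so the evaluation below uses whatever state the model ended training in. To evaluate the best checkpoint instead, it could be restored first; a minimal sketch:

```python
# Restore the best-validation checkpoint before final evaluation (not in the original post)
model.load_state_dict(torch.load('best_model_state.bin'))
model = model.to(device)
```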
 
```python
test_acc, _ = eval_model(
    model,
    test_data_loader,
    loss_fn,
    device,
    len(df_test)
)
test_acc.item()
```
 
0.9783333333333334
 
A prediction helper that collects texts, predicted labels, class probabilities, and true labels:

```python
def get_predictions(model, data_loader):
    model = model.eval()

    raw_texts = []
    predictions = []
    prediction_probs = []
    real_values = []

    with torch.no_grad():
        for d in data_loader:
            texts = d["texts"]
            input_ids = d["input_ids"].to(device)
            attention_mask = d["attention_mask"].to(device)
            targets = d["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)
            probs = F.softmax(outputs, dim=1)

            raw_texts.extend(texts)
            predictions.extend(preds)
            prediction_probs.extend(probs)
            real_values.extend(targets)

    predictions = torch.stack(predictions).cpu()
    prediction_probs = torch.stack(prediction_probs).cpu()
    real_values = torch.stack(real_values).cpu()

    return raw_texts, predictions, prediction_probs, real_values
```
 
```python
y_texts, y_pred, y_pred_probs, y_test = get_predictions(
    model,
    test_data_loader
)
```
 
```python
print(classification_report(y_test, y_pred, target_names=[str(label) for label in class_names]))
```
 
              precision    recall  f1-score   support
           0       0.99      0.99      0.99       554
           1       0.84      0.89      0.86        46
    accuracy                           0.98       600
   macro avg       0.91      0.94      0.93       600
weighted avg       0.98      0.98      0.98       600
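Since the competition scores F1, it can also be read off directly; a sketch, assuming binary F1 on the positive class (it should match the 0.86 row in the report above):

```python
from sklearn.metrics import f1_score

print(f1_score(y_test, y_pred))  # F1 on label 1, the competition metric
```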
 
Plot the confusion matrix:

```python
def show_confusion_matrix(confusion_matrix):
    hmap = sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues")
    hmap.yaxis.set_ticklabels(hmap.yaxis.get_ticklabels(), rotation=0, ha='right')
    hmap.xaxis.set_ticklabels(hmap.xaxis.get_ticklabels(), rotation=30, ha='right')
    plt.ylabel('True label')
    plt.xlabel('Predicted label');

cm = confusion_matrix(y_test, y_pred)
df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
show_confusion_matrix(df_cm)
```
 
     
Evaluate a single example:

```python
idx = 2

sample_text = y_texts[idx]
true_label = y_test[idx]
pred_df = pd.DataFrame({
    'class_names': class_names,
    'values': y_pred_probs[idx]
})
pred_df
```
 
 
  
    
       
|   | class_names | values   |
|---|-------------|----------|
| 0 | 0           | 0.999889 |
| 1 | 1           | 0.000111 |
```python
print("\n".join(wrap(sample_text)))
print()
print(f'True label: {class_names[true_label]}')
```
 
小A班应急照明灯坏[SEP]商贸服务教文卫类[SEP]消防检查[SEP]防火巡查[SEP]安全出口、疏散通道是否畅通,安全疏散指示标志、应急
照明是否完好。
True label: 0
 
```python
sns.barplot(x='values', y='class_names', data=pred_df, orient='h')
plt.ylabel('class')  # the original had 'sentiment', a leftover label
plt.xlabel('probability')
plt.xlim([0, 1]);
```
 
 
Encode the sample text from scratch and predict directly:

```python
encoded_text = tokenizer.encode_plus(
    sample_text,
    max_length=MAX_LEN,
    add_special_tokens=True,
    return_token_type_ids=False,
    pad_to_max_length=True,
    return_attention_mask=True,
    return_tensors='pt',
)
```
 
```python
input_ids = encoded_text['input_ids'].to(device)
attention_mask = encoded_text['attention_mask'].to(device)

output = model(input_ids, attention_mask)
_, prediction = torch.max(output, dim=1)

print(f'Sample text: {sample_text}')
print(f'Danger label  : {class_names[prediction]}')
```
 
Sample text:  
Danger label  : 1
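The post stops here and never shows the final step of predicting on the competition test set. A minimal sketch of how that could look, assuming `sub` has `id` and `label` columns (its exact schema is not shown in the post):

```python
# Hypothetical submission step, not part of the original post.
# Assumes test['text'] was built above and sub has columns ['id', 'label'].
test_ds_loader = DataLoader(
    EnterpriseDataset(
        texts=test['text'].values,
        labels=np.zeros(len(test), dtype=int),  # dummy labels; EnterpriseDataset expects them
        tokenizer=tokenizer,
        max_len=MAX_LEN
    ),
    batch_size=BATCH_SIZE
)

all_preds = []
model.eval()
with torch.no_grad():
    for d in test_ds_loader:
        outputs = model(d['input_ids'].to(device), d['attention_mask'].to(device))
        all_preds.extend(torch.max(outputs, dim=1)[1].cpu().numpy())

sub['label'] = all_preds
sub.to_csv('submission.csv', index=False)
```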