forked from yangheng95/PyABSA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lcfs_dual_bert.py
52 lines (43 loc) · 1.85 KB
/
lcfs_dual_bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# -*- coding: utf-8 -*-
# @FileName: lcfs_dual_bert.py
# @Time : 2021/6/20 9:30
# @Author : yangheng@m.scnu.edu.cn
# @github : https://github.com/yangheng95
# Copyright (C) 2021. All Rights Reserved.
import copy
import torch
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertPooler
from pyabsa.network.sa_encoder import Encoder
class LCFS_DUAL_BERT(nn.Module):
    """Dual-BERT local-context-focus classifier (LCFS variant).

    Two BERT encoders run side by side. ``bert4global`` is the ``bert`` module
    passed to ``__init__``, and ``bert4local`` is a ``copy.deepcopy`` of it.
    The two therefore do not share weights after construction.

    Pipeline (see ``forward``):

    1. Local hidden states are reweighted token-wise by ``lcf_vec``.
    2. The result is passed through ``bert_SA``.
    3. It is concatenated with the global hidden states.
    4. The concatenation is projected from ``2 * opt.embed_dim`` back to
       ``opt.embed_dim`` and dropped out.
    5. It is passed through ``bert_SA_`` and pooled by ``BertPooler``.
    6. ``dense`` maps it to ``opt.polarities_dim`` outputs.

    NOTE(review): the concatenation assumes ``opt.embed_dim`` equals the hidden
    size of ``bert`` and of the ``Encoder`` output — confirm against config.
    """

    # Batch field names this model consumes, in the positional order that
    # ``forward`` indexes them (inputs[0], inputs[1], inputs[2]).
    inputs = ['text_bert_indices', 'text_raw_bert_indices', 'lcf_vec']

    def __init__(self, bert, opt):
        """Build the model.

        Args:
            bert: pretrained BERT module. It is used as the global encoder and
                deep-copied to create the local encoder. Its ``config`` is
                passed to both ``Encoder`` instances and to ``BertPooler``.
            opt: config object. This class reads ``dropout``, ``embed_dim``,
                ``polarities_dim`` and (in ``forward``) ``use_bert_spc``. It is
                also handed to both ``Encoder`` instances.
        """
        super().__init__()
        # Submodules are created in the original order so that state_dict
        # keys and the order of (seeded) parameter initialisation are unchanged.
        self.bert4global = bert
        self.bert4local = copy.deepcopy(bert)
        self.opt = opt
        self.dropout = nn.Dropout(opt.dropout)
        self.bert_SA = Encoder(bert.config, opt)
        self.linear2 = nn.Linear(opt.embed_dim * 2, opt.embed_dim)
        self.bert_SA_ = Encoder(bert.config, opt)
        self.bert_pooler = BertPooler(bert.config)
        self.dense = nn.Linear(opt.embed_dim, opt.polarities_dim)

    def forward(self, inputs):
        """Score a batch.

        Args:
            inputs: sequence of tensors, indexed positionally (see ``inputs``):

                - ``inputs[0]``: token ids for the global BERT. Used only when
                  ``opt.use_bert_spc`` is true.
                - ``inputs[1]``: token ids for the local BERT. Also used for
                  the global BERT when ``opt.use_bert_spc`` is false.
                - ``inputs[2]``: local-context weights. Either one weight per
                  token, shaped ``(batch, seq_len)``, or a tensor already
                  broadcastable against the local hidden states.

        Returns:
            The output of ``self.dense``, presumably
            ``(batch, opt.polarities_dim)``. No softmax is applied here.
        """
        text_bert_indices = inputs[0] if self.opt.use_bert_spc else inputs[1]
        text_local_indices = inputs[1]
        lcf_matrix = inputs[2]

        global_context_features = self.bert4global(text_bert_indices)['last_hidden_state']
        local_context_features = self.bert4local(text_local_indices)['last_hidden_state']

        # A per-token weight vector (batch, seq_len) needs a trailing axis to
        # scale every hidden unit of its token. The hidden states are
        # presumably (batch, seq_len, hidden). Without the extra axis, the
        # multiply fails or broadcasts along the wrong dimensions.
        if lcf_matrix.dim() == 2:
            lcf_matrix = lcf_matrix.unsqueeze(-1)
        # Match the hidden states' dtype. Otherwise float64 weights (e.g.
        # collated from NumPy) would promote the features to double before
        # they reach bert_SA.
        lcf_matrix = lcf_matrix.to(dtype=local_context_features.dtype)

        # LCF layer: reweight local features, then self-encode them.
        lcf_features = torch.mul(local_context_features, lcf_matrix)
        lcf_features = self.bert_SA(lcf_features)

        # Fuse local and global views along the feature axis, then project back.
        cat_features = torch.cat((lcf_features, global_context_features), dim=-1)
        cat_features = self.linear2(cat_features)
        cat_features = self.dropout(cat_features)
        cat_features = self.bert_SA_(cat_features)

        pooled_out = self.bert_pooler(cat_features)
        dense_out = self.dense(pooled_out)
        return dense_out