gate.py
import torch
import torch.nn as nn


class GateMul(nn.Module):
    """Gated fusion of an entity embedding with numeric and text literal vectors."""

    def __init__(self, emb_size, num_lit_size, txt_lit_size, gate_activation=torch.sigmoid):
        super(GateMul, self).__init__()
        self.emb_size = emb_size
        self.num_lit_size = num_lit_size
        self.txt_lit_size = txt_lit_size
        self.gate_activation = gate_activation
        # Maps the concatenated [entity, numeric literal, text literal] vector back to emb_size.
        self.g = nn.Linear(emb_size + num_lit_size + txt_lit_size, emb_size)
        # Per-input linear maps whose sum (plus a shared bias) drives the gate.
        self.gate_ent = nn.Linear(emb_size, emb_size, bias=False)
        self.gate_num_lit = nn.Linear(num_lit_size, emb_size, bias=False)
        self.gate_txt_lit = nn.Linear(txt_lit_size, emb_size, bias=False)
        self.gate_bias = nn.Parameter(torch.zeros(emb_size))

    def forward(self, x_ent, x_lit_num, x_lit_txt):
        # Candidate fused representation computed from all three inputs.
        x = torch.cat([x_ent, x_lit_num, x_lit_txt], dim=1)
        g_embedded = torch.tanh(self.g(x))
        # Gate in (0, 1) decides, per dimension, how much of the fused
        # representation replaces the original entity embedding.
        gate = self.gate_activation(
            self.gate_ent(x_ent) + self.gate_num_lit(x_lit_num)
            + self.gate_txt_lit(x_lit_txt) + self.gate_bias
        )
        output = (1 - gate) * x_ent + gate * g_embedded
        return output

class Gate(nn.Module):
    """Gated fusion of an entity embedding with a single literal vector."""

    def __init__(self, emb_size, lit_size, gate_activation=torch.sigmoid):
        super(Gate, self).__init__()
        self.emb_size = emb_size
        self.lit_size = lit_size
        self.gate_activation = gate_activation
        # Maps the concatenated [entity, literal] vector back to emb_size.
        self.g = nn.Linear(emb_size + lit_size, emb_size)
        # Linear maps whose sum (plus a shared bias) drives the gate.
        self.gate_ent = nn.Linear(emb_size, emb_size, bias=False)
        self.gate_lit = nn.Linear(lit_size, emb_size, bias=False)
        self.gate_bias = nn.Parameter(torch.zeros(emb_size))

    def forward(self, x_ent, x_lit):
        x = torch.cat([x_ent, x_lit], dim=1)
        g_embedded = torch.tanh(self.g(x))
        # Gate in (0, 1) blends the fused representation with the original embedding.
        gate = self.gate_activation(self.gate_ent(x_ent) + self.gate_lit(x_lit) + self.gate_bias)
        output = (1 - gate) * x_ent + gate * g_embedded
        return output
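
A minimal usage sketch for both modules follows; the batch size, embedding size, and literal sizes are illustrative assumptions, not values taken from this repository.

# Usage sketch (illustrative sizes, not from the repository).
import torch

emb_size, num_lit_size, txt_lit_size = 200, 10, 50
batch = 4

gate_mul = GateMul(emb_size, num_lit_size, txt_lit_size)
gate = Gate(emb_size, num_lit_size + txt_lit_size)

x_ent = torch.randn(batch, emb_size)
x_lit_num = torch.randn(batch, num_lit_size)
x_lit_txt = torch.randn(batch, txt_lit_size)

fused_mul = gate_mul(x_ent, x_lit_num, x_lit_txt)            # shape: (batch, emb_size)
fused = gate(x_ent, torch.cat([x_lit_num, x_lit_txt], dim=1))  # shape: (batch, emb_size)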