-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathseq_utils.py
329 lines (291 loc) · 12.2 KB
/
seq_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
# sequence utility functions
import torch
import math
import numpy as np
def ot2bieos_ts(ts_tag_sequence):
    """
    Convert an OT tag sequence to the BIEOS scheme for the targeted-sentiment task
    (ts refers to targeted sentiment / aspect-based sentiment).

    'O' and 'EQ' tags are both emitted as 'O'. Every other tag must have the form
    '<pos>-<sentiment>' (e.g. 'T-POS'); it is rewritten to 'B-', 'I-', 'E-' or 'S-'
    plus the same sentiment, depending on whether its neighbours are outside tags.
    :param ts_tag_sequence: tag sequence for targeted sentiment
    :return: list of BIEOS tags with the same length as the input
    :raises ValueError: if a non-outside tag does not split into exactly two parts on '-'
    """
    # 'EQ' is regarded as 'O' everywhere, including when it is the *next* tag:
    # a target followed by 'EQ' ends at the current position.
    outside_tags = ('O', 'EQ')
    n_tags = len(ts_tag_sequence)
    new_ts_sequence = []
    # sentinel that can never be equal to a real position label
    prev_pos = '$$$'
    for i, cur_ts_tag in enumerate(ts_tag_sequence):
        if cur_ts_tag in outside_tags:
            new_ts_sequence.append('O')
            prev_pos = 'O'
            continue
        cur_pos, cur_sentiment = cur_ts_tag.split('-')
        # the span stops here when this is the last tag or the next tag is outside
        ends_here = i == n_tags - 1 or ts_tag_sequence[i + 1] in outside_tags
        if cur_pos != prev_pos:
            # previous tag is outside: the new label can only be B or S
            label = 'S' if ends_here else 'B'
        else:
            # previous tag is inside a target: the new label can only be I or E
            label = 'E' if ends_here else 'I'
        new_ts_sequence.append('%s-%s' % (label, cur_sentiment))
        prev_pos = cur_pos
    return new_ts_sequence
def ot2bieos_ts_batch(ts_tag_seqs):
    """
    Batch version of ot2bieos_ts: convert every tag sequence of the batch.
    :param ts_tag_seqs: sequence of tag sequences
    :return: list holding the converted tag sequence of each input, in input order
    """
    return [ot2bieos_ts(ts_tag_sequence=tag_seq) for tag_seq in ts_tag_seqs]
def ot2bio_ts(ts_tag_sequence):
    """
    Convert an OT tag sequence to the BIO scheme for the targeted-sentiment task.

    'O' and 'EQ' tags are both emitted as 'O' (the same treatment ot2bieos_ts and
    bio2ot_ts give to 'EQ'). Every other tag must have the form '<pos>-<sentiment>';
    it becomes 'B-<sentiment>' when the previous tag is outside and 'I-<sentiment>'
    when the previous tag has the same position label.
    :param ts_tag_sequence: tag sequence for targeted sentiment
    :return: list of BIO tags with the same length as the input
    :raises ValueError: if a non-outside tag does not split into exactly two parts on '-'
    """
    new_ts_sequence = []
    # sentinel that can never be equal to a real position label
    prev_pos = '$$$'
    for cur_ts_tag in ts_tag_sequence:
        if cur_ts_tag == 'O' or cur_ts_tag == 'EQ':
            # 'EQ' is regarded as 'O'; without this it would fail to unpack below
            new_ts_sequence.append('O')
            cur_pos = 'O'
        else:
            # current tag is a subjective tag, i.e., cur_pos is T
            cur_pos, cur_sentiment = cur_ts_tag.split('-')
            if cur_pos == prev_pos:
                # previous tag is T: continue the running target
                new_ts_sequence.append('I-%s' % cur_sentiment)
            else:
                # previous tag is outside: open a new target
                new_ts_sequence.append('B-%s' % cur_sentiment)
        prev_pos = cur_pos
    return new_ts_sequence
def ot2bio_ts_batch(ts_tag_seqs):
    """
    Batch version of ot2bio_ts: convert every tag sequence of the batch.
    :param ts_tag_seqs: sequence of tag sequences
    :return: list holding the converted tag sequence of each input, in input order
    """
    return [ot2bio_ts(ts_tag_sequence=tag_seq) for tag_seq in ts_tag_seqs]
def bio2ot_ts(ts_tag_sequence):
    """
    Perform the bio-->ot conversion for a ts tag sequence.

    'O' and 'EQ' become 'O'; any other tag '<pos>-<sentiment>' becomes 'T-<sentiment>'.
    :param ts_tag_sequence: tag sequence for targeted sentiment
    :return: list of OT tags with the same length as the input
    """
    def _to_ot(tag):
        # outside tags carry no sentiment
        if tag in ('O', 'EQ'):
            return 'O'
        # the position label is dropped, only the sentiment is kept
        _, polarity = tag.split('-')
        return 'T-%s' % polarity

    return [_to_ot(tag) for tag in ts_tag_sequence]
def bio2ot_ts_batch(ts_tag_seqs):
    """
    Batch version of bio2ot_ts: convert every tag sequence of the batch.
    :param ts_tag_seqs: sequence of tag sequences
    :return: list holding the converted tag sequence of each input, in input order
    """
    return [bio2ot_ts(ts_tag_sequence=tag_seq) for tag_seq in ts_tag_seqs]
def tag2ts(ts_tag_sequence):
    """
    Extract the targeted-sentiment spans encoded by a ts tag sequence.

    Tags are expected in the form '<pos>-<sentiment>' with pos in {B, I, E, S};
    tags that do not split into exactly two parts on '-' (such as 'O' or 'EQ')
    are ignored.
    :param ts_tag_sequence: tag sequence for ts task
    :return: list of (begin index, end index, sentiment) tuples, indices inclusive
    """
    spans = []
    # sentiments collected since the last reset
    collected = []
    start = -1
    for idx, tag in enumerate(ts_tag_sequence):
        parts = tag.split('-')
        if len(parts) != 2:
            # tag O and tag EQ are not counted
            continue
        pos, sentiment = parts
        if sentiment != 'O':
            # current word is a subjective word
            collected.append(sentiment)
        if pos == 'S':
            # singleton target
            spans.append((idx, idx, sentiment))
            collected = []
        elif pos == 'B':
            start = idx
            if len(collected) > 1:
                # drop the noisy sentiments gathered before this B tag
                collected = collected[-1:]
        elif pos == 'E':
            # only accept a span whose sentiment tags are all the same
            if idx > start > -1 and len(set(collected)) == 1:
                spans.append((start, idx, sentiment))
                collected = []
                start = -1
    return spans
def logsumexp(tensor, dim=-1, keepdim=False):
    """
    Numerically stable log(sum(exp(tensor))) along one dimension.

    The maximum along ``dim`` is subtracted before exponentiating. When that
    maximum is not finite (a slice made only of -inf, or containing +inf) no
    shift is applied, so the result is -inf / +inf instead of NaN.
    :param tensor: input torch.Tensor
    :param dim: dimension to reduce
    :param keepdim: whether the reduced dimension is kept with size 1
    :return: tensor of log-sum-exp values
    """
    max_score, _ = tensor.max(dim, keepdim=True)
    # Shifting by an infinite maximum would compute inf - inf = NaN; shift by 0 instead.
    shift = torch.where(torch.isfinite(max_score), max_score, torch.zeros_like(max_score))
    summed = (tensor - shift).exp().sum(dim, keepdim=True)
    result = shift + summed.log()
    if keepdim:
        return result
    return result.squeeze(dim)
def viterbi_decode(tag_sequence, transition_matrix,
                   tag_observations=None, allowed_start_transitions=None,
                   allowed_end_transitions=None):
    """
    Perform Viterbi decoding in log space over a sequence given a transition matrix
    specifying pairwise (transition) potentials between tags and a matrix of shape
    (sequence_length, num_tags) specifying unary potentials for possible tags per
    timestep.

    Parameters
    ----------
    tag_sequence : torch.Tensor, required.
        A tensor of shape (sequence_length, num_tags) representing scores for
        a set of tags over a given sequence.
    transition_matrix : torch.Tensor, required.
        A tensor of shape (num_tags, num_tags) representing the binary potentials
        for transitioning between a given pair of tags.
    tag_observations : Optional[List[int]], optional, (default = None)
        A list of length ``sequence_length`` containing the class ids of observed
        elements in the sequence, with unobserved elements being set to -1. Note that
        it is possible to provide evidence which results in degenerate labelings if
        the sequences of tags you provide as evidence cannot transition between each
        other, or those transitions are extremely unlikely. In this situation we log a
        warning, but the responsibility for providing self-consistent evidence ultimately
        lies with the user.
    allowed_start_transitions : torch.Tensor, optional, (default = None)
        An optional tensor of shape (num_tags,) describing which tags the START token
        may transition *to*. If provided, additional transition constraints will be used for
        determining the start element of the sequence.
    allowed_end_transitions : torch.Tensor, optional, (default = None)
        An optional tensor of shape (num_tags,) describing which tags may transition *to* the
        end tag. If provided, additional transition constraints will be used for determining
        the end element of the sequence.

    Returns
    -------
    viterbi_path : numpy.ndarray
        One-dimensional ``int32`` array with the tag indices of the best scoring tag
        sequence. The score of the path is not returned.

    Raises
    ------
    ValueError
        If ``tag_observations`` is provided with a length different from ``sequence_length``.
    """
    # Imported locally so the warning below works without a module-level logger.
    import logging

    sequence_length, num_tags = list(tag_sequence.size())

    has_start_end_restrictions = allowed_end_transitions is not None or allowed_start_transitions is not None
    if has_start_end_restrictions:
        if allowed_end_transitions is None:
            allowed_end_transitions = torch.zeros(num_tags)
        if allowed_start_transitions is None:
            allowed_start_transitions = torch.zeros(num_tags)

        # Two extra tags are appended: index -2 is START, index -1 is END.
        num_tags = num_tags + 2
        new_transition_matrix = torch.zeros(num_tags, num_tags)
        new_transition_matrix[:-2, :-2] = transition_matrix

        # Start and end transitions are fully defined, but cannot transition between each other.
        # pylint: disable=not-callable
        allowed_start_transitions = torch.cat([allowed_start_transitions, torch.tensor([-math.inf, -math.inf])])
        allowed_end_transitions = torch.cat([allowed_end_transitions, torch.tensor([-math.inf, -math.inf])])
        # pylint: enable=not-callable

        # First define how we may transition FROM the start and end tags.
        new_transition_matrix[-2, :] = allowed_start_transitions
        # We cannot transition from the end tag to any tag.
        new_transition_matrix[-1, :] = -math.inf

        new_transition_matrix[:, -1] = allowed_end_transitions
        # We cannot transition to the start tag from any tag.
        new_transition_matrix[:, -2] = -math.inf

        transition_matrix = new_transition_matrix

    if tag_observations:
        if len(tag_observations) != sequence_length:
            raise ValueError("Observations were provided, but they were not the same length "
                             "as the sequence. Found sequence of length: {} and evidence: {}"
                             .format(sequence_length, tag_observations))
    else:
        tag_observations = [-1 for _ in range(sequence_length)]

    if sequence_length == 0:
        # Nothing to decode; without this guard the first-timestep lookup below fails.
        return np.array([], dtype=np.int32)

    if has_start_end_restrictions:
        # Surround the sequence with observed START / END timesteps and give the
        # real timesteps a score of -inf for the two artificial tags.
        tag_observations = [num_tags - 2] + tag_observations + [num_tags - 1]
        zero_sentinel = torch.zeros(1, num_tags)
        extra_tags_sentinel = torch.ones(sequence_length, 2) * -math.inf
        tag_sequence = torch.cat([tag_sequence, extra_tags_sentinel], -1)
        tag_sequence = torch.cat([zero_sentinel, tag_sequence, zero_sentinel], 0)
        sequence_length = tag_sequence.size(0)

    path_scores = []
    path_indices = []

    if tag_observations[0] != -1:
        # An observed tag is forced by giving it an overwhelmingly large score.
        one_hot = torch.zeros(num_tags)
        one_hot[tag_observations[0]] = 100000.
        path_scores.append(one_hot)
    else:
        path_scores.append(tag_sequence[0, :])

    # Evaluate the scores for all possible paths.
    for timestep in range(1, sequence_length):
        # Add pairwise potentials to current scores.
        summed_potentials = path_scores[timestep - 1].unsqueeze(-1) + transition_matrix
        # Reducing dim 0 already yields vectors of shape (num_tags,). They are kept
        # one-dimensional (no squeeze) so that back-tracking can index them even
        # when num_tags == 1.
        scores, paths = torch.max(summed_potentials, 0)

        # If we have an observation for this timestep, use it
        # instead of the distribution over tags.
        observation = tag_observations[timestep]
        # Warn the user if they have passed
        # invalid/extremely unlikely evidence.
        if tag_observations[timestep - 1] != -1 and observation != -1:
            if transition_matrix[tag_observations[timestep - 1], observation] < -10000:
                logging.getLogger(__name__).warning(
                    "The pairwise potential between tags you have passed as "
                    "observations is extremely unlikely. Double check your evidence "
                    "or transition potentials!")
        if observation != -1:
            one_hot = torch.zeros(num_tags)
            one_hot[observation] = 100000.
            path_scores.append(one_hot)
        else:
            path_scores.append(tag_sequence[timestep, :] + scores)
        path_indices.append(paths)

    # Construct the most likely sequence backwards.
    _, best_path = torch.max(path_scores[-1], 0)
    # int() on a 0-dim tensor works on any device, unlike .numpy().
    viterbi_path = [int(best_path)]
    for backward_timestep in reversed(path_indices):
        viterbi_path.append(int(backward_timestep[viterbi_path[-1]]))
    # Reverse the backward path.
    viterbi_path.reverse()

    if has_start_end_restrictions:
        # Strip the artificial START / END timesteps.
        viterbi_path = viterbi_path[1:-1]
    return np.array(viterbi_path, dtype=np.int32)