forked from facebookresearch/adaptive-softmax
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathAdaptiveLoss.lua
63 lines (55 loc) · 1.81 KB
/
AdaptiveLoss.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
-- Copyright (c) 2016-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
local AdaptiveLoss, Criterion = torch.class('nn.AdaptiveLoss', 'nn.Criterion')
--- Build one cross-entropy criterion per softmax cluster (head + tails).
-- @param cutoff table of vocabulary boundaries, one entry per cluster
function AdaptiveLoss:__init(cutoff)
   Criterion.__init(self)
   self.cutoff = cutoff
   self.criterions = {}
   for idx = 1, #cutoff do
      local crit = nn.CrossEntropyCriterion()
      -- sum (do not average) the loss over the batch
      crit.nll.sizeAverage = false
      self.criterions[idx] = crit
   end
end
--- Split a flat class-index target into per-cluster targets.
-- Entry 1 is the head target: tokens beyond cutoff[1] are replaced by the
-- id of their cluster slot (cutoff[1] + i). Entry i+1 holds the 1-based
-- within-cluster indices for tail cluster i, or {} when the batch has none.
-- @param target 1D tensor of class indices
-- @return table of per-cluster target tensors (with {} placeholders)
function AdaptiveLoss:remapTarget(target)
   local cutoff = self.cutoff
   local remapped = { target:clone() }
   for cl = 1, #cutoff - 1 do
      -- mask of tokens inside cluster `cl`: cutoff[cl] < t <= cutoff[cl+1]
      local mask = target:ge(cutoff[cl] + 1):cmul(target:le(cutoff[cl + 1]))
      remapped[1][mask] = cutoff[1] + cl
      if mask:any() then
         -- shift to indices local to the cluster
         remapped[#remapped + 1] = target[mask]:add(-cutoff[cl])
      else
         remapped[#remapped + 1] = {}
      end
   end
   return remapped
end
--- Forward pass: summed cross-entropy loss over head and tail clusters.
-- Also caches per-cluster gradients in self.gradInput so updateGradInput
-- can return them without recomputation.
-- Fixes vs. original: removed the unused local `bsz`, renamed the local
-- that shadowed the `target` parameter, and gave the assert a message.
-- @param input table of score tensors, one per cluster; non-tensor
--   entries mark clusters with no targets in this batch
-- @param target 1D tensor of class indices in [1, cutoff[#cutoff]]
-- @return summed loss (number)
function AdaptiveLoss:updateOutput(input, target)
   local remapped = self:remapTarget(target)
   self.output = 0.0
   self.gradInput = {}
   for i = 1, #input do
      if torch.isTensor(input[i]) then
         assert(remapped[i]:min() > 0 and remapped[i]:max() <= input[i]:size(2),
            'remapped target out of range for cluster ' .. i)
         local criterion = self.criterions[i]
         local loss = criterion:updateOutput(input[i], remapped[i])
         self.output = self.output + loss
         -- gradient is computed here and replayed by updateGradInput
         self.gradInput[i] = criterion:updateGradInput(input[i], remapped[i])
      end
   end
   return self.output
end
-- Backward pass: returns the per-cluster gradients cached by the most
-- recent updateOutput call; the input/target arguments are unused here.
function AdaptiveLoss:updateGradInput(input, target)
   return self.gradInput
end
--- Move every per-cluster criterion to the GPU.
-- @return self, to allow call chaining
function AdaptiveLoss:cuda()
   for _, criterion in ipairs(self.criterions) do
      criterion:cuda()
   end
   return self
end