-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path optimizer.py
350 lines (281 loc) · 12 KB
/
optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
from typing import Callable, Iterable, Tuple
import torch
from torch.optim import Optimizer
class AdamW(Optimizer):
    """Adam with decoupled weight decay (AdamW).

    Each step updates exponential moving averages of the gradient (``m``)
    and of its element-wise square (``v``), applies the Adam update with the
    "efficient" bias correction folded into the step size
    (https://arxiv.org/abs/1412.6980), and then applies weight decay as a
    separate step ``p <- p - lr * p * weight_decay`` that does not pass
    through ``m``/``v``.

    Args:
        params: Iterable of parameters (or param-group dicts) to optimize.
        lr: Learning rate; must be >= 0.
        betas: Decay rates ``(beta_1, beta_2)`` of the two moving averages;
            each must lie in [0, 1).
        eps: Constant added to ``sqrt(v)`` in the denominator; must be >= 0.
        weight_decay: Decoupled weight-decay coefficient; must be >= 0.
        correct_bias: Whether to apply Adam's bias correction.

    Raises:
        ValueError: If any hyperparameter is out of range.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
        correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])
            )
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        if not 0.0 <= weight_decay:
            raise ValueError(
                "Invalid weight_decay value: {} - should be >= 0.0".format(weight_decay)
            )
        defaults = dict(
            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """Perform a single optimization step.

        Runs under ``torch.no_grad()`` so that the moment buffers never record
        autograd history, even if the gradients carry a graph (for example
        after ``backward(create_graph=True)``).

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is executed with gradient tracking enabled.

        Returns:
            Whatever ``closure`` returned, or ``None`` if no closure was given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # The closure needs autograd to build the loss and call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Hyperparameters are per group; read them once per group.
            lr = group["lr"]
            beta_1, beta_2 = group["betas"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            correct_bias = group["correct_bias"]

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )

                # Per-parameter state, created lazily and allocated directly
                # on the parameter's device.
                state = self.state[p]
                device = p.device
                if "t" not in state:
                    state["t"] = torch.tensor([0], device=device)
                if "m" not in state:
                    state["m"] = torch.zeros(grad.size(), dtype=grad.dtype, device=device)
                if "v" not in state:
                    state["v"] = torch.zeros(grad.size(), dtype=grad.dtype, device=device)

                state["t"] += 1
                t = state["t"]
                m = state["m"]
                v = state["v"]

                # 1 - Update the moving averages of the gradient and its square.
                m.mul_(beta_1).add_(grad, alpha=1 - beta_1)
                v.mul_(beta_2).addcmul_(grad, grad, value=1 - beta_2)

                # 2 - Bias correction, folded into the step size
                #     ("efficient version" of Kingma & Ba).
                step_size = lr
                if correct_bias:
                    step_size = step_size * (
                        torch.sqrt(1 - beta_2 ** t) / (1 - beta_1 ** t)
                    )

                # 3 - Gradient-based parameter update.
                p.data.sub_(step_size * m / (torch.sqrt(v) + eps))

                # 4 - Decoupled weight decay, scaled by the (uncorrected)
                #     learning rate and applied after the Adam update.
                if weight_decay != 0.0:
                    p.data.sub_(lr * p.data * weight_decay)

        return loss
class SophiaG(Optimizer):
    """Sophia optimizer with a Gauss-Newton-Bartlett (GNB) Hessian estimate.

    Sophia: Second-order Clipped Stochastic Optimization,
    https://arxiv.org/pdf/2305.14342.pdf

    The diagonal-Hessian moving average is refreshed only when the caller
    invokes :meth:`update_hessian`. :meth:`step` uses whatever estimate is
    currently stored; while it is still zero, the clipped ratio saturates at
    +/-1 and the update reduces to ``-lr * sign(exp_avg)``.

    Args:
        params: Iterable of parameters (or param-group dicts) to optimize.
        lr: Learning rate; must be >= 0.
        betas: ``(beta1, beta2)``: decay rates of the gradient EMA and the
            Hessian EMA; each must lie in [0, 1).
        rho: Scale applied to the Hessian estimate in the update denominator;
            must be >= 0.
        weight_decay: Decoupled weight-decay coefficient; must be >= 0.
        eps: Constant added to the update denominator; must be >= 0.

    Raises:
        ValueError: If any hyperparameter is out of range.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.965, 0.99),
        rho: float = 0.04,
        weight_decay: float = 0.1,
        eps: float = 1e-15,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])
            )
        if not 0.0 <= rho:
            raise ValueError("Invalid rho value: {} - should be >= 0.0".format(rho))
        if not 0.0 <= weight_decay:
            raise ValueError(
                "Invalid weight_decay value: {} - should be >= 0.0".format(weight_decay)
            )
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(
            lr=lr,
            betas=betas,
            rho=rho,
            weight_decay=weight_decay,
            eps=eps,
        )
        super().__init__(params, defaults)

    def _ensure_state(self, p):
        """Create any missing per-parameter state entries for ``p`` and return the state.

        Shared by :meth:`update_hessian` and :meth:`step`, so either may run
        first for a given parameter.
        """
        state = self.state[p]
        if "step" not in state:
            state["step"] = torch.zeros((1,), dtype=torch.float, device=p.device)
        if "exp_avg" not in state:
            state["exp_avg"] = torch.zeros_like(p)
        if "hessian" not in state:
            state["hessian"] = torch.zeros_like(p)
        return state

    @torch.no_grad()
    def update_hessian(self, bs: int):
        """Fold the current gradients into the diagonal-Hessian moving average.

        For every parameter with a gradient::

            hessian <- beta2 * hessian + (1 - beta2) * bs * grad * grad

        i.e. an EMA of the GNB estimator ``B * g_hat * g_hat``. Per the paper,
        ``p.grad`` should hold the gradient of the loss on labels sampled from
        the model's own output distribution; populating it is the caller's job.

        Args:
            bs: Batch size ``B`` used in the GNB estimator.
        """
        for group in self.param_groups:
            _, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Lazily initialise: this may run before the first step() or
                # for a parameter that had no gradient in earlier steps.
                state = self._ensure_state(p)
                # bs - bs * beta2 == bs * (1 - beta2)
                state["hessian"].mul_(beta2).addcmul_(p.grad, p.grad, value=bs - bs * beta2)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is executed with gradient tracking enabled.

        Returns:
            Whatever ``closure`` returned, or ``None`` if no closure was given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, _ = group["betas"]
            rho = group["rho"]
            lr = group["lr"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError("Sophia does not support sparse gradients")

                state = self._ensure_state(p)
                exp_avg = state["exp_avg"]
                hess = state["hessian"]
                state["step"] += 1

                # 1 - Decoupled weight decay.
                p.data.mul_(1 - lr * weight_decay)

                # 2 - Moving average of the gradient (first moment only).
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # 3 - Preconditioned step, clipped element-wise to [-1, 1].
                #     hess is an EMA of bs * g * g (non-negative for bs >= 0),
                #     so this equals sign(m) * min(|m| / (rho * hess + eps), 1).
                ratio = (exp_avg / (rho * hess + eps)).clamp(-1, 1)
                p.data.add_(ratio, alpha=-lr)

        return loss
class SophiaH(Optimizer):
    """Sophia optimizer with a Hutchinson estimate of the diagonal Hessian.

    Sophia: Second-order Clipped Stochastic Optimization,
    https://arxiv.org/pdf/2305.14342.pdf

    Every ``update_period``-th call to :meth:`step` first refreshes the
    Hessian moving average from the current gradients. That refresh
    differentiates ``p.grad`` again, so on those steps the gradients must
    have been produced with ``backward(create_graph=True)``.

    Args:
        params: Iterable of parameters (or param-group dicts) to optimize.
        lr: Learning rate; must be >= 0.
        betas: ``(beta1, beta2)``: decay rates of the gradient EMA and the
            Hessian EMA; each must lie in [0, 1).
        rho: Element-wise clipping bound for the preconditioned update;
            must be >= 0.
        weight_decay: Decoupled weight-decay coefficient; must be >= 0.
        eps: Lower bound applied to the Hessian estimate in the denominator;
            must be >= 0.
        update_period: Refresh the Hessian estimate every this many steps;
            must be > 0.

    Raises:
        ValueError: If any hyperparameter is out of range.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.parameter.Parameter],
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.96, 0.99),
        rho: float = 1e-2,
        weight_decay: float = 0.0,
        eps: float = 1e-12,
        update_period: int = 10,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                "Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1])
            )
        if not 0.0 <= rho:
            raise ValueError("Invalid rho value: {} - should be >= 0.0".format(rho))
        if not 0.0 <= weight_decay:
            raise ValueError(
                "Invalid weight_decay value: {} - should be >= 0.0".format(weight_decay)
            )
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        if not 0 < update_period:
            raise ValueError(
                "Invalid update_period value: {} - should be > 0".format(update_period)
            )
        self.update_period = update_period
        defaults = dict(
            lr=lr,
            betas=betas,
            rho=rho,
            weight_decay=weight_decay,
            eps=eps,
            update_period=update_period,
        )
        super().__init__(params, defaults)

    def _ensure_state(self, p):
        """Create any missing per-parameter state entries for ``p`` and return the state.

        Shared by :meth:`_update_hessian` and :meth:`step`, so either may run
        first for a given parameter.
        """
        state = self.state[p]
        if "step" not in state:
            state["step"] = torch.zeros((1,), dtype=torch.float, device=p.device)
        if "exp_avg" not in state:
            state["exp_avg"] = torch.zeros_like(p)
        if "hessian" not in state:
            state["hessian"] = torch.zeros_like(p)
        return state

    def _update_hessian(self):
        """Fold one Hutchinson sample per parameter into the Hessian moving average.

        For each parameter ``p`` with a gradient, draws ``u ~ N(0, I)``,
        differentiates ``<p.grad, u>`` with respect to ``p`` (the product of
        ``p``'s own Hessian block with ``u``), and updates::

            hessian <- beta2 * hessian + (1 - beta2) * u * hvp

        Raises:
            RuntimeError: If ``p.grad`` carries no autograd graph, i.e. the
                backward pass was not run with ``create_graph=True``.
        """
        for group in self.param_groups:
            _, beta2 = group["betas"]
            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if not grad.requires_grad:
                    raise RuntimeError(
                        "SophiaH Hessian update needs gradients with an autograd graph; "
                        "call backward(create_graph=True) before this step"
                    )
                # Lazily initialise: this can run before the main loop in
                # step() has created state (e.g. update_period == 1).
                state = self._ensure_state(p)
                u = torch.randn_like(grad)
                # The graph is retained so the remaining parameters can still
                # be differentiated.
                hvp = torch.autograd.grad(grad, p, grad_outputs=u, retain_graph=True)[0]
                state["hessian"].mul_(beta2).addcmul_(u, hvp, value=1 - beta2)

    @torch.no_grad()
    def step(self, closure: Callable = None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is executed with gradient tracking enabled.

        Returns:
            Whatever ``closure`` returned, or ``None`` if no closure was given.

        Raises:
            RuntimeError: If a parameter has a sparse gradient, or if the
                Hessian is due for a refresh and the gradients have no graph.
        """
        # group["step"] counts completed steps; this is the 1-based index of
        # the step being taken now (the first call is step 1, the second 2, ...).
        step = self.param_groups[0].get("step", 0) + 1

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if step % self.update_period == 0:
            self._update_hessian()

        for group in self.param_groups:
            group["step"] = group.get("step", 0) + 1

            beta1, _ = group["betas"]
            rho = group["rho"]
            lr = group["lr"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError("Sophia does not support sparse gradients")

                state = self._ensure_state(p)
                exp_avg = state["exp_avg"]
                hess = state["hessian"]
                state["step"] += 1

                # 1 - Decoupled weight decay.
                p.data.mul_(1 - lr * weight_decay)

                # 2 - Moving average of the gradient (first moment only).
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # 3 - Preconditioned step, clipped element-wise to [-rho, rho].
                #     Hutchinson samples can be negative, so the Hessian
                #     estimate is floored at eps to keep the denominator positive.
                ratio = (exp_avg / torch.clip(hess, min=eps)).clamp(-rho, rho)
                p.data.add_(ratio, alpha=-lr)

        return loss