Sometimes I need to access a specific gradient transformation state within the state of a composed gradient transformation. Does Optax provide a utility for this? Is this an anti-pattern? (In my case I have a utility that creates a composed gradient transformation, and I'd rather not access an inner gradient transformation state by a nested index because that feels too brittle.)

Currently I'm using:

```python
import functools
import operator

def find_state(state, cls):
  # Collect every sub-state of type `cls`, recursing through nested tuples
  # (e.g. the tuple of states returned by `optax.chain(...).init`).
  if isinstance(state, cls):
    return [state]
  if isinstance(state, tuple):
    return functools.reduce(
        operator.add, (find_state(child, cls) for child in state), [])
  return []
```

Then I can look up a specific state with e.g. …
Hi! Thanks a lot for the question!
As far as I know, there is no utility for this in Optax at the moment, but I also wouldn't consider it an anti-pattern (e.g. I think it's the best way to log variables from the optimizer state, see #206).
For a simple optimizer state (and many are simple), I think it's fine to use the index. For more complicated chains, I think it can be less readable to have e.g. `state[2]`. What would the nested index look like in your case? Would it only be known at runtime?
The reason I'm asking is that if it's a matter of readability, it might be possible to improve this without introducing a utility for searching within an optimizer state.
Thanks a lot for the question again!