Add a sharing mechanism to Ast terms #95
BTW, in the YOLO paper, they seem to have a sharing preservation mechanism as well (section 5.1.2). I wonder how complete it is. Is it simpler/better than ours in Delta expressions? Given that we now do source transform, can we use it, too? Or is our source transform too late, after the unrolling of Haskell control flow and datatypes duplicates terms too much already? If so, would we need to either transform the raw Haskell or make the user write in a single-assignment form or something similar (perhaps using the |
In my reading, yolo section 5.1.2 doesn't so much describe a sharing preservation mechanism, as it does making sharing that was already there visibly (!) in the AST, also visible through a computational operation:
which I read as saying: it has So it seems YOLO just uses Why can't we? I feel a Or maybe I'm wrong and that will work fine. |
Oh, I see, so they assume their program already has the sharing marked, and they just preserve it through the transformations. That's no help, unless we want horde-ad users to mark all sharing with explicit Regarding how to denote sharing, we have a compromise in our Delta expressions: the stamp-per-node is introduced by a fake
which is a bit costly, but the ids don't clutter the rest of the syntax. The important decision is, though, that surface syntax has no impact on where |
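To make the two styles of sharing concrete side by side, here is a minimal sketch (hypothetical constructor names, not the actual horde-ad definitions): the ordinary let binds a variable that scopes over a body, while the "fake" single-argument let merely stamps its subterm with an id that evaluation or interpretation can memoize on.
newtype AstVarId = AstVarId Int
newtype NodeId   = NodeId Int
data OpCode = PlusOp | TimesOp         -- just enough for the sketch

data Ast r
  = AstVar AstVarId
  | AstOp OpCode [Ast r]
  | AstLet AstVarId (Ast r) (Ast r)    -- ordinary scoped let: let x = e in body
  | AstLetGlobal NodeId (Ast r)        -- the "fake" single-argument let: a stamp, no binder
The local form keeps sharing in the syntax and survives printing, but rewriting has to respect the binder; the global form is invisible to the rewriting rules and needs a memo table at every consumer.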
The problem with |
One thing to the advantage of the About vectorising
Yes, that's an actual substitution so there will be some recomputation, but only of indexing which is constant-time anyway and fairly cheap. And furthermore if you try to turn that substitution into a let-binding to preserve sharing, you end up with infinite recursion in applying this rule. :) Not an answer to all your questions, though perhaps already interesting. |
That answers all my questions and more. Now I never end up with
That's a fair point. It's such a pity we can't trace Haskell Hah, while adding
but that's an unnatural name and argument order for
or should it be called I guess I'm accidentally circling back to our discussion of tracing functions, though this is about tracing applications. In fact, tracing function abstractions would have the same effect (and type signature and code)
where the
though in this example (that's the likely culprit for the 1M times slowdown) the second |
The good news is that this speeds up the 1M test. The bad news is
where the multiple
could not work. The best I can do is
but if so, the
With
but now I need to run vectorization on the outermost |
The good news from above scales nicely --- with more parameters, the neural network runs 10 times faster. This 10x is from [edit: reducing and probably even totally eliminating] the exponential blowup already in InterpretAst, without either vectorization or Ast inside Delta. Here's the change: |
The conclusion after profiling is that The remaining 1000x slowdown, just as exponential, is that in Delta evaluation we accumulate Ast terms in cotangent contribution maps. The delta expressions are perfectly shared, but the Ast terms inside them are not (they have I wonder how to express the sharing of Ast terms residing in different cells of the cotangent contribution maps. |
Do you have a small example of an Ast that results in something being evaluated at least twice where it should be computed once? I think I'm getting what you say, but looking at an example will help a lot. Re 1 if you're having trouble seeing how the f to use that library, I've figured it out at some point and can dig up my working example code. The output is a graph, that is: if |
Err -- data-reify doesn't directly apply, of course, because that library assumes an untyped AST, whereas ours is typed. However, since ours is not scoped, simply typed (because it is in HOAS form -- the kind of |
I don't, right now, but let me show you how the duplication arises. Let's say we have a delta term horde-ad/simplified/HordeAd/Core/Delta.hs Lines 462 to 468 in fff6dd5
Evaluating it with But given that key One extra complication is that the gradient is a tuple of terms, one per input, so we can't even syntactically express |
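To illustrate where the duplication comes from, here is a minimal sketch of delta evaluation into a cotangent contribution map (hypothetical, heavily simplified types; the real Delta and its evaluator are richer). The same incoming cotangent is pushed into both subtrees of an addition, so once cotangents are Ast terms rather than concrete arrays, the map cells end up holding syntactically duplicated, unshared copies of them.
import qualified Data.Map.Strict as M

data Delta ast
  = Zero
  | Input Int                            -- an input of the objective function
  | Add (Delta ast) (Delta ast)
  | Scale ast (Delta ast)

type CotangentMap ast = M.Map Int ast    -- accumulated contribution per input

-- eval c d m adds the contribution of d, with incoming cotangent c, to m.
eval :: Num ast => ast -> Delta ast -> CotangentMap ast -> CotangentMap ast
eval c d m = case d of
  Zero       -> m
  Input i    -> M.insertWith (+) i c m   -- c lands in a map cell
  Add d1 d2  -> eval c d1 (eval c d2 m)  -- the same c goes to both subtrees
  Scale a d1 -> eval (a * c) d1 m        -- c grows into a bigger Ast term
With concrete arrays this sharing of c is only in memory and costs nothing extra; with Ast cotangents every later traversal of the map re-walks each duplicated copy.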
Complete and utter success (but without vectorization and simplification, so far). 500x speedup by adding lets in precisely Possibly around 100x speedup (but it's exponential and I changed problem size, so the scale has changed) via adding two lets to the and then possibly a 100x speedup at a yet larger scale by adding lets to all the dual number arithmetic classes instances (which matter only for dual numbers with Ast inside) In the end, our second pipeline (the one producing an Ast gradient) that was 1M times slower is now a bit faster than the first pipeline and only a bit slower than a pipeline that goes from A problem is that we now have several places where sharing is introduced and a couple of kinds of sharing (not even counting the Another problem is that after the AD phase we no longer can apply simplification via term rewriting using our rules. The rewriting at this point breaks (e.g., via substitution) global sharing (the fake single-argument let), even though it works fine with local sharing (the ordinary let x = y in z). However, we can do local simplification when performing the final interpretation and it does matter for speed (at least when sharing is broken and so there's millions of terms; I haven't now tried disabling it). A minor problem is that memoization tables for global sharing can't be shared between different iterations of the |
I wonder about vectorization of let (local sharing) again. If I have
should I vectorize v1 at width 2 or 3? Or should I remove a portion of the sharing and vectorize instead
What if a build has 100 variables, and the I wonder if the global sharing fares any better. I start with
and let's focus on the first summand and vectorize it
the rules preserve equality fine, but I lost the sharing with the other summands, which clearly won't rewrite the It seems I can't get away from using a table as the state of the vectorization procedure. Let's repeat vectorization of the following, but with a memoization table
I vectorize the first summand to
and this seems to work fine, cheaply, not losing any sharing that can easily be preserved and without inspecting the whole program or constructing 100-element lets. What did I miss? |
Oh, I see what I have missed: the term bound in a let outside build can't contain the build variables, so we don't need to touch it. Doh. |
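In rule form (a sketch over a hypothetical mini-AST, not the actual vectorization code): build1 k (\i -> let x = e in body) rewrites to let x = e in build1 k (\i -> body), with no table and no loss of sharing, precisely because e cannot mention i.
data Ast
  = AstVar String
  | AstLet String Ast Ast               -- let x = e in body
  | AstBuild1 Int String Ast            -- build1 k (\i -> body)
  | AstOp String [Ast]

-- The bound term e was formed outside the build, so it cannot contain the
-- build variable i; the build commutes with the let and only the body needs
-- further vectorization.
vectorizeBuild1 :: Int -> String -> Ast -> Ast
vectorizeBuild1 k i (AstLet x e body) = AstLet x e (vectorizeBuild1 k i body)
vectorizeBuild1 k i body              = AstBuild1 k i body  -- other rules take over here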
Yay for sharing, I guess! :)
Yes and this is concerning, not only from a correctness and efficiency perspective (do we have all sharing now?), but also from a paper-writing perspective: how is the poor reader (and future developers of horde-ad) going to understand this mess?
I'm not sure I follow your argument here. So rewriting + global sharing doesn't combine, but rewriting + local sharing works fine? So it should work if we have proper
This I also don't follow. Is this still relevant re your progress on let vectorisation tonight? (Which is great, by the way -- I was stumped by your first message and then shared your 'Doh' upon reading your second one) |
Yes. I knew about hash-consing, but I had no idea ad-hoc sharing changes so much (removes three different sources of exponential runtime, in this case, not counting the Delta sharing). For discovering the joys of sharing I blame Simon and his fruitful insistence on source transform in place of my concrete ADVal, which had only one kind of sharing explosion potential (explosion of Delta terms size). Source transform is very powerful, but it's so fragile vs sharing.
We have to clean it up, but I'd like to first decide the fate of the user-visible sharing (
All but
Let's only mention the local
poor buggers
Probably depends on the rules. Vectorization using your methods works great. Simplification is almost fine:
does not simplify to The global version
on the other hand, simplifies to
adding Which one is better is very context-dependent. But in general, I'd rather simplify less, but not pay an unbounded and unpredictable memory tax. I've been burned with hash-consing. Moreover, global sharing is complex, stateful and not obviously correct.
Yes.
Two impediments. 1. The sharing in the transpose code can't even easily be expressed in the local way. 2. We indeed have to convert, if we naively add automatic sharing to user code, by wrapping each argument of the Ast instance of Tensor with tletR. Other than the two places, sharing can be written using the civilized local And Delta sharing is probably completely orthogonal and really well suited to the global method (Tree0), because we never rewrite Delta terms in the slightest and we do lots of global cross-referencing of them that is not captured by any syntax.
:) No, I'm fine, the memoization being invalidated by substitutions or changing arguments to the gather function is only relevant for the global lets, while vectorization is intended to work on the local lets, at least for now and at least in the paper. |
I've just implemented vectorization of |
I've added all the needed I think the problem with expressing sharing of the gradient (doable) and of cotangent contributions during transposition (not really doable IMHO) is that the transposition can't be expressed as rewriting rules, because it's stateful. The best we can do is operational semantics, but the sharing we are after is the sharing inside the environments on the left hand side of the turnstile, not on the right hand side, in the terms. So, it can't be expressed in the syntax. So probably the best we can do is keep the
can't be expressed with local lets. Fortunately, we don't have
I've implemented this and its analogue for the global lets. Apparently we don't have an evil enough test where this would be a bottleneck, but every bit of simplification helps, especially for readability when debugging. |
I'm stumped. I've implemented the ordinary local lets for transpose but the best I can do is a huge string of lets before a term with no lets in it. The same we'd get via sharing recovery from global lets, I imagine. What remains to do now is 8fd6612#r107726007, but I have no idea how to do this reasonably. The problem is that we need to share primal values (Ast terms) across the primal and dual components of a dual number. So the Ast-level let is not enough, because it only scopes over Ast terms and the dual numbers are not Ast terms (
The only solution that comes to mind is storing the let bindings as a third argument of the I'm considering ripping out local lets from transpose and adding an optional sharing recovery pass, needed whenever simplification term rewriting pass is to be performed before interpretation pass(es) or when the Ast gradient is used to build up some more complex program, perhaps to be differentiated again. This seems like a serious drawback of dual numbers. I bet CHAD doesn't have that problem. |
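A minimal sketch of the "third argument" idea (hypothetical types, not horde-ad's): carry the bindings accumulated so far next to the primal and dual parts, and merge the two lists whenever dual numbers are combined, so that a binding can scope over both components when the final term is assembled.
type AstVarId = Int
type Bindings ast = [(AstVarId, ast)]    -- bindings accumulated so far

data ADVal ast dual = D (Bindings ast) ast dual

-- Combining two dual numbers unions their bindings; wrapping the union around
-- the final primal and dual terms later restores the sharing.
addADVal :: Num ast
         => (dual -> dual -> dual)       -- addition of dual parts
         -> ADVal ast dual -> ADVal ast dual -> ADVal ast dual
addADVal addDual (D l1 u u') (D l2 v v') = D (l1 ++ l2) (u + v) (addDual u' v')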
Showing off Tapenade feature-parity (as soon as we also support Fortran and COBOL), pretty-printing (based on CHAD and using Tom's hindent trick) and how local sharing ( foo :: RealFloat a => (a, a, a) -> a
foo (x, y, z) =
let w = x * sin y
in atan2 z w + z * w
rev foo:
\s0 x2 x3 x4 dt ->
dlet (x4 * dt)
(\x6 ->
dlet (negate (x4 * (tconst 1.0
/ (x4 * x4 + (x2 * sin x3) * (x2 * sin x3))))
* dt)
(\x7 ->
dmkDomains
(fromList
[ dfromR (tfromList [])
, dfromR (sin x3 * x7 + sin x3 * x6)
, dfromR (cos x3 * (x2 * x7) + cos x3 * (x2 * x6))
, dfromR
(((x2 * sin x3)
* (tconst 1.0
/ (x4 * x4 + (x2 * sin x3) * (x2 * sin x3))))
* dt
+ (x2 * sin x3)
* dt)
])))
visible sharing:
\s0 x2 x3 x4 dt ->
dlet
(x4 * dt)
(\x6 ->
dlet
(negate
(x4 *
(tconst 1.0 /
(x4 * x4 +
tletR4 (x2 * tletR1 (sin x3)) * tletR4 (x2 * tletR1 (sin x3))))) *
dt)
(\x7 ->
dmkDomains
(fromList
[ dfromR (tfromList [])
, dfromR (tletR1 (sin x3) * x7 + tletR2 (sin x3) * x6)
, dfromR (cos x3 * (x2 * x7) + cos x3 * (x2 * x6))
, dfromR
((tletR4 (x2 * tletR1 (sin x3)) *
(tconst 1.0 /
(x4 * x4 +
tletR4 (x2 * tletR1 (sin x3)) *
tletR4 (x2 * tletR1 (sin x3))))) *
dt +
tletR3 (x2 * tletR2 (sin x3)) * dt)
])))
fooLet :: forall r n. (RealFloat (TensorOf n r), Tensor r, KnownNat n)
=> (TensorOf n r, TensorOf n r, TensorOf n r) -> TensorOf n r
fooLet (x, y, z) = tlet @r @n (x * sin y) $ \w -> atan2 z w + z * w
rev fooLet:
\s0 x2 x3 x4 dt ->
dlet (negate (x4 * (tconst 1.0
/ (x4 * x4 + (x2 * sin x3) * (x2 * sin x3))))
* dt
+ x4 * dt)
(\x7 ->
dmkDomains
(fromList
[ dfromR (tfromList [])
, dfromR (sin x3 * x7)
, dfromR (cos x3 * (x2 * x7))
, dfromR
(((x2 * sin x3)
* (tconst 1.0 / (x4 * x4 + (x2 * sin x3) * (x2 * sin x3))))
* dt
* (x2 * sin x3)
* dt)
]))
visible sharing:
\s0 x2 x3 x4 dt ->
dlet
(negate
(x4 *
(tconst 1.0 /
(x4 * x4 +
tletR2 (x2 * tletR1 (sin x3)) * tletR2 (x2 * tletR1 (sin x3))))) *
dt +
x4 * dt)
(\x7 ->
dmkDomains
(fromList
[ dfromR (tfromList [])
, dfromR (tletR1 (sin x3) * x7)
, dfromR (cos x3 * (x2 * x7))
, dfromR
((tletR2 (x2 * tletR1 (sin x3)) *
(tconst 1.0 /
(x4 * x4 +
tletR2 (x2 * tletR1 (sin x3)) *
tletR2 (x2 * tletR1 (sin x3))))) *
dt +
tletR2 (x2 * tletR1 (sin x3)) * dt)
]))
the corresponding primal part, sharing with the gradient Ast above:
\s0 x2 x3 x4 -> atan2 x4 (tletR2 (x2 * tletR1 (sin x3)))
+ x4 * tletR2 (x2 * tletR1 (sin x3)) |
Interesting how there is some sort of re-association going on between the two "visible sharing" versions! The first has this structure: let x6 = x4 * dt
in let x7 = bla dt
in ... a * (b * x6) + a * (b * x7) ... whereas the second has this structure: let x7 = bla dt + x4 * dt
in ... a * (b * x7) ... I didn't expect such reassociation to come from sharing preservation, but maybe I'm shortsighted here. Also cool stuff! The global sharing representation is still counter-intuitive to read for me. |
Yes, this is quite curious. Notice how variable Something else that is clearly visible here is how the local sharing in the surface language induces global sharing. Not by explicitly translating local variables to global identifiers, but by doing the forward pass on some terms only once, not twice, which causes the global identifiers to be generated only once, and so shared across the two branches. |
Would it be warranted to add a simplification rule like this?
It's not quite complexity preserving if n is unbounded. But maybe we should do this for small n only. Assuming sufficiently intelligent array representations, repeated fromLists of already-materialised arrays shouldn't even do all that much; but maybe that's too optimistic.
Right, global sharing seems to have it much easier here. But it's much harder to maintain -- where does the
Can't it? Which
Yep!
That is, simplifying index-of-let and vectorising build-of-let?
It's only stateful because it's Cayley-transformed. The eval function: eval :: R -> Delta -> DeltaMap -> DeltaMap is really So maybe it can be expressed using rewrite rules, but just into the language of
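For instance (a sketch, not the real code), reading the result type as DeltaMap -> DeltaMap makes the statefulness look like plain composition of endomorphisms:
import qualified Data.Map.Strict as M

type DeltaMap r = M.Map Int r            -- cotangent contribution per input

-- For a fixed cotangent c, each delta term denotes an endomorphism of
-- DeltaMap; evaluating an Add node just composes the endomorphisms of its
-- two subterms.
evalAdd :: (r -> DeltaMap r -> DeltaMap r)   -- evaluator of the first summand
        -> (r -> DeltaMap r -> DeltaMap r)   -- evaluator of the second summand
        -> (r -> DeltaMap r -> DeltaMap r)
evalAdd eval1 eval2 c = eval1 c . eval2 c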
This I don't really follow unfortunately.
(did you mean "presence", not "absence"?) As far as I know, any valid global sharing can be expressed as local sharing, although shared expressions might move a long way up the tree if they're used in textually distant parts of the code.
Yeah maybe it doesn't matter much in practice. But it's an easy simplification to make, and intuitively it shouldn't ever make things worse -- the cost of checking whether the RHS is a variable node should always be offset by saving a let frame at execution time.
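The rule itself is tiny; a sketch over a hypothetical mini-expression type (assuming, for brevity, globally unique binder names, so the renaming cannot capture):
data Expr = Var String | Let String Expr Expr | Add Expr Expr

-- A let whose right-hand side is already a variable buys nothing: rename and
-- drop the let frame.
inlineTrivialLet :: Expr -> Expr
inlineTrivialLet (Let x (Var y) body) = inlineTrivialLet (rename x y body)
inlineTrivialLet (Let x rhs body)     = Let x (inlineTrivialLet rhs) (inlineTrivialLet body)
inlineTrivialLet (Add a b)            = Add (inlineTrivialLet a) (inlineTrivialLet b)
inlineTrivialLet e                    = e

-- rename x y e replaces every occurrence of variable x in e by y.
rename :: String -> String -> Expr -> Expr
rename x y (Var v)          = Var (if v == x then y else v)
rename x y (Let v rhs body) = Let v (rename x y rhs) (rename x y body)
rename x y (Add a b)        = Add (rename x y a) (rename x y b)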
You're right that it seems difficult to do better than this; it definitely lessens the appeal of using "local" lets. Hm, I hoped it wouldn't be this bad. Would it be better if we made data DOfAst d a = DOfAst (Ast (a, Dual d a)) i.e. produce the pair late instead of early. This will allow lets to build up naturally inside the AST, but does require you to add pairs to the Ast language. |
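To illustrate the "produce the pair late" idea (a hypothetical mini-Ast with pairs, not a proposed horde-ad definition): with pairs in the language, one let can scope over both the primal and the dual component, which is exactly what a pair of two separate Ast terms cannot express.
data Ast
  = AstVar String
  | AstLet String Ast Ast
  | AstPair Ast Ast
  | AstFst Ast
  | AstSnd Ast
  | AstSin Ast
  | AstCos Ast
  | AstMul Ast Ast

-- sin lifted to a dual number packed as an Ast pair: the argument is bound
-- once, and the binding is shared by sin x (primal) and cos x * x' (dual).
sinDual :: Ast -> Ast
sinDual du =
  AstLet "d" du $
  AstLet "x"  (AstFst (AstVar "d")) $
  AstLet "x'" (AstSnd (AstVar "d")) $
  AstPair (AstSin (AstVar "x"))
          (AstMul (AstCos (AstVar "x")) (AstVar "x'"))
The price, as noted, is adding pairs (and projections) to the Ast language and pushing them through vectorization and simplification.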
Right. This may fix that problem. I will try when I revisit simplification rules next time.
Ouch, I've made a mistake. This should be memoized, but obviously it can't work with this precise sharing. We'd need
and then we memoize that However, when we have local sharing with enough sharing expressed let x = fromList [t0, t1] in let y = x ! [0] in (y, y, ...) it's stuck and that's probably what I had in mind and indeed global sharing more naturally leads to the optimal outcome here. Risking a premature generalization: if you have maximal sharing, the global style may be better, but with selective sharing (e.g., manual in response to benchmarking), local sharing is just fine, while being less complex. This makes me reconsider the idea to do forward pass (e.g., the numerical classes for dual numbers) and transpose using global sharing and then perform a sharing recovery pass in order to simplify the resulting Ast before compilation to GPU. If the resultant local sharing is as dense as above, and it may well be, because binary operators appear a lot, then it may turn out that no simplification rule applies at all, because it all stops at tbc |
Except if you add my let-fromList-inline rule from above with n >= 2, in which case this will simplify to
This would be an interesting proposition. I wonder how much you can massage the two to be equivalent by adding well-chosen simplification rules like the above.
But then simplification should include inlining of cheap computations. This is a standard part of compiler optimisation pipelines. My previous point stands -- I wonder if you can massage this to mostly not differ anymore. :) |
I managed to overcome it, at the cost of complicating our
Yes.
Oh, cool. Let me digest this and come back.
Nothing deep:
and the sharing that needs to be introduced is between components of
I don't know, right? :D I just wipe out memoization tables whenever I interpret terms inside a function. WIP. But let me amend the example to show the vague point that I have: tfromList [tgather sh v (\i -> tLetInt n0 (i + e)), tgather sh u (\i -> tLetInt n0 (i + e))] and let's assume let x i = i + e in tfromList [tgather sh v (\i -> tLetInt n0 (x i)), tgather sh u (\i -> tLetInt n0 (x i))] and we could simplify or interpret or whatever the term
I guess this moving up the tree involves lambda abstraction?
Surely. I've since applied it to a couple more places and no deterioration and I bet there are evil examples where it actually improves stuff a lot.
Note that the pretty-printed terms are of this form and indeed, I can't see a way out. However, I haven't yet grokked your "transposition is Cayley-transformed" idea, so perhaps this works. Still, I can't see how forward pass (numeric classes for ADVal) can generate better local sharing, given that both the
I hoped you wouldn't mention this ;). Our Instead of the current D u u' + D v v' = dD (u + v) (dAdd u' v')
we'd need u + v = tD (tprimalPart u + tprimalPart v) (tAdd (tdualPart u) (tdualPart v)) where tScale :: KnownNat n => TensorOf n (Primal r) -> DualOf n r -> DualOf n r
...
instance (ADTensor (Ast0 r), ShowAstSimplify r)
=> Tensor (ADVal (Ast0 r)) where
tScale = dScale
...
instance ShowAstSimplify r
=> Tensor (Ast0 r) where
tScale (AstPrimalPart s) (AstDualPart t) = AstDualPart $ s `tmult` t I forgot if there's any fundamental typing or logic problem there or just
|
I've done a very bad thing. In order to pretty-print well, so that I can ask how to implement sharing, I've implemented this sharing method:
All global lets are ripped out, but this method is messy, currently 20% slower (80% slower in the Mnist fully connected nn test) and is asymptotically slower (probably quadratic instead of linear or log-linear; I hope not exponential this time) than global sharing in some cases that my large tests apparently don't trigger yet. |
Pretty-printing and sharing, continued. [Edited multiple times based on chat feedback.] foo :: RealFloat a => (a, (a, a)) -> a
foo (x, (y, z)) =
let w = x * sin y
in atan2 z w + z * w (1: naive human version) revFooHuman :: r -> (r, (r, r)) -> (r, (r, r))
revFooHuman dret (x, (y, z)) =
let x6 = sin y
x7 = x * x6
x8 = recip (z * z + x7 * x7)
x9 = sin y
x10 = x * x9
x11 = z * dret
x12 = negate (z * x8) * dret
in ( x6 * x12 + x9 * x11
, ( cos y * (x * x12) + cos y * (x * x11)
, (x7 * x8) * dret + x10 * dret
)
) (2: automatically printed; the beautifying variant) this is generated from let revFooHorde :: TensorOf 1 r -- scalar arguments packed in tensor s0
-> TensorOf 0 r -- dret
-> TensorOf 0 r -> TensorOf 0 r -> TensorOf 0 r -- tensor arguments
-> (TensorOf 1 r, TensorOf 0 r, TensorOf 0 r, TensorOf 0 r)
revFooHorde =
-- automatically pretty-printed code starts here
\s0 dret x y z -> -- could be written (x, y, z) for free, but not (x, (y, z))
let x6 = sin y
x7 = x * x6
x8 = recip (z * z + x7 * x7)
x9 = sin y
x10 = x * x9
x11 = z * dret
x12 = negate (z * x8) * dret
in ( tfromList []
, x6 * x12 + x9 * x11
, cos y * (x * x12) + cos y * (x * x11)
, (x7 * x8) * dret + x10 * dret
) (3: adapted version) this shows how to ad-hoc adapt (2) to be almost as useful as (4). Sharing is lost but otherwise, this can easily be used for constructing new revFoo :: TensorOf 1 r -- scalar arguments packed in tensor s0
-> TensorOf 0 r -- dret
-> TensorOf 0 r -> TensorOf 0 r -> TensorOf 0 r -- tensor arguments
-> DomainsOf r
revFoo sI dretI xI yI zI =
let revFooHorde :: TensorOf 1 r -- scalar arguments packed in tensor s0
-> TensorOf 0 r -- dret
-> TensorOf 0 r -> TensorOf 0 r -> TensorOf 0 r -- tensor arguments
-> (TensorOf 1 r, TensorOf 0 r, TensorOf 0 r, TensorOf 0 r)
revFooHorde =
-- automatically pretty-printed code starts here, copied from above, except for (x, y, z)
\s0 dret (x, y, z) -> -- the tuple is a possible variant, for illustration
let x6 = sin y
x7 = x * x6
x8 = recip (z * z + x7 * x7)
x9 = sin y
x10 = x * x9
x11 = z * dret
x12 = negate (z * x8) * dret
in ( tfromList []
, x6 * x12 + x9 * x11
, cos y * (x * x12) + cos y * (x * x11)
, (x7 * x8) * dret + x10 * dret
)
-- automatically pretty-printed code ends here
in toDomains'' $ gradient sI dretI (xI, yI, zI) (4: automatically printed; the raw variant that exposes all sharing) let revFooHordeRaw :: TensorOf 1 r -- scalar arguments packed in tensor s0
-> TensorOf 0 r -- dret
-> TensorOf 0 r -> TensorOf 0 r -> TensorOf 0 r -- tensor arguments
-> DomainsOf r
revFooHordeRaw =
-- automatically pretty-printed code starts here
\s0 dret x y z ->
dlet (sin y)
(\x6 ->
dlet (x * x6)
(\x7 ->
dlet (recip (z * z + x7 * x7))
(\x8 ->
dlet (sin y)
(\x9 ->
dlet (x * x9)
(\x10 ->
dlet (z * dret)
(\x11 ->
dlet (negate (z * x8) * dret)
(\x12 ->
dmkDomains
(fromList
[ dfromR (tfromList [])
, dfromR (x6 * x12 + x9 * x11)
, dfromR (cos y * (x * x12) + cos y * (x * x11))
, dfromR ((x7 * x8) * dret + x10 * dret)
])))))))) This is fooLet :: forall r n. (RealFloat (TensorOf n r), Tensor r, KnownNat n)
=> (TensorOf n r, (TensorOf n r, TensorOf n r)) -> TensorOf n r
fooLet (x, (y, z)) =
let w0 = x * sin y
in tlet w0 $ \w ->
atan2 z w + z * w
\s0 dret x y z ->
let x7 = sin y
x8 = x * x7
x9 = recip (z * z + x8 * x8)
x10 = negate (z * x9) * dret + z * dret
in ( tfromList []
, x7 * x10
, cos y * (x * x10)
, (x8 * x9) * dret + x8 * dret ) Relu. [Edit: greatly simplified just as Tom guessed it should be.] relu :: forall n r. (ADReady r, KnownNat n, Num (TensorOf n r)
=> TensorOf n r -> TensorOf n r
relu v =
let oneIfGtZero = tmap0N (\x -> ifB (x <=* 0) 0.0 1.0) v
in oneIfGtZero * v horde-ad/simplified/HordeAd/Core/TensorClass.hs Lines 296 to 298 in cc5c619
After we apply it to a variable of shape [3,4] (which probably can't be expressed in the user language without introducing special functions that take arguments only to compute their shapes, etc.), vectorization kicks in, first transforming the nested -- START of vectorization for term
tbuild1
4
(\x5 ->
tconstant
(tfromList [tconst 0.0, tconst 1.0] !
[ifB (m3 ! [i4, i5] <=* tconst 0.0) 0 1]))
-- END of vectorization yields
tconstant
(tgather
[4]
(tconst (fromList [2] [0.0, 1.0]))
(\[i5] -> [ifB (m3 ! [i4, i5] <=* tconst 0.0) 0 1]))
-- START of vectorization for term
tbuild1
3
(\v4 ->
tconstant
(tgather
[4]
(tconst (fromList [2] [0.0, 1.0]))
(\[i5] -> [ifB (m3 ! [i4, i5] <=* tconst 0.0) 0 1])))
-- END of vectorization yields
tconstant
(tgather
[3, 4]
(tconst (fromList [2] [0.0, 1.0]))
(\[i6, i5] -> [ifB (m3 ! [i6, i5] <=* tconst 0.0) 0 1])) The primal part function, after applying to a variable of shape [3,4], vectorization and the forward pass. (Apparently the variable name counter had different state than when vectorization log has been captured.) \s0 m3 ->
let m9 = tgather [3,4] (tconst (fromList [2] [0.0,1.0]))
(\[i7, i8] -> [ifB (m3 ! [i7, i8] <=* tconst 0.0) 0 1])
in m9 * m3 Its dual part, where LetR 3 (ScaleR (AstVar [3,4] (AstVarId 9)) (InputR (InputId 0))) Its gradient function. \s0 dret m3 ->
let m9 = tgather [3,4] (tconst (fromList [2] [0.0,1.0]))
(\[i7, i8] -> [ifB (m3 ! [i7, i8] <=* tconst 0.0) 0 1])
in (tfromList [], m9 * dret) Relu of multiplication by a reluT2 :: (TensorOf 1 (Ast0 Double), Ast0 Double)
-> TensorOf 1 (Ast0 Double)
reluT2 (t, r) = relu (t * tkonst 5 (tscalar r)) After we apply it to a variable of shape -- START of vectorization for term
tbuild1
5
(\x4 ->
tconstant
(tfromList [tconst 0.0, tconst 1.0] !
[ifB ((v3 * tkonst 5 (s0 ! [0])) ! [i4] <=* tconst 0.0) 0 1]))
-- END of vectorization yields
tconstant
(tgather
[5]
(tconst (fromList [2] [0.0, 1.0]))
(\[i4] -> [ifB (v3 ! [i4] * s0 ! [0] <=* tconst 0.0) 0 1])) The primal part function, after applying to a variable of shape [5], vectorization and the forward pass. \s0 v3 ->
let v6 = tkonst 5 (s0 ! [0])
v7 =
tgather
[5]
(tconst (fromList [2] [0.0, 1.0]))
(\[i5] ->
[ ifB
((let x11 = v3 ! [i5]
x12 = s0 ! [0]
in x11 * x12) <=*
tconst 0.0)
0
1
])
v8 = v3 * v6
in v7 * v8 Its dual part. LetR 10 (ScaleR (AstVar [5] (AstVarId 7)) (LetR 9 (AddR (ScaleR (AstVar [5] (AstVarId 6)) (InputR (InputId 0))) (ScaleR (AstVar [5] (AstVarId 3)) (LetR 8 (KonstR 5 (LetR 7 (IndexZ (LetR 6 (FromVectorR [ScalarR (Input0 (InputId 0))])) [AstIntConst 0] [1])))))))) Its gradient function, which looks misleading, because the BTW, note that some sharing (and its variables) comes from the forward pass (these appear in the primal part function as well) while other sharing appears as late as in the transpose pass (
let v6 = tkonst 5 (s0 ! [0])
v7 =
tgather
[5]
(tconst (fromList [2] [0.0, 1.0]))
(\[i5] ->
[ ifB
((let x11 = v3 ! [i5]
x12 = s0 ! [0]
in x11 * x12) <=*
tconst 0.0)
0
1
])
v8 = v3 * v6
v9 = v7 * dret
v10 = tscatter [1] (tsum (v3 * v9)) (\[] -> [0])
in (tfromList [tconst 0.0 + v10 ! [0]], v6 * v9) Fortunately our Ast simplification handles this misleading trivial \s0 dret v3 ->
let v9 =
tconstant
(tgather
[5]
(tconst (fromList [2] [0.0, 1.0]))
(\[i5] -> [ifB (v3 ! [i5] * s0 ! [0] <=* tconst 0.0) 0 1])) *
dret
in (tkonst 1 (tconst 0.0 + tsum (v3 * v9)), tkonst 5 (s0 ! [0]) * v9) The same with reluMax :: forall n r. (ADReady r, KnownNat n)
=> TensorOf n r -> TensorOf n r
reluMax v = tmap0N (maxB 0) v horde-ad/simplified/HordeAd/Core/TensorClass.hs Lines 296 to 298 in cc5c619
After we apply it to a variable of shape [3,4], vectorization kicks in, first transforming the nested -- START of vectorization for term
tbuild1
4
(\x5 ->
tfromList [tconst 0.0, m3 ! [i4, i5]] !
[ifB (tconst 0.0 >=* m3 ! [i4, i5]) 0 1])
-- END of vectorization yields
tgather
[4]
(tfromList [tconstant (tkonst 4 (tconst 0.0)), m3 ! [i4]])
(\[i6] -> [ifB (tconst 0.0 >=* m3 ! [i4, i6]) 0 1, i6])
-- START of vectorization for term
tbuild1
3
(\v4 ->
tgather
[4]
(tfromList [tconstant (tkonst 4 (tconst 0.0)), m3 ! [i4]])
(\[i6] -> [ifB (tconst 0.0 >=* m3 ! [i4, i6]) 0 1, i6]))
-- END of vectorization yields
tgather
[3, 4]
(tfromList [tconstant (tkonst 3 (tkonst 4 (tconst 0.0))), m3])
(\[i7, i6] -> [ifB (tconst 0.0 >=* m3 ! [i7, i6]) 0 1, i7, i6]) primal \s0 m3 ->
tgather
[3, 4]
(tfromList [tkonst 3 (tkonst 4 (tconst 0.0)), m3])
(\[i8, i9] -> [ifB (tconst 0.0 >=* m3 ! [i8, i9]) 0 1, i8, i9]) dual LetR 5 (GatherZ [3,4] (LetR 4 (FromListR [ZeroR,InputR (InputId 0)])) <function> [2,3,4]) gradient \s0 dret m3 ->
let t12 =
tscatter
[2, 3, 4]
dret
(\[i10, i11] -> [ifB (tconst 0.0 >=* m3 ! [i10, i11]) 0 1, i10, i11])
in (tfromList [], t12 ! [1]) For comparison, here's the gradient Ast for the fast implementation of \s0 dret m3 ->
let m9 = tgather [3,4] (tconst (fromList [2] [0.0,1.0]))
(\[i7, i8] -> [ifB (m3 ! [i7, i8] <=* tconst 0.0) 0 1])
in (tfromList [], m9 * dret) Relu (using reluMaxT2 :: (TensorOf 1 (Ast0 Double), Ast0 Double)
-> TensorOf 1 (Ast0 Double)
reluMaxT2 (t, r) = reluMax (t * tkonst 5 (tscalar r)) After we apply it to a variable of shape -- START of vectorization for term
tbuild1
5
(\x4 ->
tfromList [tconst 0.0, (v3 * tkonst 5 (s0 ! [0])) ! [i4]] !
[ifB (tconst 0.0 >=* (v3 * tkonst 5 (s0 ! [0])) ! [i4]) 0 1])
-- END of vectorization yields
tgather
[5]
(tfromList [tconstant (tkonst 5 (tconst 0.0)), v3 * tkonst 5 (s0 ! [0])])
(\[i5] -> [ifB (tconst 0.0 >=* v3 ! [i5] * s0 ! [0]) 0 1, i5]) primal \s0 v3 ->
let v6 = tkonst 5 (s0 ! [0])
in tgather
[5]
(tfromList [tkonst 5 (tconst 0.0), v3 * v6])
(\[i7] ->
[ ifB
(tconst 0.0 >=*
(let x14 = v3 ! [i7]
x15 = s0 ! [0]
in x14 * x15))
0
1
, i7
]) dual LetR 14 (GatherZ [5] (LetR 13 (FromListR [ZeroR,LetR 12 (AddR (ScaleR (AstVar [5] (AstVarId 6)) (InputR (InputId 0))) (ScaleR (AstVar [5] (AstVarId 3)) (LetR 11 (KonstR 5 (LetR 10 (IndexZ (LetR 9 (FromVectorR [ScalarR (Input0 (InputId 0))])) [AstIntConst 0] [1]))))))])) <function> [2,5]) gradient \s0 dret v3 ->
let v6 = tkonst 5 (s0 ! [0])
m11 =
tscatter
[2, 5]
dret
(\[i8] ->
[ ifB
(tconst 0.0 >=*
(let x9 = v3 ! [i8]
x10 = s0 ! [0]
in x9 * x10))
0
1
, i8
])
v12 = m11 ! [1]
v13 = tscatter [1] (tsum (v3 * v12)) (\[] -> [0])
in (tfromList [tconst 0.0 + v13 ! [0]], v6 * v12) Our Ast simplification pass can deal with the trivial \s0 dret v3 ->
let v12 =
tscatter [2, 5] dret (\[i8] -> [ifB (tconst 0.0 >=* v3 ! [i8] * s0 ! [0]) 0 1, i8])
! [1]
in (tkonst 1 (tconst 0.0 + tsum (v3 * v12)), tkonst 5 (s0 ! [0]) * v12) For comparison, here's the gradient Ast for the fast implementation of \s0 dret v3 ->
let v9 =
tconstant
(tgather
[5]
(tconst (fromList [2] [0.0, 1.0]))
(\[i5] -> [ifB (v3 ! [i5] * s0 ! [0] <=* tconst 0.0) 0 1])) *
dret
in (tkonst 1 (tconst 0.0 + tsum (v3 * v9)), tkonst 5 (s0 ! [0]) * v9) |
I'm not sure what you mean with this. If you have lambda abstractions in your trees, the moving upwards might cross lambda abstractions, but it's still just moving the expression higher up the tree. E.g. bla (bla (n + tletR1 e) bla) foo (foo (tletR1 e + bar) foo)
~>
let x = e
in bla (bla (n + x) bla) foo (foo (x + bar) foo) in which
TODO think about this
Yes lol, you're right. Err, yeah. ¯\_(ツ)_/¯
:D About your pretty-printed story:
My hand-written, systematic version looks like this: revFooTom :: r -> (r, (r, r)) -> (r, (r, r))
revFooTom dt (x, (y, z)) =
let s = sin y
w = x * s
a = atan2 z w
b = z * w
c = a + b
dc = dt
da = dc
db = dc
dz1 = da * w / (z * z + w * w)
dw1 = da * -z / (z * z + w * w)
dz2 = db * w
dw2 = db * z
dw = dw1 + dw2
dz = dz1 + dz2
dx = dw * s
ds = dw * x
dy = ds * cos y
in (dx, (dy, dz)) which simplifies to this, if you inline definitions if they either just rebind a name, or are used at most once: revFooTom :: r -> (r, (r, r)) -> (r, (r, r))
revFooTom dt (x, (y, z)) =
let s = sin y
w = x * s
dw = dt * -z / (z * z + w * w) + dt * z
in ( dw * s
, ( (dw * x) * cos y
, dt * w / (z * z + w * w) + dt * w
)
) Why do you compute Also, your (4) version includes the snippet Re your relu examples, what are the lessons we should draw from this? It's good to have examples, but the only thing I'm lifting from it is that 1. scatter with invertible indexing function needs to be simplified (perhaps to a gather), and 2. it would really be nice if we can do some fusion after AD again so that snippets like tgather [5] (tconst (fromList [2] [0.0,1.0]))
(\[i5] -> [ifB ((let x12 = s0 ! [0] in
let x13 = x3 ! [i5]
in x13 * x12) <=* tconst 0.0) 0 1]) don't persist into code generation but instead become something like this: ifB (((x3 ! [i5]) * (s0 ! [0])) <=* tconst 0.0)
(tkonst 5 (tscalar 0.0))
(tkonst 5 (tscalar 1.0)) and subsequently this: tkonst 5 (tscalar
(ifB (((x3 ! [i5]) * (s0 ! [0])) <=* tconst 0.0)
0.0 1.0)) or something. |
I think we've already discussed that on the call a week ago? I've shown you these snippets: tfromList [tgather sh v (\i -> tLetInt n0 (i + e)), tgather sh u (\i -> tLetInt n0 (i + e))] let x i = i + e in tfromList [tgather sh v (\i -> tLetInt n0 (x i)), tgather sh u (\i -> tLetInt n0 (x i))] and we agreed the real problem here was too little exposed sharing and that with enough sharing lambda-abstraction is not necessary to lift sharing up the tree.
Well, sure, this is simple as long as the shared expression does not contain the lambda-abstracted variables, right? But I think we talked about that.
OTOH, there is merit in leaving delta evaluation as it is for now and mentioning in the paper "it's the same as in the POPL paper, only extended and expressed in the language of In the long term, surely reworking this would give us new insight and the (relatively small) new sharing introduced by delta evaluation could perhaps be inserted locally, not at the top level, where the large amount of sharing coming from the forward pass has to reside (if we stick to simple dual number representation).
Perfect. This is equal to the snippet shown below "rev of fooLet (instantiated to TensorOf 0 r) printed with loss of sharing."
This is interesting. I'd like it to work without
Well spotted. I've corrected it now. My mistake when manually formatting and, at the same time, tweaking pretty-printing and updating the results.
I'm sure @awf will draw some lessons. Myself I'm curious what's the most efficient way to define
Me, similarly, but I wanted to introduce
I do apply the full simplification pass and show the result below the remark "Unfortunately, at this time we don't have any scatter simplification rules", but the result is very unsatisfying.
Unfortunately we don't rewrite gathers into conditionals, even just because we don't have (float) conditionals in the syntax (also, sadly
This is a pretty cool simplification, but it's non-local, which makes it costly --- you have to look at the two branches and ensure they have the same tree prefix (or that they can be rewritten, perhaps complicating them, to have the same prefix). |
Oh that's what you meant. Yes, no please don't introduce lambdas while lifting. Indeed we talked about this, and the answer was that things become really annoying with scoped global sharing identifiers. :D (At least until you eliminate them and make them local sharing, if at all)
Stable names to get global sharing, then the typed sharing recovery algorithm to convert to local sharing. Accelerate literature calls this whole procedure "sharing recovery" instead of just the second half.
Oh oops you're right, that was an invalid transformation. Still the intent was not to turn gather into conditionals but to commute the ifB and the gather, assuming (incorrectly) that the condition was independent of the index variable.
Yeah, same commutation operation, just the other way round. In any case the same snippet would be more neatly written as: tbuild [5] (\[i5] -> ifB (x3![i5] * s0![0] <=* tconst 0.0) 0.0 1.0) which is indeed a |
I've eliminated the 20% and 80% slowdown at the cost of some extra hacks (and one more Ast constructor). The quadratic complexity is still there in pathological cases, e.g., when the API user takes Ast primal parts of dual numbers and embeds them in new dual numbers an enormous number of times. So perhaps the terrible top-level-local sharing may stay for now, while we are still taking over the world and so our resources are limited.
Done. :)
Got it. Now we only need a volunteer for #102. This is pretty standalone, I think.
This is complex, but makes sense. I've seen that |
I'm now leaning towards replicating the trick we have in Delta terms (Tom's Tree0 method --- a stamp per each term node). Then we could add the same stamp to the Delta expressions created from Ast terms.
There are benefits. The user doesn't have to worry and let from Haskell implies this permanent sharing (not only temporary sharing in memory). We get stamps in fewer Delta terms, which does not incur a risk of sharing violation as long as interpretation of Ast terms doesn't duplicate newly created Delta terms (but it can duplicate Delta terms received from interpreted Ast subterms).
There are some drawbacks. We get many stamps in Ast terms: one stamp per Ast node. The implementation of interpretation of Ast in Tensor can't just have one extra case for AstLet, but instead needs to keep a table of interpreted subterms. Moreover, interpretation in an ADVal instance of Tensor has to pass on the stamps, so maybe this interpretation can't share code with all the others. I haven't thought through Vectorization or Simplification, but given that they can incur duplication, we can't introduce sharing after they are complete. So they probably need to keep tables as well. I'm not sure if we can get away with stamps only in tensor terms, given that integer terms suffer from blowup when gathers are fused and consequently large terms are substituted for integer variables.
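A minimal sketch of what the "table of interpreted subterms" could look like (hypothetical names; whether interpretation is monadic at all is itself an assumption here): every interpretation of a stamped node first consults a table keyed by the stamp, so a subterm shared via its stamp is interpreted only once.
import qualified Data.IntMap.Strict as IM
import Data.IORef (IORef, readIORef, modifyIORef')

interpretMemoized
  :: IORef (IM.IntMap value)   -- table of already interpreted subterms
  -> Int                       -- the stamp of this node
  -> IO value                  -- interpretation of the node proper
  -> IO value
interpretMemoized tableRef stamp interpretNode = do
  table <- readIORef tableRef
  case IM.lookup stamp table of
    Just v  -> pure v          -- already interpreted: reuse the result
    Nothing -> do
      v <- interpretNode
      modifyIORef' tableRef (IM.insert stamp v)
      pure v
As noted above, vectorization and simplification would need to thread a table of the same kind if they are to preserve the stamps' sharing.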