[wip] resnet batchnorm backward fusion spec #4370
thanks - adding this to the scheduler roadmap!
```python
# easy case: merge 4 reduces in backward into 1
# double reduce case: merge stat calculations from 2 to 1 (be careful of long reduces!)
# sum(x - \bar{x}): one kernel just calculates this, can be eliminated
# pre-expand fusion: is it fast? -2 kernels possible, 1 fw, 1 bw
```
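A hedged sketch (not from this PR) of the shape of the "easy case": several parallel reduces over the same batchnorm-sized tensor that share an input and a reduce shape, so a fused scheduler could emit one kernel instead of four. The exact four reduces in the real bn backward may differ; `check_schedule` is the helper from test/test_schedule.py.

```python
from tinygrad import Tensor

# illustrative stand-ins for activations and incoming gradients (shapes from this discussion)
x = Tensor.empty(2, 16, 8, 8)
dy = Tensor.empty(2, 16, 8, 8)

# four parallel reduces over the same (N, C, H, W) input, all reducing over (0, 2, 3)
outs = [dy.sum((0, 2, 3)), (dy * x).sum((0, 2, 3)), x.sum((0, 2, 3)), (x * x).sum((0, 2, 3))]

# today these schedule as 4 reduce kernels; under the spec they would become 1, e.g.
# check_schedule(outs, 1)
```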
does this refer to E_2_16_64n1 + E_2048 (graph ref: https://tiny-tools-client.vercel.app/?id=f7b72a41bad14974970329924c89b2c0)?
#4235 could do this, but it won't, because <LB METAL (2, 16, 8, 8) float (<UnaryOps.CAST: 3>, None)> is forced_realize. I think it breaks the API if we fuse a forced_realize parent with its child.
I am referring to E2_16_64n1 (full forward with relu) and E2_16_64 (full backward through batchnorm). The first can be fused with the next conv, and the latter can be fused with the next backward conv. (E_2048 simulates the backward from the next layer, plus relu backward)
This test case does not have the convs (to focus on batchnorm), so it cannot happen here. Will add more cases.
Added detailed behavior spec. The fusion decision for the parallel reduces should be straightforward and "free" performance-wise, but fusing conv(a + b) may be bad in some cases. Need a heuristic to decide when a buffer counts as a "big" buffer and when one counts as a "small" buffer. The specs so far can remove 8 out of 14 extraneous memory passes in bn(conv2d).relu(), with an estimated time saving of 33ms on BS=256 resnet. (Edited because I posted fake news)
the scheduler change is a little tricky, since you need to make sure that each grouping is a contiguous sub-DAG. My current solution is to do the grouping while toposorting, which should work for the specific bn training case, but is it possible to make it clean? Probably deferring contiguous reduces until you run out of nodes in the queue and then grouping them would work.
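A minimal sketch of that deferral idea, over a hypothetical node type with `parents` and an `is_reduce` flag (this is not the PR's scheduler code): run a Kahn-style toposort, hold back reduce nodes as they become ready, and when the queue of non-reduce nodes runs dry, emit the held-back reduces as one fusion group. Every node in such a group has all of its parents already scheduled, so merging the group cannot create a scheduling cycle.

```python
from collections import deque

def toposort_with_reduce_groups(nodes):
  # hypothetical node objects: n.parents is a list of nodes, n.is_reduce is a bool
  indegree = {n: len(n.parents) for n in nodes}
  children = {n: [] for n in nodes}
  for n in nodes:
    for p in n.parents: children[p].append(n)
  ready = deque(n for n in nodes if indegree[n] == 0)
  deferred, order, groups = [], [], []
  while ready or deferred:
    if not ready:
      # no more non-reduce work: the deferred reduces become one fusion group
      groups.append(list(deferred))
      ready.extend(deferred); deferred.clear()
      continue
    n = ready.popleft()
    order.append(n)
    for c in children[n]:
      indegree[c] -= 1
      if indegree[c] == 0:
        (deferred if c.is_reduce else ready).append(c)
  return order, groups
```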
I need to think about the scheduler change a bit more, but in general we don't wanna do merge schedules; if there is grouping to be done it should be here https://github.com/tinygrad/tinygrad/blob/master/tinygrad/engine/schedule.py#L225-L228
```python
new_arg = MemBuffer(new_lbs.index(old_lbs[ast.arg.idx]), ast.arg.dtype, ast.arg.st) if ast.op in [BufferOps.LOAD, BufferOps.STORE] else ast.arg
return LazyOp(ast.op, tuple(_replace_bufis(x, old_lbs, new_lbs) for x in ast.src), new_arg)

def _merge_prescheduled(prescheduled: List[_LBScheduleItem]):
```
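For reference, a hedged sketch of what a `_merge_prescheduled` built on the `_replace_bufis` above could look like; the `_LBScheduleItem` field names used here (`ast` as a tuple of per-output LazyOps, `outputs`, `inputs`) are assumptions for illustration, not necessarily the PR's actual structure. The idea: concatenate the outputs, keep only inputs produced outside the group, and reindex every MemBuffer against the merged buffer list.

```python
from typing import List

def _merge_prescheduled(prescheduled: List["_LBScheduleItem"]):  # field names are assumed
  # merged buffer order: all group outputs first, then deduplicated external inputs
  outs = [lb for lsi in prescheduled for lb in lsi.outputs]
  ins = list(dict.fromkeys(lb for lsi in prescheduled for lb in lsi.inputs if lb not in outs))
  new_lbs = outs + ins
  # rewrite each AST so its LOAD/STORE MemBuffer indices point into the merged list
  new_ast = []
  for lsi in prescheduled:
    old_lbs = list(lsi.outputs) + list(lsi.inputs)
    new_ast.extend(_replace_bufis(op, old_lbs, new_lbs) for op in lsi.ast)
  return tuple(new_ast), outs, ins
```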
I've gone through this route in multioutput, tinygrad/tinygrad/engine/schedule.py, line 86 in 6c2cb8e:
def _schedule_outputs(outs:List[_LBScheduleItem], reduce_for_op:Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem:
I think you need to rebuild the entire AST.
I am prototyping with merge_prescheduled because I need to toposort to find these fusion opportunities (I don't see a way to analyze the graph locally to find them), and I need shapetracker information to match (lazybuffer, st) read pairs, conveniently provided by preschedule. The rules as implemented are a little in the style of a "performance heuristic" though, which is a little different from the other rules we have. Is it possible to move back to pure scheduling land?
I think all of your fusion targets are children of https://tiny-tools-client.vercel.app/?id=3ef8c4a72b0c4999acca0dff9288b2fa. Could traversing its local graph work?
Some of them are also children of the forward pass. How can we tell if there is a path forward -> BN forward -> stuff -> fusion targets, so that we don't fuse bn forward and backward? The first attempt did toposort + local children, but if you don't have all inputs before E_2048 (with BN we are lucky), you have to get lucky with the toposort order (most of the tests will not pass).
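One possible answer, sketched over the same hypothetical graph with `children` pointers (an assumption, not tinygrad's API): a reachability check that asks whether a candidate reaches one of the fusion targets through at least one node outside the group. If it does, including it would make the group depend on its own output, so it has to stay out.

```python
from collections import deque

def reaches_targets_through_outside(src, targets, group, children):
  # BFS from src's children that are outside the group; hitting any fusion target
  # means there is a src -> ... -> stuff -> target path, so src must not be fused
  seen = set()
  q = deque(c for c in children[src] if c not in group)
  while q:
    n = q.popleft()
    if n in seen: continue
    seen.add(n)
    if n in targets: return True
    q.extend(children[n])
  return False
```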
```python
# match by input + ST and two shapes? start with contiguous input only, check shapes (should determine reduces)

# what if same input + st but one is early and another is late?
check_schedule([x.sum(0, keepdim=True) + a, (a + b).sum()], 2)
```
is this a real-world case?
Could be... maybe if you have a bias weight, with out.sum(0) + bias -> next layer and (bias**2).sum() -> LARS?
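A hedged sketch of that hypothetical case (shapes and names made up), where the same bias buffer feeds both an elementwise consumer and a separate reduce, matching the "same input, one early and one late" pattern in the test above:

```python
from tinygrad import Tensor

out = Tensor.empty(256, 64)
bias = Tensor.empty(64)

next_in = out.sum(0) + bias       # late (elementwise) use of bias feeding the next layer
trust = (bias * bias).sum()       # early (reduced) use of bias, LARS-style norm
# check_schedule([next_in, trust], 2)
```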
```python
check_schedule([sum1, (x + sum1).sum()], 2)
del sum1

# super tricky crossing dag case
```
The (conservative) heuristic I am using is that this fusion should never add extra loads from bijective shapetrackers. If a shapetracker is bijective, then its size matches the full_shape of the kernel, and all non-bijective loads must be from smaller buffer(region)s. In the normal case, the non-bijective "small" buffers are from expands and are very small compared to the bijective ones (here it's 1/16), so adding these won't hurt. Here, fusing the diagonal will save 1 memory pass over a big buffer.
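A minimal sketch of that rule with a hypothetical load record (not tinygrad's ShapeTracker API): a load is "big" when it is bijective, i.e. its element count matches both the underlying buffer and the kernel's full_shape, and a fusion is accepted only if it adds no new big loads.

```python
from math import prod
from typing import NamedTuple, Tuple, Set

class Load(NamedTuple):            # hypothetical (buffer, shapetracker) read pair
  buf_id: int
  shape: Tuple[int, ...]           # shape the kernel indexes this buffer with
  buf_size: int                    # element count of the underlying buffer

def is_bijective(load: Load, full_shape: Tuple[int, ...]) -> bool:
  # a bijective load reads every buffer element exactly once, so all three sizes agree
  return prod(load.shape) == load.buf_size == prod(full_shape)

def fusion_adds_no_big_loads(existing: Set[Load], fused: Set[Load], full_shape) -> bool:
  # conservative rule: extra small (expand-style) loads are fine, extra big ones are not
  return all(not is_bijective(l, full_shape) for l in fused - existing)
```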
In fact, for simple reduces like these from bijective shapetrackers, it should be fine to fuse many unrelated reduces. Simple reduces don't really need a lot of cache -- the cache really helps when you have expands like (1, a) * (b, 1), since you can do an nm-sized tile with only n + m loads.
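A quick bit of arithmetic for that point, under made-up tile sizes: for an expand-style product (1, a) * (b, 1), an n x m output tile can be computed from n + m cached inputs instead of re-reading inputs per output element.

```python
n, m = 32, 32
per_element_loads = 2 * n * m   # 2048 loads if each of the n*m outputs re-reads both inputs
tiled_loads = n + m             # 64 loads if the tile keeps n row values and m column values in cache
print(per_element_loads, tiled_loads)
```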
I think this may even be a real-world case -- consider x and y to be the forward outputs of different layers.
If you fuse those targets, doesn't the cache fill up with a bunch of the "stuff" bufs? We wanna fuse if they're sharing parents.
We need to allow small "stuff"s (the bn backward takes some inputs from bn forward). See the argument for the bijective heuristic above.
Hm, I think one of these kernels has a superset of "stuffs" across the rest of the fusion targets. I think that makes it safe to not check the "stuffs" 🤔 Actually no, it doesn't, since one of the "stuffs" that only the superset kernel has could be a descendant of the rest of the fusion targets.
This reverts commit 7875b26.
This is cool. Can we get some of these tests merged? (even if they are disabled for now)
small example for easy inspection for now