PyTorch Developer Podcast

CUDA graph trees

Episode Summary

CUDA graph trees are the internal implementation of CUDA graphs used in PT2 when you say mode="reduce-overhead". Their primary innovation is that they allow the reuse of memory across multiple CUDA graphs, as long as those graphs form a tree of the potential execution paths you can go down. This greatly reduces the memory usage of CUDA graphs in PT2. There are some operational implications to using CUDA graphs which are described in the podcast.

Episode Transcription

Hello, everyone, and welcome to the PyTorch Dev Podcast. Today, I want to talk about CUDA Graph Trees, our CUDA Graph integration with PyTorch 2. Most of this was work done by Elias Ellison, so kudos to him for actually building all of this.

First off, let's remind ourselves what CUDA Graphs are. I do have a podcast about them, so if you want more details about CUDA Graphs themselves, you can go there. But CUDA Graphs are essentially a way to remove overhead from applications that are calling CUDA kernels. The idea is: instead of running all of the possibly very expensive host code that glues a bunch of CUDA kernels together, we just smash it all into a recording that runs the CUDA kernels one after another, exactly the same way they were run before.

So in PyTorch Eager, we have an API for using CUDA graphs called make_graphed_callables. And it basically does exactly what you would expect. It will go ahead and CUDA graph record your function, and you will get exactly what you asked for. And maybe this is what you want. Maybe it isn't.
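To make that concrete, here's a minimal sketch of the eager API; the module and shapes are just placeholders picked for illustration:

```python
import torch

# Placeholder module and sample input, just for illustration.
model = torch.nn.Linear(128, 128).cuda()
sample_input = torch.randn(8, 128, device="cuda")

# make_graphed_callables records the callable under a CUDA graph and returns
# a drop-in replacement that replays the recording on each call.
graphed_model = torch.cuda.make_graphed_callables(model, (sample_input,))

out = graphed_model(torch.randn(8, 128, device="cuda"))
```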

It's actually kind of hard to use CUDA graphs in a lot of situations, right? You have to make sure that there's no CPU compute in your program, nothing that varies from run to run, and no calls to unsafe operators. Those will just cause CUDA Graph recording to fail, because CUDA Graphs will say, no, no, no, you can't read things out onto the CPU. And when you are passing in the inputs to CUDA Graphs, they all have to live at static addresses, because those addresses are being burned into your CUDA Graph. So if you have an input, you have to make sure you copy it into a fixed buffer. All of this needs to be handled by hand.
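Here's a rough sketch of what that hand-managed version looks like with the public torch.cuda.CUDAGraph API; the function and sizes are made up for the example:

```python
import torch

def f(x):
    return torch.relu(x @ x)

# The input must live at a fixed address that gets burned into the recording.
static_input = torch.randn(1024, 1024, device="cuda")

# Warm up on a side stream so the capture doesn't record allocator warmup work.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        f(static_input)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = f(static_input)

def graphed_f(x):
    static_input.copy_(x)  # copy the caller's data into the burned-in buffer
    g.replay()
    return static_output   # also lives at a fixed address inside the graph pool
```

Every detail there (warmup, fixed buffers, copy-in) is the by-hand bookkeeping the rest of this episode is about automating.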

You can do it if you're very motivated, and people often are very motivated and will manually CUDA graph their code. But one of the things that we wanted to do with PyTorch 2 was to make it easier for people to get this overhead reduction without having to go through all this rigmarole.

And of course, PyTorch 2 actually does help a lot with overhead reduction intrinsically, because we're in the business of taking your models and factoring out all the Python code, so we don't actually have to run any of your Python code; we only have to run the residual bytecode afterwards that does exactly the Python state updates we need. And by fusing kernels together we reduce fixed costs, because, well, the fewer kernels you're running, the less overhead you have to pay.

But it's still the case that for a lot of really overhead-bound models, with very small compute and lots and lots of operations, CUDA graphs still gives you a pretty sizable efficiency improvement even when you're using the PyTorch 2 compiler. And yes, there are things we could do to further reduce the overhead of PyTorch execution even in PyTorch 2. But there's nothing faster than zero, right? When you run with CUDA graphs, there is no host-side overhead by construction, because you're going straight to running the CUDA kernels one by one by one.
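In PyTorch 2, all of this sits behind a single flag. A minimal sketch, with an arbitrary placeholder model:

```python
import torch

# Placeholder model, just to show the API surface.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 128), torch.nn.ReLU(), torch.nn.Linear(128, 128)
).cuda()

# mode="reduce-overhead" is what turns on CUDA graph trees under the hood.
compiled = torch.compile(model, mode="reduce-overhead")

x = torch.randn(8, 128, device="cuda")
for _ in range(3):  # the first few calls compile and record; later calls replay
    out = compiled(x)
```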

So CUDA graphs are cool, and I want to turn the clock back to the eve of the PyTorch 2 release, when we were having a call. I remember I was driving home from, you know, doing some maintenance on my Tesla, so I was on the highway, phoned into a conference call we were having, where the question was basically: what are we going to do about CUDA graphs?

And the problem was, CUDAGraphs, we could tell from our benchmarks, made things a lot faster when we were running PyTorch 2, but they used too much memory. Why did CUDA graphs in PyTorch 2 use too much memory?

The problem was related to graph breaks. Specifically, let's imagine that you've got your model and there are some graph breaks. You know, you've got graph 1, graph 2, graph 3. Obviously, we can't CUDA graph the entire thing, because we have no idea what's going on at the graph breaks. So instead, we CUDA graph each graph separately.
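As a hypothetical example of what causes those breaks (the function below is made up, but the .item() call is a typical culprit):

```python
import torch

@torch.compile(mode="reduce-overhead")
def f(x):
    y = x.sin() + 1            # graph 1
    if y.sum().item() > 0:     # graph break: data-dependent control flow on the CPU
        return y.cos()         # graph 2
    return y.tan()             # graph 3

f(torch.randn(1024, device="cuda"))
```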

And when you CUDA graph each graph separately, well, how exactly do CUDA graphs work? Normally, you end up with an isolated block of CUDA memory which contains enough space to store all the inputs, because remember, it's all static addresses, right? So the next time you call this CUDA graph, you have to give it tensors in exactly the same locations they were in last time.

To make sure you actually have those addresses available, you have to keep that memory around. So for every CUDA graph, you have a big CUDA memory allocation, which has enough space for all the inputs, modulo parameters, because parameters you can just assume have some static location and don't change, so everyone can reference those static addresses directly. It has all of your input space, and also enough space for all of the intermediate compute you might do, because obviously in the middle of your graph you may do allocations, and those allocations are also going to have hard-coded addresses, so you need to have them in your CUDA graph.

So when you have three CUDA graphs, what you end up having is 3x the amount of memory you need because each CUDA graph has its own pool of memory that's disjoint from the other ones being like, hey, this is the memory that I need to actually do my compute because I've burned in all these static addresses and so I need to reserve it for myself when I do it.

And this is very memory expensive, because when I ordinarily run my program in eager mode, I don't have this hoarding behavior, right? When I'm doing stuff with the CUDA caching allocator, I ask for some memory, I use it, and when I'm done I return it to the CUDA caching allocator, which is then allowed to hand that memory off to someone else so they can use it for something else. But these CUDA graphs can't do that. They have to hold onto the memory, because the next time you call them, they need to make sure that memory is actually available.

Okay, so CUDA graphs were using up too much memory. We were like, oh my god, what are we going to do about this? How are we going to launch PyTorch 2 with a version of CUDA graphs that takes up this much memory?

And we were thinking of ideas for how to deal with this. One of the ideas we had was: hey, in normal eager mode, we're perfectly happy to reuse memory across allocations. So there's nothing fundamentally stopping you from reusing memory allocations between separate CUDA graphs either, right?

So remember, all the CUDA graph is doing is saying, hey, there's a static address, and the memory at that static address needs to be available when I use it. So if, for example, the CUDA graphs get called in exactly the same order every single time, then what you can just do is say, okay, well, this memory is no longer being used. I needed it for the first graph, but I'm not using it anymore at this point, so let me go ahead and use it for something else when I'm running my second CUDA graph. And then I don't need memory for the sum of the intermediates of all three graphs; I can reuse it, so my memory usage looks a lot more like what the high-water-mark memory usage was in eager mode.

But there's a problem with this. And the problem is that when you have a graph break in PyTorch 2, you can't actually guarantee which graph will be called next, right? Because maybe the reason you had a graph break was that the user had a dynamic conditional which is going to shunt you into one graph or another. So if you do all this memory reuse, and then suddenly some other graph gets called, well, uh-oh, maybe some memory that you were expecting to be available is no longer available, and you're in trouble.

But there is a maybe obvious next step in this case, right? Which is: what if, when we diverge between two CUDA graphs, we simply imagine that there are two paths we can take? At the time I'm done with the first graph, the memory allocator is in some state. Then, if I go to graph two, I do some things based on graph two; but if I go to graph three instead, I do some other things. It's like one of those time travel movies where you make a decision, and depending on the decision, the future branches off into two possible different futures. We just want to do the same thing for CUDA graphs.

And this leads to the concept of CUDA graph trees, which is what we actually implemented in PyTorch 2. CUDA graph trees solve the problem of memory reuse in CUDA graphs by saying, well, it's a choose-your-own-adventure. The memory you're going to end up using for the CUDA graph recording is the maximum of the memory usage over all the possible branches you could take. And because we are only allowed to evolve the CUDA graph along paths of the tree, every path on the tree is going to have a consistent allocation-deallocation pattern.

As long as I go down that same path, I can simply reuse exactly the same memory addresses as before. And if I take a different path, well, that path is its own execution, and I'm guaranteed not to change my mind and go down another path of the tree partway through. Each of these paths is self-contained. And then eventually I get to the end of my training loop iteration, I go back to the beginning, and ostensibly, usually, when you're done with a single training step, all your memory is dead, so everything can be assumed to be cleared and you can start reusing things again.

So this is the basic concept of CUDA graph trees, right? The main idea is we want to reuse memory across graphs. By reusing memory across graphs, we get rid of the big memory usage caused by CUDA graphs. And the tech you have to build to actually do this is some sort of ability to checkpoint the state of the memory allocator. So that if you're like, hey, I'm running CUDA graphs, and I want to record what happens if it goes this way, and I also want to record what happens if it goes some other way, I need a way to reset the state of my allocator to what it was at that point in time, so that I can then do a bunch of other allocations and deallocations based on what I see in the next case.
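To make the shape of that idea a bit more concrete, here's a purely conceptual sketch. None of these names are PyTorch's real allocator internals; it's just one way the bookkeeping could look:

```python
# Conceptual sketch only: hypothetical names, not PyTorch's actual allocator API.

class AllocatorState:
    """Snapshot of which blocks in the shared CUDA graph pool are live."""
    def __init__(self, live_blocks):
        self.live_blocks = frozenset(live_blocks)

class GraphTreeNode:
    def __init__(self, recording, state_after):
        self.recording = recording      # the captured CUDA graph
        self.state_after = state_after  # allocator checkpoint taken right after recording
        self.children = {}              # graph id -> GraphTreeNode (possible next graphs)

def record_branch(parent, allocator, graph_id, record_fn):
    # Rewind the allocator to how it looked right after the parent graph,
    # so this branch hands out the same addresses any sibling branch would.
    allocator.restore(parent.state_after)
    recording = record_fn()
    child = GraphTreeNode(recording, allocator.snapshot())
    parent.children[graph_id] = child
    return child
```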

Okay, so that's the basic implementation idea behind CUDA graph trees. There are some operational implications to how we've implemented them. One of the discussions that Lezcano opened up on GitHub is: hey, maybe we should turn on CUDA graphs, this reduce-overhead mode, by default when you're running torch.compile. And maybe this is a good idea. We're a bit nervous about it.

And the reason we're nervous about it is that although CUDA graph trees are pretty good at what they set out to be, which is a way of using CUDA graphs where we basically let you just say "try reduce-overhead" and PyTorch 2 takes care of all the safety conditions for you. We don't have a problem if you're doing CPU compute or calling unsafe operators, because, hey, we're PyTorch 2: we actually get a graph, and then we can look at it and ask, are you doing any compute on the CPU? Are you calling nonzero? And then we can just disable CUDA graphs if those things have happened.

And because we're PyTorch 2, we also keep track of all the inputs. And we know, oh, these inputs are parameters, so we can statically bake them in. These inputs are just regular eager inputs, so we're going to allocate dedicated buffers for them in the CUDAGraph memory pool and copy them in. And we do all of this for you because we have a pretty deep understanding of what is going on in your code because, hey, having graphs is great.

And furthermore, because CUDA graph trees have this sort of choose-your-own-adventure property to them, we can even do this in the presence of graph breaks. Obviously, the code at the graph breaks is going to run slow, but all the stuff inside the PyTorch 2 graphs is actually going to run fast.

But this safety, this abstraction, is not complete. One of the big things you have to be aware of is that eventually, when we're doing CUDA graph trees, we want to stop extending the tree and go back to the beginning of it, right? If we always keep appending more recordings onto the tree, that's kind of pointless, because if you're continuously recording new CUDA graphs, you're never getting the benefit of replaying a CUDA graph. You only get the benefit of CUDA graphs when you have a pre-recorded CUDA graph and you replay it again.

So at some point we have to say, okay, we're done recording. We're going to go back to the root of the tree, and now we can follow a path, and hopefully that path consists entirely of CUDA graphs we've already recorded, so we can go zip, zap, very fast.

So, when we restart the tree, when we go back to the root of the tree, we now have the big constraint which is that we actually need to have freed all the memory associated with the CUDAGraph memory pool, because we're going to go stomp over it again in an unpredictable way when we start using the memory again.

And as I said, usually user code is written so that this isn't a problem, but you can get it wrong, right? If I hold onto a tensor that is an output of a CUDA graph tree, then, well, that tensor, if it stays live, is going to inhibit CUDA graph trees from reusing that memory, because we don't want to stomp over the data and leave you with a bunch of garbage in a tensor that's still hanging around.
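A hypothetical illustration of that footgun (the model is a placeholder, and the fix shown in the comment is just one way to write it):

```python
import torch

compiled = torch.compile(torch.nn.Linear(64, 64).cuda(), mode="reduce-overhead")

saved = []
for step in range(10):
    out = compiled(torch.randn(8, 64, device="cuda"))
    # Problematic: `out` lives in the CUDA graph memory pool, and keeping it
    # alive across iterations blocks that pool from being reused.
    saved.append(out)
    # Safer: copy it out of the pool first, e.g. saved.append(out.detach().clone())
```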

Another problem that is sort of non-transparent with CUDA graph trees is what happens when you have mutations on input tensors. Remember, I said that when you run CUDA graphs on an input tensor that doesn't have a static address, we go ahead and copy it into the CUDA graph pool. So once you've copied it in, that's a separate allocation for the input in question. If the program in question then goes ahead and tries to mutate the input, it would mutate the CUDA graph's internal copy of the memory, but not the original user input, which may have been allocated in eager mode. We don't do an unsafe thing in this case; we just disable CUDA graph trees in this situation. But, you know, if you're applying CUDA graph trees to some random code that you haven't actually looked at, it's possible it doesn't work, because there are things that look pretty benign, things we would have gotten past with graph breaks, but that just inhibit CUDA graphs from working.
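For example, something like this hypothetical snippet looks harmless, but the in-place update of a plain eager input is exactly the pattern described above:

```python
import torch

@torch.compile(mode="reduce-overhead")
def step(x):
    x.add_(1)      # in-place mutation of a non-static (eager-allocated) input
    return x * 2

x = torch.randn(1024, device="cuda")   # ordinary eager tensor, no static address
step(x)  # CUDA graphs get skipped rather than silently mutating only the graph's copy
```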

So if you're like, oh, I think my model really should run with CUDA graphs, then you have to actually go look and see whether CUDA graphs are in fact running when you turn them on with PyTorch 2. Because chances are we may have turned them off for any number of reasons, some of which are just fundamental framework limitations rather than anything you, the user, did wrong. But often it's not too difficult to adjust your code to handle these cases.

Finally, CUDA graph trees are not free, right? They do change the cost model of your program. I already mentioned one of the things that changes: when you have a CUDA graph tree with a lot of branches, ordinarily you'd only use the memory associated with a branch when you actually go down that branch. But a CUDA graph tree is going to have a standing allocation which represents the maximum memory usage over all the branches you could possibly take. So you'd better not be relying on the fact that, well, sometimes my memory does go up, but it doesn't happen all the time, so things are okay. You're just always going to pay the worst-case memory usage.

Also, your memory usage is going to be worse than it would have been in eager mode, because when you CUDA graph things, those allocations have to go in a separate memory pool from the eager memory pool. If you're running everything in eager, the CUDA caching allocator may be able to make better use of your memory by serving everything from a shared pool. But when you separate the pools, your memory usage generally gets worse, because now you've got two pools: if something is free in one pool and you need an allocation in the other, that doesn't help you. You just have to go ahead and allocate more memory in the other pool.
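One way to eyeball this effect, sketched under the assumption that a few warmup calls are enough to get the graph recorded, is to compare the allocator's reserved memory with and without the mode:

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU()).cuda()
x = torch.randn(64, 1024, device="cuda")

compiled = torch.compile(model, mode="reduce-overhead")
for _ in range(3):          # warm up so the CUDA graph actually gets recorded
    compiled(x)
torch.cuda.synchronize()

# memory_reserved includes the standing CUDA graph pool on top of the eager pool.
print(torch.cuda.memory_reserved() // 2**20, "MiB reserved")
```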

So there can be some memory inflation in this case. And finally, CUDA graphs could, in the worst case, make your model run slower. That's because for all the inputs which don't have fixed memory addresses, we have to copy them into the CUDA graph region, and that's a copy you didn't have to do in normal eager mode.

Now, in a situation where you don't have lots of graph breaks, where you have like one graph and your inputs are not too big, this is usually not a big deal, because the savings from CUDA graphs more than swamp the fixed copy cost. But if you have a lot of graph breaks, then this can potentially be a performance problem.

So maybe we can turn on CUDA graphs by default, but we'll probably have to work on it some more because definitely CUDAGraphs is one of those things where you need to turn it on and then see if your model is doing what you expect or not.

There are some ideas that we have for improvements to CUDAGraphs. Some of the limitations that I talked about, such as mutations to input tensors, in principle can be fixed with just more engineering.

Another tool that we have in our toolbox is re-recording CUDA graphs. We've actually already implemented this for dynamic shapes. The issue with dynamic shapes is that normally they don't work with CUDA graphs, because a CUDA graph has everything burned in, including the sizes of the tensors in question. So you can't have a single CUDA graph that works for multiple dynamic sizes. But what you can do is, for every dynamic size you see, re-record the CUDA graph using the same dynamic kernels that Inductor generated, just with a different size.

And this is profitable because it's a lot cheaper to re-record a CUDA graph, whose cost is on the order of how long it takes to run the model, than it is to do the entire PyTorch 2 recompilation again, which is pretty expensive. In part that's because compile times are slow, but it's also just a lot more work: you're actually generating kernels and things like that.
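In user-facing terms, this is the situation where something like the following hypothetical snippet keeps one compiled artifact, while the CUDA graph may get re-recorded per concrete size under the hood:

```python
import torch

@torch.compile(mode="reduce-overhead", dynamic=True)
def f(x):
    return torch.relu(x @ x.T)

# The same compiled kernels serve every size; the CUDA graph may be
# re-recorded for each concrete size that shows up.
for n in (128, 256, 512):
    f(torch.randn(n, 64, device="cuda"))
```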

So this is something we can do to work around problems that CUDA graphs have. And another case that Animesh and Laith have been looking into is: hey, we probably also want to do re-recording of CUDA graphs when we have a CUDA graph that references a lot of parameters, but we actually want to run it with a lot of different sets of parameters.

A common situation where this occurs is, let's say you have a bunch of transformer blocks in your model, and you're only compiling the transformer block, and you want to CUDA graph the transformer block. It would be nice if you could have a single compile product that works for all of the transformer blocks in your program. But the parameters for these blocks are different, and if you naively CUDA graph it, you would have to do a copy-in on the parameters, which is generally a terrible idea. Unless it's a diffusion model, because apparently, according to Dima, who I was talking to about this, diffusion models don't have as much of a problem with this copy-in.

So, to deal with this problem, what you can do is re-record the CUDA graph for each individual block. The compilation cost is that you compile once with a version of your model that can work for arbitrary parameters, and then for every particular transformer block, you re-record the CUDA graph with the new static addresses for its parameters. Once you've done that for however many dozens of transformer blocks, they can all be reused. And this doesn't blow up memory usage, because you reuse the same memory for each of these recordings.

Okay, so I hope that told you a little bit about what to expect with CUDA graph trees. This is what happens when you use mode="reduce-overhead" in torch.compile. That's everything I wanted to say today. Talk to you next time.