Compiler collectives are a PT2 feature whereby compiler instances across multiple ranks use NCCL collectives to communicate information to other instances. This is used to ensure we consistently decide whether inputs are static or dynamic across all ranks. See also the PR at https://github.com/pytorch/pytorch/pull/130935
Hello everyone and welcome to the PyTorch Dev Podcast. Today I want to talk about compiler collectives, a new feature in PyTorch 2 compilation which allows the compiler to talk to the instances of the compiler running on other ranks in distributed training, in order to share information that may be useful to those ranks.
To explain why compiler collectives are useful, I first need to recount a particular problem that we encountered in our production deployment of PyTorch 2 inside Meta. The problem looked something like this. Occasionally, we would have jobs running with PyTorch 2 enabled, and they would hit a NCCL timeout. Now, NCCL timeouts occur whenever you have a NCCL collective and some of the ranks participating in it have to wait too long for a result. The timeout exists because one of the common reasons you wait too long is that there's a deadlock and the collective is never actually going to finish. We have a timeout to make sure we actually kill the nodes and release resources in this situation.
In this particular case, we were hitting NCCL timeouts. The first thing you do when you have a NCCL timeout is go and look at what the heck all the jobs were doing at the time they crashed. We noticed that in this particular case, some of the jobs were compiling. Why were some ranks compiling code with torch.compile while other ranks were just waiting in the network collective?
Using TLParse, a log parser that we have for PyTorch 2 that can tell you what was going on on all the nodes (see a previous podcast for more information), we noticed that the ranks that were compiling were actually doing an extra recompilation that the other ranks were not. Some of the ranks had just gone ahead, run the code and gotten all the way to the collective, while this poor unlucky rank was recompiling.
Further inspection of the trace revealed why this rank had decided it needed to recompile: one of the inputs to a graph it had already compiled had changed size. The graph that this rank had previously compiled was static; it had assumed that the size of the input at that location was static. By contrast, the other ranks had already compiled a dynamic graph.
What had happened here is a consequence of something we call automatic dynamic. Automatic dynamic in PyTorch 2 says: we don't know whether your inputs are static or dynamic unless you explicitly tell us. If you don't tell us, we will assume that all of your inputs are static. Then we react to what we see at runtime: if you pass us a tensor of size 5 on the first run and a tensor of size 7 on the second run, we will realize, oh, it looks like you want this to be dynamic, and we will recompile your graph with that size treated as dynamic.
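A toy Python model can make the automatic dynamic policy concrete. The sketch below assumes a single tensor input whose number of dimensions never changes, and it tracks only sizes: the first call compiles with every dimension static, and a later call whose size differs in a static dimension triggers a recompile with that dimension marked dynamic. The class AutomaticDynamicTracker is hypothetical and far simpler than Dynamo's real guards.

```python
from typing import List, Optional, Sequence


class AutomaticDynamicTracker:
    """Toy model of automatic dynamic for a single tensor input (hypothetical)."""

    def __init__(self) -> None:
        # None = nothing compiled yet. Otherwise one entry per dim:
        # an int for a static (specialized) size, None for a dynamic size.
        self.compiled: Optional[List[Optional[int]]] = None
        self.compiles = 0

    def _matches(self, shape: Sequence[int]) -> bool:
        # The compiled graph is reusable if every static dim matches.
        assert self.compiled is not None
        return all(c is None or c == s for c, s in zip(self.compiled, shape))

    def call(self, shape: Sequence[int]) -> bool:
        """Run one call; return True if it triggered a (re)compile.

        Assumes the number of dims never changes between calls."""
        if self.compiled is None:
            # First call: assume every dim is static.
            self.compiled = list(shape)
        elif self._matches(shape):
            return False
        else:
            # A static dim changed: recompile with that dim dynamic.
            self.compiled = [
                None if c is None or c != s else c
                for c, s in zip(self.compiled, shape)
            ]
        self.compiles += 1
        return True


tracker = AutomaticDynamicTracker()
for shape in [(5, 16), (7, 16), (9, 16)]:
    recompiled = tracker.call(shape)
    print(shape, "compile" if recompiled else "reuse", tracker.compiled)
```

The second call is the recompile that automatic dynamic costs you; once dimension 0 is dynamic, later sizes reuse the same graph.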
The problem is that most of the other ranks had gotten inputs whose sizes varied between the first run and the second run, say going from 5 to 7, and they had all recompiled with that input dynamic. But the unlucky rank, the one that was actually recompiling at the time of the NCCL timeout, had happened to get an input of exactly the same size both times, and it had happily assumed, well, let's just keep it static. It only got caught out on the next run, when the size finally changed and it suddenly had to recompile.
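This failure mode can be reproduced with a few lines of simulation. The sketch assumes a one-dimensional, data-dependent input and a simplified policy (static on the first call, dynamic after the first observed size change). The per-rank sizes are made up for illustration and are not from the production job.

```python
from typing import Dict, List


def compile_steps(sizes: List[int]) -> List[int]:
    """Iterations at which one rank (re)compiles under the toy policy."""
    steps: List[int] = []
    static_size = None
    dynamic = False
    for i, s in enumerate(sizes):
        if i == 0:
            static_size = s  # first call: specialize on the observed size
            steps.append(i)
        elif not dynamic and s != static_size:
            dynamic = True  # size changed: recompile as dynamic, once
            steps.append(i)
    return steps


def ranks_compiling_at(ranks: Dict[int, List[int]], iteration: int) -> List[int]:
    """Which ranks are compiling (instead of running) at a given iteration."""
    return [r for r, sizes in ranks.items() if iteration in compile_steps(sizes)]


# Hypothetical input sizes seen by each rank over three iterations.
ranks = {
    0: [5, 7, 6],
    1: [4, 9, 3],
    2: [5, 5, 8],  # the unlucky rank: same size on iterations 0 and 1
}
for i in range(3):
    print(f"iteration {i}: ranks compiling = {ranks_compiling_at(ranks, i)}")
```

Ranks 0 and 1 compile at iterations 0 and 1, but rank 2 saw size 5 twice and recompiles at iteration 2 alone, while the others are already waiting in the next collective.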
When we first ran into this problem, I thought, oh my god, automatic dynamic was a mistake. Except that it's not really a mistake, because if we didn't have automatic dynamic, this model would not have compiled at all. It's a useful mistake, but in some sense it's architecturally a bit questionable, because the whole point of SPMD distributed training is that you want all the nodes to be doing the same thing. You want all the nodes to be compiling at the same time. It's really bad if only one of the nodes is recompiling.
Even if we adjusted the NCCL timeout so that we didn't time out (if you wait long enough, the recompiling node will eventually finish and you will be able to make progress), this divergence is still not good, because all of the other ranks are waiting for this one straggler rank to finish compiling. There was also another ongoing problem with our production deployment, where things that were supposed to compile in 30 minutes were actually taking two hours. That really exacerbated the problem in this particular case.
When we noticed this problem, it was kind of an interesting issue, and it wasn't entirely clear what we should do about it. Should we just force a bigger timeout in this situation? Should we do something else? The solution we settled on, which balanced being easy to implement against not requiring too many extra constraints from the user, is this thing called compiler collectives.
Compiler collectives are an abstract idea. The abstract idea is: when I am doing compilation in my PyTorch 2 process, let me assume (and this is a new assumption) that everyone in the group inside my training job is compiling at the same time. If I can assume everyone is compiling at the same time, then during compilation I can do a collective with all the other nodes to basically ask them, hey, what's going on?
This is an abstract idea, and you can use it for all sorts of things. But here is what we're going to use it for to solve this particular recompilation problem: have all the ranks talk to each other whenever they see a tensor input, and have each rank tell all the other ranks the size it saw for that tensor input.
In this particular case, the input that was dynamic actually varies across ranks because it has some sort of data-dependent size. This is a recommendation model, so there's a sparse feature involved and not all the ranks are getting the same sizes. When the sizes are unbalanced across ranks like this and all ranks talk to each other to figure out what's going on, they'll say, hey, actually, everyone has a different size for this. So even though this is the first time I've run, and I don't necessarily know what the size of this input should be, let me just go ahead and make it dynamic.
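The decision rule can be sketched as a pure function over sizes that have already been exchanged, for example with an all-gather such as torch.distributed.all_gather_object. The function name decide_dynamic_dims and the data layout (one list of shapes per rank, in the same input order on every rank) are assumptions for illustration, not PyTorch's internal representation.

```python
from typing import List, Sequence, Set, Tuple

Shape = Tuple[int, ...]


def decide_dynamic_dims(gathered: Sequence[Sequence[Shape]]) -> List[Set[int]]:
    """gathered[r][i] is the shape that rank r saw for tensor input i.

    Returns, per input, the dims whose size differs across ranks; those
    become dynamic. Every rank runs this same pure function on the same
    all-gathered data, so every rank reaches the same decision."""
    decisions: List[Set[int]] = []
    for shapes in zip(*gathered):  # shapes = this input's shape on each rank
        ndim = len(shapes[0])
        decisions.append({d for d in range(ndim) if len({s[d] for s in shapes}) > 1})
    return decisions


# Three ranks; input 0 has a data-dependent leading dim, input 1 does not.
gathered = [
    [(5, 64), (32,)],  # rank 0
    [(7, 64), (32,)],  # rank 1
    [(5, 64), (32,)],  # rank 2
]
print(decide_dynamic_dims(gathered))
```

Rank 2 saw size 5 for input 0, just like rank 0, yet it still marks dimension 0 dynamic because rank 1 saw 7. If every rank had seen the same size, the dimension would stay static everywhere, and any later recompile would happen on all ranks together.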
More importantly, because all the nodes are talking to each other, we can ensure that they consistently decide whether a particular input should be dynamic or static. This way, either we never recompile, or, if every rank happens to be unlucky and sees exactly the same size on iterations 1 and 2, everyone recompiles at the next step. That's not great, but at least you're not going to hit a NCCL timeout, because everyone is still doing the same thing.
Now, I slightly lied in this explanation. I suggested that we do a communication every time we see a tensor input. But you don't really want to do that, because communications are expensive; typically you want to batch them together. What we actually do is run the Dynamo tracing process to the very end of the region we want to compile, collecting all the inputs we've seen along the way. At that point, we do the collective and have everyone say, hey, here are the sizes of all of the inputs I've seen.
Dynamic tracing is something you can't really do retroactively: you need to have made the decision to make something dynamic at the very beginning. So we just tell everyone to restart their Dynamo analysis, and this time, decide whether inputs are dynamic based on the results of the compiler collective.
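The batch-then-restart flow can be sketched as a simulation. toy_trace stands in for a Dynamo trace, RestartAnalysis and compile_region are hypothetical names, and fake_all_gather replaces a real collective by returning every simulated rank's sizes. The point is the control flow: trace to the end of the region while only recording sizes, run one collective, then restart the trace with the dynamic decisions in hand.

```python
from typing import Callable, List, Optional, Sequence, Set, Tuple

Shape = Tuple[int, ...]


class RestartAnalysis(Exception):
    """Hypothetical signal: discard this trace and start over with new info."""

    def __init__(self, seen: List[Shape]) -> None:
        super().__init__("restart analysis")
        self.seen = seen


def toy_trace(inputs: Sequence[Shape], dynamic: Optional[List[Set[int]]]) -> List[str]:
    """Toy stand-in for tracing one region.

    With dynamic=None it only records input sizes and, at the end of the
    region, raises RestartAnalysis so the caller can run ONE collective.
    With decisions available, it emits symbols for dynamic dims."""
    seen: List[Shape] = []
    graph: List[str] = []
    for i, shape in enumerate(inputs):
        seen.append(tuple(shape))  # just record; no communication per input
        dims = [
            f"s{i}_{d}" if dynamic is not None and d in dynamic[i] else str(n)
            for d, n in enumerate(shape)
        ]
        graph.append(f"input{i}: [{', '.join(dims)}]")
    if dynamic is None:
        raise RestartAnalysis(seen)
    return graph


def compile_region(
    my_inputs: Sequence[Shape],
    all_gather: Callable[[List[Shape]], List[List[Shape]]],
) -> List[str]:
    dynamic: Optional[List[Set[int]]] = None
    while True:
        try:
            return toy_trace(my_inputs, dynamic)
        except RestartAnalysis as restart:
            gathered = all_gather(restart.seen)  # one batched collective
            # zip(*gathered) regroups per input: (rank0_shape, rank1_shape, ...)
            dynamic = [
                {d for d in range(len(shapes[0])) if len({s[d] for s in shapes}) > 1}
                for shapes in zip(*gathered)
            ]


# Two simulated ranks; input 0 has a data-dependent leading dim.
rank_inputs = [[(5, 64), (32,)], [(7, 64), (32,)]]
collectives: List[List[Shape]] = []


def fake_all_gather(seen: List[Shape]) -> List[List[Shape]]:
    # A real all-gather would combine `seen` from every rank; here we just
    # record the call and return the precomputed sizes of both ranks.
    collectives.append(seen)
    return rank_inputs


print(compile_region(rank_inputs[0], fake_all_gather))
```

Both simulated ranks produce the same graph, with a symbolic leading dimension, on their first compile and after a single collective.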
We actually already have this restart capability. We use it to deal with graph breaks: when a graph break happens in the middle of an inlined call stack, we don't have the ability to graph break inside a nested user frame. We need to pop all the way back to where the first inlined function call happened. But in full generality, that involves rewinding arbitrary changes to mutable state. Instead of figuring out how to reverse all that, we just say, okay, we're going to start over, but this time we're going to stop immediately when we get to that function call. Same idea here: we restart and use our new knowledge to make different decisions while compiling.
Compiler collectives are pretty cool. I actually successfully ran them on the production model that sent us on this wild goose chase in the first place, and there's a really interesting consequence. Not only does it solve the recompile problem, which actually happens pretty rarely (I don't know for sure that I've solved it; confirming that would require actually running the real model until it hits the problem; I have a synthetic test case that shows that it works, but I don't know definitively that it works with the real model).
On top of that, because the compilers are talking to each other even on the first iteration, I can actually skip the stutter step that typically happens with automatic dynamic. The stutter step is where I first compile with static shapes and then, on the second iteration, compile again with dynamic shapes. I don't need to do that anymore, because I figure out immediately that the shapes are all dynamic. This actually drops compile time for this model from 95 minutes to 63 minutes. That's pretty cool.
The problem with this approach is that it's not universally applicable. I said that we're going to assume that every rank compiles at the same time, but it's really easy to write a valid SPMD program with torch.compile that doesn't have this property. Say one rank is doing one thing and another rank is doing another thing, and it just so happens that the first rank has one graph to compile but the second rank has two graphs to compile. This is a kind of strange architecture, and you're definitely doing something unusual if this happens, but it's possible.
In this situation, you can't turn on compiler collectives, because you're just going to deadlock when one of the compiled regions tries to talk to the other ranks. Fortunately, the deadlock is pretty obvious: if you have any sort of ability to look at all the stacks when you suspect a deadlock (a basic capability you should have when doing distributed training), you'll see that someone was blocked in a compiler collective, and you'll say, okay, I guess that's what happened.
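The deadlock can be illustrated with a thread-based analogy, using Python's threading.Barrier in place of a NCCL collective and the barrier timeout in place of the NCCL timeout. This is not how PyTorch implements compiler collectives; it only shows what happens when ranks disagree on how many compiled regions, and therefore how many collectives, there are.

```python
import threading
from typing import Dict

# Stand-in for a compiler collective: nobody proceeds until every rank arrives.
# The barrier timeout plays the role of the NCCL timeout.
barrier = threading.Barrier(parties=2, timeout=1.0)
results: Dict[int, str] = {}


def rank_main(rank: int, graphs_to_compile: int) -> None:
    done = 0
    try:
        for _ in range(graphs_to_compile):
            barrier.wait()  # one collective per compiled region
            done += 1
        results[rank] = f"ok after {done} collective(s)"
    except threading.BrokenBarrierError:
        results[rank] = f"timed out after {done} collective(s)"


threads = [
    threading.Thread(target=rank_main, args=(0, 1)),  # rank 0 compiles one graph
    threading.Thread(target=rank_main, args=(1, 2)),  # rank 1 compiles two graphs
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)
```

Rank 0 finishes after its single compile, while rank 1 blocks in its second collective until the timeout fires. With a real NCCL collective, the stack of that blocked rank is what points you at the compiler collective.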
But it does make it a little tricky to figure out how to roll this out, because right now in nightlies it's a configuration option and it's not on by default. I actually want this to be on for most of the jobs we're running, but it's going to take a bit of work to figure out how to roll it out.
Okay, that's basically it for compiler collectives. The original PR is actually very simple. I had to fix some bugs because there were some funny interactions, but it basically worked pretty well. I do kind of wonder whether this is the right approach. There are a bunch of other approaches that we thought about, which I want to talk about briefly here, because they are interesting alternatives.
One of the other ideas I had was: why not just mark the input in question as dynamic with mark_dynamic, so you don't have to go through all this rigmarole of having compilers talk to each other to figure it out? Sometimes I think this is exactly what you should do. But in this particular model, it's not a single graph getting compiled; it's actually 10 subgraphs, five of which have non-trivial graph content. The graph that is getting recompiled is embedded in the middle of this opaque model that I don't really understand. Due to some vagaries in our environment, I can't even edit it directly; it's produced as a side effect of some other compilation process. Yeah, this is kind of a crazy setup, but that's just where we are for this particular model. It's not obvious where to put the mark_dynamic, because where to put it depends on where Dynamo decided to put the graph breaks, and in general that's not well defined.
Another idea I've had in the past about automatic dynamic concerns the fact that we have to run once and then run again to figure things out. That has always bugged me: can't we just record this somehow, and then the next time around just do the right thing? Seems pretty reasonable, right? Think of profile-guided optimization: the way profile-guided optimizers in traditional compilers work is that you run your program, you get a profile, you put it somewhere, and then the compiler uses it to optimize your code. If the code changes, the profile might be a little stale and your compilation might not be as good, but as long as you refresh the profile, things will be good again. It's kind of operationally complicated to run, and maybe we still want to do it, but I got talked out of it, so I'm not sure what I'm going to do.
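The profile-guided idea could look something like the sketch below: persist which input dimensions turned out to be dynamic, and load them at the start of the next run so compilation can treat them as dynamic from the beginning. The JSON format, the file name and the key scheme (a string naming the function and argument) are all hypothetical; this illustrates an idea that was considered, not something compiler collectives do.

```python
import json
import os
import tempfile
from typing import Dict, List

# Hypothetical profile format: {"<function>:<argument>": [dynamic dims]}
Profile = Dict[str, List[int]]


def save_profile(path: str, profile: Profile) -> None:
    """Persist which input dims turned out to be dynamic during this run."""
    with open(path, "w") as f:
        json.dump(profile, f, sort_keys=True)


def load_profile(path: str) -> Profile:
    """Load a previous run's profile; no file simply means no hints."""
    if not os.path.exists(path):
        return {}
    with open(path) as f:
        return json.load(f)


def initial_dynamic_dims(profile: Profile, key: str) -> List[int]:
    """Dims to treat as dynamic on the very first compile of `key`."""
    return profile.get(key, [])


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "dynamic_profile.json")
    first_run = load_profile(path)  # nothing recorded yet
    # ...the first run hits the static-then-dynamic stutter and learns dim 0...
    save_profile(path, {"forward:x": [0]})
    second_run = load_profile(path)  # the next run starts with dim 0 dynamic
    print(initial_dynamic_dims(first_run, "forward:x"),
          initial_dynamic_dims(second_run, "forward:x"))
```

In this design a stale profile only costs performance: if the code changes and a key no longer matches, you fall back to the normal static-then-dynamic behavior until the profile is refreshed.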
Another idea is: what's the point of forcing every compiler to compile at the same time? Don't you really just want one compiler to compile everything and then send the result everywhere? If you're really doing SPMD, then it's going to be the same thing everywhere, and it's true, that would be pretty nice. But there are some problems. One of them is that you don't have an obvious artifact without running Dynamo, because when you have a bunch of graph breaks, once again, Dynamo is calling the shots about where the graph breaks are.
You can imagine some sort of record/replay setup where you record a Dynamo execution and then replay it on subsequent runs, and that's exactly what you want. That would certainly work, but it's operationally complicated. You still have to run the entire training script to get the recording, which is your quote-unquote compile product; there's no offline compilation process. Actually, one of the big problems with trying to make PyTorch 2 more ahead-of-time is that it really leans into eager mode. A lot of the time, we solve problems by simply assuming that we're actually running the model with real data. If you take that away, say if you're trying to do a full export workflow, things get a lot harder in a lot of ways.
But compiling in only one place is kind of a good idea, and we're talking about it for some of the easier regimes. For example, once you get to Inductor, the compilation is in some sense much more well behaved than Dynamo, because what is Inductor? It takes in an FX graph and a big pile of config options and produces a bunch of Triton kernels that you want to actually run. This is good old-fashioned compiler input and output. You could very much imagine doing the compilation of this graph somewhere else: some other service where, if all the ranks ask it for the same compile result, it notices that the requests are the same, batches them into a single request, compiles once, and returns the result to everyone. We call this remote execution for Inductor. We actually talked about this at the most recent composability sync, so go check that out if you're interested.
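The deduplicating compile service described above can be sketched in-process. The class DedupCompileService is hypothetical and is not PyTorch's remote execution design; threads stand in for ranks, and fake_inductor stands in for an Inductor compile that maps a graph plus config to kernels. The key idea is that requests are keyed by a hash of (graph, config), so identical requests share one compilation.

```python
import hashlib
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List, Optional


class _Entry:
    """One compile job; waiters block on `done` until the result is ready."""

    def __init__(self) -> None:
        self.done = threading.Event()
        self.result: Optional[str] = None
        self.error: Optional[Exception] = None


class DedupCompileService:
    """Hypothetical in-process stand-in for a remote compile service.

    The first request for a (graph, config) key runs the compile; identical
    requests, concurrent or later, wait for and share that single result."""

    def __init__(self, compile_fn: Callable[[str, str], str]) -> None:
        self._compile_fn = compile_fn
        self._lock = threading.Lock()
        self._entries: Dict[str, _Entry] = {}

    @staticmethod
    def _key(graph: str, config: str) -> str:
        return hashlib.sha256(f"{graph}\0{config}".encode()).hexdigest()

    def request(self, graph: str, config: str) -> Optional[str]:
        key = self._key(graph, config)
        with self._lock:
            entry = self._entries.get(key)
            is_owner = entry is None
            if is_owner:
                entry = _Entry()
                self._entries[key] = entry
        if is_owner:
            try:
                entry.result = self._compile_fn(graph, config)
            except Exception as exc:
                entry.error = exc
            finally:
                entry.done.set()  # wake every waiter, success or failure
        entry.done.wait()
        if entry.error is not None:
            raise entry.error
        return entry.result


calls: List[str] = []


def fake_inductor(graph: str, config: str) -> str:
    calls.append(graph)  # count how many real compiles happen
    return f"kernels for {graph}"


service = DedupCompileService(fake_inductor)
# Eight simulated ranks ask for the same graph at roughly the same time.
with ThreadPoolExecutor(max_workers=8) as pool:
    outs = list(pool.map(lambda _: service.request("graph_A", "cfg"), range(8)))
print(len(calls), set(outs))
```

Because the entry stays in the table after the compile finishes, later identical requests also hit it, which is one reason this idea shares so much machinery with caching.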
Maybe we'll have that in the future. We're still fighting fires with our existing cache deployment, and caching shares a lot of similarities with remote execution; you have to solve a lot of the same problems. We're still working on getting caching under control, so that's kind of where we are right now.
Alright, that's everything I want to talk about with compiler collectives. Talk to you next time.