PyTorch Developer Podcast

Inductor - Define-by-run IR

Episode Summary

Define-by-run IR is how Inductor defines the internal compute of a pointwise/reduction operation. It is characterized by a function that calls a number of functions in the 'ops' namespace, where these ops can be overridden by different handlers depending on what kind of semantic analysis you need to do. The ops Inductor supports include regular arithmetic operators, but also memory load/store, indirect indexing, masking and collective operations like reductions.

Episode Transcription

Hello everyone and welcome to the PyTorch Dev Podcast. Today I want to talk in a little more detail about the define-by-run portion of Inductor's IR. What is the define-by-run portion of Inductor's IR? Well, in our previous podcast episode about Inductor's IR, we talked about the various IR nodes that are explicitly represented as an intermediate representation between ATen operations and when we actually do Triton code generation. The define-by-run portion of the IR is the specific part of the IR which is responsible for representing the element-wise operations that you might be interested in doing when you're representing some operation.

The canonical example of when we use define-by-run IR in Inductor is when we're representing a pointwise operation. To represent a pointwise operation, we get a regular Pointwise IR node that represents the entirety of the pointwise operation. But then there is an inner function, which represents the actual compute that is going to happen inside the pointwise operation: the adds, the muls, whatever, all of that stuff.

So when you are thinking about, you know, where the data structures for Inductor's IR are, you'll see that all the top-level ones are classes that subclass IRNode, but all the little bits of actual compute are going to be done via this define-by-run IR.

So how can I go about reading about what exactly this define-by-run IR is? Well, if you were asking me a month ago, I'd say, well, you kind of have to figure it out by reading the code. Fortunately, I recently landed a pull request to Inductor that basically documents the entirety of what I call the ops handler inside Inductor. Because the way the define-by-run IR works is that we're constructing functions that call other functions, in this case operations in the ops namespace, whose meaning we have the ability to override so that we can do different things depending on what we need to do.

So just to break it down: inside Inductor, there's a module called virtualized. What this module does is define some dynamically scoped variables, which represent various things you might be interested in querying when you're performing operations in Inductor. The one we're particularly interested in is a global, well, not really global, it's thread-local, variable called ops, which represents all of the potential operations you can do inside the define-by-run Inductor IR.

So ops has a method named add, it has a method named store, load, etc. When we are defining, for example, a pointwise operation and we want to define the inner function for that pointwise operation, what the inner function is going to look like is this: it's a function that takes in some indexes as arguments, saying, you know, hey, this is where you should actually get information from. Typically, you're going to go ahead and do a load to read out the information in question, and then, with the result, call addition or multiplication or whatever actual operation you want to do. And this all gets packaged up into the inner function, which gets associated with, for example, a pointwise operation.
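To make this concrete, here is a minimal, self-contained sketch of the pattern, with invented names; it is not the real Inductor API, just the shape of it. A swappable `ops` object delegates every call to a handler, and an inner function for a pointwise add calls `ops.load` and `ops.add` through it:

```python
class OpsContext:
    """Stand-in for the thread-local ops object in Inductor's virtualized module."""
    def __init__(self):
        self.handler = None

    def __getattr__(self, name):
        # Delegate every op (load, add, mul, ...) to the current handler.
        return getattr(self.handler, name)

ops = OpsContext()

def inner_fn(index):
    # Hypothetical inner function for out = a + b at a given index.
    a = ops.load("a_buf", index)
    b = ops.load("b_buf", index)
    return ops.add(a, b)

class EagerHandler:
    """One possible handler: actually compute values from concrete buffers."""
    def __init__(self, buffers):
        self.buffers = buffers
    def load(self, name, index):
        return self.buffers[name][index]
    def add(self, x, y):
        return x + y

ops.handler = EagerHandler({"a_buf": [1, 2, 3], "b_buf": [10, 20, 30]})
print([inner_fn(i) for i in range(3)])  # -> [11, 22, 33]
```

The key point is that `inner_fn` never names a handler; swapping `ops.handler` changes what the same function means.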

So when I do something like this, I have the ability to basically change the meaning of what calls to ops mean depending on what I need to do. The simplest example of what you might want to do is turn one of these inner functions into a string representing, well, what is the actual computation that you want to do? When you print out a pointwise IR node and you get out, you know, hey, the inner function is this thing, we're actually calling this inner function string function, which is doing this operation.

So what exactly does it do? Well, it says, okay, let me go ahead and override the ops handler, the meaning of the ops object in virtualized, so that it points to a kernel printing handler. That handler basically says, okay, whenever you call me, I'm going to turn your call into a string representing whatever it is you called me with, and then return those strings. And so when you're done, you basically get a string representation of all the operations that happened. Everything you want to do, code generation, semantic analysis, all of it operates by overriding the meaning of ops in the virtualized namespace and then running the inner function directly. You can even reify this inner function into plain FX IR. It's in fact very simple: instead of passing in regular index variables, you just pass in FX proxies, and when you do calls on those FX proxies, you instead record what calls actually happened into an FX graph.
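Here is a tiny self-contained sketch of the string-printing idea (again with invented names, not real Inductor code): swap in a handler that formats every op call as a string, then run the same kind of inner function:

```python
class StringHandler:
    """Format every op call as a string instead of computing anything."""
    def __getattr__(self, name):
        def fmt(*args):
            return f"{name}({', '.join(map(str, args))})"
        return fmt

class OpsContext:
    handler = None
    def __getattr__(self, name):
        return getattr(self.handler, name)

ops = OpsContext()
ops.handler = StringHandler()

def inner_fn(index):
    a = ops.load("a", index)
    b = ops.load("b", index)
    return ops.add(a, b)

# Running the inner function now *prints* the computation rather than doing it.
print(inner_fn("i0"))  # -> add(load(a, i0), load(b, i0))
```

Because every op just returns a string, nested calls compose into a readable expression with no extra machinery.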

So, you know, very simple, very easy to write code this way. Writing things in this way is also very convenient in Python, because Python supports a lot of metaprogramming. If you're writing one of these operator handlers, for most things you have a very generic formula that works, and that matters because there are tons of these operations. The way to think about it is: for every torch operator which we support pointwise compute on, and that includes things like negate and sign, there is an ops definition. Now, sometimes when we're actually doing codegen, we can desugar these into more primitive operations, and we often do. But just for ease of sort of wiring everything up, basically everything that is supported in the Torch frontend gets an ops operator inside of this define-by-run IR.

So there's a lot of these that you have to handle, and people often don't need to handle them all individually. They can just write a generic __getattr__ that takes in some list of positional arguments, takes in some list of keyword arguments, and then does the operation on all of these things. So that's pretty nice.
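The __getattr__ trick can be sketched like this: a hypothetical analysis handler that catches every op by name with one generic rule, here counting ops while delegating the actual math to Python's operator module. The names are invented for illustration:

```python
import operator

class EagerArith:
    """Delegate each op name to the stdlib operator module (add, mul, ...)."""
    def __getattr__(self, name):
        return getattr(operator, name)

class CountingHandler:
    """Generic wrapper: one __getattr__ handles every op uniformly."""
    def __init__(self, inner):
        self.inner = inner
        self.counts = {}

    def __getattr__(self, name):
        def wrapped(*args, **kwargs):
            # Generic rule applied to all ops: count, then delegate.
            self.counts[name] = self.counts.get(name, 0) + 1
            return getattr(self.inner, name)(*args, **kwargs)
        return wrapped

ops = CountingHandler(EagerArith())
result = ops.add(ops.mul(2, 3), 4)
print(result, ops.counts)  # -> 10 {'mul': 1, 'add': 1}
```

One method body covers the whole op vocabulary, which is exactly why this style scales to the large surface area described above.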

So what exactly should I expect to see when I am looking at the supported operations in ops inside this define-by-run IR?

So as I've mentioned, there's all of the regular, you know, arithmetic computation that you might be interested in. Those are quite uniform, so I'm not going to talk about them too much.

There's also operations for reading from and storing to memory, so store and load. There's also operations for interacting with randomness. Randomness is directly encoded in the define-by-run IR because random operations typically require special code generation.

And there's also some kind of really unusual things that are supported in this IR. For example, one of the very important things we need to do when we are generating code is compute indexing expressions that say exactly where in memory we want to read from, right? The normal situation when you're doing indexing is you get a bunch of SymPy expressions representing some sort of indexing compute. These are represented as SymPy expressions because we want to be able to simplify them: in general, the expression is going to be very complicated, you need to multiply every index dimension by the stride and do all of that, but sometimes it can be simplified quite a bit, and then maybe you only need a single index variable at the end. So you typically have these SymPy expressions floating around. But of course, sometimes you want to do indirect indexing, where you have some computation that you did based on tensor data, and then you want to use that in an indexing expression. So there's an indirect_indexing function, which essentially takes a regular value that you computed, a regular tensor compute value, and turns it into a SymPy expression so that you can use it in subsequent indexing operations.
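A self-contained sketch of the indirect indexing idea (not real Inductor code; a tiny stand-in class plays the role that a real SymPy expression plays in Inductor): a value loaded from one buffer becomes a symbolic index used to load from another, as in a gather, out[i] = values[indices[i]]:

```python
class Sym:
    """Stand-in for a symbolic index expression carrying a concrete value."""
    def __init__(self, val):
        self.val = val

class EagerHandler:
    def __init__(self, buffers):
        self.buffers = buffers

    def load(self, name, index):
        # Accept either a plain loop index or a symbolic (indirect) index.
        i = index.val if isinstance(index, Sym) else index
        return self.buffers[name][i]

    def indirect_indexing(self, value):
        # Turn a loaded tensor value back into an index expression.
        return Sym(value)

ops = EagerHandler({"indices": [2, 0, 1], "values": [10, 20, 30]})

def inner_fn(i):
    idx = ops.indirect_indexing(ops.load("indices", i))
    return ops.load("values", idx)

print([inner_fn(i) for i in range(3)])  # -> [30, 10, 20]
```

The unusual part is the direction of the conversion: every other op consumes and produces opaque handler-defined values, while indirect_indexing crosses over from the value world into the indexing-expression world.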

So this one's very unusual and typically needs special handling, because most of the operations inside ops return, I want to say, just some tensor value. It's actually not well defined what the ops handler returns, because different analyses will override the return value to mean different things. If I'm formatting my inner function as a string, these functions are going to take in strings representing the various inputs and then return a string saying, hey, this is what the output string is going to be. And if I'm doing code generation, what I'm typically passing around is not a string but this thing called a CSE value, which is like a string, but we're also doing some common subexpression elimination while we're at it. But indirect_indexing is different: it takes in one of these unspecified values and produces a SymPy expression.

Now, unlike all of the regular tensor compute pointwise operations, we don't actually override the meaning of SymPy expressions. SymPy expressions are always done via SymPy; they're always represented explicitly as the SymPy abstract syntax tree. So you actually do need to provide a SymPy expression, even if it's just a bogus one, when you're implementing something like indirect_indexing. Let me mention some of the other unusual operations we support.

So for example, the define-by-run IR is also higher order in some cases. The masked operator handles a situation where you are performing some set of operations, say some loads and stores, but they may not always be valid. For example, you are doing an indirect load and sometimes the index is invalid. What's actually happening is that you had some condition which said, hey, should I do the load or not? I can't unconditionally do the load, because if I unconditionally do the load, I'll have an illegal memory access.

So I need to mask out the load: only on the parallel lanes where the index is valid should I actually do the load. The masked operator lets us do this by simply saying, okay, give me a mask saying whether or not the index is valid, and then give me some function, like an inner function inside of my define-by-run IR, which actually has the stores and loads that I want to have run in a masked fashion. And actually, I checked the implementation while I was preparing this podcast, and all we do is override the meaning of store and load before we go ahead and execute the body of the masked operation.
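Here is a hypothetical eager sketch of the semantics (the real Inductor implementation, as described above, temporarily overrides load and store inside the body; this simplified version just branches on the mask): where the mask is false, the body's load never runs, so an out-of-bounds index can't fault:

```python
class EagerHandler:
    def __init__(self, buffers):
        self.buffers = buffers

    def load(self, name, index):
        return self.buffers[name][index]

    def masked(self, mask, body, other):
        # Only run the body (and its loads) where the mask is true;
        # elsewhere, produce the fallback value without touching memory.
        return body() if mask else other

ops = EagerHandler({"x": [10, 20, 30]})

def inner_fn(i, n=3):
    # A shifted read x[i + 1], which is only valid when i + 1 < n.
    return ops.masked(i + 1 < n, lambda: ops.load("x", i + 1), 0)

print([inner_fn(i) for i in range(3)])  # -> [20, 30, 0]
```

The higher-order part is that masked takes a function, a little nested inner function, rather than a value, which is what lets the handler decide whether and how its body runs.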

So, you know, not only do you override the meaning of operations at the top level when you're trying to decide what to do, but we can also recursively override their meaning within these local scopes to make them do different things.

So, you know, the last set of operations that you'll get are some collective-style aggregation things. If you're doing reductions or scans, we also have operations representing those in the IR because, well, you need a little more juice to actually represent that.

We do have dedicated top-level IR nodes representing reductions and things like that, and these special operations are typically not valid unless they're run in a context like that. But they're also something interesting to know about. Peter Bell is the expert on scan, having been the one who implemented it in the first place.
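A simplified, hypothetical sketch of how a reduction op might fit together (the signature and the driving loop are invented for illustration; in Inductor the enclosing Reduction IR node supplies the loop over the reduced range, which is why the op is only valid in that context):

```python
class EagerHandler:
    def __init__(self, buffers):
        self.buffers = buffers

    def load(self, name, index):
        return self.buffers[name][index]

    def reduction(self, reduction_type, values):
        # Combine the per-element values according to the reduction kind.
        return {"sum": sum, "max": max, "min": min}[reduction_type](values)

ops = EagerHandler({"x": [3, 1, 4, 1, 5]})

def inner_fn(rindex):
    # Per-element compute over the reduction dimension.
    return ops.load("x", rindex)

# The enclosing Reduction node would drive this loop during codegen.
total = ops.reduction("sum", [inner_fn(r) for r in range(5)])
print(total)  # -> 14
```

This also shows why these ops need "more juice": unlike a pointwise op, the reduction's result depends on the whole reduced range, not a single index.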

So we have talked about the define-by-run IR in more detail. We've talked about the operators inside it and the general way you work with it, namely by overriding virtualized. That's everything I wanted to talk about today. Talk to you next time.