PyTorch Developer Podcast

API design via lexical and dynamic scoping

Episode Summary

Lexical and dynamic scoping are useful tools for reasoning about API design choices in PyTorch involving context managers, global flags, dynamic dispatch, and BC-breaking changes. I'll walk through three case studies: one from Python itself (changing the meaning of division to true division), and two from PyTorch (device context managers, and __torch_function__ for factory functions).

Episode Notes


Further reading.

Episode Transcription

Hello, everyone, and welcome to the PyTorch Dev Podcast. Today I want to talk about lexical scoping, dynamic scoping, and how these programming language concepts relate to library design in PyTorch, specifically with regard to backwards compatibility and other questions. When I talk to people about working on PyTorch, I sometimes get questions from people who knew me before I joined the PyTorch project, back when I was a Haskell developer working on compilers. They ask me whether I'm doing any programming languages stuff here in machine learning land, and I'm always very happy to say yes. In fact, I use programming language concepts all the time as a developer on the PyTorch project. Today's podcast about lexical and dynamic scoping is an example of how I use these concepts to reason about some fairly complicated API design questions that PyTorch, as a Python library, has to answer whenever we design a new API.

So to start with, I need to explain what lexical scoping and dynamic scoping are. When we talk about scoping, we're typically talking about how we resolve the meaning of a variable. When I have a function and I refer to the variable x, how do I know what x is? Lexical scoping says that the value of x is given by whatever definition of x is lexically closest. By lexically closest, I mean: imagine you're looking at the source code of your program. You see the x, and your eye wanders up and out of the enclosing blocks until you find a block that actually defines the variable x. That definition is the one your use of the variable refers to.

In contrast, dynamic scoping is a form of scoping where the reference to x doesn't refer to whatever is lexically obvious. Instead, you can think of x as an implicit global variable that gets rebound whenever you do an assignment (and restored when the scope of that binding ends). So the value of x is not what you saw in the source text, but whatever the caller of your function set the variable to before calling the function. You have to look at the call stack to figure out the value of a dynamically scoped variable.

Very concretely, the Python programming language has no native support for dynamic scoping, but a lot of what people use context managers for is effectively dynamic scoping. A context manager is what you use with the with statement: you write with blah:, and inside that block something different happens because of the context manager. Context managers are a very easy way to implement dynamic scoping: when you enter the context manager, you set some global variable to a new value, and when you exit, you reset it to its original value. That's basically equivalent to a dynamically scoped variable assignment. Ordinary variable references in Python, of course, are resolved lexically, and so are the identifiers you import from modules. Okay, up until this point, this is stuff you might have been told about in your undergrad programming languages class.
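A minimal pure-Python sketch of the contrast: lexical_reader resolves x from the source text, while dynamic_reader sees whatever the innermost active mode(...) block set. The names mode, _current_mode, and the reader functions are hypothetical, chosen only for illustration.

```python
import contextlib

# Lexical scoping: `x` in lexical_reader always resolves to this module-level
# binding, because that's the nearest enclosing definition in the source text.
x = "module-level x"

def lexical_reader():
    return x

def caller_with_local_x():
    x = "caller's local x"  # a different binding; lexical_reader can't see it
    return lexical_reader()

# Dynamic scoping emulated with a context manager: set a global on entry and
# restore it on exit, so what a reader sees depends on the active call stack.
_current_mode = "default"

@contextlib.contextmanager
def mode(value):
    global _current_mode
    saved = _current_mode
    _current_mode = value
    try:
        yield
    finally:
        _current_mode = saved  # restore the outer binding even on exceptions

def dynamic_reader():
    return _current_mode

print(caller_with_local_x())  # module-level x
with mode("inner"):
    print(dynamic_reader())   # inner
print(dynamic_reader())       # default
```

A production version would more likely store the value in a contextvars.ContextVar, so that each thread and asyncio task sees its own binding; the plain global keeps the sketch short.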

So what the heck does this have to do with PyTorch API design? The first thing I want to talk about is a case study in what happens when you want to change the semantics of a library (or, in this example, the Python language itself), and why the choice between lexical and dynamic scoping for doing so has big implications for how usable the result is.

Here's how the case study goes. Back in Python 2, the Python developers made a bad decision: they defined the slash operator on integers to mean integer division. This was a very understandable mistake, because in languages like C, a single slash between two integers is integer division. But lots of people were using Python as a calculator, and they'd ask things like, what is 1 divided by 2? And Python would, helpfully or unhelpfully depending on your perspective, say 0. That was very unexpected.

So the Python developers decided they wanted to change the meaning of division from integer division to true division, so that if you divide 1 by 2 you get 0.5 instead of 0. Obviously, this is BC breaking. So how do you deal with a problem like this? When you have a BC-breaking change, you want some way to let people opt into the new behavior before it becomes mandatory, and then only at some later point in time (namely Python 3) make it required.

So there's an intermediate period when you can change the meaning of your program to switch from integer division to true division. How exactly did Python do this? Python actually had to introduce a special mechanism called a future import. There's a special module called __future__, and you could write from __future__ import division. What that did was change the meaning of all the slashes inside your current module from integer division to true division.

Now, if you're like me, you're thinking: why the heck did they have to introduce an entirely new language feature? Although there is a real __future__ module, the from __future__ import statement is really a compiler directive that changes how Python compiles your module. Why couldn't they have just said, okay, instead of importing division from the normal module, import division from a true-division module? That's how you'd handle it for a function: if I wanted to change a function's semantics, I could have a v1 of the module and a v2 of the module, and pick which module to import the function from to get one version or the other.

Well, the reason is that the division operator isn't a function. Division in Python desugars into a call to a magic method, and whether it desugars into a call to __div__ or to __truediv__ depends precisely on your version of Python and whether you did from __future__ import division. So the meaning of division in Python is not defined by lexical scoping. In some languages, like Haskell, the meaning of division is lexically scoped: it's provided by the Prelude module, which is implicitly imported into your program, and that's how you tell what division means. That's not the case in Python. Division always desugars into one of these method invocations, and method invocations are neither lexically scoped nor dynamically scoped. Instead, they're a form of dynamic dispatch, where you ask the object what the meaning of the operation should be.
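In Python 3, / always desugars to a __truediv__ call and // to __floordiv__ (Python 2's __div__ is gone), falling back to the right operand's reflected method when needed. The hypothetical Traced class below just reports which magic method the operator reached, which makes the dynamic dispatch visible: the meaning of / comes from the operand's type, not from anything imported into the calling module.

```python
import operator

class Traced:
    """Hypothetical wrapper that reports which magic method an operator called."""

    def __init__(self, value):
        self.value = value

    def __truediv__(self, other):
        return ("__truediv__", self.value / other)

    def __floordiv__(self, other):
        return ("__floordiv__", self.value // other)

print(Traced(1) / 2)   # ('__truediv__', 0.5)
print(Traced(1) // 2)  # ('__floordiv__', 0)

# The operator is sugar for asking the operand's type for the method.
print(operator.truediv(Traced(3), 2))  # ('__truediv__', 1.5)
```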

So to change which method gets invoked in this case, you need some real juice from the language itself, and that's why the __future__ mechanism exists. Python's problem was that they wanted to change the meaning of a method invocation in a backwards-incompatible way, but they had no way of letting people opt into it module by module. So they introduced a language feature that let you change the meaning of the operator from one thing to the other.

In PyTorch, we often want to make BC-breaking changes to methods. Unfortunately for us, there's no way to implement a __future__-style mechanism inside PyTorch: it requires language support, and Python doesn't give us that. The best approximation is some sort of global flag, which you can use to toggle between the old behavior and the new behavior. But notice this is very different from what from __future__ import division does. The future import only affects the division operators inside your module. If you import some other module that's using old-school integer division, that module's division stays the way it used to be. So it's very local: you can figure out what the division operators mean just by looking at the top of your file. With a global flag, you don't know what the meaning is without walking up the call stack and looking for whoever set the global. So we try very hard not to do this in PyTorch, and the reason why will become clear in my second case study.
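A sketch of why a global flag is not a lexical opt-in, assuming a hypothetical USE_TRUE_DIVISION flag (not a real PyTorch setting). The three sections stand in for three separate modules; in a real program the user would write something like library.USE_TRUE_DIVISION = True, but the effect is the same: third-party code that never asked for the new behavior gets it anyway.

```python
# --- "library" module (hypothetical) ---
USE_TRUE_DIVISION = False  # hypothetical global BC flag, not a real PyTorch setting

def divide(a, b):
    # Reads the flag at call time, so whoever set it last decides the meaning.
    return a / b if USE_TRUE_DIVISION else a // b

# --- "third-party" module written and tested against the old behavior ---
def bucket_index(total, count):
    return divide(total, count)  # author expects an int

# --- user script ---
print(bucket_index(7, 2))  # 3
USE_TRUE_DIVISION = True   # opting in here also changes the third-party code
print(bucket_index(7, 2))  # 3.5
```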

Case study two: the device context manager. To explain this case study, I first have to explain what a device context manager is. This is a little tricky, because there's no such thing in PyTorch, but it has been requested over and over again by many different users. Here's what this hypothetical mechanism would do. When you write PyTorch programs, you often want to write them so that the same code runs on both CPU and CUDA.

What does this look like? You have your script, and you want to debug and test it on CPU, and then at some point you want to rerun it on CUDA. If you know anything about PyTorch's API, we don't exactly make this easy. You have to actually plan your program out, explicitly parameterize over the device, and then toggle it with your options.

If you just write really plain straight-line code, you'll probably end up hard-coding whether it operates on CPU or CUDA. The device context manager is a concept that lets you write the naive code (allocate a bunch of tensors with no device argument, do a bunch of operations on them) and then implicitly change the meaning of the factory functions. So if you use this context manager and say, hey, set the default device to CUDA, then any inner calls to factory functions will actually produce CUDA tensors instead of CPU tensors.
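A pure-Python sketch of what such a device context manager could look like. FakeTensor, empty, default_device, Linear, and MLP are all hypothetical stand-ins (no PyTorch required); the point is that the nested Linear constructors never mention a device, yet they pick up "cuda" from the caller's with block.

```python
import contextlib
from dataclasses import dataclass

@dataclass
class FakeTensor:
    """Stand-in for a tensor that only tracks its size and device."""
    size: tuple
    device: str

_default_device = "cpu"  # the dynamically scoped default

@contextlib.contextmanager
def default_device(device):
    """Hypothetical device context manager (not a PyTorch API)."""
    global _default_device
    saved = _default_device
    _default_device = device
    try:
        yield
    finally:
        _default_device = saved

def empty(*size, device=None):
    """Factory: an explicit device wins, otherwise use the dynamic default."""
    return FakeTensor(size, device if device is not None else _default_device)

class Linear:
    def __init__(self, n_in, n_out):
        self.weight = empty(n_out, n_in)  # no device argument anywhere

class MLP:
    def __init__(self):
        self.layers = [Linear(4, 8), Linear(8, 2)]  # nested construction

with default_device("cuda"):
    model = MLP()
print([layer.weight.device for layer in model.layers])  # ['cuda', 'cuda']
print(empty(3).device)  # cpu (the default is restored after the block)
```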

This is a decent example of dynamic scoping in action. When you use one of these context managers, it's not just the local calls to factory functions in your module that change from CPU to CUDA; it's also all the inner calls in any modules you instantiate, and everything else. And this is kind of desirable, because one of the things people find very annoying about how things have to be done today is that you have to plumb the device you want recursively down into all of the creation functions you call, which here means all of the submodules in your modules.

By the way, we used to not even let you plumb the device down. But Joel Schlosser very recently landed a patch to PyTorch that makes all modules take a device argument, so you can choose the device at module construction time. Before that, you always had to construct your module on CPU and then move it to the device you wanted, which is kind of inefficient, and a lot of people didn't like having to do that.
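For contrast, a sketch of the explicit style that a device argument on modules enables. In real PyTorch, modules such as torch.nn.Linear accept device= (and dtype=) keyword arguments; the classes below are hypothetical pure-Python stand-ins so the example runs on its own.

```python
from dataclasses import dataclass

@dataclass
class FakeTensor:
    """Stand-in for a tensor that only tracks its size and device."""
    size: tuple
    device: str

def empty(*size, device="cpu"):
    # The default is fixed in this signature; callers opt in to another device.
    return FakeTensor(size, device)

class Linear:
    def __init__(self, n_in, n_out, device="cpu"):
        self.weight = empty(n_out, n_in, device=device)

class MLP:
    def __init__(self, device="cpu"):
        # The device has to be plumbed down by hand into every submodule.
        self.layers = [Linear(4, 8, device=device), Linear(8, 2, device=device)]

model = MLP(device="cuda")
print([layer.weight.device for layer in model.layers])  # ['cuda', 'cuda']
```

The cost is visible in MLP: every constructor must accept and forward device, and a submodule that forgets to do so silently falls back to CPU.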

So anyway, this device context manager would let you change, for example, where your modules get allocated without explicitly passing in this device argument. A lot of people would like this; it would make things very convenient. And we don't want to do it. Why not? Because it reaches recursively down the call stack and changes the semantics of every call underneath it. This is both a blessing and a curse.

The blessing is that you don't have to coordinate with anyone to change the device: you set this context manager, and magically the meaning of all of your factory functions changes. The curse is that you don't have to coordinate with anyone. Suppose someone writes code that assumes torch.empty is just going to give them a CPU tensor, because when they tested the code on their machine, it gave them a CPU tensor. How could that possibly go wrong? That code is going to break unpredictably.

And in practice, this code does break unpredictably, because we have a janky version of the device context manager called set_default_tensor_type, which you can use to change the default tensor type from CPU to CUDA. Please don't do this. We really hate this function and want to get rid of it. But people are always posting on the forums saying, hey, I did this thing, and now some library code I'm calling doesn't work.

So the problem with untyped dynamic scoping is that it is a global tax on all code written against your library. If you have primitive function calls whose behavior is modulated by some dynamic scope (that is, by a context manager), everyone who writes library code is obligated to make sure their code works under all possible settings of that context manager. In this case, whenever I write a bare torch.empty rather than torch.empty with an explicit device='cpu', I'm obligated to make sure it still works when the default device is CUDA.

Maybe this is possible, and maybe it's even the right trade-off to make. But historically PyTorch hasn't had this requirement, so a lot of code is not written under this assumption. If you want to add a device context manager and do it right (and by right, I mean the context manager actually works in 99% of the situations you use it in), you have to painstakingly audit all of your Python code to make sure it does the right thing in this case. Blech.

So dynamic scoping leads to unpredictable effects, because it lets you reach into code that wasn't expecting to be modulated. Sometimes this is a good thing: it saves you from having to explicitly pass arguments around. If you're Emacs, you love dynamic scoping, because it makes it so easy to set some variables and then use them later somewhere else without mucking about with function signatures. But this implicitness also comes with a cost.

Okay, I have one last case study, and it relates to __torch_function__ and also to a new mechanism proposed by NumPy for handling factory functions. A little backstory first. __torch_function__ lets you write an object with a __torch_function__ magic method on it. Whenever you pass such an object into torch.cat, torch.add, or any of the other functions in the torch namespace, we call this magic __torch_function__ method, so that you can override the meaning of operations involving tensor subclasses.

This is very useful, and you can use it to implement all sorts of interesting tensor-like objects without having to monkey-patch all of PyTorch's functions. But there is a problem: __torch_function__ is predicated on the idea that any given operation takes an actual tensor as an argument. That's because it dispatches in the very Pythonic dynamic dispatch style: we look for an argument that has a __torch_function__ on it, and that's the __torch_function__ implementation we call.

So what happens when you have a function that doesn't take any tensor arguments? An example is a factory function like torch.empty, which just takes a list of sizes and gives you a tensor. Custom classes have a problem here: they need to override these factory functions somehow, but they have no way of doing so, because their standard overriding mechanism is dynamic dispatch, and there's nothing to dispatch on in this situation.
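A simplified pure-Python emulation of the dispatch rule, to show exactly where factory functions fall through. dispatch, add, zeros, and LoggingTensor are hypothetical; only the classmethod signature __torch_function__(cls, func, types, args=(), kwargs=None) follows the real protocol, and PyTorch's actual dispatch also scans keyword arguments and orders subclasses before their parents.

```python
_NO_OVERRIDE = object()  # sentinel: no argument claimed the call

def dispatch(func, args, kwargs):
    # Simplified: only positional args are scanned, with no subclass ordering.
    overloaded = [a for a in args if hasattr(type(a), "__torch_function__")]
    types = tuple(type(a) for a in overloaded)
    for arg in overloaded:
        result = type(arg).__torch_function__(func, types, args, kwargs)
        if result is not NotImplemented:
            return result
    return _NO_OVERRIDE

def add(a, b):
    """Stand-in for torch.add: consult overrides first, then run the default."""
    result = dispatch(add, (a, b), {})
    return a + b if result is _NO_OVERRIDE else result

def zeros(n):
    """Stand-in for torch.zeros: its only argument is a size, so nothing to dispatch on."""
    result = dispatch(zeros, (n,), {})
    return [0.0] * n if result is _NO_OVERRIDE else result

class LoggingTensor:
    """Hypothetical tensor-like class that wants to intercept every torch call."""
    calls = []

    def __init__(self, value):
        self.value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        cls.calls.append(func.__name__)
        plain = [a.value if isinstance(a, cls) else a for a in args]
        return cls(func(*plain, **(kwargs or {})))

x = add(LoggingTensor(2), 3)
print(type(x).__name__, x.value)  # LoggingTensor 5
z = zeros(3)
print(type(z).__name__)           # list: the subclass never got a say
print(LoggingTensor.calls)        # ['add']
```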

There are a bunch of ways to solve this problem. As the saying goes, if the mountain won't come to Muhammad, Muhammad must go to the mountain. If you want dynamic dispatch and the factory function doesn't have it, turn it into a call that does. We have a bunch of methods on tensors, like new_empty and new_zeros, and you can use those in place of the good old-fashioned factory functions in the main torch namespace. That works, and you just have to handle those methods in your __torch_function__ to get things going. This preserves the same property: you're using objects that are lexically in scope to do the dynamic dispatch that gets you to the implementation you want.
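A toy sketch of the method-based workaround. Real torch.Tensor has new_zeros and new_empty methods; the Tensor class here is a hypothetical pure-Python stand-in, and for brevity it builds the result with type(self) instead of routing through __torch_function__. Library code that allocates via an existing tensor lets the subclass decide the result type.

```python
class Tensor:
    """Toy base tensor: the factory method dispatches on the type of `self`."""

    def __init__(self, data):
        self.data = list(data)

    def new_zeros(self, n):
        # Method call => dynamic dispatch on an existing object, so a subclass
        # gets a result of its own type.
        return type(self)([0.0] * n)

def zeros(n):
    """Module-level factory: no tensor argument, so it can only build the base type."""
    return Tensor([0.0] * n)

class MyTensor(Tensor):
    """Hypothetical subclass that wants factories to produce MyTensor."""

def make_buffer(like, n):
    # Library code allocates relative to a tensor that is lexically in scope.
    return like.new_zeros(n)

m = MyTensor([1.0, 2.0])
print(type(zeros(2)).__name__)           # Tensor
print(type(make_buffer(m, 2)).__name__)  # MyTensor
```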

There's an elaboration on this idea, currently a NumPy proposal, which, instead of creating new method variants on tensors for all the factory functions, wraps them up into a module. Given a tensor, you can extract a module corresponding to the module you would have called the factory functions on, but specialized for the subclass in question.

What does this look like? I've got a tensor, and I want to create a new tensor. I call the module accessor on this tensor, which gives me something that looks like the torch module: it has empty, ones, and zeros on it. But this module is special: if I call zeros on it, I get a tensor of the same subclass as the original tensor I got the module from.

Same idea: use the lexically scoped values to get the module, and then do the dynamic dispatch on the module itself. That way you don't have to shove everything into the method namespace.
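A toy sketch of the module-accessor idea, with a hypothetical module() method that returns a torch-like namespace bound to the tensor's own class. The Python array API standard's __array_namespace__ method follows a similar shape; the code below does not implement that standard or any NumPy proposal.

```python
import types

class Tensor:
    """Toy base tensor with a hypothetical namespace accessor."""

    def __init__(self, data):
        self.data = list(data)

    def module(self):
        """Hypothetical: return a torch-like namespace specialized to type(self)."""
        cls = type(self)
        return types.SimpleNamespace(
            zeros=lambda n: cls([0.0] * n),
            ones=lambda n: cls([1.0] * n),
            empty=lambda n: cls([None] * n),  # stand-in for uninitialized memory
        )

class MyTensor(Tensor):
    pass

def init_params(like, n):
    # Fetch the namespace once from an object in lexical scope, then call
    # ordinary-looking factory functions on it.
    xp = like.module()
    return xp.zeros(n), xp.ones(n)

w, b = init_params(MyTensor([3.0]), 2)
print(type(w).__name__, w.data)  # MyTensor [0.0, 0.0]
print(type(b).__name__, b.data)  # MyTensor [1.0, 1.0]
```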

Of course, there's another way to do this, and that's using a context manager. This is more likely than you might think. In previous podcasts, I've talked about functorch, a library for doing functional transformations on PyTorch programs. In functorch, there's a very natural place where a context manager would apply: when you use one of the higher-order combinators like vmap to do an operation on a tensor.

When I enter vmap, I'm effectively turning on vmap-ness, and that means I might very reasonably want to implicitly override the behavior of all the factory functions as well. This is actually very natural. In fact, JAX has this concept, called omnistaging. Previously, JAX only staged out operations that depended on the inputs being transformed, but at some point they realized it's really useful to be able to override the behavior of all these free functions too, so they went ahead and did that. That's what's called omnistaging in JAX.

So which of these is the right thing?

Well, if we look back at the device context manager case study, PyTorch said: we want explicitness. We've got all this code that's already been written that doesn't expect you to change the meaning of things under its feet, so let's make sure you keep doing things explicitly, and we won't add this context manager. And when we look at this __torch_function__ module case, there is a solution that lets you stay with the lexical attitude, which is honestly PyTorch's attitude.

But you can also see that there's a lot of merit to dynamic scoping here, and the backwards-compatibility problems are not as pressing. Although you might not have written your code so that it works correctly on both CPU and CUDA, with vmap you're explicitly asking for vmap. So first, you're probably going to make sure all the code you're calling works correctly in this case. And second, vmap is very carefully designed so that the code on the inside looks exactly like you're handling a single example. So it really is supposed to work even if you swap out the semantics of everything: you're just adding batch dimensions that your code should be indifferent to.

So what's the right answer? Well, I don't really know. When people asked me for a device context manager, I used to call over Greg, and Greg would say, no, we're not going to do this, because everyone's code isn't going to work in this case. But maybe, if you're willing to put in the work to make everything work correctly across the library and the whole ecosystem, some dynamic scoping might actually be pretty helpful. It's a lot of work, though, and I'd want to see an honest attempt at it. That's everything I wanted to talk about today. Talk to you next time.