Getting Started With CUDA for Python Programmers

[Jeremy Howard]

Introduction to CUDA

Hi there, I’m Jeremy Howard from answer.ai, and this is Getting Started with CUDA. CUDA is, of course, what we use to program NVIDIA GPUs if we want them to go super fast and we want maximum flexibility. And it has a reputation of being very hard to get started with. The truth is, it’s actually not so bad. You just have to know some tricks. And so in this video, I’m going to show you some of those tricks. So let’s switch to the screen and take a look. So I’m going to be doing all of the work today in notebooks. This might surprise you. You might be thinking that to do work with CUDA, we have to do stuff with compilers and terminals and things like that. And the truth is, actually, it turns out we really don’t, thanks to some magic that is provided by PyTorch. You can follow along in all of these steps, and I strongly suggest you do so in your own computer.

[00:01:02]

Accessing the Colab Notebook

You can go to the CUDA mode organization in GitHub, find the lecture two repo there, and you’ll see there is a lecture three folder. This is lecture three of the CUDA mode series. You don’t need to have seen any of the previous ones, however, to follow along. In the readme there, you’ll see there’s a lecture three section, and at the bottom, there is a click to go to the Colab version. Yep, you can run all of this in Colab for free. You don’t even have to have a GPU available to run the whole thing.

Overview of the Book “Programming Massively Parallel Processors”

We’re going to be following along with some of the examples from this book, Programming Massively Parallel Processors. Programming Massively Parallel Processors is a really great book to read, and once you’ve completed today’s lesson, you should be able to make a great start on this book.

[00:02:02]

It goes into a lot more detail about some of the things that we’re going to cover fairly quickly. It’s okay if you don’t have the book, but if you want to go deeper, I strongly suggest you get it. In fact, you’ll see in the repo that lecture two in this series actually was a deep dive into chapters one to three of that book, and so actually you might want to do lecture two, confusingly enough, after this one, lecture three, to get more details about some of what we’re talking about. Okay, so let’s dive into the notebook.

Converting PyTorch Code to CUDA

So what we’re going to be doing today is we’re going to be doing a whole lot of stuff with plain old PyTorch first to make sure that we get all the ideas, and then we will try to convert each of these things into CUDA. So in order to do this, we’re going to start by importing a bunch of stuff. In fact, let’s do all of this in Colab.

[00:03:00]

Setting up Colab Runtime

So here we are in Colab, and you should make sure that you set in Colab your runtime to the T4 GPU. That’s one you can use plenty of for free, and it’s easily good enough to run everything we’re doing today. And once you’ve got that running, we can import the libraries we’re going to need, and we can start on our first exercise.

Converting RGB to Grayscale

So the first exercise actually comes from chapter two of the book, and chapter two of the book teaches how to do this problem, which is converting an RGB color picture into a grayscale picture. And it turns out that the recommended formula for this is to take 0.21 of the red pixel, 0.72 of the green pixel, 0.07 of the blue pixel, and add them up together. And that creates the luminance value, which is what we’re seeing here. That’s a common way, kind of the standard way to go from RGB to grayscale. So we’re going to do this. We’re going to make a CUDA kernel to do this.
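
Just to make the formula concrete, here is roughly what it looks like applied to a channel-first PyTorch image tensor. This is a minimal sketch: the function name is made up, and the notebook itself may use slightly different luminance weights.

import torch

def to_grayscale(img):                     # img: uint8 tensor of shape (3, height, width)
    r, g, b = img.float()                  # the three colour planes
    return (0.21*r + 0.72*g + 0.07*b).to(torch.uint8)   # luminance, shape (height, width)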

[00:04:02]

Downloading a Puppy Image

So the first thing we’re going to need is a picture. And anytime you need a picture, I recommend going for a picture of a puppy. So we’ve got here a URL to a picture of a puppy, so we’ll just go ahead and download it. And then we can use torchvision.io to load that. So this is already part of Colab. If you’re interested in running stuff on your own machine or a server in the cloud, I’ll show you how to set that up at the end of this lecture.

Reading the Image

So let’s read in the image. And if we have a look at the shape of it, it says it’s 3 by 1066 by 1600. So I’m going to assume that you know the basics of PyTorch here. If you don’t know the basics of PyTorch, I am a bit biased, but I highly recommend my course, which covers exactly that. You can go to course.fast.ai, and you get the benefit also of having seen some very cute bunnies.

[00:05:02]

And along with the very cute bunnies, it basically takes you through all of everything you need to be an effective practitioner of modern deep learning.

PyTorch Basics

So finish part one if you want to go right into those details. But even if you just do the first two or three lessons, that will give you more than enough of what you need to know to understand this kind of code and these kinds of outputs. So I’m assuming you’ve done all that. So you’ll see here we’ve got a rank 3 tensor. There are three channels, so they’re like the faces of a cube, if you like. There are 1066 rows on each face, so that’s the height. And then there are 1600 columns in each row, so that’s the width. So if we then look at the first couple of channels, and the first three rows, and the first four columns, you can see here that these are unsigned 8-bit integers.

[00:06:00]

So they’re bytes. And so here they are. So that’s what an image looks like. Hopefully you know all that already.

Displaying the Image

So let’s take a look at our image. To do that, I’m just going to create a simple little function, show image, that will create a matplotlib plot. Remove the axes. If it’s color, which this one is, it’ll change the order of the axes from channel by height by width, which is what PyTorch uses, to height by width by channel, which is what matplotlib expects. So we change the order of the axes to be 1, 2, 0. And then we can show the image, putting it on the CPU if necessary. Now, we’re going to be working with this image in Python, which is going to be just pure Python to start with, before we switch to CUDA. That’s going to be really slow, so we’ll resize it to have the smallest dimension be 150.
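
A helper along the lines of the one described here might look something like this. It’s a sketch only; the exact name and arguments in the notebook may differ.

import matplotlib.pyplot as plt

def show_img(x, figsize=(4, 3), **kwargs):
    plt.figure(figsize=figsize)
    plt.axis('off')                        # remove the axes
    if len(x.shape) == 3:                  # colour image: channel-first -> channel-last for matplotlib
        x = x.permute(1, 2, 0)
    plt.imshow(x.cpu(), **kwargs)          # put it on the CPU if necessary, then display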

[00:07:04]

So that’s the height in this case. So we end up with a 150 by 225 shape, which is a rectangle, which is 33,750 pixels, each one with R, G, and B values. And there is our puppy. So see, wasn’t it a good idea to make this a puppy?

Grayscale Conversion in Python

Okay, so how do we convert that to grayscale? Well, the book has told us the formula to use. Go through every pixel, and do that to it. All right, so here is the loop. We’re going to go through every pixel, and do that to it, and stick that in the output. So that’s the basic idea. So what are the details of this? Well, here we’ve got channel by row by column. So how do we loop through every pixel? Well, the first thing we need to know is how many pixels are there. So we can say channel by height by width is the shape.

[00:08:03]

So now we’ve defined those three variables. So the number of pixels is the height times the width. And so to loop through all those pixels, an easy way to do them is to flatten them all out into a vector. Now what happens when you flatten them all out into a vector? Well, as we saw, they’re currently stored in this format where we’ve got one face, and then another face, and then there’s a third face, which we haven’t got printed here. Within each face then there is one row, we’re just showing the first few, and then the next row, and then the next row, and then within each row you’ve got column, column, column. So let’s say we had a small image — in fact, we can do it like this. We could say here’s our red, so we’ve got the pixels of 0, 1, 2, 3, 4, 5.

[00:09:07]

So let’s say this was a height 2, width 3, 3 channel image. So then there’ll be 6, 7, 8, 9, 10, 11 for green, and 12, 13, 14, 15, 16, 17 for blue. So let’s say these are the pixels. So when these are flattened out, it’s going to turn into a single vector just like so: 0 through 5, then 6 through 11, then 12 through 17. So actually when we talk about an image, we initially see it as a bunch of pixels.

[00:10:06]

We can think of it as having three channels. But in practice, in our computer, the memory is all laid out linearly. Everything just has an address in memory. You can think of your computer’s memory as one giant vector. And so when we say flatten, what that’s actually doing is it’s turning our channel by height by width into a big vector like this. Okay, so now that we’ve done that, we can say, all right, the place we’re going to be putting this into, the result, we’re going to start out with just an empty vector of length n.

[00:11:05]

We’ll go through all of the n values from 0 to n minus 1. And we’re going to put in the output value 0.29-ish times the input value at x[i]. So this will be here in the red bit. And then 0.59 times x[i + n]. So n here is this distance. It’s the number of pixels: 1, 2, 3, 4, 5, 6. See? 1, 2, 3, 4, 5, 6. So that’s why to get to green, we have to jump up to i plus n. And then to get to blue, we have to jump to i plus 2n.

[00:12:01]

See? And so that’s how this works. We’ve flattened everything out, and we’re indexing into this flattened out thing directly. And so at the end of that, our grayscale is all done. So we can then just reshape that into height by width. And there it is. There’s our grayscale puppy. And you can see here, the flattened image is just a single vector with all those channel values flattened out as we described. Okay. Now, that is incredibly slow. It’s nearly two seconds for an image with only about 34,000 pixels in it.
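
Pulling the last few paragraphs together, the pure-Python version is roughly this. The names are illustrative, and the luminance weights here are the 0.29-ish/0.59-ish values mentioned above; the book’s 0.21/0.72/0.07 would work just as well.

import torch

def rgb2grey_py(x):                        # x: uint8 tensor of shape (3, h, w)
    c, h, w = x.shape
    n = h*w                                # number of pixels per channel
    x = x.flatten()                        # one long vector: all the reds, then greens, then blues
    res = torch.empty(n, dtype=x.dtype, device=x.device)
    for i in range(n):
        res[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n]
    return res.view(h, w)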

Introduction to CUDA

So to speed it up, we are going to want to use CUDA. How come CUDA is able to speed things up? Well, the reason CUDA is able to speed things up is because it is set up in a very different way to how a normal CPU is set up.

[00:13:10]

Understanding GPU Architecture

And we can actually see that if we look at some of this information about what is in an RTX 3090 card, for example. Now, an RTX 3090 card is a fantastic GPU. You can get them secondhand, pretty good value. It’s a really good choice, particularly for hobbyists. What is inside a 3090? It has 82 SMs. What’s an SM? An SM is a streaming multiprocessor. So you can think of this as almost like a separate CPU in your computer. And so there’s 82 of these. So that’s already a lot more than you have CPUs in your computer. But then each one of these has 128 CUDA cores.

[00:14:01]

So these CUDA cores are all able to operate at the same time. These multiprocessors are all able to operate at the same time. So that gives us 128 times 82, which is 10,496 CUDA cores in total that can all work at the same time. So that’s a lot more than any CPU we’re familiar with can do. And the 3090 isn’t even at the very top end. It’s really a very good GPU, but there are some with even more CUDA cores.

Using CUDA Cores

So how do we use them all? Well, we need to be able to set up our code in such a way that we can say here is a piece of code that you can run on lots of different pieces of data, lots of different pieces of memory at the same time, so that you can do 10,000 things at the same time. And so CUDA does this in a really simple and pretty elegant way, which is it basically says, OK, take out the inner loop.

[00:15:04]

CUDA Kernel Concept

So here’s our inner loop. The stuff where you can run 10,000 of these at the same time. They’re not going to influence each other at all. So you see these do not influence each other at all. All they do is they stick something into some output memory. So it doesn’t even return something. You can’t return something from these CUDA kernels, as they’re going to be called. All you can do is you can modify memory in such a way that you don’t know what order they’re going to run in. They could all run at the same time. Some could run a little bit before another one, and so forth. So the way that CUDA does this is it says, OK, write a function. And in your function, write a line of code which I’m going to call as many dozens, hundreds, thousands, millions of times as necessary to do all the work that’s needed. And I’m going to do this in parallel for you as much as I can. In the case of running on a 3090, up to 10,000 things all at once.

[00:16:06]

And I will get this done as fast as possible. So all you have to do is basically write the line of code you want to be called lots of times. And then the second thing you have to do is say how many times to call that code. And so what will happen is that piece of code, called the kernel, will be called for you. It’ll be passed in whatever arguments you ask to be passed in, which in this case will be the input array tensor, the output tensor, and the size of how many pixels are in each channel. And it’ll tell you, OK, this is the i-th time I’ve called it. Now, we can simulate that in Python very, very simply.

Simulating CUDA Kernels in Python

A single for loop. Now, this doesn’t happen in parallel, so it’s not going to speed it up. But the kind of results, the semantics, are going to be identical to CUDA. So here is a function we’ve called run kernel.

[00:17:01]

We’re going to pass it in a function. We’re going to say how many times to run the function and what arguments to call the function with. And so each time it will call the function, passing in the index — which iteration it’s up to — and the arguments that we’ve requested. OK, so we can now create something to call that. So let’s, just like before, get the number of channels, height and width, the number of pixels, flatten it out, create the result tensor that we’re going to put things in. And this time, rather than calling the loop directly, we will call run kernel. We will pass in the name of the function to be called as f. We will pass in the number of times, which is the number of pixels for the loop. And we’ll pass in the arguments that are going to be required inside our kernel.

[00:18:00]

So we’re going to need out, we’re going to need x, and we’re going to need n. So you can see here we’re using no external libraries at all. We have just plain Python and a tiny bit of PyTorch, just enough to create a tensor and to index into tensors. And that’s all that’s being used. But conceptually, it’s doing the same thing as a CUDA kernel would do, nearly. And we’ll get to the nearly in just a moment. But conceptually, you could see that you could now potentially write something, which if you knew that this was running a bunch of things totally independently of each other, conceptually, you could now very easily parallelize that. And that’s what CUDA does.
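
As a sketch, the simulated kernel runner and the kernel it calls look roughly like this. The function names and argument order are assumptions based on the description above.

import torch

def run_kernel(f, times, *args):
    # call the "kernel" f once per index -- sequentially here, in parallel on a real GPU
    for i in range(times):
        f(i, *args)

def grey_k(i, out, x, n):
    # one "thread": compute a single output pixel from the flattened input
    out[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n]

def rgb2grey_pyk(x):
    c, h, w = x.shape
    n = h*w
    x = x.flatten()
    res = torch.empty(n, dtype=x.dtype, device=x.device)
    run_kernel(grey_k, n, res, x, n)
    return res.view(h, w)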

CUDA Kernel Execution with Blocks and Threads

However, it’s not quite that simple. It does not simply create a single list of numbers, like range n does in Python, and pass each one in turn into your kernel.

[00:19:14]

But instead, it actually splits the range of numbers into what’s called blocks. So in this case, you know, maybe there’s like a thousand pixels we wanted to get through. It’s going to group them into blocks of 256 at a time. And so, in Python, it looks like this. In practice, a CUDA kernel runner is not a single for loop that loops n times. But instead, it is a pair of nested for loops. So you don’t just pass in a single number and say this is the number of pixels, but you pass in two numbers.

[00:20:01]

The number of blocks and the number of threads. We’ll get into that in a moment. But these are just numbers — you can put any numbers you like here. And if you choose two numbers that multiply to get the thing that we want, which is the n times we want to call it, then this can do exactly the same thing. Because we’re now going to pass in: what’s the index of the outer loop we’re up to? What’s the index in the inner loop we’re up to? How many things do we go through in the inner loop? And therefore, inside the kernel, we can find out what index we’re up to by multiplying the block index times the block dimension. So that is to say the i by the threads and add the inner loop index, the j. So that’s what we pass in with the i, j threads. But inside the kernel, we call it block index, thread index, and block dimension. So if you look at the CUDA book, you’ll see here this is exactly what they do.
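
In Python, that pair of nested loops looks roughly like this (a sketch; the name blk_kernel is an assumption):

def blk_kernel(f, blocks, threads, *args):
    # CUDA-style runner: an outer loop over blocks and an inner loop over threads within each block
    for i in range(blocks):
        for j in range(threads):
            f(i, j, threads, *args)   # pass the block index, thread index, and block dimension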

[00:21:04]

They say the index is equal to the block index times the block dimension plus the thread index. There’s a .x thing here that we can ignore for now. We’ll look at that in a moment. But in practice, this is actually how CUDA works. So it has all these blocks, and inside there are threads, and you can just think of them as numbers. You can see these blocks, they just have numbers. 0, 1, dot, dot, dot, dot, and so forth. Now that does mean something a little bit tricky, though, which is, well, the first thing I’ll say is how do we pick these numbers, the number of blocks and the number of threads?

Choosing the Number of Threads

For now in practice, we’re just always going to say the number of threads is 256. And that’s a perfectly fine number to use as a default anyway. You can’t go too far wrong just always picking 256, nearly always.

[00:22:01]

So don’t worry about that too much for now, optimizing that number. So if we say, okay, we want to have 256 threads. So remember that’s the inner loop, or if we look inside our kernel runner here, that’s our inner loop. So this inner loop is going to be called 256 times. So how many times do you have to call this? Well, you’re going to have to call it n, number of pixels, divided by 256 times. Now that might not be an integer, so you’ll have to round that up, so ceiling. And so that’s how we can calculate the number of blocks we need to make sure that our kernel is called enough times. Now we do have a problem, though, which is that the number of times we would have liked to have called it, which previously was equal to the number of pixels, might not be a multiple of 256.

[00:23:01]

Guard Block in CUDA Kernels

So we might end up going too far, and so that’s why we also need in our kernel now this if statement. And so this is making sure that the index that we’re up to does not go past the number of pixels we have. And this appears in basically every CUDA kernel you’ll see, and it’s called the guard, or the guard block. So this is our guard to make sure we don’t go out of bounds. So this is the same line of code we had before, and now we’ve also just added this thing to calculate the index, and we’ve added the guard. And this is like the pretty standard first lines from any CUDA kernel. So we can now run those, and they’ll do exactly the same thing as before. And so the obvious question is, well, why? Why do CUDA kernels work in this weird block and thread way?
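
Before we get to the why, here is roughly what the guarded version just described looks like in Python, reusing the blk_kernel runner sketched above. The names are assumptions.

import math, torch

def grey_bk(blockidx, threadidx, blockdim, out, x, n):
    i = blockidx*blockdim + threadidx      # which pixel this "thread" is responsible for
    if i < n:                              # the guard: the last block may overshoot the data
        out[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n]

def rgb2grey_bk(x):
    c, h, w = x.shape
    n = h*w
    x = x.flatten()
    res = torch.empty(n, dtype=x.dtype, device=x.device)
    threads = 256
    blocks = math.ceil(n/threads)          # ceiling division: enough blocks to cover every pixel
    blk_kernel(grey_bk, blocks, threads, res, x, n)
    return res.view(h, w)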

Why Blocks and Threads?

Why don’t we just tell them the number of times to run it?

[00:24:03]

Why do we have to do it by blocks and threads? And the reason why is because of some of this detail that we’ve got here, which is that CUDA sets things up for us so that everything in the same block — or, to say it more completely, thread block — will all be given some shared memory.

Shared Memory in CUDA

And they’ll also all be given the opportunity to synchronize, which is just to basically say, okay, everything in this block has to get to this point before you can move on. All of the threads in a block will be executed on the same streaming multiprocessor. And so we’ll see later in later lectures, but won’t be taught by me, that by using blocks smartly, you can make your code run more quickly. And the shared memory is particularly important.

[00:25:01]

So shared memory is a little bit of memory in the GPU that all the threads in a block share, and it’s fast. It’s super, super, super fast. Now, there’s not very much of it — on a 3090, it’s 128K. So very small. So this is basically the same as a cache in a CPU. The difference, though, is that on a CPU, you’re not going to be manually deciding what goes into your cache. But on the GPU, you do. It’s all up to you. So at the moment, this cache is not going to be used when we create our CUDA code, because we’re just getting started. And so we’re not going to worry about that optimization. But to go fast, you want to use that cache. And also, you want to use the register file — something a lot of people don’t realize is that there’s actually quite a lot of register memory, even more register memory than shared memory. So anyway, those are all things to worry about down the track, not needed for getting started.

[00:26:04]

Setting up CUDA Environment

So how do we go about using CUDA? There is a basically standard setup block that I would recommend, and that we’re going to add here. And what happens in this setup block is we’re going to set an environment variable. You wouldn’t use this in production or when you’re going for speed. But this says: if you get an error, stop right away, basically. So it waits to see how things go, and that way it can tell you exactly when an error occurs and where it happens. So that slows things down. But it’s good for development. We’re also going to install two modules.

Installing CUDA Modules

One is a build tool, which is required by PyTorch to compile your C++ CUDA code. The second is a very handy little thing called Wurlitzer.

[00:27:03]

And the only place you’re going to see that used is in this line here, where we load this extension called Wurlitzer. Without this, anything you print from your CUDA code, in fact, from your C++ code full stop, won’t appear in a notebook. So you always want to do this where you’re doing stuff in CUDA so you can use print statements to debug things. Okay. So if you’ve got some CUDA code, how do you use it from Python?
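
Before we get to that, the setup cell being described looks roughly like this. The environment variable is assumed to be CUDA_LAUNCH_BLOCKING, and the build tool PyTorch needs is ninja; in a notebook the install and extension lines are shell/magic commands, shown here as comments.

import os
# make CUDA launches synchronous so errors are reported where they happen (slow; debugging only)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# in the notebook, the installs and the extension load are shell/magic commands:
# !pip install -q ninja wurlitzer     # ninja is the build tool PyTorch needs for inline compilation
# %load_ext wurlitzer                 # so printf output from C/C++/CUDA shows up in the notebook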

Using Load Inline for CUDA Code

The answer is that PyTorch comes with a very handy thing called load inline, which is inside torch.utils.cpp_extension. Load inline is a marvelous function to which you just pass a list of any of the CUDA code strings that you want to compile, any of the plain C++ strings that you want to compile, and any functions in that C++ you want to make available to PyTorch.

[00:28:14]

And it will go and compile it all, turn it into a Python module, and make it available right away, which is pretty amazing. I’ve just created a tiny little wrapper for that called load CUDA, just to streamline it a tiny bit.

Loading CUDA Code

But behind the scenes, it’s just going to call load inline. The other thing I’ve done is I’ve created a string that contains some C++ code. I mean, this is all C code, I think, but it’s compiled as C++ code. We’ll call it C++ code — C++ code that we want included in all of our CUDA files. We need to include this header file to make sure that we can access PyTorch tensor stuff.

[00:29:03]

We want to be able to use IO, and we want to be able to check for exceptions.

Defining Macros for CUDA Code

And then I also define three macros. The first macro just checks that a tensor is CUDA. The second one checks that it’s contiguous in memory, because sometimes PyTorch can actually split things up over different memory pieces. And then if we try to access that in this flattened out form, it won’t work. And then the way we’re actually going to use it, check input, we’ll just check both of those things. So if something’s not on CUDA, or it’s not contiguous, we ain’t going to be able to use it, so we always have this. And then the third thing we do here is we define ceiling division. Ceiling division is just this. Although you can implement it a different way, like this. And so this will do ceiling division.
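
As a sketch, a load_cuda wrapper around load_inline plus that boilerplate string might look like this. The wrapper, macro, and helper names are assumptions; load_inline itself and TORCH_CHECK are real PyTorch pieces.

import torch
from torch.utils.cpp_extension import load_inline

def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False):
    # compile the given CUDA and C++ source strings into an importable Python module
    return load_inline(name="inline_ext", cpp_sources=[cpp_src], cuda_sources=[cuda_src],
                       functions=funcs, extra_cuda_cflags=["-O2"] if opt else [], verbose=verbose)

cuda_begin = r'''
#include <torch/extension.h>
#include <stdio.h>
#include <c10/cuda/CUDAException.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// ceiling division: how many blocks of size b are needed to cover a items
inline unsigned int cdiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }
'''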

[00:30:01]

And so this is what we’re going to call in order to figure out how many blocks we need. You don’t have to worry about the details of this too much, it’s just a standard setup we’re going to use.

Writing CUDA Kernels

Okay, so now we need to write our CUDA kernel. Now how do you write the CUDA kernel? Well, all I did, and I recommend you do, is take your Python kernel, and paste it into ChatGPT, and say convert this to equivalent C code, using the same names, formatting, etc, where possible. Paste it in, and ChatGPT will do it for you. Unless you’re very comfortable with C, in which case just writing it yourself is fine. But this way, since you’ve already got the Python, why not just do this? It basically was pretty much perfect, I found. Although it did assume that these were floats, and they’re actually not floats, so I had to change a couple of data types, but basically I was able to use it almost as is.

[00:31:07]

Using ChatGPT to Convert Python to C Code

And so particularly, you know, for people who are much more Python programmers nowadays, like me, this is a nice way to write 95% of the code you need. What else do we have to change? Well, as we saw in our picture earlier, it’s not called blockidx, it’s called blockIdx.x — and likewise blockDim.x and threadIdx.x — so we have to add the .x there. Other than that, if we compare, so as you can see, these two pieces of code look nearly identical. We’ve had to add data types to them, we’ve had to add semicolons, we had to get rid of the colon, we had to add curly brackets, that’s about it. So it’s not very different at all.

[00:32:01]

Understanding C Data Types

So if you haven’t done much C programming, yeah, don’t worry about it too much, because, you know, the truth is actually it’s not that different for this kind of calculation-intensive work. One thing we should talk about is this: what’s unsigned char star? This is just how you write uint8 in C. If you’re not sure how to change a data type between the PyTorch spelling and the C spelling, you could ask ChatGPT, or you can Google it, but this is how you write byte. The star, in practice, is basically how you say this is an array. So this says that x is an array of bytes. It actually means it’s a pointer, but pointers are treated, as you can see here, as arrays by C.

[00:33:04]

So you don’t really have to worry about the fact it’s a pointer, it just means for us that it’s an array. But in C, the only kind of arrays that it knows how to deal with are these one-dimensional arrays, and that’s why we always have to flatten things out. We can’t use multidimensional tensors really directly in these CUDA kernels in this way. So we’re going to end up with these one-dimensional C arrays. Yeah, other than that, it’s going to look exactly the same — in fact, because we did our Python like that, it’s going to look identical.

CUDA Kernel Syntax

The void here just means it doesn’t return anything, and then the dunder global (__global__) here is a special thing added by CUDA, and there are three things that can appear. And this simply says where this function can be called from and where it will run. And so you can put dunder device (__device__), and that means compile it so that you can only call it on the GPU. You can say dunder global, and that says, okay, you can call it from the CPU or GPU, and it’ll run on the GPU.

[00:34:08]

Or you can write dunder host, which you don’t have to, and that just means it’s a normal C or C++ program that runs on the CPU side. So anytime we want to call something from the CPU side to run something on the GPU, which is basically almost always when we’re doing kernels, you write dunder global. So here we’ve got dunder global, we’ve got our kernel, and that’s it. So then we need the thing to call that kernel.

Calling CUDA Kernels

So earlier to call the kernel, we called this block kernel function, passed in the kernel, and passed in the blocks and threads and the arguments. With CUDA, we don’t have to use a special function; there is a weird special syntax built into CUDA to do it for us. To use the weird special syntax, you say, okay, what’s the kernel, the function that I want to call, and then you use these weird triple angle brackets.

[00:35:06]

Triple Angle Bracket Syntax

So the triple angle brackets is a special CUDA extension to the C++ language, and it means this is a kernel, please call it on the GPU. And between the triple angle brackets, there’s a number of things you can pass, but you have to pass at least the first two things, which is how many blocks, how many threads. So how many blocks, ceiling division, number of pixels divided by threads, and how many threads. As we said before, let’s just pick 256 all the time and not worry about it. So that says call this function as the GPU kernel, and then passing in these arguments. We have to pass in our input tensor, our output tensor, and how many pixels. And you’ll see that for each of these tensors, we have to use a special method, .data_ptr(), and that’s going to convert it into a C pointer to the tensor’s data.

[00:36:07]

Data Pointer Method

So that’s why by the time it arrives in our kernel, it’s a C pointer. You also have to tell it what data type you want it to be treated as. This says treat it as uint8. So this is a C++ template parameter here, and this is a method. The other thing you need to know is that in C++, dot means calling a method of an object, whereas colon colon is basically like calling a method of a class in Python. So you don’t say torch dot empty, you say torch colon colon empty to create our output, whereas back when we did it in Python, we said torch dot empty. Also, in Python, we just created a length-n vector and then did a .view.

[00:37:06]

It doesn’t really matter how we do it, but in this case, we actually created a two-dimensional tensor by passing in this thing in curly brackets here — this is called a C++ list initializer, and it’s just basically a little list containing height comma width. So this tells it to create a two-dimensional matrix, which is why we don’t need dot view at the end.

Creating Output Tensors

We could have done it the dot view way as well — it would probably be better to keep it consistent, but this is what I wrote at the time. The other interesting thing when we create the output is if you pass in input dot options, so this is our input tensor, that just says, oh, use the same data type and the same device, CUDA device, as our input has. This is a really nice, convenient way, which I don’t even think we have in Python, to say, make sure that this is the same data type and on the same device. If you say auto here, this is quite convenient, you don’t have to specify what type this is.

[00:38:02]

We could have written torch colon colon tensor, but by writing auto, it just says figure it out yourself, which is another convenient little C++ thing. After we call the kernel, if there’s an error in it, we won’t necessarily get told, so to tell it to check for an error, you have to write this.

Checking for CUDA Errors

This is a macro that’s again provided by PyTorch. The details don’t matter. You should just always call it after you call a kernel to make sure it works, and then you can return the tensor that you allocated, passed in as a pointer, and filled in. Okay, now, as well as the CUDA source, you also need C++ source, and the C++ source is just something that says here is a list of all of the details of the functions that I want you to make available to the outside world, in this case Python.
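
Pulling the pieces together, the CUDA source string described so far looks roughly like this, reusing the cuda_begin boilerplate sketched earlier. It’s a sketch, not necessarily the exact notebook code.

cuda_src = cuda_begin + r'''
__global__ void rgb_to_grayscale_kernel(unsigned char* x, unsigned char* out, int n) {
    int i = blockIdx.x*blockDim.x + threadIdx.x;   // which pixel this thread handles
    if (i < n) out[i] = 0.2989*x[i] + 0.5870*x[i+n] + 0.1140*x[i+2*n];   // the guard, then the weighted sum
}

torch::Tensor rgb_to_grayscale(torch::Tensor input) {
    CHECK_INPUT(input);
    int h = input.size(1), w = input.size(2);
    auto output = torch::empty({h, w}, input.options());   // same dtype and device as the input
    int threads = 256;
    rgb_to_grayscale_kernel<<<cdiv(w*h, threads), threads>>>(
        input.data_ptr<unsigned char>(), output.data_ptr<unsigned char>(), w*h);
    C10_CUDA_KERNEL_LAUNCH_CHECK();   // surface any launch error straight away
    return output;
}
'''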

[00:39:04]

C++ Source Code

And so this is basically your header, effectively. So you can just copy and paste the full line here from your function definition, and stick a semicolon on the end. That’s something you can always do. And so then we call our load CUDA function that we looked at earlier, passing in the CUDA source code, the C++ source code, and then a list of the names of the functions that are defined there that you want to make available to Python.

Loading CUDA Module

So we just have one, which is the RGB to grayscale function. And believe it or not, that’s all you have to do. This will automatically — you can see it running in the background now, with a hugely long command — compile our files. So it’s created a main.cpp for us, and it’s going to put it into a main.o for us, and compile everything up, link it all together, and create a module.
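
In code, that step is roughly this (again a sketch, building on the load_cuda wrapper sketched above):

cpp_src = "torch::Tensor rgb_to_grayscale(torch::Tensor input);"

module = load_cuda(cuda_src, cpp_src, ['rgb_to_grayscale'], verbose=True)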

[00:40:05]

And you can see here, we then take that module, it’s been passed back, put it into a variable called module. And then when it’s done, it will load that module. And if we look inside the module that we just created, you’ll see now that apart from the normal auto-generated stuff Python adds, it’s got a function in it, RGB to grayscale.

Using CUDA Function from Python

Okay, so that’s amazing. We now have a CUDA function that’s been made available from Python. And we can even see, if we want to, this is where it put it all. So we can have a look. And there it is. You can see it’s created a main.cpp, it’s compiled it into a main.o, it’s created a library that we can load up, it’s created a CUDA file, it’s created a build script, and we could have a look at that build script if we wanted to.

[00:41:04]

And there it is. So none of this matters too much. It’s just nice to know that PyTorch is doing all this stuff for us, and we don’t have to worry about it. So that’s pretty cool. So in order to pass a tensor to this, we’re going to be checking that it’s contiguous and on CUDA. So we better make sure it is.

Running CUDA Code on the Full Image

So we’re going to create an imageC variable, which is the image made contiguous and put onto the CUDA device. And now we can actually run this on the full-sized image, not on the tiny little minimized image we created before. This has got many more pixels. It’s got 1.7 million pixels, whereas before we had, I think it was 35,000, 34,000. And it’s gone down from 1.5 seconds to 1 millisecond.

[00:42:01]

So that is amazing. It’s dramatically faster, both because it’s now running in compiled code and because it’s running on the GPU. The step of putting the data onto the GPU is not part of what we timed. And that’s probably fair enough, because normally you do that once, and then you run a whole lot of CUDA things on it. We have, though, included the step of moving it off the GPU and putting it onto the CPU as part of what we’re timing. And one key reason for that is that if we didn’t do that, it can actually run our Python code at the same time that the CUDA code is still running, and so the amount of time shown could be dramatically less, because it hasn’t finished synchronizing. So by adding this, it forces it to complete the CUDA run and to put the data back onto the CPU.
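
A minimal usage-and-timing sketch of what’s being described (the variable names are assumptions):

imgc = img.contiguous().cuda()              # the kernel requires contiguous memory on the CUDA device

# in the notebook this line sits under a %%timeit magic
res = module.rgb_to_grayscale(imgc).cpu()   # .cpu() forces the GPU work to finish before timing stops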

[00:43:02]

That kind of synchronization you can also trigger just by printing a value from it, or you can synchronize it manually. So after we’ve done that, and we can have a look, and we should get exactly the same grayscale puppy.

Creating a CUDA Kernel from Python

Okay. So we have successfully created our first real working CUDA kernel from Python. This approach of writing it in Python and then converting it to CUDA is not particularly common. But I’m not just doing it as an educational exercise. That’s how I like to write my CUDA kernels. At least as much of it as I can. Because it’s much easier to debug in Python.

[00:44:03]

It’s much easier to see exactly what’s going on. And so, and I don’t have to worry about compiling. It takes about 45 or 50 seconds to compile even our simple example here. I can just run it straight away. And once it’s working, to convert that into C, as I mentioned, you know, ChatGPT can do most of it for us. So I think this is actually a fantastically good way of writing CUDA kernels, even as you start to get somewhat familiar with them.

Benefits of Writing CUDA Kernels in Python

Because it lets you debug and develop much more quickly. A lot of people avoid writing CUDA just because that process is so painful. And so here’s a way that we can make that process less painful. So let’s do it again.

Implementing Matrix Multiplication

And this time we’re going to do it to implement something very important, which is matrix multiplication.

[00:45:00]

So matrix multiplication, as you probably know, is fundamentally critical for deep learning. It’s like the most basic linear algebra operation we have. And the way it works is that you have an input matrix M and a second input matrix N. And we go through every row of M. So we go through every row of M until we get to, here we are up to this one. And every column of N. And here we are up to this one. And then we take the dot product at each point of that row with that column. And this here is the dot product of those two things. And that is what matrix multiplication is. So it’s a very simple operation, conceptually. And it’s one that we do many, many, many times in deep learning.

[00:46:03]

And basically every deep learning, every neural network has this as its most fundamental operation. Of course, we don’t actually need to implement matrix multiplication from scratch, because it’s done for us in libraries. But we will often do things where we have to kind of fuse in some kind of matrix multiplication-like pieces. And so, you know, and of course, it’s also just a good exercise.

Matrix Multiplication in Python

So let’s take a look at how to do matrix multiplication, first of all, in pure Python. So in the fast.ai course that I mentioned, there’s a very complete in-depth dive into matrix multiplication in part two, lesson 11, where we spend like an hour or two talking about nothing but matrix multiplication. We’re not going to go into that much detail here. But what we do do in that is we use the MNIST data set to do this.

[00:47:03]

And so we’re going to do the same thing here. We’re going to grab the MNIST data set of handwritten digits. And they are 28 by 28 digits. They look like this. 28 by 28 is 784. So to basically do a single layer of a neural net — well, without the activation function — we would do a matrix multiplication of the image flattened out by a weight matrix with 784 rows and however many columns we like. And if we’re going to go straight to the output — so this would be a linear function, a linear model — we’d have 10 columns, one for each digit. So here’s, this is our weights. We’re not actually going to do any learning here. This is not any deep learning or logistic regression learning. This is just for an example. Okay. So we’ve got our weights and we’ve got our input data, x_train and x_valid.

[00:48:05]

And so we’re going to start off by implementing this in Python. Now, again, Python’s really slow. So let’s make this smaller. So matrix one will just be five rows. Matrix two will be all the weights. So that’s going to be a 5 by 784 matrix multiplied by a 784 by 10 matrix. Now, these two have to match — of course they have to match, because otherwise this dot product won’t work: the row has to line up with the column. Okay. So let’s pull that out into A rows, A columns, B rows, B columns. And obviously A columns and B rows are the things that have to match. And then the output will be A rows by B columns. So 5 by 10. So let’s create an output full of zeros with rows by columns in it.

[00:49:03]

And so now we can go ahead and go through every row of A, every column of B, and do the dot product, which involves going through every item in the innermost dimension, all 784 of them, multiplying together the equivalent things from M1 and M2 and summing them up into the output tensor that we created. So that’s going to give us, as we said, a 5 by 10 output. And here it is.

Creating a Matrix Multiplication Function

Okay. So this is how I always create things in Python. I basically almost never have to debug. I almost never have, like, errors, unexpected errors in my code because I’ve written every single line one step at a time in Python. I’ve checked them all as I go.

[00:50:00]

And then I copy all the cells and merge them together, stick a function header on, like so. And so here is matmul. So this is exactly the code we’ve already seen. And we can call it. And we’ll see that for 39,200 innermost operations, it took us about a second. So that’s pretty slow.
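
For reference, the merged-together function is roughly this (a sketch with assumed names):

import torch

def matmul(a, b):
    (ar, ac), (br, bc) = a.shape, b.shape
    assert ac == br, "inner dimensions must match"
    c = torch.zeros(ar, bc)
    for i in range(ar):            # every row of a
        for j in range(bc):        # every column of b
            for k in range(ac):    # the dot product along the inner dimension
                c[i, j] += a[i, k] * b[k, j]
    return c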

Implementing Matrix Multiplication Kernel

Okay. So now that we’ve done that, you might not be surprised to hear that we now need to do the innermost loop as a kernel call in such a way that it can be run in parallel. Now, in this case, the innermost loop is not this line of code. It’s actually this line of code. I mean, we can choose it to be whatever we want it to be. But in this case, this is how we’re going to do it. We’re going to say that every cell in the output tensor, like this one here, is going to be one CUDA thread.

[00:51:07]

So one CUDA thread is going to do the dot product. So this is the bit that does the dot product. So that’ll be our kernel. So we can write that matmul block kernel is going to contain that. Okay. So that’s exactly the same thing that we just copied from above. And so now we’re going to need something to run this kernel. And you might not be surprised to hear that in CUDA, we are going to call this using blocks and threads.

Using 2D Blocks and Threads

But something that’s rather handy in CUDA is that the blocks and threads don’t have to be just a 1D vector. They can be a 2D or even 3D tensor.

[00:52:01]

So in this case, you can see we’ve got one, two, it’s a little hard to see exactly where they stop, two, three, four blocks. And so then for each block, and that’s kind of in one dimension, and then there’s also one, two, three, four, five blocks in the other dimension. And so each of these blocks has an index. So this one here is going to be zero, zero, a little bit hard to see. This one here is going to be one, three, and so forth. And this one over here is going to be three, four.

[00:53:01]

So rather than just having an integer block index, we’re going to have a tuple block index. And then within a block, there’s going to be — to pick, let’s say, this exact spot here, didn’t do that very well — a thread index. And again, the thread index won’t be a single index into a vector. It’ll be two elements. So in this case, it’d be 0, 1, 2, 3, 4, 5, 6 rows down, and 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 — I can’t count, 12 maybe — across. So this here is actually going to be defined by two things. One is by the block.

[00:54:02]

And so the block is 3, 4. And the thread is 6, 12. So that’s how CUDA lets us index into two-dimensional grids using blocks and threads.

Indexing into 2D Grids

We don’t have to. It’s just a convenience if we want to. And in fact, we can use up to three dimensions. So to create our kernel runner, rather than just having two nested loops for blocks and threads, we’re going to have two lots of two nested loops, for both of our x and y blocks and threads — or our rows and columns blocks and threads.

[00:55:10]

Creating a 2D Block Kernel Runner

So it ends up looking a bit messy, because we now have four nested for loops. So we’ll go through our blocks on the y-axis, and then through our blocks on the x-axis, and then through our threads on the y-axis, and then through our threads on the x-axis. And so what that means is that you can think of this Cartesian product as being: for each block, for each thread. Now to get the .y and the .x, we’ll use this handy little Python standard library thing called SimpleNamespace. I use that so much, I just give it the name ns, because I use namespaces all the time in my quick and dirty code. So we go through all those four. We then call our kernel, and we pass in an object containing the y and x coordinates.

[00:56:04]

And that’s going to be our block. And we also pass in our thread, which is an object with the y and x coordinates of our thread. And it’s going to eventually do all possible blocks, and all possible thread numbers for each of those blocks. And we also need to tell it how big is each block, how high and how wide. And so that’s what this is. This is going to be a SimpleNamespace, an object with an x and y, as you can see. So we need to know how big they are. Just like earlier on, we had to know the block dimension, and that’s why we passed in threads. So remember this is all pure PyTorch. We’re not actually calling out to any CUDA; we’re not calling out to any libraries other than just a tiny bit of PyTorch for the indexing and tensor creation. So you can run all of this by hand, make sure you understand.
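
A sketch of that 2D kernel runner, with the four nested loops and SimpleNamespace objects just described (the names are assumptions):

from types import SimpleNamespace as ns

def blk_kernel2d(f, blocks, threads, *args):
    # a 2D grid: loop over block coordinates, then over thread coordinates within each block
    for i0 in range(blocks.y):
        for i1 in range(blocks.x):
            for j0 in range(threads.y):
                for j1 in range(threads.x):
                    f(ns(x=i1, y=i0), ns(x=j1, y=j0), threads, *args)   # block, thread, block dimension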

[00:57:01]

You can put it in the debugger, you can step through it.

Calling the Matrix Multiplication Kernel

And so it’s going to call our function. So here’s our matrix multiplication function. As we said, it’s a kernel that contains the dot product that we wrote earlier. So now the guard is going to have to check that the row number we’re up to is not past the number of rows we have, and the column number we’re up to is not past the number of columns we have. And we also need to know what column number we’re up to — and this is exactly the same as we’ve seen before. And in fact, you might remember in the CUDA, we had blockIdx.x. And this is why, right? Because in CUDA, it always gives you these three-dimensional dim3 structures. So you have to put this .x. So we can find out the column this way, and then we can find out the row by seeing how many blocks have we gone through, how big is each block in the y-axis, and how many threads have we gone through in the y-axis.

[00:58:05]

So what row number are we up to? What column number are we up to? Is that inside the bounds of our tensor? If not, then just stop. And then otherwise, do our dot product and put it into our output tensor.

Pure Python Implementation of Matrix Multiplication

So that’s all pure Python. And so now we can call it by getting the height width of our first input, the height and width of our second input, and so then k and k2, the inner dimensions, ought to match. We can then create our output. And so now threads per block is not just the number 256, but it’s a pair of numbers. It’s an x and a y. And we’ve selected two numbers that multiply together to create 256. So again, this is a reasonable choice if you’ve got two-dimensional inputs to spread it out nicely.

[00:59:04]

One thing to be aware of here is that your threads per block can’t be bigger than 1024. So we’re using 256, which is safe, right? And notice that you have to multiply these together. 16 times 16 is going to be the number of threads per block. So these are safe numbers to use. You’re not going to run out of blocks, though. 2 to the 31 is the maximum number of blocks for dimension 0, and then 2 to the 16 for dimensions 1 and 2. I think it’s actually minus 1, but don’t worry about that. So don’t have too many threads, but you can have lots of blocks. But of course, each streaming multiprocessor is going to run all of these on the same device, and they’re also going to have access to shared memory. So that’s why you use a few threads per block. So our blocks, the x, we’re going to use the ceiling division.

[01:00:02]

The y, we’re going to use the same ceiling division. So if any of this is unfamiliar, go back to our earlier example, because the code’s all copied from there. And now we can call our 2D block kernel runner, passing in the kernel, the number of blocks, the number of threads per block, our input matrices flattened out, our output matrix flattened out, and the dimensions that it needs, because they get all used here. And return the result. And so if we call that matmul with a 2D block, and we can check that they are close to what we got in our original manual loops, and of course they are because it’s running the same code.
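Pulled together, the pure-Python 2D version looks roughly like this. It relies on the blk_kernel2d runner sketched above, and the names are assumptions.

import math, torch
from types import SimpleNamespace as ns

def matmul_bk(blockidx, threadidx, blockdim, m, n, out, h, w, k):
    r = blockidx.y*blockdim.y + threadidx.y      # which output row this thread handles
    c = blockidx.x*blockdim.x + threadidx.x      # which output column
    if r >= h or c >= w: return                  # the 2D guard
    o = 0.
    for i in range(k):                           # the dot product along the inner dimension
        o += m[r*k + i] * n[i*w + c]
    out[r*w + c] = o

def matmul_2d(m, n):
    h, k = m.shape
    k2, w = n.shape
    assert k == k2, "inner dimensions must match"
    output = torch.zeros(h*w, dtype=m.dtype)
    tpb = ns(x=16, y=16)                                     # 16*16 = 256 threads per block
    blocks = ns(x=math.ceil(w/tpb.x), y=math.ceil(h/tpb.y))  # ceiling division in each dimension
    blk_kernel2d(matmul_bk, blocks, tpb, m.flatten(), n.flatten(), output, h, w, k)
    return output.view(h, w)
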

Implementing Matrix Multiplication in CUDA

So now that we’ve done that, we can do the CUDA version. Now the CUDA version is going to be so much faster, we do not need to use this slimmed down matrix anymore.

[01:01:05]

Using the Full MNIST Dataset

We can use the whole thing. So to check that it’s correct, I want a fast CPU based approach that I can compare to. So previously it took about a second to do 39,000 elements. So I’m not going to explain how this works, but I’m going to use a broadcasting approach to get a fast CPU based approach.

Fast CPU-Based Matrix Multiplication

If you check the fast.ai course, we teach you how to do this broadcasting approach, but it’s a pure Python approach which manages to do it all in a single loop rather than three nested loops. It gives the same answer for the cut down tensors, but much faster, only four milliseconds. And so it’s fast enough that we can now run it on the whole input matrices, and it takes about 1.3 seconds.
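
Not necessarily the exact code from the course, but one common way to write that single-loop broadcasting version is:

import torch

def matmul_broadcast(m1, m2):
    h = m1.shape[0]
    out = torch.zeros(h, m2.shape[1])
    for i in range(h):
        # broadcast row i of m1 against every column of m2, then sum down the inner dimension
        out[i] = (m1[i, :, None] * m2).sum(dim=0)
    return out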

[01:02:03]

And so this broadcast optimized version, as you can see, it’s much faster. And now we’ve got 392 million additions going on in the middle of our three loops, effectively three loops, but we’re broadcasting them. So this is much faster. But the reason I’m really doing this is so that we can store this result to compare to. So that makes sure that our CUDA version is correct.

Converting Python Matrix Multiplication to CUDA

Okay. So how do we convert this to CUDA? You might not be surprised to hear that what I did was I grabbed this function and I passed it over to ChatGPT and said, please rewrite this in C. And it gave me something basically that I could use first time. And here it is. This time I don’t have unsigned char star, I have float star. Other than that, this looks almost exactly like the Python we had, with exactly the same changes we saw before.

[01:03:03]

We’ve now got the .y and .x versions. Once again, we’ve got dunder global, which says please run this on the GPU when we call it from the CPU. So the kernel, I don’t think there’s anything to talk about there.

Calling the CUDA Matrix Multiplication Kernel

And then the thing that calls the kernel, matmul, is going to be passed in two tensors. We’re going to check that they’re both contiguous and check that they are on the CUDA device. We’ll grab the height and width of the first and second tensors. We’re going to grab the inner dimension. We’ll make sure that the inner dimensions of the two matrices match, just like before. And this is how you do an assertion in PyTorch CUDA code. You call TORCH_CHECK, pass in the thing to check, pass in the message to pop up if there’s a problem. So these are a really good thing to spread around all through your CUDA code to make sure that everything is as you thought it was going to be.

[01:04:03]

Just like before, we create an output. So now when we create our number of threads, we don’t say threads is 256.

Using Dim3 Structure for Threads per Block

We instead say this is a special thing provided by CUDA for us, dim3. So this is basically a tuple with three elements. So we’re going to create a dim3 called tpb. It’s going to be 16 by 16. Now I said it has three elements. Where’s the third one? That’s okay. It just treats the third one as being one. So we can ignore it. So that’s the number of threads per block. And then how many blocks will there be? Well, in the x dimension, it’ll be w divided by x, ceiling division. In the y dimension, it will be h divided by y, ceiling division. And so that’s the number of blocks we have. So just like before, we call our kernel just by calling it like a normal function, but then we add this weird triple angle bracket thing telling it how many blocks and how many threads.

[01:05:09]

Calling CUDA Kernels with Dim3 Structures

So these aren’t ints anymore. These are now dim3 structures. And that’s what we use, these dim3 structures. And in fact, even before, what actually happened behind the scenes when we did the grayscale thing is even though we passed in 256 as an int, we actually ended up with a dim3 structure — it’s just that in that case the .y and .z values were set to 1 automatically. So we’ve actually already used a dim3 structure without quite realizing it. And then just like before, pass in all of the tensors we want, casting them to pointers.

[01:06:01]

Well, not just casting them — converting them to pointers to some particular data type — and passing in any other information that our function, our kernel, will need.
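
As a sketch, the CUDA source for the matmul described here, reusing the cuda_begin boilerplate and load_cuda wrapper from earlier (names are assumptions where not stated in the lecture):

cuda_src = cuda_begin + r'''
__global__ void matmul_k(float* m, float* n, float* out, int h, int w, int k) {
    int r = blockIdx.y*blockDim.y + threadIdx.y;   // output row
    int c = blockIdx.x*blockDim.x + threadIdx.x;   // output column
    if (r >= h || c >= w) return;                  // the 2D guard
    float o = 0;
    for (int i = 0; i < k; ++i) o += m[r*k + i] * n[i*w + c];
    out[r*w + c] = o;
}

torch::Tensor matmul(torch::Tensor m, torch::Tensor n) {
    CHECK_INPUT(m); CHECK_INPUT(n);
    int h = m.size(0), w = n.size(1), k = m.size(1);
    TORCH_CHECK(k == n.size(0), "Size mismatch!");
    auto output = torch::zeros({h, w}, m.options());
    dim3 tpb(16, 16);                              // 16*16 = 256 threads per block
    dim3 blocks(cdiv(w, tpb.x), cdiv(h, tpb.y));
    matmul_k<<<blocks, tpb>>>(m.data_ptr<float>(), n.data_ptr<float>(), output.data_ptr<float>(), h, w, k);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return output;
}
'''
cpp_src = "torch::Tensor matmul(torch::Tensor m, torch::Tensor n);"
module = load_cuda(cuda_src, cpp_src, ['matmul'])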

Loading the CUDA Module

Okay, so then we call load CUDA again. That’ll compile this into a module. We make sure that the inputs are both contiguous and on the CUDA device. And then we call module.matmul, passing those in, putting the result on the CPU, and checking that they’re all close — and it says yes they are. So this is now running not on just the first five rows, but on the entire MNIST dataset.

Running CUDA Matrix Multiplication on MNIST

And on the entire MNIST dataset, using an optimized CPU approach, it took 1.3 seconds. Using CUDA, it takes 6 milliseconds. So that is quite a big improvement.

[01:07:02]

Cool. The other thing I will mention is that PyTorch can do a matrix multiplication for us just by using the @ operator.

Using PyTorch’s Matrix Multiplication Function

And obviously gives the same answer. How long does that take to run? That takes 2 milliseconds. So three times faster. And in many situations it will be much more than three times faster. So why are we still pretty slow compared to PyTorch? I mean, this isn’t bad to do 392 million of these calculations in 6 milliseconds. But if PyTorch can do it so much faster, what are they doing?

Optimizing CUDA Performance with Shared Memory

Well, the trick is that they are taking advantage in particular of this shared memory. So shared memory is a small memory space that is shared amongst the threads in a block, and it is much faster than global memory.

[01:08:00]

In our matrix multiplication, when we have one of these blocks — and it’s going to do one block at a time, all in the same SM — it’s going to be reusing the same 16 rows and columns again and again and again, each time with access to the same shared memory. So you can see how you could really potentially cache a lot of the information you need and reuse it rather than going back to the slower memory again and again. So this is an example of the kinds of things that you could optimize potentially once you get to that point.

Using 1D, 2D, or 3D Blocks and Threads

The only other thing that I wanted to mention here is that this 2D block idea is totally optional. You can do everything with 1D blocks or with 2D blocks or with 3D blocks and threads. And just to show that, I’ve actually got an example at the end here which converts RGB to grayscale using the 2D blocks.

[01:09:11]

Implementing RGB to Grayscale with 2D Blocks

Because remember earlier when we did this, it was with 1D blocks. It gives exactly the same result. And if we compare the code, the version that was done with the 1D threads and blocks is quite a bit shorter than the version that uses 2D threads and blocks. And so in this case, even though we’re manipulating pixels, where you might think that using the 2D approach would be neater and more convenient, in this particular case it wasn’t really. I mean it’s still pretty simple code, but we have to deal with the columns and rows, the .x and .y, separately. The guard’s a little bit more complex.

[01:10:00]

We have to find out what index we’re actually up to here. Whereas this kernel, it was just much more direct, just two lines of code. And then calling the kernel, you know, again it’s a little bit more complex with the threads per block stuff rather than this. But the key thing I wanted to point out is that these two pieces of code do exactly the same thing.

Choosing the Block and Thread Structure

So don’t feel like if you don’t want to use a 2D or 3D block thread structure, you don’t have to. You can just use a 1D one. The 2D stuff is only there if it’s convenient for you to use, and you want to use it. Don’t feel like you have to. So yeah, I think that’s basically all the key things that I wanted to show you all today.

Conclusion

The main thing I hope you take from this is that even for Python programmers, for data scientists, it’s not way outside our comfort zone.

[01:11:03]

You know, we can write these things in Python, we can convert them pretty much automatically. We end up with code that, you know, looks reasonably familiar, even though it’s now in a different language. We can do everything inside notebooks. We can test everything as we go. We can print things from our kernels. And so, you know, it’s hopefully feeling a little bit less beyond our capabilities than we might have previously imagined.

Importance of CUDA for Modern Deep Learning

So I’d say, yeah, you know, go for it. I think it’s also, like, I think it’s increasingly important to be able to write CUDA code nowadays, because for things like flash attention, or for things like quantization, GPTQ, AWQ, bits and bytes, these are all things you can’t write in PyTorch. Our models are getting more sophisticated.

[01:12:03]

The kind of assumptions that libraries like PyTorch make about what we want to do are, you know, increasingly less and less accurate. So we’re having to do more and more of this stuff ourselves nowadays in CUDA. And so I think it’s a really valuable capability to have.

Setting up CUDA on Local Machines

Now, the other thing I’ll mention is we did it all in Colab today. But we can also do things on our own machines, if you have a GPU, or on a cloud machine. And getting set up for this, again, it’s much less complicated than you might expect. And in fact, I can show you — it’s basically three or four lines of bash script to get it all set up. It’ll run on Windows, it’ll run under WSL, it’ll also run on Linux. Of course, CUDA stuff doesn’t really work on Mac, so not on Mac.

[01:13:01]

Actually, I’ll put a link to this into the video notes. But for now, I’m just going to jump to a Twitter thread where I wrote this all down, to show you all the steps.

Using Conda for CUDA Development

So the way to do it is to use something called Conda. Conda is something that very, very, very few people understand. A lot of people think it’s like a replacement for like pip or poetry or something. It’s not. It’s better to think of it as a replacement for Docker. You can literally have multiple different versions of Python, multiple different versions of CUDA, multiple different C++ compilation systems, all in parallel at the same time on your machine, and switch between them. You can only do this with Conda. And everything just works, right? So you don’t have to worry about all the confusing stuff around .run files or Ubuntu packages or anything like that.

[01:14:00]

You can do everything with just Conda.

Installing Conda

You need to install Conda. I’ve actually got a script for that — you just run the script. It’s a tiny script, as you see. If you just run it, it’ll automatically figure out which Miniconda you need, it’ll automatically figure out what shell you’re on, and it’ll just go ahead and download it and install it for you. Okay, so run that script, restart your terminal. Now you’ve got Conda.

Finding the Required CUDA Version

Step two is find out what version of CUDA PyTorch wants you to have. So if I click Linux, Conda, CUDA, 12.1 is the latest. So then step three is run this shell command, replacing 12.1 with whatever version PyTorch currently wants.

Installing CUDA Tools

It’s actually 12.1 for me at this point. And that will install everything. All the stuff you need to profile, debug, build, etc.

[01:15:03]

All the NVIDIA tools you need, the full suite, will all be installed, and it’s coming directly from NVIDIA, so you’ll have the proper versions. As I said, you can have multiple versions installed at once in different environments, no problem at all.

Installing PyTorch with CUDA

And then finally, install PyTorch. And this command here will install PyTorch. For some reason I wrote nightly here. You don’t need the nightly, so just remove dash nightly. So this will install the latest version of PyTorch using the NVIDIA CUDA stuff that you just installed. If you’ve used Conda before and it was really slow, that’s because it used to use a different solver which was thousands or tens of thousands of times slower than the modern one, which has just been added and made default in the last couple of months. So nowadays this should all run very fast. And as I said, it’ll run under WSL on Windows. It’ll run on Ubuntu. It’ll run on Fedora. It’ll run on Debian.

[01:16:00]

It’ll all just work.

Benefits of Using Conda

So that’s how I strongly recommend getting yourself set up for local development. You don’t need to worry about using Docker. As I said, you can switch between different CUDA versions, different Python versions, different compilers, and so forth, without having to worry about any of the Docker stuff. And it’s also efficient enough that if you’ve got the same libraries and so forth installed in multiple environments, it’ll hard link them. So it won’t even use additional hard drive space. So it’s also very efficient.

Setting up CUDA on Cloud Machines

Great. So that’s how you can get started on your own machine, or on the cloud, or whatever. So hopefully you’ll find that helpful as well.

Next Steps

All right. Thanks very much for watching. I hope you found this useful. And I look forward to hearing about what you create with CUDA.

[01:17:00]

In terms of going to the next steps, check out the other CUDA mode lectures. I will link to them. And I would also recommend trying out some projects of your own. So, for example, you could try to implement something like 4-bit quantization, or FlashAttention, or anything like that. Those are pretty big projects, but you can try to break them up into smaller things that you build up one step at a time. And, of course, look at other people’s code. So look at the implementation of FlashAttention, look at the implementation of Bits and Bytes, look at the implementation of GPTQ, and so forth. The more time you spend reading other people’s code, the better.

Outro

All right. I hope you found this useful. And thank you very much for watching.