
For about two years now, I have been applying deep learning to medical applications such as hearing aids. During this time, I used PyTorch to build neural networks (NNs) and learned the fundamentals of building and training them. My confidence in my abilities in this field was at an all-time high, but then I came to a realization: I had no idea what PyTorch was doing under the hood.

I knew the basics of the math behind NNs: dot products, derivatives, backpropagation, and so on. What I had no idea about were the terms I saw floating around the PyTorch community, such as autograd and computational graphs. What the heck does autograd even mean? Does it have something to do with graduating? I wanted to learn and understand these terms, and I knew there was no better place to start than with the man himself.

[Figure: Andrej Karpathy]

Andrej Karpathy, former director of AI and Autopilot Vision at Tesla, has put out great educational content on YouTube. The most valuable piece, in my opinion, is his walkthrough video of micrograd, a tiny autograd engine he wrote. While I'm not sure he uses the exact terms ‘autograd’ and ‘computational graph’ in that walkthrough, he does an excellent job of explaining both.

Now that I had a good grasp of what these terms meant, I wanted to try developing my own version of micrograd, called slimgrad. Unlike micrograd, which works with individual scalar values, I wanted slimgrad to support matrix operations as well, because I find it easier to conceptualize NNs using matrices. Then I ran into a question: how the heck do you represent matrix operations, specifically dot products, in the form of a computational graph?

What even is a computational graph?

Before I dive into that question, I want to take a moment to explain what a computational graph even is. A computational graph is simply a mathematical expression drawn as a directed graph: the values and operations are nodes, and the edges show which values feed into which operation. Drawing an expression this way makes it easy to follow how values flow from the inputs to the output. Take a look at this example.

[Figure: computational graph of $5+4=9$]

Here we can see the operation $5+4=9$ in the form of a computational graph, where $5$, $4$, $+$, and $9$ are all represented as nodes. This is just a basic example, but any mathematical expression can be depicted this way. Now that we understand what a computational graph is, we are one step closer to understanding the computational graph of a dot product.
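
Here is a minimal Python sketch of that same graph. It assumes a stripped-down, micrograd-style `Value` class written only for illustration (it is not micrograd's actual code); `describe` is likewise a hypothetical helper. Each `Value` remembers which nodes produced it and with which operation, and that is all a computational graph really needs.

```python
class Value:
    """A scalar node in a computational graph (hypothetical, micrograd-style sketch)."""

    def __init__(self, data, children=(), op=""):
        self.data = data          # the scalar this node holds
        self.children = children  # the nodes that fed into this one
        self.op = op              # the operation that produced this node ("" for inputs)

    def __add__(self, other):
        # Adding two nodes creates a new node that remembers both inputs and the "+" op.
        return Value(self.data + other.data, (self, other), "+")


def describe(node, depth=0):
    """Print the graph by walking from the output node back to its inputs."""
    label = f"{node.data}" + (f"  (result of '{node.op}')" if node.op else "")
    print("  " * depth + label)
    for child in node.children:
        describe(child, depth + 1)


a = Value(5)
b = Value(4)
c = a + b
describe(c)
# 9  (result of '+')
#   5
#   4
```

The output node `c` points back to its two inputs, mirroring the edges in the drawing.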

What is a matrix?

Studying computer science in college brought its fair share of matrices. However, just like with all other math, I studied for the test and then forgot everything shortly after. It wasn’t until I tried to implement NNs that I needed a better understanding of what matrices are.

At first, I thought matrices were in a class of their own. You know: algebra over here, calculus over there, and matrices somewhere else. I never thought about how they related to one another. After hours of pulling my hair out, I finally came to a conclusion about what matrices are: a matrix is simply a data structure for storing values. Matrix operations, such as dot products, tell us which scalar operations to perform on which elements, in what order, and where each result belongs in the output matrix. Take a look at this illustration.

[Figure: Visualization of the dot product]

Here we can see an illustration of the dot product between two matrices, also known as matrix multiplication: each entry of the result is the dot product of one row of the first matrix with one column of the second. As you can see, the dot product tells us the order in which to apply scalar operations. Beyond that ordering, there is no complex math, just addition and multiplication.
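
To make that ordering concrete, here is a minimal sketch in plain Python (lists of lists, no libraries) that spells out every scalar step of matrix multiplication, where each output entry is $C_{ij} = \sum_k A_{ik} B_{kj}$. The function name `matmul` is my own choice for illustration, not part of slimgrad or PyTorch.

```python
def matmul(A, B):
    """Multiply A (n x m) by B (m x p) using only scalar * and +."""
    n, m, p = len(A), len(B), len(B[0])
    assert all(len(row) == m for row in A), "inner dimensions must match"
    C = [[0] * p for _ in range(n)]
    for i in range(n):              # each row of A
        for j in range(p):          # each column of B
            total = 0
            for k in range(m):      # walk along the shared dimension
                total += A[i][k] * B[k][j]  # one multiply, one add
            C[i][j] = total         # place the result at row i, column j
    return C


A = [[1, 2],
     [3, 4]]
B = [[5, 6],
     [7, 8]]
print(matmul(A, B))  # [[19, 22], [43, 50]]
```

For example, the top-left entry is $1 \cdot 5 + 2 \cdot 7 = 19$: nothing but multiplication and addition, done in a fixed order.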

What does this mean for the computational graph of a dot product?

Now to answer the big question of this blog post: what does the computational graph of a dot product look like? Given everything we’ve covered, it won’t look any different from the computational graph of $5+4=9$, because the only scalar operations that take place within a dot product are addition and multiplication. Take a look at this example.

[Figure: Computational graph of the dot product]

As you can see, this computational graph is no different from one we would draw for any other scalar operations; it simply has more nodes.
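
To check that claim in code, here is a minimal sketch that builds a 2×2 matrix product entirely out of scalar nodes and then backpropagates through it. It assumes a simplified, micrograd-style `Value` class with only `+` and `*` (my own illustration, not Karpathy's exact code and not slimgrad); `matmul` and `to_values` are hypothetical helper names. Because $C_{ij} = \sum_k A_{ik} B_{kj}$, the gradient of each product with respect to $A_{ik}$ is just $B_{kj}$, and the graph recovers that without any matrix-specific rule.

```python
class Value:
    """Scalar node that records how it was computed (simplified micrograd-style sketch)."""

    def __init__(self, data, children=(), op=""):
        self.data = data
        self.grad = 0.0
        self.children = children
        self.op = op
        self._backward = lambda: None  # pushes this node's grad to its children

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other), "+")
        def _backward():
            # d(a + b)/da = 1 and d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other), "*")
        def _backward():
            # d(a * b)/da = b and d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward = _backward
        return out

    def backward(self):
        # Topologically sort the graph, then apply the chain rule from the output back.
        order, seen = [], set()
        def visit(node):
            if node not in seen:
                seen.add(node)
                for child in node.children:
                    visit(child)
                order.append(node)
        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            node._backward()


def matmul(A, B):
    """Matrix product where every entry is built from Value additions and multiplications."""
    n, m, p = len(A), len(B), len(B[0])
    C = []
    for i in range(n):
        row = []
        for j in range(p):
            total = A[i][0] * B[0][j]  # start from the first product (no zero node needed)
            for k in range(1, m):
                total = total + A[i][k] * B[k][j]
            row.append(total)
        C.append(row)
    return C


def to_values(M):
    return [[Value(x) for x in row] for row in M]


A = to_values([[1, 2], [3, 4]])
B = to_values([[5, 6], [7, 8]])
C = matmul(A, B)
print([[c.data for c in row] for row in C])  # [[19, 22], [43, 50]]

# Reduce to a single scalar so we can backpropagate through the whole graph.
loss = C[0][0] + C[0][1] + C[1][0] + C[1][1]
loss.backward()
print([[a.grad for a in row] for row in A])  # [[11.0, 15.0], [11.0, 15.0]]
```

For two 2×2 matrices, the product alone contributes 8 multiplication nodes and 4 addition nodes, one node per scalar step. Each gradient in `A` is the sum of a row of `B` (for example, $5 + 6 = 11$), exactly what the chain rule gives when applied node by node.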

Conclusion

Now, I am no mathematician, but it seems like, at the end of the day, fields of math such as matrices, calculus, and others boil down to scalar operations. Knowing this makes it easier to build computational graphs not only for matrices, but for all branches of mathematics!