Lab 6: Matrix Multiply – Tensor Cores
Prologue: Logistics
Due Dates
For this lab, you’ll be turning in the following deliverables:
- Checkpoint: Due Monday, October 21, 11:59pm (Gradescope)
- Final Submission: Due Friday, October 25, 11:59pm (Gradescope)
See the “Deliverables” section at the end of this document for more information on what you’ll be turning in.
Starter Code
You can get the starter code for this lab by cloning the lab repository:
Update October 18: The starter code has been updated since its initial release, and now includes the file `matmul_3.cu` for Part 3. If you cloned the repo before this update, please pull the latest commit, or manually copy the new `matmul_3.cu` file to your computer. (link to raw file)
Introduction
Goals for This Lab
So far in our exploration of matrix multiplication, we’ve focused primarily on optimizing data movement (Lab 4) and work partitioning (Lab 5). As we’ve worked to reduce bottlenecks along those dimensions, the run times of our implementations have increasingly become dominated by the cost of the floating point computations in our kernels’ innermost loops. Up until now, we’ve been implementing those core floating point computations using fused multiply-add (FMA) instructions. However, we can do better: modern NVIDIA GPUs support so-called “tensor core” instructions, which are designed specifically to accelerate matrix multiplication workloads. In this third and final matrix multiplication lab, we’ll be looking at how to use those tensor core instructions to speed up our kernels.
The tensor core instructions we’ll be using in this lab aren’t exposed by default in the CUDA C++ language, so we’ll be accessing them via inline PTX assembly.1 Since we haven’t worked with inline PTX before, this lab will walk through how to make use of tensor core instructions step-by-step:
- The first part of this lab is a brief introduction to inline PTX, using a simple bit manipulation instruction as an example.
- Next, we’ll look at how to access tensor core instructions in PTX, and how to work with the data layouts those tensor core instructions expect.
- Finally, we’ll integrate tensor core instructions into our full matrix multiplication kernel, and try to obtain a speedup over what we were able to achieve using FMAs.
Note on Terminology: What is a “Tensor Core?”
Although the phrase “tensor core” might conjure up mental images of something similar to a “CPU core” – perhaps something with its own register file and program counter, decoding and executing a programmable stream of instructions in sequence – a tensor core is not actually that kind of “core” in the traditional computer architecture sense.
The phrase “tensor core” is just NVIDIA’s name for a particular kind of functional unit which exists on recent generations of NVIDIA GPUs. Tensor cores are not fundamentally different from ALUs or FPUs – each tensor core is attached to a warp scheduler, and solely executes math operations.
From a software point of view, tensor cores simply provide another kind of math instruction which your code is able to invoke. As we’ll see, these tensor core instructions have some interesting and unusual properties, but their existence doesn’t radically alter anything about the CUDA programming model.
Part 1: Warm-Up – Introduction to Inline PTX
Before attempting to use tensor cores, we’ll look at an example of how to use inline PTX to access a simpler instruction: the `lop3` instruction, which stands for “Logical Operation on 3 Inputs.”

The `lop3` instruction generalizes bit-wise operations like `&`, `|`, `~`, and `^`. It computes an arbitrary bit-wise function of three 32-bit integer inputs (per CUDA thread), with the exact behavior of the bit-wise function determined by a programmer-supplied lookup table. Because there are only 2^3 = 8 possible tuples of three bits, the full contents of this lookup table can be encoded in a single 8-bit integer. In this exercise, we’ll be looking at how to use `lop3` to compute the specific bit-wise operation `(a & b) | c`, the `lop3` lookup table for which can be expressed using the constant `0b11101010`.
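Where does that constant come from? Per the PTX documentation, the `lop3` lookup table can be derived by evaluating the desired bit-wise expression on three fixed 8-bit template constants. Here is a minimal sketch (the names `TA`, `TB`, `TC`, and `LOP3_LUT` are our own, not part of any API):

```cuda
#include <cstdint>

// Per the PTX ISA docs, lop3's 8-bit lookup table is obtained by applying the
// desired bit-wise expression to these three fixed template constants.
constexpr uint32_t TA = 0xF0, TB = 0xCC, TC = 0xAA;

// For (a & b) | c: (0xF0 & 0xCC) | 0xAA = 0xC0 | 0xAA = 0xEA = 0b11101010.
constexpr uint32_t LOP3_LUT = (TA & TB) | TC;

static_assert(LOP3_LUT == 0b11101010, "matches the constant used in this lab");
```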
The `lop3` instruction can be useful in applications requiring bit-level manipulation, such as when decompressing neural network weights from 4-bit to 16-bit precision (search for “lop3” in that paper!). However, the CUDA C++ language does not by default expose any way to explicitly invoke `lop3` from user code.2 To access it, we’ll need to use inline PTX.
To embed inline PTX in our CUDA programs, we can use the `asm(...)` construct, whose general form looks like this:

```cuda
asm("<ptx instruction string>" : /* output operands */ : /* input operands */);
```
You can read about CUDA’s inline PTX syntax in detail here: inline PTX syntax docs. For our purposes, the following should be sufficient:
- The `lop3` instruction we want to invoke can be written using the PTX string `"lop3.b32 %0, %1, %2, %3, 0b11101010;"`
  - The identifier `lop3.b32` is the full name of the instruction, with the suffix `b32` indicating that it’s meant to operate on 32-bit values.
  - The percent-prefixed numbers `%0`, `%1`, `%2`, `%3` are “holes” in the PTX string, which will be replaced with references to specific PTX registers during compilation. We’ll need to associate each hole with a variable from the surrounding CUDA program in order for the compiler to know which register names it should plug in for the final PTX.
  - The `%0` operand is the destination register, and the `%1`, `%2`, `%3` operands are the source registers (PTX is more like “Intel syntax” than “AT&T syntax”).
  - The constant `0b11101010` encodes the lookup table corresponding to the bit-wise operation `(a & b) | c`.
  - Note the quotes, and the final semicolon in the PTX inside the quotes!
- To specify an output variable, use either the syntax `"=<type-abbreviation>"(var_name)` or `"+<type-abbreviation>"(var_name)`.
  - Use `=` if the output operand is only written to by the instruction, and `+` if it is both read from and written to. (In the case of `lop3.b32`, and for the tensor core instructions we’ll be seeing later, the output operands are only written to.)
  - Replace `<type-abbreviation>` with a letter indicating the type of the register. For 32-bit integer registers, the appropriate letter is `r`. You can find the letters corresponding to other types of registers in the inline PTX docs.
  - In our case, this will look something like `"=r"(var_name)`, where `var_name` is some `uint32_t` variable from the surrounding scope in your CUDA program.
- To specify an input variable, use just `"<type-abbreviation>"(var_name)`, without the leading `=` or `+`.
  - In our case, this will look something like `"r"(var_name)`.
  - When you have multiple input or output variables in an `asm` statement, you can separate them with commas, like `"r"(var_a), "r"(var_b), "r"(var_c)`.
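Putting these pieces together, here is what a complete `asm` statement can look like. To avoid giving away the exercise, this sketch invokes the simple two-input `add.u32` instruction rather than `lop3`; the function and variable names are illustrative only:

```cuda
#include <cstdint>

// Illustrative only: computes a + b by invoking the PTX add.u32 instruction
// through inline assembly.
__device__ uint32_t add_via_inline_ptx(uint32_t a, uint32_t b) {
    uint32_t result;
    asm("add.u32 %0, %1, %2;"  // %0 is the destination; %1 and %2 are sources
        : "=r"(result)         // output operand: write-only 32-bit integer register
        : "r"(a), "r"(b));     // input operands
    return result;
}
```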
Now we’re ready to write some inline PTX!
Deliverable: In the file `exercise_lop3.cu`, fill in the body of the function `lop3_kernel` to perform the operation `*out = (*a & *b) | *c;` using the `lop3.b32` instruction, invoked via inline PTX.
Question 1 for final write-up: Look at the assembly code generated for your `exercise_lop3.cu` file, using either Compiler Explorer or the `--asm` flag in Telerun. What does the generated PTX for your kernel look like? Specifically, what happened to the “holes” in the inline PTX string? What does the generated SASS look like?
Part 2: Warm-Up – Invoking Tensor Core Instructions
Now that we’ve looked at how inline PTX works, we can start using it to interact with the tensor cores on our GPU!
The RTX A4000 GPU we’re using belongs to NVIDIA’s Ampere generation (specifically, “Compute Capability 8.6”). On Ampere, there are tensor core instructions available in three different floating-point flavors:
- `f16` – The 16-bit floating point format defined by the IEEE 754 standard, with 5 exponent bits and 10 mantissa bits.
- `bf16` – “Brain float 16,” with 8 exponent bits and 7 mantissa bits. Popular in deep learning.
- `tf32` – “TensorFloat-32,” which is basically ordinary 32-bit floating point, but with compromises made in the accuracy of the multiplications performed by the tensor core. The mantissa of each input value is implicitly truncated to 10 bits (down from the ordinary 23) before participating in the multiplication.
Additionally, for all three of these formats, tensor cores support accumulating results in full 32-bit precision, effectively casting the results of the tensor core’s lower-precision multiplications up to FP32 before adding them together or to an existing partial sum.
Because the kernels we developed in Lab 4 and Lab 5 work in 32-bit precision, we’ll be focusing on TF32 precision in this lab for the sake of compatibility with our existing code. Given that TF32 tensor core instructions perform multiplications in lower precision than the FP32 FMAs we’ve been using until now, we can expect to see some unavoidable accuracy loss when we adapt our kernel to use tensor cores.
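For intuition about how much precision TF32 gives up, the following sketch (our own illustration, assuming simple truncation as described above) emulates the effect on a single FP32 value:

```cuda
#include <cstdint>

// Illustration only: emulate TF32's input truncation by clearing the
// 23 - 10 = 13 low-order mantissa bits of an FP32 value.
__device__ float truncate_to_tf32_precision(float x) {
    uint32_t bits = __float_as_uint(x);
    bits &= ~0x1FFFu;  // zero the 13 low mantissa bits, keeping 10
    return __uint_as_float(bits);
}
```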
So – what kind of TF32-precision tensor core functionality do we actually have on our A4000 GPU? The answer is simple – we have exactly two instructions:3

- `mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32` (`HMMA.1684.F32.TF32` in SASS)
- `mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32` (`HMMA.1688.F32.TF32` in SASS)
(You can find the PTX documentation for these instructions here: MMA PTX docs.)
Both of these instructions are “matrix-multiply-accumulate” (MMA) instructions; conceptually, they each implement an operation like:

D = A * B + C

where A, B, C, and D are matrices. The two instructions differ only in the dimensions of the matrices they operate on:

| Instruction | A Dimensions | B Dimensions | C, D Dimensions |
|---|---|---|---|
| m16n8k4 | 16 * 4 | 4 * 8 | 16 * 8 |
| m16n8k8 | 16 * 8 | 8 * 8 | 16 * 8 |
Empirically, the course staff have observed that these instructions are equivalent in terms of FLOP throughput; the `m16n8k4` variant performs half as much work per instruction as `m16n8k8`, but twice as many `m16n8k4` instructions can execute per cycle on average as `m16n8k8`.
In this part of the lab, we’ll look at how we can use the `m16n8k8` TF32 MMA instruction to execute a single 16 * 8 * 8 matrix multiplication. As we’ll see, this isn’t actually trivial – in particular, it will require understanding the unusual way in which tensor core instructions expect their operands to be laid out in registers.
Warp-Level Semantics
To understand how the `mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32` instruction works, the most important fact to establish is that tensor core instructions fundamentally operate at the warp level.

As we saw in Labs 1 and 2, the GPU’s hardware always executes instructions in a 32-wide SIMD fashion, with every 32 consecutive CUDA threads grouped together as 32 lanes of a vector. Viewing the GPU as a SIMD machine, virtually all the instructions we’ve seen our GPU execute so far in this course have been element-wise vector operations, with each instruction applying an identical, independent operation in each lane (modulo masking). When every instruction we execute is element-wise, we can often get away with ignoring the fact that the GPU is a SIMD machine at all, and simply pretend that every CUDA thread is executing its own independent stream of instructions. However, tensor core instructions break this illusion, because they are not element-wise.4
When a warp executes a tensor core operation like our `m16n8k8` instruction, it is not executing a separate, independent matrix multiplication for each CUDA thread in the warp; rather, it is executing a single 16 * 8 * 8 matrix multiplication cooperatively across the entire warp, with the input and output data for the instruction distributed across the registers of all the CUDA threads in the warp. When thinking about tensor core instructions, it’s most helpful to think of each “register” in your program as a 32-word-wide vector register, rather than as a single scalar register per CUDA thread.

With all of that in mind, let’s take a look at how the `m16n8k8` instruction we’re using actually expects data to be laid out in registers. First, a bit of math:
- The A matrix is 16 * 8 words, so we need 16 * 8 / 32 = 4 registers to store it.
- The B matrix is 8 * 8 words, so we need 8 * 8 / 32 = 2 registers to store it.
- The C / D matrix is 16 * 8 words, so we need 16 * 8 / 32 = 4 registers to store it.
Accordingly, the PTX syntax for invoking our tensor core instruction looks like this:
```
mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
    {%0, %1, %2, %3},     /* 'D' matrix */
    {%4, %5, %6, %7},     /* 'A' matrix */
    {%8, %9},             /* 'B' matrix */
    {%10, %11, %12, %13}; /* 'C' matrix */
```

(PTX syntax supports `/* ... */` comments.)
From the perspective of each CUDA thread, each of these `%0`, `%1`, etc. operands is a 1-word scalar register. Collectively across the entire warp, each operand is a 32-word vector register.

How does the `m16n8k8` instruction expect data to be packed into these registers? We present the layouts below.
(These diagrams are courtesy of Claude 3.5 Sonnet. You can click any image to expand it. You can also access interactive versions of these diagrams here.)
(The PTX documentation also contains its own versions of these diagrams.)
Essentially:
- The ‘A’ matrix is split into 4 quadrants. Each quadrant is mapped to a separate register, first top-to-bottom then left-to-right. Within each quadrant, the data is laid out in the corresponding register in row-major order.
- The ‘B’ matrix is split vertically into 2 halves. Each half is mapped to a separate register, first top then bottom. Within each half, the data is laid out in the corresponding register in column-major order.
- The ‘C’ matrix (and ‘D’ matrix) is split vertically into 2 halves, and each half is sliced into alternating vertical stripes, with all the even stripes grouped together and all the odd stripes grouped together. Each stripe-set in each half is mapped to a separate register. Within each stripe-set, the data is laid out in the corresponding register in row-major order.

Recall that for the ‘A’ matrix, the “vertical” and “horizontal” dimensions correspond to the `i` and `k` indices in the matrix multiply computation, whereas for ‘B’ they correspond to the `k` and `j` indices, and for ‘C’ they correspond to `i` and `j`.
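To make the ‘A’ description concrete, here is the index math it implies, written as a sketch (this helper is our own illustration, not part of the starter code; `lane` is a thread’s index within its warp, and `reg` selects one of the four ‘A’ registers):

```cuda
#include <cstdint>

// Which element (i, k) of the 16 * 8 'A' matrix does a given lane hold in a
// given register? Quadrants are ordered top-to-bottom then left-to-right, and
// each 8 * 4 quadrant is stored row-major across the warp's 32 lanes.
__device__ void a_fragment_element(uint32_t lane, uint32_t reg,
                                   uint32_t* i, uint32_t* k) {
    *i = (lane / 4) + 8 * (reg % 2);  // rows 0-7 for regs 0, 2; rows 8-15 for regs 1, 3
    *k = (lane % 4) + 4 * (reg / 2);  // cols 0-3 for regs 0, 1; cols 4-7 for regs 2, 3
}
```

You can work out the analogous index math for the ‘B’ and ‘C’/‘D’ layouts in the same way.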
A Note on Register Types
You now have almost everything you need in order to invoke the `mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32` instruction to perform a 16 * 8 * 8 matrix-multiply-accumulate. There is, however, one remaining quirk of the PTX interface to be aware of: this PTX instruction expects every operand to be a 32-bit integer register. Of course, the bits these integer values carry will actually encode 32-bit floating-point data, but it expects them to be integer registers nonetheless. To cope with this, you can use the built-in `__float_as_uint` and `__uint_as_float` functions to reinterpret the bits of a `float` as a `uint32_t`, and vice-versa. (These conversion functions are purely a compile-time formality and should ultimately have zero cost at run time.)
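Combining the inline PTX syntax from Part 1 with the operand list above, one way to wrap the MMA in a warp-level device helper might look like the sketch below (the helper name and array-based calling convention are our own; packing the right data into these fragments is the interesting part, and is left to you):

```cuda
#include <cstdint>

// Sketch: each thread supplies its 4 'A' words, 2 'B' words, and 4 'C' words
// (already bit-cast to uint32_t via __float_as_uint), and receives its 4 'D'
// words. All 32 threads of the warp must execute this together.
__device__ void mma_m16n8k8_tf32(uint32_t d[4], const uint32_t a[4],
                                 const uint32_t b[2], const uint32_t c[4]) {
    asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
```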
Implementation
Deliverable: In the file `exercise_mma.cu`, implement the function `mma_16x8x8_kernel` to perform a single 16 * 8 * 8 matrix multiplication on the matrices stored in `a` and `b`, and accumulate the results of that matrix multiplication into `c`, using the tensor core instruction `mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32`. In addition to invoking this tensor core instruction, your kernel can use whatever additional CUDA logic you like to compute indices, move data around, etc. The data in `a`, `b`, and `c` is stored in row-major layout in global memory. Note that the kernel `mma_16x8x8_kernel` will be launched with exactly 32 CUDA threads (one warp).
Question 2 for final write-up: Look at the assembly code generated for your `exercise_mma.cu` file. What does the generated SASS look like? Can you find the tensor core instruction?
Part 3: Accelerating Matrix Multiply
Now that we’ve seen how to invoke tensor core instructions on our GPU, we’re ready to integrate them into our full matrix multiply kernel!
For this lab, we’ll be focusing on just a subset of the problem sizes from Lab 5. Here they are:
size_i | size_j | size_k |
---|---|---|
3072 | 3072 | 3072 |
512 | 3072 | 3072 |
256 | 3072 | 3072 |
128 | 3072 | 3072 |
64 | 3072 | 3072 |
32 | 3072 | 3072 |
16 | 3072 | 3072 |
Analysis
To understand the maximum performance we can achieve with our tensor core implementation on each of these problem sizes, we can repeat a similar analysis to the one we carried out for the previous lab. For this analysis, you can assume that the theoretical peak TF32 tensor core throughput on our A4000 GPU is given by:
(128 FLOP / tensor core / cycle)
* (4 tensor cores / SM)
* (48 SMs)
* (1.56 GHz)
= 38.34 TFLOP/s
Question 3 for final write-up: For each of the problem sizes in this lab, walk through the following analysis (you may find it helpful to reuse some of your calculations from Lab 5 Question 2):

1. Considering the total number of FLOPs required to process this problem size, what is the fastest we could process this problem size if tensor core throughput were the only constraint?

2. Considering (1) as well as the minimum time required to access each unique matrix element in DRAM, what lower bound does this imply for the run time of our algorithm? Is this workload compute-bound or bandwidth-bound?

3. Considering (2), what is the maximum TFLOP/s we could achieve on this problem size? (This is just the total FLOPs divided by (2).)

4. How does (3) compare to the maximum throughput achievable if we were to use FMAs rather than tensor cores (as we calculated in Lab 5 Question 2.6)? Is the workload constrained by the same resource (either compute or bandwidth) in both cases, or is one scenario compute-bound while the other is bandwidth-bound?
Implementation
Our goal for the final part of this lab will be to write a matrix multiply kernel which uses tensor cores to run faster than any FMA-based matrix multiply realistically could on our largest problem sizes.
To calibrate our expectations for how fast an FMA-based kernel could realistically run, we’ve measured the performance of NVIDIA’s highly-optimized cuBLAS library on each problem size when running without tensor cores:5
cuBLAS Performance Without Tensor Cores:
size_i | size_j | size_k | Time (ms) | Throughput (TFLOP/s) |
---|---|---|---|---|
3072 | 3072 | 3072 | 4.05 | 14.32 |
512 | 3072 | 3072 | 0.80 | 12.07 |
256 | 3072 | 3072 | 0.46 | 10.40 |
128 | 3072 | 3072 | 0.24 | 10.00 |
64 | 3072 | 3072 | 0.13 | 9.57 |
32 | 3072 | 3072 | 0.11 | 5.26 |
16 | 3072 | 3072 | 0.11 | 2.73 |
Our goal will be to write an implementation which beats these non-tensor-core cuBLAS numbers on the following problem sizes:

- 3072 * 3072 * 3072
- 512 * 3072 * 3072
- 256 * 3072 * 3072
- 128 * 3072 * 3072
For the other problem sizes, your implementation should be correct, but it’s okay if it achieves worse performance than cuBLAS.
Deliverable: In the file `matmul_3.cu`, implement the function `launch_matmul_tensor`, and any associated kernels, so that when `size_i` is `3072`, `512`, `256`, or `128`, it achieves a higher throughput than our FMA-based cuBLAS baseline. To do this, you will (almost certainly) need to use tensor cores.
To hit this performance target, you don’t need any techniques other than what we’ve already discussed in Lab 4, Lab 5, and this lab. All the suggestions from the previous labs continue to apply; a few especially important ones which you may find it helpful to keep in mind are:
- Avoid register spills, by not exceeding the register file capacity and by not accessing arrays using dynamic indices.
- Use microtiles for register-level reuse. This is still relevant when using tensor cores! (But how should you structure the microtiles now?)
- Overlap data movement with computation, using e.g. asynchronous copy instructions.
- Avoid many-way bank conflicts, either by adjusting your data layouts or by adjusting your access patterns.
Good luck! Once you’ve implemented your optimized kernel, you can answer the final question of the lab:
Question 4 for final write-up: How does the performance of your implementation compare to the cuBLAS FMA baseline for each problem size? What fraction of theoretical peak throughput (calculated in Question 3.3) were you able to achieve for each problem size? What did you need to change about your kernel design in order to make use of tensor core instructions? What RRMSE numbers do you observe for your implementation, and how do they compare to the RRMSE numbers for your non-tensor-core implementation from Lab 5? Did you encounter any interesting bugs along the way? Finally, optionally: do you have any ideas for how it might be possible to develop an implementation which runs even faster?
Congratulations – you’ve reached the end of the matrix multiplication labs for 6.S894! We hope you’ve had as much fun working through them as we’ve had creating them.
You’re now well on your way to being able to implement the kinds of high-performance matrix multiplication kernels which power the world’s most computationally demanding deep learning applications, as well as important applications in many other domains.
Further Reading
If you want to learn more about matrix multiplication, there are a huge number of additional topics you may find it interesting to look into, including:
- Reduced precisions:
  - The tensor cores on our GPU support 16-bit floating point (PTX docs), 8-bit integer (PTX docs), and 4-bit integer (PTX docs) precisions. Handling these precisions efficiently poses new and interesting challenges, like needing to pack multiple values into each 32-bit word (CUDA docs).
  - Newer architectures, like Hopper, also support 8-bit floating point, and NVIDIA’s upcoming Blackwell GPUs will support exotic formats like 4-bit floating point and 6-bit (!) floating point (link).
- Fused and variant kernels:
  - It’s very common for high-performance machine learning applications to employ kernels which implement matrix multiplication combined with some other operation. This can look like applying an element-wise operation or a reduction on the final output of a matrix multiplication before writing it back to DRAM (relevant paper), but it also includes even more complicated kernel designs like FlashAttention (link, and state-of-the-art version) which embed matrix-multiply-like computations inside other kinds of workloads. An interesting recent example is FLUX (link), which combines matrix multiplication with cross-GPU communication.
- Higher-level tools:
  - In this lab we’ve been programming the GPU at a very low level of abstraction, but it’s also possible, and often preferable, to write matrix multiplications using higher-level CUDA libraries like CUTLASS (link), or domain-specific languages like Triton (link). Many of the performance optimizations used in the implementations of those tools are techniques that you now know how to implement by hand!
- Improved scheduling and partitioning:
  - NVIDIA’s best-performing matrix multiply kernels use techniques slightly more advanced than (but very similar to!) the ideas we’ve discussed in these past three labs. Some of the key differences from the designs we’ve covered include overlapping data movement with computation using software pipelines with more than two stages (i.e. staging data into a ring buffer rather than double-buffering it) (link), and using partitioning strategies which achieve better load-balancing than split-k, such as “stream-k” (link).
- Newer GPUs:
  - NVIDIA’s Hopper generation of GPUs includes several new features which are relevant when writing matrix multiply kernels. Notably, these include an evolved version of asynchronous copy instructions using something called the “tensor memory accelerator” (TMA) (link), as well as new extremely high-throughput tensor core instructions which run asynchronously, can load their inputs directly from shared memory, and work at a granularity of four warps at a time (PTX docs). Even though these features are different from the ones we’ve covered in this course, we hope you feel that you’re now very well-prepared to figure out how to use them (we think you are!).
- Alternative platforms:
  - Because deep learning is such an active space, there’s now a veritable zoo of different hardware accelerator platforms all designed to multiply matrices. Spanning various levels of readiness, performance, and public availability, these include: Google’s TPUs, AMD’s GPUs, Apple’s M-series chips, Intel’s Gaudi, Graphcore’s IPUs, Amazon’s Trainium, Tenstorrent’s various accelerators, Cerebras’s Wafer-Scale Engine, Groq’s TSPs, the chips in Tesla’s Dojo supercomputer, Etched’s Sohu, SambaNova’s RDU, Microsoft’s MAIA, and Meta’s MTIA – and we’re sure we forgot some! You may get a chance to program some of these in the future – and although they’re not NVIDIA GPUs, many of the fundamental ideas we’ve covered should be relevant as you learn to program any of them!
In a few weeks, we’ll start discussing ideas for final projects. If any of the matrix-multiplication-related topics above sound interesting to you, keep them in mind when you’re thinking about what you might want to work on for your final project!
Deliverables
Checkpoint (Due Monday, October 21, 2024)
For the checkpoint for this lab, we ask that you:
- Complete Part 1, so that you’re familiar with inline PTX syntax by the time we meet together for live lab.
- See how far you can get on Part 2 and Part 3, and let us know anything you’re stuck on.
  - We won’t require it, in case people get stuck, but if you can, please try to complete Part 2 and get started on Part 3 before live lab. If everyone gets started on Part 3 ahead of time, then we can spend as much time as possible during live lab discussing strategies for how to make our matrix multiply implementations run fast.
On the Gradescope assignment “Lab 6 Checkpoint,” (link) submit your completed code for Part 1, and your answers to the prompts about how you’re doing on Part 2 and Part 3.
Final Submission (Due Friday, October 25, 2024)
On the Gradescope assignment “Lab 6 Final” (link), submit your completed code for `exercise_lop3.cu`, `exercise_mma.cu`, and `matmul_3.cu`, as well as a PDF write-up containing your answers to Questions 1 - 4.
1. NVIDIA has developed an external library called “CUTLASS” which provides a higher-level C++ interface for interacting with tensor cores. However, CUTLASS is built on top of many layers of complicated C++ template metaprogramming machinery, and in the course staff’s experience, accessing tensor cores directly via PTX provides better clarity about what’s actually going on. Libraries like CUTLASS can be convenient in practice, but they’re never necessary; anything you can do using CUTLASS, you can also do yourself using inline PTX.

2. In practice, thanks to compiler optimizations in the PTX-to-SASS translation step, if you write CUDA code which implements three-way bit-wise operations in terms of normal two-way bit-wise operators like `(a & b) | c`, the compiler will sometimes end up generating fast `LOP3` instructions for you at the SASS level anyway. However, explicitly invoking the `lop3` instruction via inline PTX provides more control.

3. In PTX there is also an API called `wmma`, which superficially appears to offer yet another way to use the machine’s tensor cores. However, inspecting the SASS generated for `wmma.mma` instructions reveals that, on our GPU, it ultimately compiles to the same `HMMA` instructions which are already exposed through the `mma` API we’re using for this lab. As far as we can tell, this alternate `wmma` API exists mostly for historical reasons.

4. Tensor core instructions aren’t the only instructions on the GPU with warp-level semantics; there are also warp-level reductions and warp-level permutations, among others. This blog post has some interesting commentary on such warp-level functions and their history.

5. We make sure the cuBLAS kernels we’re calling won’t use tensor cores by explicitly requesting matrix multiplies in full FP32 precision as opposed to TF32 precision.