Lab 6: Matrix Multiply – Tensor Cores

Prologue: Logistics

Due Dates

For this lab, you’ll be turning in two deliverables: a checkpoint, due Monday, October 21, 2024, and a final submission, due Friday, October 25, 2024.

See the “Deliverables” section at the end of this document for more information on what you’ll be turning in.

Starter Code

You can get the starter code for this lab by cloning the lab repository:

git clone git@github.com:accelerated-computing-class/lab6.git

Update October 18: The starter code has been updated since its initial release, and now includes the file matmul_3.cu for Part 3. If you cloned the repo before this update, please pull the latest commit, or manually copy the new matmul_3.cu file to your computer. (link to raw file)

Introduction

Goals for This Lab

So far in our exploration of matrix multiplication, we’ve focused primarily on optimizing data movement (Lab 4) and work partitioning (Lab 5). As we’ve worked to reduce bottlenecks along those dimensions, the run times of our implementations have increasingly become dominated by the cost of the floating point computations in our kernels’ innermost loops. Up until now, we’ve been implementing those core floating point computations using fused multiply-add (FMA) instructions. However, we can do better: modern NVIDIA GPUs support so-called tensor core instructions, which are designed specifically to accelerate matrix multiplication workloads. In this third and final matrix multiplication lab, we’ll be looking at how to use those tensor core instructions to speed up our kernels.

The tensor core instructions we’ll be using in this lab aren’t exposed by default in the CUDA C++ language, so we’ll be accessing them via inline PTX assembly.1 Since we haven’t worked with inline PTX before, this lab will walk through how to make use of tensor core instructions step-by-step:

  1. The first part of this lab is a brief introduction to inline PTX, using a simple bit manipulation instruction as an example.

  2. Next, we’ll look at how to access tensor core instructions in PTX, and how to work with the data layouts those tensor core instructions expect.

  3. Finally, we’ll integrate tensor core instructions into our full matrix multiplication kernel, and try to obtain a speedup over what we were able to achieve using FMAs.

Note on Terminology: What is a “Tensor Core?”

Although the phrase “tensor core” might conjure up mental images of something similar to a “CPU core” – perhaps something with its own register file and program counter, decoding and executing a programmable stream of instructions in sequence – a tensor core is not actually that kind of “core” in the traditional computer architecture sense.

The phrase “tensor core” is just NVIDIA’s name for a particular kind of functional unit which exists on recent generations of NVIDIA GPUs. Tensor cores are not fundamentally different from ALUs or FPUs – each tensor core is attached to a warp scheduler, and solely executes math operations.

From a software point of view, tensor cores simply provide another kind of math instruction which your code is able to invoke. As we’ll see, these tensor core instructions have some interesting and unusual properties, but their existence doesn’t radically alter anything about the CUDA programming model.

Part 1: Warm-Up – Introduction to Inline PTX

Before attempting to use tensor cores, we’ll look at an example of how to use inline PTX to access a simpler instruction: the lop3 instruction, which stands for “Logical Operation on 3 Inputs.”

The lop3 instruction generalizes bit-wise operations like &, |, ~, and ^. It computes an arbitrary bit-wise function of three 32-bit integer inputs (per CUDA thread), with the exact behavior of the bit-wise function determined by a programmer-supplied lookup table. Because there are only 2^3 = 8 possible tuples of three bits, the full contents of this lookup table can be encoded in a single 8-bit integer. In this exercise, we’ll be looking at how to use lop3 to compute the specific bit-wise operation (a & b) | c, the lop3 lookup table for which can be expressed using the constant 0b11101010.
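To see where this constant comes from, note that bit i of the lookup table is just the desired function evaluated on the three bits of i = (a << 2) | (b << 1) | c. A convenient shortcut (our sketch, not something the lab requires) is to evaluate the function on “selector” constants whose bit patterns spell out each input’s column of the truth table:

// Deriving the lookup-table constant for f(a, b, c) = (a & b) | c.
constexpr unsigned TA = 0xF0;             // bit i of TA is the 'a' bit of index i
constexpr unsigned TB = 0xCC;             // bit i of TB is the 'b' bit of index i
constexpr unsigned TC = 0xAA;             // bit i of TC is the 'c' bit of index i
constexpr unsigned LUT = (TA & TB) | TC;  // = 0xEA = 0b11101010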

The lop3 instruction can be useful in applications requiring bit-level manipulation, such as when decompressing neural network weights from 4-bit to 16-bit precision (search for “lop3” in that paper!). However, the CUDA C++ language does not by default expose any way to explicitly invoke lop3 from user code.2 To access it, we’ll need to use inline PTX.

To embed inline PTX in our CUDA programs, we can use the asm(...) construct, which looks like this:

asm(
    " snippet of PTX code, with 'holes' to fill in "
    : /* output variables to associate with holes */
    : /* input variables to associate with holes */
);

You can read about CUDA’s inline PTX syntax in detail here: inline PTX syntax docs. For our purposes, the following should be sufficient:
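For example, here’s a minimal device-side helper (a sketch following the pattern in those docs) that adds two integers using the PTX add.s32 instruction. The “%0”, “%1”, “%2” holes are bound to the C++ variables listed after the PTX string; the constraint “r” means “a 32-bit register,” and the “=” prefix marks an output that the instruction writes:

__device__ int add_via_ptx(int a, int b) {
    int d;
    // %0 is bound to d (output), %1 and %2 to a and b (inputs).
    asm("add.s32 %0, %1, %2;" : "=r"(d) : "r"(a), "r"(b));
    return d;
}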

Now we’re ready to write some inline PTX!

Deliverable: In the file exercise_lop3.cu, fill in the body of the function lop3_kernel to perform the operation *out = (*a & *b) | *c; using the lop3.b32 instruction, invoked via inline PTX.

Question 1 for final write-up: Look at the assembly code generated for your exercise_lop3.cu file, using either Compiler Explorer or the --asm flag in Telerun. What does the generated PTX for your kernel look like? Specifically, what happened to the “holes” in the inline PTX string? What does the generated SASS look like?

Part 2: Warm-Up – Invoking Tensor Core Instructions

Now that we’ve looked at how inline PTX works, we can start using it to interact with the tensor cores on our GPU!

The RTX A4000 GPU we’re using belongs to NVIDIA’s Ampere generation (specifically, “Compute Capability 8.6”). On Ampere, tensor core instructions are available in three different floating-point flavors: FP16, BF16, and TF32.

Additionally, for all three of these formats, tensor cores support accumulating results in full 32-bit precision, effectively casting the results of the tensor core’s lower-precision multiplications up to FP32 before adding them together or to an existing partial sum.

Because the kernels we developed in Lab 4 and Lab 5 work in 32-bit precision, we’ll be focusing on TF32 precision in this lab for the sake of compatibility with our existing code. Given that TF32 tensor core instructions perform multiplications in lower precision than the FP32 FMAs we’ve been using until now, we can expect to see some unavoidable accuracy loss when we adapt our kernel to use tensor cores.
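In other words, each TF32 MMA operation roughly computes (in our notation, not the PTX spec’s):

D_{ij} = C_{ij} + \sum_{k} \mathrm{fp32}\big(\mathrm{tf32}(A_{ik}) \cdot \mathrm{tf32}(B_{kj})\big)

where the inputs are rounded to TF32 before being multiplied, and all additions are carried out in FP32.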

So – what kind of TF32-precision tensor core functionality do we actually have on our A4000 GPU? The answer is simple – we have exactly two instructions:3

mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32
mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32

(You can find the PTX documentation for these instructions here: MMA PTX docs.)

Both of these instructions are “matrix-multiply-accumulate” (MMA) instructions; conceptually, they each implement an operation like:

D ← A B + C

where A, B, C, and D are matrices. The two instructions differ only in the dimensions of the matrices they operate on:

Instruction    A Dimensions    B Dimensions    C, D Dimensions
m16n8k4        16 * 4          4 * 8           16 * 8
m16n8k8        16 * 8          8 * 8           16 * 8

Empirically, the course staff have observed that these instructions are equivalent in terms of FLOP throughput; the m16n8k4 variant performs half as much work per instruction as m16n8k8, but twice as many m16n8k4 instructions can execute per cycle on average as m16n8k8.
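To make the comparison concrete: each m16n8k8 instruction performs 16 * 8 * 8 = 1024 multiply-adds (2048 FLOPs) per warp, while each m16n8k4 instruction performs 16 * 8 * 4 = 512 multiply-adds (1024 FLOPs).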

In this part of the lab, we’ll look at how we can use the m16n8k8 TF32 MMA instruction to execute a single 16 * 8 * 8 matrix multiplication. As we’ll see, this isn’t actually trivial – in particular, it will require understanding the unusual way in which tensor core instructions expect their operands to be laid out in registers.

Warp-Level Semantics

To understand how the mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 instruction works, the most important fact to establish is that tensor core instructions fundamentally operate at the warp level.

As we saw in Labs 1 and 2, the GPU’s hardware always executes instructions in a 32-wide SIMD fashion, with every 32 consecutive CUDA threads grouped together as 32 lanes of a vector. Viewing the GPU as a SIMD machine, virtually all the instructions we’ve seen our GPU execute so far in this course have been element-wise vector operations, with each instruction applying an identical, independent operation in each lane (modulo masking). When every instruction we execute is element-wise, we can often get away with ignoring the fact that the GPU is a SIMD machine at all, and simply pretend that every CUDA thread is executing its own independent stream of instructions. However, tensor core instructions break this illusion, because they are not element-wise.4

When a warp executes a tensor core operation like our m16n8k8 instruction, it is not executing a separate, independent matrix multiplication for each CUDA thread in the warp; rather, it is executing a single 16 * 8 * 8 matrix multiplication cooperatively across the entire warp, with the input and output data for the instruction distributed across the registers of all the CUDA threads in the warp. When thinking about tensor core instructions, it’s most helpful to think of each “register” in your program as a 32-word-wide vector register, rather than as a single scalar register per CUDA thread.

With all of that in mind, let’s take a look at how the m16n8k8 instruction we’re using actually expects data to be laid out in registers. First, a bit of math: the A matrix contains 16 * 8 = 128 TF32 values, which, spread evenly across the warp’s 32 lanes, works out to 4 registers per CUDA thread. Likewise, B contains 8 * 8 = 64 values (2 registers per thread), and C and D each contain 16 * 8 = 128 FP32 values (4 registers per thread).

Accordingly, the PTX syntax for invoking our tensor core instruction looks like this:

mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32
    {%0, %1, %2, %3},     /* 'D' matrix */
    {%4, %5, %6, %7},     /* 'A' matrix */
    {%8, %9},             /* 'B' matrix */
    {%10, %11, %12, %13}; /* 'C' matrix */

(PTX syntax supports /* ... */ comments.)

From the perspective of each CUDA thread, each of the operands %0, %1, etc. is a 1-word scalar register. Collectively, across the entire warp, each operand is a 32-word vector register.

How does the m16n8k8 instruction expect data to be packed into these registers? We present the layouts below.

(These diagrams are courtesy of Claude 3.5 Sonnet.)

(The PTX documentation also contains its own versions of these diagrams.)

Essentially, each operand register holds a different piece of its matrix in each lane, so that the warp’s 32 lanes together cover every element of the A, B, C, and D tiles.

Recall that for the ‘A’ matrix, the “vertical” and “horizontal” dimensions correspond to the i and k indices in the matrix multiply computation, whereas for ‘B’ they correspond to the k and j indices, and for ‘C’ they correspond to i and j.
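When translating these diagrams into index arithmetic, it helps to first compute each lane’s position within the warp; the layout diagrams organize the 32 lanes into eight groups of four consecutive lanes. A minimal sketch (variable names are ours):

unsigned lane = threadIdx.x % 32;    // lane index within the warp (0..31)
unsigned group_id = lane / 4;        // which group of 4 lanes (0..7)
unsigned lane_in_group = lane % 4;   // position within that group (0..3)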

A Note on Register Types

You now have almost everything you need in order to invoke the mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 instruction to perform a 16 * 8 * 8 matrix-multiply-accumulate. There is, however, one remaining quirk of the PTX interface to be aware of: this PTX instruction expects every operand to be a 32-bit integer register. Of course, the bits these integer values carry will actually encode 32-bit floating-point data, but it expects them to be integer registers nonetheless. To cope with this, you can use the built-in __float_as_uint and __uint_as_float functions to reinterpret the bits of a float as a uint32_t, and vice-versa. (These conversion functions are purely a compile-time formality and should ultimately have zero cost at run time.)
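Putting the pieces together, the invocation itself might look roughly like the sketch below. The wrapper’s name and signature are ours; each lane passes in its own fragment words (already reinterpreted as integers with __float_as_uint), and filling those words according to the layouts above is the real work of the exercise:

#include <cstdint>

// Sketch: one warp-wide 16 * 8 * 8 TF32 MMA via inline PTX. All 32 lanes of
// the warp must call this together, each supplying its own fragment words.
__device__ void mma_m16n8k8_tf32(
    uint32_t &d0, uint32_t &d1, uint32_t &d2, uint32_t &d3,
    uint32_t a0, uint32_t a1, uint32_t a2, uint32_t a3,
    uint32_t b0, uint32_t b1,
    uint32_t c0, uint32_t c1, uint32_t c2, uint32_t c3)
{
    asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
        : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
          "r"(b0), "r"(b1),
          "r"(c0), "r"(c1), "r"(c2), "r"(c3));
}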

Implementation

Deliverable: In the file exercise_mma.cu, implement the function mma_16x8x8_kernel to perform a single 16 * 8 * 8 matrix multiplication on the matrices stored in a and b, and accumulate the results of that matrix multiplication into c, using the tensor core instruction mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32. In addition to invoking this tensor core instruction, your kernel can use whatever additional CUDA logic you like to compute indices, move data around, etc. The data in a, b, and c is stored in row-major layout in global memory. Note that the kernel mma_16x8x8_kernel will be launched with exactly 32 CUDA threads (one warp).

Question 2 for final write-up: Look at the assembly code generated for your exercise_mma.cu file. What does the generated SASS look like? Can you find the tensor core instruction?

Part 3: Accelerating Matrix Multiply

Now that we’ve seen how to invoke tensor core instructions on our GPU, we’re ready to integrate them into our full matrix multiply kernel!

For this lab, we’ll be focusing on just a subset of the problem sizes from Lab 5. Here they are:

size_i    size_j    size_k
3072      3072      3072
512       3072      3072
256       3072      3072
128       3072      3072
64        3072      3072
32        3072      3072
16        3072      3072

Analysis

To understand the maximum performance we can achieve with our tensor core implementation on each of these problem sizes, we can repeat a similar analysis to the one we carried out for the previous lab. For this analysis, you can assume that the theoretical peak TF32 tensor core throughput on our A4000 GPU is given by:

  (128 FLOP / tensor core / cycle)
* (4 tensor cores / SM)
* (48 SMs)
* (1.56 GHz)

= 38.34 TFLOP/s

Question 3 for final write-up: For each of the problem sizes in this lab, walk through the following analysis (you may find it helpful to reuse some of your calculations from Lab 5 Question 2):

  1. Considering the total number of FLOPs required to process this problem size, what is the fastest we could process this problem size if tensor core throughput were the only constraint?

  2. Considering (1) as well as the minimum time required to access each unique matrix element in DRAM, what lower bound does this imply for the run time of our algorithm? Is this workload compute-bound or bandwidth-bound?

  3. Considering (2), what is the maximum TFLOP/s we could achieve on this problem size? (This is just the total FLOP count divided by (2).)

  4. How does (3) compare to the maximum throughput achievable if we were to use FMAs rather than tensor cores (as we calculated in Lab 5 Question 2.6)? Is the workload constrained by the same resource (either compute or bandwidth) in both cases, or is one scenario compute-bound while the other is bandwidth-bound?

Implementation

Our goal for the final part of this lab will be to write a matrix multiply kernel which uses tensor cores to run faster than any FMA-based matrix multiply realistically could on our largest problem sizes.

To calibrate our expectations for how fast an FMA-based kernel could realistically run, we’ve measured the performance of NVIDIA’s highly-optimized cuBLAS library on each problem size when running without tensor cores:5

cuBLAS Performance Without Tensor Cores:

size_i    size_j    size_k    Time (ms)    Throughput (TFLOP/s)
3072      3072      3072      4.05         14.32
512       3072      3072      0.80         12.07
256       3072      3072      0.46         10.40
128       3072      3072      0.24         10.00
64        3072      3072      0.13         9.57
32        3072      3072      0.11         5.26
16        3072      3072      0.11         2.73

Our goal will be to write an implementation which beats these non-tensor-core cuBLAS numbers on the problem sizes where size_i is 3072, 512, 256, or 128.

For the other problem sizes, your implementation should be correct, but it’s okay if it achieves worse performance than cuBLAS.

Deliverable: In the file matmul_3.cu, implement the function launch_matmul_tensor, and any associated kernels, so that when size_i is 3072, 512, 256, or 128, it achieves a higher throughput than our FMA-based cuBLAS baseline. To do this, you will (almost certainly) need to use tensor cores.

To hit this performance target, you don’t need any techniques other than what we’ve already discussed in Lab 4, Lab 5, and this lab. All the suggestions from the previous labs continue to apply – in particular, the Lab 4 advice on optimizing data movement and the Lab 5 advice on work partitioning remain just as important here.

Good luck! Once you’ve implemented your optimized kernel, you can answer the final question of the lab:

Question 4 for final write-up: How does the performance of your implementation compare to the cuBLAS FMA baseline for each problem size? What fraction of theoretical peak throughput (calculated in Question 3.3) were you able to achieve for each problem size? What did you need to change about your kernel design in order to make use of tensor core instructions? What RRMSE numbers do you observe for your implementation, and how do they compare to the RRMSE numbers for your non-tensor-core implementation from Lab 5? Did you encounter any interesting bugs along the way? Finally, optionally: do you have any ideas for how it might be possible to develop an implementation which runs even faster?

Congratulations – you’ve reached the end of the matrix multiplication labs for 6.S894! We hope you’ve had as much fun working through them as we’ve had creating them.

You’re now well on your way to being able to implement the kinds of high-performance matrix multiplication kernels which power the world’s most computationally demanding deep learning applications, as well as important applications in many other domains.

Further Reading

If you want to learn more about matrix multiplication, there are a huge number of additional topics you may find it interesting to look into beyond what we’ve covered in these labs.

In a few weeks, we’ll start discussing ideas for final projects. If any matrix-multiplication-related topics sound interesting to you, keep them in mind when you’re thinking about what you might want to work on for your final project!

Deliverables

Checkpoint (Due Monday, October 21, 2024)

For the checkpoint for this lab, we ask that you complete Part 1 and briefly report on your progress through Parts 2 and 3.

On the Gradescope assignment “Lab 6 Checkpoint,” (link) submit your completed code for Part 1, and your answers to the prompts about how you’re doing on Part 2 and Part 3.

Final Submission (Due Friday, October 25, 2024)

On the Gradescope assignment “Lab 6 Final,” (link), submit your completed code for exercise_lop3.cu, exercise_mma.cu, and matmul_3.cu, as well as a PDF write-up containing your answers to Questions 1 - 4.


1

NVIDIA has developed an external library called “CUTLASS” which provides a higher-level C++ interface for interacting with tensor cores. However, CUTLASS is built on top of many layers of complicated C++ template metaprogramming machinery, and in the course staff’s experience, accessing tensor cores directly via PTX provides better clarity about what’s actually going on. Libraries like CUTLASS can be convenient in practice, but they’re never necessary; anything you can do using CUTLASS, you can also do yourself using inline PTX.

2

In practice, thanks to compiler optimizations in the PTX-to-SASS translation step, if you write CUDA code which implements three-way bit-wise operations in terms of normal two-way bit-wise operators like (a & b) | c, the compiler will sometimes end up generating fast LOP3 instructions for you at the SASS level anyway. However, explicitly invoking the lop3 instruction via inline PTX provides more control.
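For example, a plain C++ helper like the following sketch may already compile down to a single LOP3 at the SASS level – though whether it does depends on the compiler version and the surrounding code:

__device__ unsigned and_or(unsigned a, unsigned b, unsigned c) {
    // The PTX-to-SASS optimizer can often fuse this three-way bit-wise
    // expression into one LOP3 instruction on its own.
    return (a & b) | c;
}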

3

In PTX there is also an API called wmma, which superficially appears to offer yet another way to use the machine’s tensor cores. However, inspecting the SASS generated for wmma.mma instructions reveals that, on our GPU, it ultimately compiles to the same HMMA instructions which are already exposed through the mma API we’re using for this lab. As far as we can tell, this alternate wmma API exists mostly for historical reasons.

4

Tensor core instructions aren’t the only instructions on the GPU with warp-level semantics; there are also warp-level reductions and warp-level permutations, among others. This blog post has some interesting commentary on such warp-level functions and their history.

5

We make sure the cuBLAS kernels we’re calling won’t use tensor cores by explicitly requesting matrix multiplies in full FP32 precision as opposed to TF32 precision.