Note: Triton is best supported today; Amazon NKI DSL support is in active development.
pip install triton-viz
Note: A browser with WebGL/OpenGL enabled is required (standard in modern browsers).
See README.md for development installs, frontend builds, and full documentation.
Analyze kernels with visualization, profiling, and sanitization, switching between the three with a single line of code
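That single line is the triton_viz.trace decorator: the string argument selects the client ("tracer" for the visualizer, "profiler", or "sanitizer"), and nothing else about the kernel or its launch changes. A minimal sketch, using a made-up increment_kernel purely for illustration:

import triton
import triton.language as tl
import triton_viz

@triton_viz.trace("tracer")  # swap to "profiler" or "sanitizer" to change the analysis
@triton.jit
def increment_kernel(x_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Illustrative kernel (not from the examples below): add 1.0 to each element in place.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(x_ptr + offsets, x + 1.0, mask=mask)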
The visualizer currently supports a detailed look into load, store, and matmul operations for 1/2/3D tensors (more operations and dimensions coming soon).
import torch
import triton
import triton.language as tl
import triton_viz
@triton_viz.trace("tracer")
@triton.jit
def add_3d_slices_kernel(
    input_ptr1,
    input_ptr2,
    output_ptr,
    stride_x,
    stride_y,
    stride_z,
    slice_x,
    slice_y,
    slice_z,
    BLOCK_SIZE_X: tl.constexpr,
    BLOCK_SIZE_Y: tl.constexpr,
    BLOCK_SIZE_Z: tl.constexpr,
):
    # kernel definition here
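    # (Sketch, not from the original example, which elides the body: element-wise add over one 3D block.)
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)
    pid_z = tl.program_id(axis=2)
    offs_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
    offs_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    offs_z = pid_z * BLOCK_SIZE_Z + tl.arange(0, BLOCK_SIZE_Z)
    # Combine per-axis offsets into flat (z, y, x) offsets and mask out-of-range elements.
    offsets = (
        offs_z[:, None, None] * stride_z
        + offs_y[None, :, None] * stride_y
        + offs_x[None, None, :] * stride_x
    )
    mask = (
        (offs_z[:, None, None] < slice_z)
        & (offs_y[None, :, None] < slice_y)
        & (offs_x[None, None, :] < slice_x)
    )
    a = tl.load(input_ptr1 + offsets, mask=mask)
    b = tl.load(input_ptr2 + offsets, mask=mask)
    tl.store(output_ptr + offsets, a + b, mask=mask)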
if __name__ == "__main__":
    torch.manual_seed(0)
    input1 = torch.randn(16, 16, 32, device="cpu")
    input2 = torch.randn(16, 16, 32, device="cpu")
    output = torch.empty_like(input1)
    slice_z, slice_y, slice_x = input1.shape
    stride_z, stride_y, stride_x = input1.stride()
    # Block sizes are not defined in the original snippet; these values are illustrative.
    BLOCK_SIZE_X, BLOCK_SIZE_Y, BLOCK_SIZE_Z = 8, 8, 8
    grid = (
        triton.cdiv(slice_x, BLOCK_SIZE_X),
        triton.cdiv(slice_y, BLOCK_SIZE_Y),
        triton.cdiv(slice_z, BLOCK_SIZE_Z),
    )
    # Launch kernel
    add_3d_slices_kernel[grid](
        input1,
        input2,
        output,
        stride_x,
        stride_y,
        stride_z,
        slice_x,
        slice_y,
        slice_z,
        BLOCK_SIZE_X=BLOCK_SIZE_X,
        BLOCK_SIZE_Y=BLOCK_SIZE_Y,
        BLOCK_SIZE_Z=BLOCK_SIZE_Z,
    )
    triton_viz.launch(share=False, port=5001)
Interactive 3D rendering allows you to inspect tensor layouts and memory access patterns from any perspective. Left-click and drag to change the camera angle. Scroll to zoom. Right-click and drag to pan.
Inspect operation inputs/outputs at specific Program IDs (PIDs) to understand block-level behavior. View which elements are loaded/stored in each program (highlighted in blue/orange respectively). Toggle "All Program IDs" to see how all programs load tensor values simultaneously.
import torch
import triton
import triton.language as tl
import triton_viz
@triton_viz.trace("tracer")
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # kernel definition here
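    # (Sketch, not from the original example, which elides the body: a simple tiled matmul.
    #  M, N, K are multiples of the block sizes in the launch below, so bounds masks are omitted.)
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc)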
if __name__ == "__main__":
    torch.manual_seed(0)
    M, N, K = 32, 32, 64
    a = torch.randn((M, K), dtype=torch.float32)
    b = torch.randn((K, N), dtype=torch.float32)
    c = torch.empty((M, N), dtype=torch.float32)
    BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    matmul_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )
    triton_viz.launch(share=False, port=5001)
Directly map visual operations back to your source code lines for seamless debugging.
Visualize tensor values with color gradients to quickly identify outliers, zeros, or saturation.
Analyze the distribution of values within your tensors to spot outliers or skew and make better-informed quantization decisions.
The profiler flags performance hazards such as non-unrolled loops, inefficient mask usage, and missing buffer_load optimizations, while tracking load/store byte counts with low-overhead sampling.
import torch
import triton
import triton.language as tl
import triton_viz
@triton_viz.trace("profiler")
@triton.jit
def simple_kernel(
    x_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x, mask=mask)

if __name__ == "__main__":
    cfg.reset()  # cfg: triton-viz configuration object; its import is elided in the original snippet
    device = "cpu"
    size = 12
    BLOCK_SIZE = 8
    torch.manual_seed(0)
    x = torch.arange(size, dtype=torch.float32, device=device)
    output = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(size, meta["BLOCK_SIZE"]),)
    simple_kernel[grid](x, output, size, BLOCK_SIZE=BLOCK_SIZE)
The profiler collects detailed performance metrics, including cycle counts, memory bandwidth, and cache hit rates.
============================================================ Profiler Issues Summary ============================================================
============================================================
---------- Profiler: For-Loop Unrolling Statistics ---------
============================================================
No for-loops detected.
============================================================
============================================================
---------- Profiler: Mask Ratio Statistics -----------------
============================================================
────────────────────────────────────────
Overall Load Operations:
Total mask elements: 8
False elements: 4
Masked percentage: 50.00%
Overall Store Operations:
Total mask elements: 8
False elements: 4
Masked percentage: 50.00%
────────────────────────────────────────
Per-Operation Breakdown:
Top 5 Operations by False Elements:
────────────────────────────────────────
#1. LOAD at load_store.py:22
Total elements: 8
False elements: 4 (50.0%)
Code: x = tl.load(x_ptr + offsets, mask=mask)
#2. STORE at load_store.py:23
Total elements: 8
False elements: 4 (50.0%)
Code: tl.store(output_ptr + offsets, x, mask=mask)
============================================================
============================================================
---------- Profiler: Buffer Load Issue Detection -----------
============================================================
>>>>>> Warning: Potential Buffer Load Issue Detected! <<<<<<
Some memory access offsets are within 32-bit range,
but Buffer Load optimization was NOT used in the kernel.
This may lead to suboptimal performance on AMD GPUs.
Consider enabling Buffer Load optimization.
============================================================
============================================================ Profiler Issues Summary Ends =======================================================
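The 50% mask ratio in the report follows directly from the launch parameters: with size = 12 and BLOCK_SIZE = 8, the second program covers offsets 8 through 15, and only offsets 8 through 11 fall inside the tensor. A quick plain-Python check (assuming, as the totals of 8 elements suggest, that the sampled statistics cover that single program):

size, BLOCK_SIZE = 12, 8
offsets = range(1 * BLOCK_SIZE, 2 * BLOCK_SIZE)        # offsets handled by program id 1
false_elements = sum(1 for o in offsets if o >= size)  # elements whose mask is False
print(false_elements, "of", BLOCK_SIZE)                # 4 of 8 -> 50.00% masked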
The sanitizer symbolically checks tensor memory accesses for out-of-bounds errors and emits rich reports with tensor metadata, call stack, and expression trees, with an optional fake-memory backend to avoid real reads.
import torch
import triton
import triton.language as tl
import triton_viz
@triton_viz.trace("sanitizer")
@triton.jit
def gemm_kernel(
    A, B, C, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, TILE_SIZE: tl.constexpr
):
    m_block = tl.program_id(0)
    n_block = tl.program_id(1)
    range_m = tl.arange(0, TILE_SIZE)
    range_n = tl.arange(0, TILE_SIZE)
    range_k = tl.arange(0, TILE_SIZE)
    range_m_block = TILE_SIZE * m_block + range_m[:, None]
    range_n_block = TILE_SIZE * n_block + range_n[None, :]
    accum = tl.zeros((TILE_SIZE, TILE_SIZE), dtype=tl.float32)
    for k_block in range(K // TILE_SIZE):
        range_k_block = TILE_SIZE * k_block + range_k
        A_off = K * range_m_block + range_k_block[None, :]
        A_tile = tl.load(A + A_off)
        B_off = N * range_k_block[:, None] + range_n_block
        B_tile = tl.load(B + B_off)
        accum += tl.dot(A_tile, B_tile, allow_tf32=False)
    C_off = N * range_m_block + range_n_block
    tl.store(C + C_off, accum)

def test_gemm():
    M, N, K = 32, 32, 32
    A = torch.randn((M, K))
    B = torch.randn((K, N))
    C = torch.empty((M, N))
    tile_size = 16
    gemm_kernel[(M // tile_size, N // tile_size)](A, B, C, M, N, K, tile_size)
    print("GEMM ran without any out-of-bounds errors!")
test_gemm()
When code is correct, the sanitizer validates all memory accesses without interrupting execution.
GEMM ran without any out-of-bounds errors!
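As an optional numerical sanity check (not part of the original example, and assuming the kernel runs on real memory rather than the fake-memory backend), the result can be compared against PyTorch by adding one line at the end of test_gemm(), after the kernel launch:

    assert torch.allclose(C, A @ B, atol=1e-4), "Triton GEMM disagrees with torch.matmul"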
import torch
import triton
import triton.language as tl
import triton_viz
@triton_viz.trace("sanitizer")
@triton.jit
def gemm_kernel(
    A, B, C, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, TILE_SIZE: tl.constexpr
):
    m_block = tl.program_id(0)
    n_block = tl.program_id(1)
    range_m = tl.arange(0, TILE_SIZE)
    range_n = tl.arange(0, TILE_SIZE)
    range_k = tl.arange(0, TILE_SIZE)
    range_m_block = TILE_SIZE * m_block + range_m[:, None]
    range_n_block = TILE_SIZE * n_block + range_n[None, :]
    accum = tl.zeros((TILE_SIZE, TILE_SIZE), dtype=tl.float32)
    for k_block in range(K // TILE_SIZE):
        range_k_block = TILE_SIZE * k_block + range_k
        A_off = K * range_m_block + range_k_block[None, :]
        A_tile = tl.load(A + A_off + 1)  # Out-Of-Bounds Access HERE!
        B_off = N * range_k_block[:, None] + range_n_block
        B_tile = tl.load(B + B_off)
        accum += tl.dot(A_tile, B_tile, allow_tf32=False)
    C_off = N * range_m_block + range_n_block
    tl.store(C + C_off, accum)

def test_gemm():
    M, N, K = 32, 32, 32
    A = torch.randn((M, K))
    B = torch.randn((K, N))
    C = torch.empty((M, N))
    tile_size = 16
    gemm_kernel[(M // tile_size, N // tile_size)](A, B, C, M, N, K, tile_size)
    print("GEMM ran without any out-of-bounds errors!")
test_gemm()
The sanitizer catches invalid memory accesses instantly, providing a detailed report with tensor info, call stack, and symbolic expression trees.
🚨 ILLEGAL MEMORY ACCESS DETECTED 🚨

━━━ Code Context ━━━
File: triton-viz/examples/sanitizer/gemm_oob.py
Function: gemm_kernel
Line 25:
   22 │     for k_block in range(K // TILE_SIZE):
   23 │         range_k_block = TILE_SIZE * k_block + range_k
   24 │         A_off = K * range_m_block + range_k_block[None, :]
 → 25 │         A_tile = tl.load(A + A_off + 1)  # Out-Of-Bounds Access HERE!
   26 │
   27 │         B_off = N * range_k_block[:, None] + range_n_block
   28 │         B_tile = tl.load(B + B_off)

━━━ Tensor Information ━━━
arg: A
dtype: torch.float32
shape: torch.Size([32, 32])
strides: (32, 1)
device: cpu
contiguous: True
base_ptr: 0x000000002fa75080
size: 4096 bytes
valid_range: [0x000000002fa75080, 0x000000002fa76080)

━━━ Call Stack ━━━
#1 gemm_kernel at gemm_oob.py:25
   └─ A_tile = tl.load(A + A_off + 1)  # Out-Of-Bounds Access HERE!

━━━ Violation Details ━━━
Violation address: 0x000000002fa76080

━━━ Symbolic Expression Tree ━━━
load [dtype=pointer]
├── ptr: addptr [dtype=<['16', '16'], pointer >]
│ ├── ptr: splat [dtype=<['16', '16'], pointer >]
│ │ ├── block_type: const=<['16', '16'], pointer > [dtype=<['16', '16'], pointer >]
│ │ └── arg: const=799494272 [dtype=pointer ]
│ └── offset: add [dtype=<['16', '1'], int32>]
│   ├── lhs: add [dtype=<['16', '1'], int32>]
│   │ ├── lhs: broadcast [dtype=<['16', '1'], int32>]
│   │ │ └── arg: mul [dtype=<['16', '1'], int32>]
│   │ │   ├── lhs: splat [dtype=<['16', '1'], int32>]
│   │ │   │ ├── block_type: const=<['16', '1'], int32> [dtype=<['16', '1'], int32>]
│   │ │   │ └── arg: const=32 [dtype=int32]
│   │ │   └── rhs: add [dtype=<['16', '1'], int32>]
│   │ │     ├── lhs: splat [dtype=<['16', '1'], int32>]
│   │ │     │ ├── block_type: const=<['16', '1'], int32> [dtype=<['16', '1'], int32>]
│   │ │     │ └── arg: mul [dtype=int32]
│   │ │     │   ├── lhs: const=16 [dtype=int32]
│   │ │     │   └── rhs: pid_0 [dtype=int32]
│   │ │     │     └── axis: const=0 [dtype=int32]
│   │ │     └── rhs: expand_dims [dtype=<['16'], int32>]
│   │ │       ├── arg: arange [dtype=<['16'], int32>]
│   │ │       │ ├── ret_ty: const=<['16'], int32> [dtype=<['16'], int32>]
│   │ │       │ ├── start: const=0 [dtype=int32]
│   │ │       │ └── end: const=16 [dtype=int32]
│   │ │       └── axis: const=1 [dtype=int32]
│   │ └── rhs: broadcast [dtype=<['16'], int32>]
│   │   └── arg: expand_dims [dtype=<['16'], int32>]
│   │     ├── arg: add [dtype=<['16'], int32>]
│   │     │ ├── lhs: splat [dtype=<['16'], int32>]
│   │     │ │ ├── block_type: const=<['16'], int32> [dtype=<['16'], int32>]
│   │     │ │ └── arg: mul [dtype=int32]
│   │     │ │   ├── lhs: const=16 [dtype=int32]
│   │     │ │   └── rhs: const=loop_i_12 [dtype=int32]
│   │     │ └── rhs: arange [dtype=<['16'], int32>]
│   │     │   ├── ret_ty: const=<['16'], int32> [dtype=<['16'], int32>]
│   │     │   ├── start: const=0 [dtype=int32]
│   │     │   └── end: const=16 [dtype=int32]
│   │     └── axis: const=0 [dtype=int32]
│   └── rhs: splat [dtype=<['16', '16'], int32>]
│     ├── block_type: const=<['16', '16'], int32> [dtype=<['16', '16'], int32>]
│     └── arg: const=1 [dtype=int32]
├── mask: None
└── other: None

End of IMA Diagnostic Report
Special thanks to AMD for supporting Triton-Viz.
Get started with Triton-Viz in minutes. No GPU required for development.