Open Source | MIT Licensed

Analyze and debug your tile-based DSL kernels

Note: Triton is best supported today; Amazon NKI DSL support is in active development.

Install
pip install triton-viz

Note: A browser with WebGL/OpenGL enabled is required (standard in modern browsers).
See README.md for development installs, frontend builds, and full documentation.

Analysis Clients

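Switch between visualization, profiling, and sanitization with a single line of code: the client name passed to `@triton_viz.trace(...)` (`"tracer"`, `"profiler"`, or `"sanitizer"`) selects the analysis.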

The visualizer currently provides a detailed view of load, store, and matmul operations on 1D, 2D, and 3D tensors (more operations and dimensions coming soon).

examples/visualizer/3dims.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("tracer")
@triton.jit
def add_3d_slices_kernel(
    input_ptr1,
    input_ptr2,
    output_ptr,
    stride_x,
    stride_y,
    stride_z,
    slice_x,
    slice_y,
    slice_z,
    BLOCK_SIZE_X: tl.constexpr,
    BLOCK_SIZE_Y: tl.constexpr,
    BLOCK_SIZE_Z: tl.constexpr,
):
    """Kernel traced by the 3D visualizer example.

    The body is elided on this page; see ``examples/visualizer/3dims.py`` for
    the full definition. The launch site passes two input tensors, an output
    tensor of the same shape, per-axis strides and extents (``slice_*``), and
    a per-axis tile size, on a 3D grid with one grid axis per tensor axis.
    Judging by its name it presumably adds the two inputs tile by tile into
    the output -- confirm against the example file.
    """
    pass  # kernel definition here


if __name__ == "__main__":
    torch.manual_seed(0)
    input1 = torch.randn(16, 16, 32, device="cpu")
    input2 = torch.randn(16, 16, 32, device="cpu")
    output = torch.empty_like(input1)

    # torch reports shape/stride outermost-first, i.e. in (z, y, x) order.
    slice_z, slice_y, slice_x = input1.shape
    stride_z, stride_y, stride_x = input1.stride()

    # Per-axis tile sizes (powers of two, as Triton tile sizes normally are).
    # They divide the (x, y, z) = (32, 16, 16) extents evenly, giving a
    # 2 x 2 x 2 grid of full tiles.
    BLOCK_SIZE_X = 16
    BLOCK_SIZE_Y = 8
    BLOCK_SIZE_Z = 8

    grid = (
        triton.cdiv(slice_x, BLOCK_SIZE_X),
        triton.cdiv(slice_y, BLOCK_SIZE_Y),
        triton.cdiv(slice_z, BLOCK_SIZE_Z),
    )

    # Launch kernel
    add_3d_slices_kernel[grid](
        input1,
        input2,
        output,
        stride_x,
        stride_y,
        stride_z,
        slice_x,
        slice_y,
        slice_z,
        BLOCK_SIZE_X=BLOCK_SIZE_X,
        BLOCK_SIZE_Y=BLOCK_SIZE_Y,
        BLOCK_SIZE_Z=BLOCK_SIZE_Z,
    )
    # Serve the recorded trace in the visualizer UI on port 5001.
    triton_viz.launch(share=False, port=5001)
Fullscreen for full detail

Interactive 3D rendering allows you to inspect tensor layouts and memory access patterns from any perspective. Left-click and drag to change camera angle. Scroll to zoom. Right-click and drag to pan.

examples/visualizer/3dims.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("tracer")
@triton.jit
def add_3d_slices_kernel(
    input_ptr1,
    input_ptr2,
    output_ptr,
    stride_x,
    stride_y,
    stride_z,
    slice_x,
    slice_y,
    slice_z,
    BLOCK_SIZE_X: tl.constexpr,
    BLOCK_SIZE_Y: tl.constexpr,
    BLOCK_SIZE_Z: tl.constexpr,
):
    """Kernel traced by the 3D visualizer example.

    The body is elided on this page; see ``examples/visualizer/3dims.py`` for
    the full definition. The launch site passes two input tensors, an output
    tensor of the same shape, per-axis strides and extents (``slice_*``), and
    a per-axis tile size, on a 3D grid with one grid axis per tensor axis.
    Judging by its name it presumably adds the two inputs tile by tile into
    the output -- confirm against the example file.
    """
    pass  # kernel definition here


if __name__ == "__main__":
    torch.manual_seed(0)
    input1 = torch.randn(16, 16, 32, device="cpu")
    input2 = torch.randn(16, 16, 32, device="cpu")
    output = torch.empty_like(input1)

    # torch reports shape/stride outermost-first, i.e. in (z, y, x) order.
    slice_z, slice_y, slice_x = input1.shape
    stride_z, stride_y, stride_x = input1.stride()

    # Per-axis tile sizes (powers of two, as Triton tile sizes normally are).
    # They divide the (x, y, z) = (32, 16, 16) extents evenly, giving a
    # 2 x 2 x 2 grid of full tiles.
    BLOCK_SIZE_X = 16
    BLOCK_SIZE_Y = 8
    BLOCK_SIZE_Z = 8

    grid = (
        triton.cdiv(slice_x, BLOCK_SIZE_X),
        triton.cdiv(slice_y, BLOCK_SIZE_Y),
        triton.cdiv(slice_z, BLOCK_SIZE_Z),
    )

    # Launch kernel
    add_3d_slices_kernel[grid](
        input1,
        input2,
        output,
        stride_x,
        stride_y,
        stride_z,
        slice_x,
        slice_y,
        slice_z,
        BLOCK_SIZE_X=BLOCK_SIZE_X,
        BLOCK_SIZE_Y=BLOCK_SIZE_Y,
        BLOCK_SIZE_Z=BLOCK_SIZE_Z,
    )
    # Serve the recorded trace in the visualizer UI on port 5001.
    triton_viz.launch(share=False, port=5001)
Fullscreen for full detail

Inspect operation inputs/outputs at specific Program IDs (PIDs) to understand block-level behavior. View which elements are loaded/stored in each program (highlighted in blue/orange respectively). Toggle "All Program IDs" to see how all programs load tensor values simultaneously.

examples/visualizer/matmul.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("tracer")
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Matmul kernel traced by the visualizer example.

    The body is elided on this page; see ``examples/visualizer/matmul.py`` for
    the full definition. The launch site passes ``a`` (M, K), ``b`` (K, N) and
    an output ``c`` (M, N) with their row/column strides, on a 2D grid of
    ``cdiv(M, BLOCK_M) x cdiv(N, BLOCK_N)`` programs -- presumably one output
    tile of ``c = a @ b`` per program; confirm against the example file.
    """
    pass  # kernel definition here


if __name__ == "__main__":
    # Fixed seed so the visualized tensor values are identical across runs.
    torch.manual_seed(0)

    # Operands: a is (M, K), b is (K, N); c is an uninitialized (M, N) output.
    M, N, K = 32, 32, 64
    a = torch.randn((M, K), dtype=torch.float32)
    b = torch.randn((K, N), dtype=torch.float32)
    c = torch.empty((M, N), dtype=torch.float32)

    # Tile sizes; the grid covers the (M, N) output in BLOCK_M x BLOCK_N tiles.
    BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    # Row/column strides of each 2-D tensor, in the kernel's argument order.
    stride_am, stride_ak = a.stride()
    stride_bk, stride_bn = b.stride()
    stride_cm, stride_cn = c.stride()

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )

    # Serve the recorded trace in the visualizer UI on port 5001.
    triton_viz.launch(share=False, port=5001)
Fullscreen for full detail

Directly map visual operations back to your source code lines for seamless debugging.

examples/visualizer/matmul.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("tracer")
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Matmul kernel traced by the visualizer example.

    The body is elided on this page; see ``examples/visualizer/matmul.py`` for
    the full definition. The launch site passes ``a`` (M, K), ``b`` (K, N) and
    an output ``c`` (M, N) with their row/column strides, on a 2D grid of
    ``cdiv(M, BLOCK_M) x cdiv(N, BLOCK_N)`` programs -- presumably one output
    tile of ``c = a @ b`` per program; confirm against the example file.
    """
    pass  # kernel definition here


if __name__ == "__main__":
    # Fixed seed so the visualized tensor values are identical across runs.
    torch.manual_seed(0)

    # Operands: a is (M, K), b is (K, N); c is an uninitialized (M, N) output.
    M, N, K = 32, 32, 64
    a = torch.randn((M, K), dtype=torch.float32)
    b = torch.randn((K, N), dtype=torch.float32)
    c = torch.empty((M, N), dtype=torch.float32)

    # Tile sizes; the grid covers the (M, N) output in BLOCK_M x BLOCK_N tiles.
    BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    # Row/column strides of each 2-D tensor, in the kernel's argument order.
    stride_am, stride_ak = a.stride()
    stride_bk, stride_bn = b.stride()
    stride_cm, stride_cn = c.stride()

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )

    # Serve the recorded trace in the visualizer UI on port 5001.
    triton_viz.launch(share=False, port=5001)
Fullscreen for full detail

Visualize tensor values with color gradients to quickly identify outliers, zeros, or saturation.

examples/visualizer/matmul.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("tracer")
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Matmul kernel traced by the visualizer example.

    The body is elided on this page; see ``examples/visualizer/matmul.py`` for
    the full definition. The launch site passes ``a`` (M, K), ``b`` (K, N) and
    an output ``c`` (M, N) with their row/column strides, on a 2D grid of
    ``cdiv(M, BLOCK_M) x cdiv(N, BLOCK_N)`` programs -- presumably one output
    tile of ``c = a @ b`` per program; confirm against the example file.
    """
    pass  # kernel definition here


if __name__ == "__main__":
    # Fixed seed so the visualized tensor values are identical across runs.
    torch.manual_seed(0)

    # Operands: a is (M, K), b is (K, N); c is an uninitialized (M, N) output.
    M, N, K = 32, 32, 64
    a = torch.randn((M, K), dtype=torch.float32)
    b = torch.randn((K, N), dtype=torch.float32)
    c = torch.empty((M, N), dtype=torch.float32)

    # Tile sizes; the grid covers the (M, N) output in BLOCK_M x BLOCK_N tiles.
    BLOCK_M, BLOCK_N, BLOCK_K = 16, 16, 32
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))

    # Row/column strides of each 2-D tensor, in the kernel's argument order.
    stride_am, stride_ak = a.stride()
    stride_bk, stride_bn = b.stride()
    stride_cm, stride_cn = c.stride()

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
    )

    # Serve the recorded trace in the visualizer UI on port 5001.
    triton_viz.launch(share=False, port=5001)
Fullscreen for full detail

Inspect how values are distributed within your tensors to guide quantization decisions.

The profiler flags performance hazards like non-unrolled loops, inefficient mask usage, and missing buffer_load optimizations while tracking load/store byte counts with low-overhead sampling.

examples/profiler/load_store.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("profiler")
@triton.jit
def simple_kernel(
    x_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy ``n_elements`` values from ``x_ptr`` to ``output_ptr``.

    Each program handles one contiguous chunk of ``BLOCK_SIZE`` elements,
    selected by its program id along axis 0. Lanes at or past ``n_elements``
    are masked off, so a final partial chunk neither reads nor writes out of
    bounds. For example, with 12 elements and ``BLOCK_SIZE=8`` the second
    program has 4 masked-off lanes, which shows up in the profiler's mask
    statistics.
    """
    pid = tl.program_id(axis=0)
    # First element index owned by this program.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # True for in-bounds lanes; shared by the load and the store below.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x, mask=mask)


if __name__ == "__main__":
    device = "cpu"
    size = 12
    BLOCK_SIZE = 8
    torch.manual_seed(0)
    # 0.0 .. 11.0: 12 elements, so the second BLOCK_SIZE=8 chunk is only half
    # full and exercises the kernel's mask.
    x = torch.arange(size, dtype=torch.float32, device=device)
    output = torch.empty_like(x)
    # One program per BLOCK_SIZE chunk; Triton calls this with the launch's
    # meta-parameters (including BLOCK_SIZE) to size the grid.
    grid = lambda meta: (triton.cdiv(size, meta["BLOCK_SIZE"]),)
    simple_kernel[grid](x, output, size, BLOCK_SIZE)

Collects detailed performance metrics including cycle counts, memory bandwidth, and cache hit rates.

TERMINAL OUTPUT
============================================================ Profiler Issues Summary ============================================================
============================================================
---------- Profiler: For-Loop Unrolling Statistics ---------
============================================================
No for-loops detected.
============================================================
============================================================
---------- Profiler: Mask Ratio Statistics -----------------
============================================================
────────────────────────────────────────
Overall Load Operations:
  Total mask elements:     8
  False elements:          4
  Masked percentage:       50.00%
Overall Store Operations:
  Total mask elements:     8
  False elements:          4
  Masked percentage:       50.00%
────────────────────────────────────────
Per-Operation Breakdown:
Top 5 Operations by False Elements:
────────────────────────────────────────
#1. LOAD at load_store.py:22
    Total elements: 8
    False elements: 4 (50.0%)
    Code: x = tl.load(x_ptr + offsets, mask=mask)
#2. STORE at load_store.py:23
    Total elements: 8
    False elements: 4 (50.0%)
    Code: tl.store(output_ptr + offsets, x, mask=mask)
============================================================
============================================================
---------- Profiler: Buffer Load Issue Detection -----------
============================================================
>>>>>> Warning: Potential Buffer Load Issue Detected! <<<<<<
Some memory access offsets are within 32-bit range,
but Buffer Load optimization was NOT used in the kernel.
This may lead to suboptimal performance on AMD GPUs.
Consider enabling Buffer Load optimization.
============================================================
============================================================ Profiler Issues Summary Ends =======================================================

The sanitizer symbolically checks tensor memory accesses for out-of-bounds errors. Each violation produces a detailed report with tensor metadata, the call stack, and the symbolic expression tree. An optional fake-memory backend avoids performing real memory reads.

examples/sanitizer/gemm.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("sanitizer")
@triton.jit
def gemm_kernel(
    A, B, C, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, TILE_SIZE: tl.constexpr
):
    """Tiled matrix multiply C = A @ B, checked by the sanitizer.

    Offsets are built as ``row * row_length + col`` with row lengths K (A),
    N (B) and N (C). A is therefore treated as a contiguous row-major (M, K)
    buffer, B as (K, N) and C as (M, N); no strides are passed. Program
    (program_id(0), program_id(1)) computes one TILE_SIZE x TILE_SIZE tile of
    C, accumulating in float32 over ``K // TILE_SIZE`` steps along K. No masks
    are used, so M, N and K are assumed to be multiples of TILE_SIZE.
    """
    m_block = tl.program_id(0)
    n_block = tl.program_id(1)
    range_m = tl.arange(0, TILE_SIZE)
    range_n = tl.arange(0, TILE_SIZE)
    range_k = tl.arange(0, TILE_SIZE)
    # Global row indices of this tile, shaped (TILE_SIZE, 1) ...
    range_m_block = TILE_SIZE * m_block + range_m[:, None]
    # ... and global column indices, shaped (1, TILE_SIZE); combining the two
    # broadcasts to a full (TILE_SIZE, TILE_SIZE) tile.
    range_n_block = TILE_SIZE * n_block + range_n[None, :]
    accum = tl.zeros((TILE_SIZE, TILE_SIZE), dtype=tl.float32)
    for k_block in range(K // TILE_SIZE):
        range_k_block = TILE_SIZE * k_block + range_k
        # Offsets into A: this tile's rows, this K step's columns.
        A_off = K * range_m_block + range_k_block[None, :]
        A_tile = tl.load(A + A_off)

        # Offsets into B: this K step's rows, this tile's columns.
        B_off = N * range_k_block[:, None] + range_n_block
        B_tile = tl.load(B + B_off)

        # allow_tf32=False disables TF32 so the dot uses full fp32 inputs.
        accum += tl.dot(A_tile, B_tile, allow_tf32=False)
    C_off = N * range_m_block + range_n_block
    tl.store(C + C_off, accum)


def test_gemm():
    """Multiply two random 32x32 matrices with gemm_kernel under the sanitizer.

    The kernel is launched on a (M // tile_size, N // tile_size) grid, one
    program per output tile. The message is printed once the launch returns.
    """
    M = N = K = 32
    tile_size = 16
    A = torch.randn((M, K))
    B = torch.randn((K, N))
    C = torch.empty((M, N))

    launch_grid = (M // tile_size, N // tile_size)
    gemm_kernel[launch_grid](A, B, C, M, N, K, tile_size)
    print("GEMM ran without any out-of-bounds errors!")

# Guarded so importing this module (e.g. during pytest collection, which also
# picks up test_gemm on its own) does not launch the kernel as a side effect.
if __name__ == "__main__":
    test_gemm()

When code is correct, the sanitizer validates all memory accesses without interrupting execution.

TERMINAL OUTPUT
GEMM ran without any out-of-bounds errors!
examples/sanitizer/gemm_oob.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("sanitizer")
@triton.jit
def gemm_kernel(
    A, B, C, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, TILE_SIZE: tl.constexpr
):
    """Deliberately broken tiled GEMM used to demonstrate the sanitizer.

    This is the reference tiled C = A @ B kernel with one change: the load of
    A adds 1 to every offset. A is addressed as a contiguous row-major (M, K)
    buffer (offset ``K * row + col``). For the tile covering A's last row and
    last K step, the final element therefore lands one element past the end
    of A, which the sanitizer reports as an illegal memory access.
    """
    m_block = tl.program_id(0)
    n_block = tl.program_id(1)
    range_m = tl.arange(0, TILE_SIZE)
    range_n = tl.arange(0, TILE_SIZE)
    range_k = tl.arange(0, TILE_SIZE)
    # Global row indices (TILE_SIZE, 1) and column indices (1, TILE_SIZE) of
    # this program's output tile.
    range_m_block = TILE_SIZE * m_block + range_m[:, None]
    range_n_block = TILE_SIZE * n_block + range_n[None, :]
    accum = tl.zeros((TILE_SIZE, TILE_SIZE), dtype=tl.float32)
    for k_block in range(K // TILE_SIZE):
        range_k_block = TILE_SIZE * k_block + range_k
        A_off = K * range_m_block + range_k_block[None, :]
        A_tile = tl.load(A + A_off + 1) # Out-Of-Bounds Access HERE!

        B_off = N * range_k_block[:, None] + range_n_block
        B_tile = tl.load(B + B_off)

        accum += tl.dot(A_tile, B_tile, allow_tf32=False)
    C_off = N * range_m_block + range_n_block
    tl.store(C + C_off, accum)


def test_gemm():
    """Launch the deliberately out-of-bounds gemm_kernel under the sanitizer.

    The inputs are random 32x32 matrices, launched on a
    (M // tile_size, N // tile_size) grid. The sanitizer is expected to flag
    the shifted load of A. The message below is printed only if the launch
    returns normally.
    """
    M = N = K = 32
    tile_size = 16
    A = torch.randn((M, K))
    B = torch.randn((K, N))
    C = torch.empty((M, N))

    launch_grid = (M // tile_size, N // tile_size)
    gemm_kernel[launch_grid](A, B, C, M, N, K, tile_size)
    print("GEMM ran without any out-of-bounds errors!")

# Guarded so importing this module (e.g. during pytest collection, which also
# picks up test_gemm on its own) does not launch the kernel as a side effect.
if __name__ == "__main__":
    test_gemm()

The sanitizer catches invalid memory accesses instantly, providing a detailed report with tensor info, call stack, and symbolic expression trees.

TERMINAL OUTPUT
🚨 ILLEGAL MEMORY ACCESS DETECTED 🚨
━━━ Code Context ━━━
File: triton-viz/examples/sanitizer/gemm_oob.py
Function: gemm_kernel
Line 25:
  22 │     for k_block in range(K // TILE_SIZE):
  23 │         range_k_block = TILE_SIZE * k_block + range_k
  24 │         A_off = K * range_m_block + range_k_block[None, :]
→ 25 │         A_tile = tl.load(A + A_off + 1) # Out-Of-Bounds Access HERE!
  26 │
  27 │         B_off = N * range_k_block[:, None] + range_n_block
  28 │         B_tile = tl.load(B + B_off)
━━━ Tensor Information ━━━
arg:         A
dtype:       torch.float32             shape:       torch.Size([32, 32])
strides:     (32, 1)                   device:      cpu
contiguous:  True                      base_ptr:    0x000000002fa75080
size:        4096 bytes                valid_range: [0x000000002fa75080, 0x000000002fa76080)
━━━ Call Stack ━━━
#1 gemm_kernel at gemm_oob.py:25
   └─ A_tile = tl.load(A + A_off + 1) # Out-Of-Bounds Access HERE!
━━━ Violation Details ━━━
Violation address: 0x000000002fa76080
━━━ Symbolic Expression Tree ━━━
load [dtype=pointer]
├── ptr: addptr [dtype=<['16', '16'], pointer>]
│   ├── ptr: splat [dtype=<['16', '16'], pointer>]
│   │   ├── block_type: const=<['16', '16'], pointer> [dtype=<['16', '16'], pointer>]
│   │   └── arg: const=799494272 [dtype=pointer]
│   └── offset: add [dtype=<['16', '1'], int32>]
│       ├── lhs: add [dtype=<['16', '1'], int32>]
│       │   ├── lhs: broadcast [dtype=<['16', '1'], int32>]
│       │   │   └── arg: mul [dtype=<['16', '1'], int32>]
│       │   │       ├── lhs: splat [dtype=<['16', '1'], int32>]
│       │   │       │   ├── block_type: const=<['16', '1'], int32> [dtype=<['16', '1'], int32>]
│       │   │       │   └── arg: const=32 [dtype=int32]
│       │   │       └── rhs: add [dtype=<['16', '1'], int32>]
│       │   │           ├── lhs: splat [dtype=<['16', '1'], int32>]
│       │   │           │   ├── block_type: const=<['16', '1'], int32> [dtype=<['16', '1'], int32>]
│       │   │           │   └── arg: mul [dtype=int32]
│       │   │           │       ├── lhs: const=16 [dtype=int32]
│       │   │           │       └── rhs: pid_0 [dtype=int32]
│       │   │           │           └── axis: const=0 [dtype=int32]
│       │   │           └── rhs: expand_dims [dtype=<['16'], int32>]
│       │   │               ├── arg: arange [dtype=<['16'], int32>]
│       │   │               │   ├── ret_ty: const=<['16'], int32> [dtype=<['16'], int32>]
│       │   │               │   ├── start: const=0 [dtype=int32]
│       │   │               │   └── end: const=16 [dtype=int32]
│       │   │               └── axis: const=1 [dtype=int32]
│       │   └── rhs: broadcast [dtype=<['16'], int32>]
│       │       └── arg: expand_dims [dtype=<['16'], int32>]
│       │           ├── arg: add [dtype=<['16'], int32>]
│       │           │   ├── lhs: splat [dtype=<['16'], int32>]
│       │           │   │   ├── block_type: const=<['16'], int32> [dtype=<['16'], int32>]
│       │           │   │   └── arg: mul [dtype=int32]
│       │           │   │       ├── lhs: const=16 [dtype=int32]
│       │           │   │       └── rhs: const=loop_i_12 [dtype=int32]
│       │           │   └── rhs: arange [dtype=<['16'], int32>]
│       │           │       ├── ret_ty: const=<['16'], int32> [dtype=<['16'], int32>]
│       │           │       ├── start: const=0 [dtype=int32]
│       │           │       └── end: const=16 [dtype=int32]
│       │           └── axis: const=0 [dtype=int32]
│       └── rhs: splat [dtype=<['16', '16'], int32>]
│           ├── block_type: const=<['16', '16'], int32> [dtype=<['16', '16'], int32>]
│           └── arg: const=1 [dtype=int32]
├── mask: None
└── other: None
End of IMA Diagnostic Report

Sponsored by

Special thanks to AMD for supporting Triton-Viz.

Ready to visualize your kernels?

Get started with Triton-Viz in minutes. No GPU required for development.

View on GitHub Browse Examples