Open Source | MIT Licensed

Analyze and debug your tile-based DSL kernels

Note: Triton is the default DSL frontend; Amazon NKI support is optional and in active development.

Install
pip install triton-viz
Install (NKI support)
pip install triton-viz[nki] --extra-index-url https://pip.repos.neuron.amazonaws.com

Note: A browser with WebGL/OpenGL enabled is required (standard in modern browsers).
See README.md for development installs, web UI builds, DSL frontend selection, and full documentation.

Analysis Clients

Switch between visualization, profiling, and sanitization of your kernels with a single line of code

The visualizer currently supports a detailed look into load, store, and matmul operations for N-dimensional tensors (more operations coming soon).

examples/visualizer/3dims.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("tracer")
@triton.jit
def add_3d_slices_kernel(
    input_ptr1,
    input_ptr2,
    output_ptr,
    stride_x,
    stride_y,
    stride_z,
    slice_x,
    slice_y,
    slice_z,
    BLOCK_SIZE_X: tl.constexpr,
    BLOCK_SIZE_Y: tl.constexpr,
    BLOCK_SIZE_Z: tl.constexpr,
):
    # kernel definition here


if __name__ == "__main__":
    # Seed the RNG so the visualized tensor values are reproducible.
    torch.manual_seed(0)
    input1 = torch.randn(16, 16, 32, device="cpu")
    input2 = torch.randn(16, 16, 32, device="cpu")
    output = torch.empty_like(input1)

    # Shape and stride tuples are unpacked in (z, y, x) order, i.e. the last
    # tensor dimension is treated as x.
    slice_z, slice_y, slice_x = input1.shape
    stride_z, stride_y, stride_x = input1.stride()

    # Tile sizes used for both the launch grid and the kernel's constexpr
    # arguments. They were referenced but never defined in this snippet, which
    # raises NameError when run as shown.
    # NOTE(review): 8 is an illustrative choice that evenly divides the
    # 16 x 16 x 32 input; confirm against the full example file.
    BLOCK_SIZE_X = 8
    BLOCK_SIZE_Y = 8
    BLOCK_SIZE_Z = 8

    # One program per (x, y, z) tile.
    grid = (
        triton.cdiv(slice_x, BLOCK_SIZE_X),
        triton.cdiv(slice_y, BLOCK_SIZE_Y),
        triton.cdiv(slice_z, BLOCK_SIZE_Z),
    )

    # Launch kernel
    add_3d_slices_kernel[grid](
        input1,
        input2,
        output,
        stride_x,
        stride_y,
        stride_z,
        slice_x,
        slice_y,
        slice_z,
        BLOCK_SIZE_X=BLOCK_SIZE_X,
        BLOCK_SIZE_Y=BLOCK_SIZE_Y,
        BLOCK_SIZE_Z=BLOCK_SIZE_Z,
    )
    # Open the visualizer UI locally on port 5001 (no public share link).
    triton_viz.launch(share=False, port=5001)
Fullscreen for full detail

Interactive 3D rendering allows you to inspect tensor layouts and memory access patterns from any perspective. Left-click and drag to change camera angle. Scroll to zoom. Right-click and drag to pan.

examples/visualizer/9dims_copy.py
import torch
import triton
import triton.language as tl
import triton_viz


BLOCK = 1024
SHAPE = (2, 3, 4, 3, 4, 2, 4, 2, 3)


@triton_viz.trace("tracer")
@triton.jit
def copy_9d_kernel(x_ptr, y_ptr, n_elements, BLOCK: tl.constexpr):
    # Flat masked copy: each program moves one BLOCK-wide chunk of elements
    # from x_ptr to y_ptr, addressing the data as a 1-D range.
    block_index = tl.program_id(0)
    lane_offsets = block_index * BLOCK + tl.arange(0, BLOCK)
    # Lanes at or past n_elements are disabled; masked loads yield 0.0.
    in_bounds = lane_offsets < n_elements
    src_ptrs = x_ptr + lane_offsets
    dst_ptrs = y_ptr + lane_offsets
    chunk = tl.load(src_ptrs, mask=in_bounds, other=0.0)
    tl.store(dst_ptrs, chunk, mask=in_bounds)


if __name__ == "__main__":
    # Source tensor with the 9-D SHAPE on the CPU, plus a same-shaped output.
    src = torch.randn(SHAPE, device="cpu")
    dst = torch.empty_like(src)
    total_elements = src.numel()

    # One program per BLOCK-sized chunk of the element count.
    launch_grid = (triton.cdiv(total_elements, BLOCK),)
    copy_9d_kernel[launch_grid](src, dst, total_elements, BLOCK=BLOCK)

    # Open the visualizer UI locally on port 5001 (no public share link).
    triton_viz.launch(share=False, port=5001)
Fullscreen for full detail

Slice, permute, and inspect every value in N-dimensional tensors just as easily as 1D, 2D, or 3D tensors.

examples/visualizer/3dims.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("tracer")
@triton.jit
def add_3d_slices_kernel(
    input_ptr1,
    input_ptr2,
    output_ptr,
    stride_x,
    stride_y,
    stride_z,
    slice_x,
    slice_y,
    slice_z,
    BLOCK_SIZE_X: tl.constexpr,
    BLOCK_SIZE_Y: tl.constexpr,
    BLOCK_SIZE_Z: tl.constexpr,
):
    # kernel definition here


if __name__ == "__main__":
    # Seed the RNG so the visualized tensor values are reproducible.
    torch.manual_seed(0)
    input1 = torch.randn(16, 16, 32, device="cpu")
    input2 = torch.randn(16, 16, 32, device="cpu")
    output = torch.empty_like(input1)

    # Shape and stride tuples are unpacked in (z, y, x) order, i.e. the last
    # tensor dimension is treated as x.
    slice_z, slice_y, slice_x = input1.shape
    stride_z, stride_y, stride_x = input1.stride()

    # Tile sizes used for both the launch grid and the kernel's constexpr
    # arguments. They were referenced but never defined in this snippet, which
    # raises NameError when run as shown.
    # NOTE(review): 8 is an illustrative choice that evenly divides the
    # 16 x 16 x 32 input; confirm against the full example file.
    BLOCK_SIZE_X = 8
    BLOCK_SIZE_Y = 8
    BLOCK_SIZE_Z = 8

    # One program per (x, y, z) tile.
    grid = (
        triton.cdiv(slice_x, BLOCK_SIZE_X),
        triton.cdiv(slice_y, BLOCK_SIZE_Y),
        triton.cdiv(slice_z, BLOCK_SIZE_Z),
    )

    # Launch kernel
    add_3d_slices_kernel[grid](
        input1,
        input2,
        output,
        stride_x,
        stride_y,
        stride_z,
        slice_x,
        slice_y,
        slice_z,
        BLOCK_SIZE_X=BLOCK_SIZE_X,
        BLOCK_SIZE_Y=BLOCK_SIZE_Y,
        BLOCK_SIZE_Z=BLOCK_SIZE_Z,
    )
    # Open the visualizer UI locally on port 5001 (no public share link).
    triton_viz.launch(share=False, port=5001)
Fullscreen for full detail

Inspect operation inputs/outputs at specific Program IDs (PIDs) to understand block-level behavior. View which elements are loaded/stored in each program (highlighted in blue/orange respectively). Toggle "All Program IDs" to see how all programs load tensor values simultaneously.

examples/visualizer/matmul.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("tracer")
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # kernel definition here


if __name__ == "__main__":
    # Seed the RNG so the visualized tensor values are reproducible.
    torch.manual_seed(0)

    # Operand shapes: lhs is (M, K), rhs is (K, N), result is (M, N).
    M, N, K = 32, 32, 64
    lhs = torch.randn((M, K), dtype=torch.float32)
    rhs = torch.randn((K, N), dtype=torch.float32)
    result = torch.empty((M, N), dtype=torch.float32)

    # Tile sizes passed to the kernel as constexpr arguments.
    tile_m, tile_n, tile_k = 16, 16, 32
    # One program per (tile_m x tile_n) tile of the result.
    launch_grid = (triton.cdiv(M, tile_m), triton.cdiv(N, tile_n))

    matmul_kernel[launch_grid](
        lhs,
        rhs,
        result,
        M,
        N,
        K,
        # Each 2-D stride tuple unpacks to (row stride, column stride).
        *lhs.stride(),
        *rhs.stride(),
        *result.stride(),
        BLOCK_M=tile_m,
        BLOCK_N=tile_n,
        BLOCK_K=tile_k,
    )
    # Open the visualizer UI locally on port 5001 (no public share link).
    triton_viz.launch(share=False, port=5001)
Fullscreen for full detail

Directly map visual operations back to your source code lines for seamless debugging.

examples/visualizer/matmul.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("tracer")
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # kernel definition here


if __name__ == "__main__":
    # Seed the RNG so the visualized tensor values are reproducible.
    torch.manual_seed(0)

    # Operand shapes: lhs is (M, K), rhs is (K, N), result is (M, N).
    M, N, K = 32, 32, 64
    lhs = torch.randn((M, K), dtype=torch.float32)
    rhs = torch.randn((K, N), dtype=torch.float32)
    result = torch.empty((M, N), dtype=torch.float32)

    # Tile sizes passed to the kernel as constexpr arguments.
    tile_m, tile_n, tile_k = 16, 16, 32
    # One program per (tile_m x tile_n) tile of the result.
    launch_grid = (triton.cdiv(M, tile_m), triton.cdiv(N, tile_n))

    matmul_kernel[launch_grid](
        lhs,
        rhs,
        result,
        M,
        N,
        K,
        # Each 2-D stride tuple unpacks to (row stride, column stride).
        *lhs.stride(),
        *rhs.stride(),
        *result.stride(),
        BLOCK_M=tile_m,
        BLOCK_N=tile_n,
        BLOCK_K=tile_k,
    )
    # Open the visualizer UI locally on port 5001 (no public share link).
    triton_viz.launch(share=False, port=5001)
Fullscreen for full detail

Visualize tensor values with color gradients to quickly identify outliers, zeros, or saturation.

examples/visualizer/matmul.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("tracer")
@triton.jit
def matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # kernel definition here


if __name__ == "__main__":
    # Seed the RNG so the visualized tensor values are reproducible.
    torch.manual_seed(0)

    # Operand shapes: lhs is (M, K), rhs is (K, N), result is (M, N).
    M, N, K = 32, 32, 64
    lhs = torch.randn((M, K), dtype=torch.float32)
    rhs = torch.randn((K, N), dtype=torch.float32)
    result = torch.empty((M, N), dtype=torch.float32)

    # Tile sizes passed to the kernel as constexpr arguments.
    tile_m, tile_n, tile_k = 16, 16, 32
    # One program per (tile_m x tile_n) tile of the result.
    launch_grid = (triton.cdiv(M, tile_m), triton.cdiv(N, tile_n))

    matmul_kernel[launch_grid](
        lhs,
        rhs,
        result,
        M,
        N,
        K,
        # Each 2-D stride tuple unpacks to (row stride, column stride).
        *lhs.stride(),
        *rhs.stride(),
        *result.stride(),
        BLOCK_M=tile_m,
        BLOCK_N=tile_n,
        BLOCK_K=tile_k,
    )
    # Open the visualizer UI locally on port 5001 (no public share link).
    triton_viz.launch(share=False, port=5001)
Fullscreen for full detail

Analyze the distribution of values within your tensors to understand your data and make the best quantization decisions.

The profiler flags performance hazards like non-unrolled loops, inefficient mask usage, and missing buffer_load optimizations while tracking load/store byte counts with low-overhead sampling.

examples/profiler/load_store.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("profiler")
@triton.jit
def simple_kernel(
    x_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Masked element-wise copy: each program copies one BLOCK_SIZE-wide chunk
    # of elements from x_ptr to output_ptr.
    pid = tl.program_id(axis=0)
    # First element index handled by this program.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Lanes at or past n_elements are masked off for both the load and the
    # store below.
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(output_ptr + offsets, x, mask=mask)


if __name__ == "__main__":
    # `cfg` was used without being imported in this snippet (NameError when
    # run as shown). Import triton-viz's config object locally so the example
    # is self-contained.
    from triton_viz import config as cfg

    # NOTE(review): presumably restores triton-viz settings to their defaults
    # before profiling; confirm against the triton_viz config documentation.
    cfg.reset()
    device = "cpu"
    # 12 elements with a block of 8 gives two programs; the second program's
    # last four lanes fall outside the tensor and are masked off by the kernel.
    size = 12
    BLOCK_SIZE = 8
    torch.manual_seed(0)
    x = torch.arange(size, dtype=torch.float32, device=device)
    output = torch.empty_like(x)
    # The grid is a callable so the launch size is derived from the
    # BLOCK_SIZE meta-parameter at launch time.
    grid = lambda meta: (triton.cdiv(size, meta["BLOCK_SIZE"]),)
    simple_kernel[grid](x, output, size, BLOCK_SIZE)

Collects detailed performance metrics including cycle counts, memory bandwidth, and cache hit rates.

TERMINAL OUTPUT
============================================================ Profiler Issues Summary ============================================================
============================================================
---------- Profiler: For-Loop Unrolling Statistics ---------
============================================================
No for-loops detected.
============================================================
============================================================
---------- Profiler: Mask Ratio Statistics -----------------
============================================================
────────────────────────────────────────
Overall Load Operations:
  Total mask elements:     8
  False elements:          4
  Masked percentage:       50.00%
Overall Store Operations:
  Total mask elements:     8
  False elements:          4
  Masked percentage:       50.00%
────────────────────────────────────────
Per-Operation Breakdown:
Top 5 Operations by False Elements:
────────────────────────────────────────
#1. LOAD at load_store.py:22
    Total elements: 8
    False elements: 4 (50.0%)
    Code: x = tl.load(x_ptr + offsets, mask=mask)
#2. STORE at load_store.py:23
    Total elements: 8
    False elements: 4 (50.0%)
    Code: tl.store(output_ptr + offsets, x, mask=mask)
============================================================
============================================================
---------- Profiler: Buffer Load Issue Detection -----------
============================================================
>>>>>> Warning: Potential Buffer Load Issue Detected! <<<<<<
Some memory access offsets are within 32-bit range,
but Buffer Load optimization was NOT used in the kernel.
This may lead to suboptimal performance on AMD GPUs.
Consider enabling Buffer Load optimization.
============================================================
============================================================ Profiler Issues Summary Ends =======================================================

The sanitizer symbolically checks tensor memory accesses for out-of-bounds errors and emits rich reports with tensor metadata, call stack, and expression trees, with optional fake-memory storage to avoid real reads.

examples/sanitizer/gemm.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("sanitizer")
@triton.jit
def gemm_kernel(
    A, B, C, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, TILE_SIZE: tl.constexpr
):
    # Tiled matrix multiply: each program accumulates one TILE_SIZE x TILE_SIZE
    # tile of C from tiles of A and B.
    # Offsets are computed as row * K (for A) and row * N (for B and C), so the
    # code assumes contiguous row-major storage. No masks are applied to the
    # loads or the store. M is accepted but not read in the body.
    m_block = tl.program_id(0)
    n_block = tl.program_id(1)
    range_m = tl.arange(0, TILE_SIZE)
    range_n = tl.arange(0, TILE_SIZE)
    range_k = tl.arange(0, TILE_SIZE)
    # Row indices as a column vector and column indices as a row vector, so
    # that adding them broadcasts to a 2-D tile of offsets.
    range_m_block = TILE_SIZE * m_block + range_m[:, None]
    range_n_block = TILE_SIZE * n_block + range_n[None, :]
    accum = tl.zeros((TILE_SIZE, TILE_SIZE), dtype=tl.float32)
    # Walk the K dimension one tile at a time; any remainder of
    # K % TILE_SIZE is not visited.
    for k_block in range(K // TILE_SIZE):
        range_k_block = TILE_SIZE * k_block + range_k
        A_off = K * range_m_block + range_k_block[None, :]
        A_tile = tl.load(A + A_off)

        B_off = N * range_k_block[:, None] + range_n_block
        B_tile = tl.load(B + B_off)

        accum += tl.dot(A_tile, B_tile, allow_tf32=False)
    C_off = N * range_m_block + range_n_block
    tl.store(C + C_off, accum)


def test_gemm():
    """Launch the traced GEMM kernel on 32x32 CPU tensors, then print a success line."""
    M = N = K = 32
    tile_size = 16

    A = torch.randn((M, K))
    B = torch.randn((K, N))
    C = torch.empty((M, N))

    # One program per tile_size x tile_size tile of C.
    launch_grid = (M // tile_size, N // tile_size)
    gemm_kernel[launch_grid](A, B, C, M, N, K, tile_size)
    print("GEMM ran without any out-of-bounds errors!")

test_gemm()

When code is correct, the sanitizer validates all memory accesses without interrupting execution.

TERMINAL OUTPUT
GEMM ran without any out-of-bounds errors!
examples/sanitizer/gemm_oob.py
import torch
import triton
import triton.language as tl
import triton_viz


@triton_viz.trace("sanitizer")
@triton.jit
def gemm_kernel(
    A, B, C, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, TILE_SIZE: tl.constexpr
):
    # Tiled matrix multiply with a deliberate out-of-bounds load on A, used to
    # demonstrate the sanitizer. Each program accumulates one
    # TILE_SIZE x TILE_SIZE tile of C from tiles of A and B.
    # Offsets are computed as row * K (for A) and row * N (for B and C), so the
    # code assumes contiguous row-major storage. No masks are applied to the
    # loads or the store. M is accepted but not read in the body.
    m_block = tl.program_id(0)
    n_block = tl.program_id(1)
    range_m = tl.arange(0, TILE_SIZE)
    range_n = tl.arange(0, TILE_SIZE)
    range_k = tl.arange(0, TILE_SIZE)
    # Row indices as a column vector and column indices as a row vector, so
    # that adding them broadcasts to a 2-D tile of offsets.
    range_m_block = TILE_SIZE * m_block + range_m[:, None]
    range_n_block = TILE_SIZE * n_block + range_n[None, :]
    accum = tl.zeros((TILE_SIZE, TILE_SIZE), dtype=tl.float32)
    # Walk the K dimension one tile at a time; any remainder of
    # K % TILE_SIZE is not visited.
    for k_block in range(K // TILE_SIZE):
        range_k_block = TILE_SIZE * k_block + range_k
        A_off = K * range_m_block + range_k_block[None, :]
        # Intentional bug: the extra + 1 shifts every A address by one element.
        A_tile = tl.load(A + A_off + 1) # Out-Of-Bounds Access HERE!

        B_off = N * range_k_block[:, None] + range_n_block
        B_tile = tl.load(B + B_off)

        accum += tl.dot(A_tile, B_tile, allow_tf32=False)
    C_off = N * range_m_block + range_n_block
    tl.store(C + C_off, accum)


def test_gemm():
    """Launch the traced GEMM kernel on 32x32 CPU tensors, then print a success line."""
    M = N = K = 32
    tile_size = 16

    A = torch.randn((M, K))
    B = torch.randn((K, N))
    C = torch.empty((M, N))

    # One program per tile_size x tile_size tile of C.
    launch_grid = (M // tile_size, N // tile_size)
    gemm_kernel[launch_grid](A, B, C, M, N, K, tile_size)
    print("GEMM ran without any out-of-bounds errors!")

test_gemm()

The sanitizer catches invalid memory accesses instantly, providing a detailed report with tensor info, call stack, and symbolic expression trees.

TERMINAL OUTPUT
🚨 ILLEGAL MEMORY ACCESS DETECTED 🚨
━━━ Code Context ━━━
File: triton-viz/examples/sanitizer/gemm_oob.py
Function: gemm_kernel
Line 25:
  22 │     for k_block in range(K // TILE_SIZE):
  23 │         range_k_block = TILE_SIZE * k_block + range_k
  24 │         A_off = K * range_m_block + range_k_block[None, :]
→ 25 │         A_tile = tl.load(A + A_off + 1) # Out-Of-Bounds Access HERE!
  26 │
  27 │         B_off = N * range_k_block[:, None] + range_n_block
  28 │         B_tile = tl.load(B + B_off)
━━━ Tensor Information ━━━
arg:         A
dtype:       torch.float32             shape:       torch.Size([32, 32])
strides:     (32, 1)                   device:      cpu
contiguous:  True                      base_ptr:    0x000000002fa75080
size:        4096 bytes                valid_range: [0x000000002fa75080, 0x000000002fa76080)
━━━ Call Stack ━━━
#1 gemm_kernel at gemm_oob.py:25
   └─ A_tile = tl.load(A + A_off + 1) # Out-Of-Bounds Access HERE!
━━━ Violation Details ━━━
Violation address: 0x000000002fa76080
━━━ Symbolic Expression Tree ━━━
load [dtype=pointer]
├── ptr: addptr [dtype=<['16', '16'], pointer>]
│   ├── ptr: splat [dtype=<['16', '16'], pointer>]
│   │   ├── block_type: const=<['16', '16'], pointer> [dtype=<['16', '16'], pointer>]
│   │   └── arg: const=799494272 [dtype=pointer]
│   └── offset: add [dtype=<['16', '1'], int32>]
│       ├── lhs: add [dtype=<['16', '1'], int32>]
│       │   ├── lhs: broadcast [dtype=<['16', '1'], int32>]
│       │   │   └── arg: mul [dtype=<['16', '1'], int32>]
│       │   │       ├── lhs: splat [dtype=<['16', '1'], int32>]
│       │   │       │   ├── block_type: const=<['16', '1'], int32> [dtype=<['16', '1'], int32>]
│       │   │       │   └── arg: const=32 [dtype=int32]
│       │   │       └── rhs: add [dtype=<['16', '1'], int32>]
│       │   │           ├── lhs: splat [dtype=<['16', '1'], int32>]
│       │   │           │   ├── block_type: const=<['16', '1'], int32> [dtype=<['16', '1'], int32>]
│       │   │           │   └── arg: mul [dtype=int32]
│       │   │           │       ├── lhs: const=16 [dtype=int32]
│       │   │           │       └── rhs: pid_0 [dtype=int32]
│       │   │           │           └── axis: const=0 [dtype=int32]
│       │   │           └── rhs: expand_dims [dtype=<['16'], int32>]
│       │   │               ├── arg: arange [dtype=<['16'], int32>]
│       │   │               │   ├── ret_ty: const=<['16'], int32> [dtype=<['16'], int32>]
│       │   │               │   ├── start: const=0 [dtype=int32]
│       │   │               │   └── end: const=16 [dtype=int32]
│       │   │               └── axis: const=1 [dtype=int32]
│       │   └── rhs: broadcast [dtype=<['16'], int32>]
│       │       └── arg: expand_dims [dtype=<['16'], int32>]
│       │           ├── arg: add [dtype=<['16'], int32>]
│       │           │   ├── lhs: splat [dtype=<['16'], int32>]
│       │           │   │   ├── block_type: const=<['16'], int32> [dtype=<['16'], int32>]
│       │           │   │   └── arg: mul [dtype=int32]
│       │           │   │       ├── lhs: const=16 [dtype=int32]
│       │           │   │       └── rhs: const=loop_i_12 [dtype=int32]
│       │           │   └── rhs: arange [dtype=<['16'], int32>]
│       │           │       ├── ret_ty: const=<['16'], int32> [dtype=<['16'], int32>]
│       │           │       ├── start: const=0 [dtype=int32]
│       │           │       └── end: const=16 [dtype=int32]
│       │           └── axis: const=0 [dtype=int32]
│       └── rhs: splat [dtype=<['16', '16'], int32>]
│           ├── block_type: const=<['16', '16'], int32> [dtype=<['16', '16'], int32>]
│           └── arg: const=1 [dtype=int32]
├── mask: None
└── other: None
End of IMA Diagnostic Report

Sponsored by

Special thanks to AMD for supporting Triton-Viz.

Ready to visualize your kernels?

Get started with Triton-Viz in minutes. No GPU required for development.

View on GitHub Browse Examples