Using torch.compile¶
The torch.compile backend is the recommended way to run models on WebGPU: PyTorch traces your function or model into a graph, and the backend optimizes that graph for WebGPU execution.
Basic Usage¶
import torch
from torch_webgpu.compiler.webgpu_compiler import webgpu_backend

# Compile a function
@torch.compile(backend=webgpu_backend)
def my_function(x, y):
    return torch.relu(x @ y)

# Or compile a model
model = MyModel()
compiled_model = torch.compile(model, backend=webgpu_backend)
How It Works¶
1. Tracing - PyTorch traces your function or model to create an FX graph
2. IR Conversion - the FX graph is converted to torch-webgpu's IR
3. Lowering - the IR is lowered to WebGPU operations
4. Execution - the operations run on WebGPU
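To see what step 1 produces, you can pass a throwaway backend that prints the traced FX graph and then runs it eagerly. This sketch uses only standard torch.compile machinery; inspecting_backend is our own name, and for real runs you would pass webgpu_backend instead.

import torch

# Pass-through backend: print the FX graph, then execute it eagerly
def inspecting_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.graph.print_tabular()  # one row per node: opcode, name, target, args
    return gm.forward

@torch.compile(backend=inspecting_backend)
def f(x, y):
    return torch.relu(x @ y)

f(torch.randn(4, 8), torch.randn(8, 2))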
Compiling Models¶
import torch
import torch.nn as nn
from torch_webgpu.compiler.webgpu_compiler import webgpu_backend

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(512, 256)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = MyModel()
model.eval()  # Important: set to eval mode

# Compile
compiled = torch.compile(model, backend=webgpu_backend)

# Run
with torch.no_grad():
    output = compiled(torch.randn(1, 512))
Compiler Options¶
# Dynamic shapes (for variable input sizes)
compiled = torch.compile(model, backend=webgpu_backend, dynamic=True)

# Disable compilation for debugging (the model runs eagerly)
compiled = torch.compile(model, backend=webgpu_backend, disable=True)
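If you would rather fail loudly than fall back to eager execution on a graph break, fullgraph=True (a standard torch.compile option, not specific to this backend) turns any break into an error:

# Raise an error on any graph break instead of splitting the graph
compiled = torch.compile(model, backend=webgpu_backend, fullgraph=True)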
What Gets Compiled¶
The compiler handles:
- ✅ Linear layers
- ✅ Activations (ReLU, SiLU, GELU, etc.)
- ✅ Normalization (LayerNorm, RMSNorm)
- ✅ Attention (scaled_dot_product_attention)
- ✅ Embeddings
- ✅ Element-wise ops (add, mul, div, etc.)
- ✅ Reductions (sum, mean, max, etc.)
- ✅ Shape ops (view, reshape, transpose, etc.)
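As a quick sanity check, here is a sketch that exercises several of these in one compiled function; the shapes are arbitrary, and it assumes the ops listed above are available on your setup.

import torch
import torch.nn.functional as F
from torch_webgpu.compiler.webgpu_compiler import webgpu_backend

@torch.compile(backend=webgpu_backend)
def block(x, w):
    h = F.layer_norm(x, x.shape[-1:])            # normalization
    h = F.gelu(h @ w)                            # matmul + activation
    a = F.scaled_dot_product_attention(h, h, h)  # attention
    return a.mean(dim=-1)                        # reduction

out = block(torch.randn(1, 16, 64), torch.randn(64, 64))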
Debugging¶
Enable debug output to see what the compiler is doing; when enabled, the backend prints shape information for each operation it lowers.
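A sketch with two layers of logging: TORCH_WEBGPU_DEBUG is a hypothetical name for the backend's own toggle (check the project for the actual flag), while torch._logging is standard PyTorch and dumps the Dynamo tracing logs.

import logging
import os

import torch

# Hypothetical toggle for the backend's per-op shape output --
# check torch-webgpu's docs for the real flag name
os.environ["TORCH_WEBGPU_DEBUG"] = "1"

# Standard PyTorch compiler logging: shows tracing and graph details
torch._logging.set_logs(dynamo=logging.DEBUG)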
Common Issues¶
Graph Break¶
If you see a "graph break", the compiler hit something it can't trace, such as data-dependent control flow or an unsupported construct:

# This may cause a graph break
def forward(self, x):
    if x.sum() > 0:  # Dynamic condition - graph break!
        return x * 2
    return x
Solution: keep control flow static, or move the branch into tensor ops, as sketched below.
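One way to do that is to replace the Python branch with a tensor-level select such as torch.where, which stays inside the graph:

import torch

def forward(self, x):
    # Both branches are computed; torch.where picks per the condition tensor,
    # so tracing never has to evaluate a Python bool on tensor data
    return torch.where(x.sum() > 0, x * 2, x)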
Unsupported Op¶
The operation isn't implemented yet. Check Supported Ops or open an issue.
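If a single helper is the culprit, you can also fence it off so the rest of the model still compiles. torch.compiler.disable is standard PyTorch; fancy_op and its body are made up for illustration.

import torch

@torch.compiler.disable
def fancy_op(x):
    # Hypothetical unsupported operation: runs eagerly, outside the compiled graph
    return x.to_sparse().to_dense()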
Performance Tips¶
- Use eval mode - call model.eval() before compiling
- Use no_grad - wrap inference in with torch.no_grad():
- Warm up - the first run compiles; subsequent runs are faster
- Batch inputs - larger batches are more efficient
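To see the warm-up effect, time the first call against a later one. This sketch reuses the compiled model from the example above; the second call is fast because the graph compiled for those shapes is cached.

import time

import torch

x = torch.randn(8, 512)
with torch.no_grad():
    t0 = time.perf_counter()
    compiled(x)  # first call: triggers tracing and compilation
    print(f"first call:  {time.perf_counter() - t0:.3f}s")

    t0 = time.perf_counter()
    compiled(x)  # same shapes: reuses the cached compiled graph
    print(f"second call: {time.perf_counter() - t0:.3f}s")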