API Reference¶
torch_webgpu Module¶
Importing¶
Importing `torch_webgpu` registers the WebGPU backend with PyTorch. After importing, you can create and operate on tensors with `device="webgpu"`.
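A minimal sketch of the import-and-use pattern. The import is guarded so the snippet also runs where the extension is not installed (the fallback to CPU is only for illustration, not part of the API):

```python
import torch

# Importing torch_webgpu registers the "webgpu" device with PyTorch.
# Guard the import so this sketch degrades gracefully without the extension.
try:
    import torch_webgpu  # noqa: F401
    device = "webgpu"
except ImportError:
    device = "cpu"  # standard PyTorch fallback for this sketch

x = torch.ones(2, 2, device=device)
```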
Compiler¶
webgpu_backend¶
The torch.compile backend for WebGPU.
Usage:

```python
compiled_fn = torch.compile(fn, backend=webgpu_backend)
compiled_model = torch.compile(model, backend=webgpu_backend)
```
Parameters for torch.compile:
| Parameter | Type | Description |
|---|---|---|
| `model` | callable | Function or `nn.Module` to compile |
| `backend` | str or callable | Use `webgpu_backend` |
| `dynamic` | bool | Enable dynamic shapes (default: `False`) |
| `fullgraph` | bool | Require full graph capture (default: `False`) |
| `disable` | bool | Disable compilation (default: `False`) |
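Putting the parameters together, a hedged sketch: the import path for `webgpu_backend` is an assumption (this reference does not state it), and the snippet falls back to the built-in `"eager"` backend so it runs without the extension:

```python
import torch

try:
    from torch_webgpu import webgpu_backend  # assumed import path
    backend = webgpu_backend
except ImportError:
    backend = "eager"  # built-in debug backend, used only as a fallback here

def scaled_add(a, b):
    return a * 2 + b

# dynamic/fullgraph shown with their documented defaults
compiled = torch.compile(scaled_add, backend=backend,
                         dynamic=False, fullgraph=False)
a = torch.randn(4)
b = torch.randn(4)
out = compiled(a, b)
```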
Device Operations¶
Creating Tensors¶
```python
# Direct creation
x = torch.tensor([1, 2, 3], device="webgpu")
x = torch.randn(3, 4, device="webgpu")
x = torch.zeros(10, device="webgpu")
x = torch.ones(5, 5, device="webgpu")
```
Moving Tensors¶
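Tensors move between the host and the WebGPU device with the standard `.to()` / `.cpu()` methods. A sketch, with the import guarded so it also runs without the extension installed:

```python
import torch

# Fall back to CPU so the sketch runs where torch_webgpu is absent.
try:
    import torch_webgpu  # noqa: F401
    device = "webgpu"
except ImportError:
    device = "cpu"

x = torch.randn(3, 4)     # created on the host
y = x.to(device)          # host -> device copy
z = y.to("cpu")           # device -> host copy
```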
Internal APIs¶
Warning
Internal APIs may change without notice.
High-Level IR¶
Operations in the high-level intermediate representation.
Low-Level IR¶
Operations in the low-level intermediate representation.
Lowering¶
Maps IR operations to actual PyTorch/WebGPU operations.
C++ Extension¶
The C++ extension provides the core WebGPU operations:
torch.ops.webgpu¶
Registered operations accessible via torch.ops.webgpu:
| Operation | Description |
|---|---|
| `create_buffer` | Create a WebGPU buffer |
| `write_buffer` | Write data to a buffer |
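Assuming the extension is loaded, these ops are called through the `torch.ops.webgpu` namespace. The argument lists below are assumptions (this reference does not document the op schemas), and the calls are guarded so the sketch degrades when the extension is absent:

```python
import torch

def make_buffer(data):
    # NOTE: the signatures of create_buffer/write_buffer below are
    # assumptions; consult the extension's op schemas for the real ones.
    try:
        buf = torch.ops.webgpu.create_buffer(len(data))  # assumed signature
        torch.ops.webgpu.write_buffer(buf, data)         # assumed signature
        return buf
    except (AttributeError, RuntimeError):
        return None  # extension not loaded; ops are unavailable

result = make_buffer([1.0, 2.0, 3.0])
```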
Environment Variables¶
| Variable | Description |
|---|---|
| `DAWN_PREFIX` | Path to a Dawn installation (used when building the extension) |
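A build-time sketch: the Dawn path below is a placeholder, and the build command is an assumption shown only as a comment:

```shell
# Point the build at a Dawn installation (path is an example placeholder)
export DAWN_PREFIX="$HOME/dawn"
# then build the extension, e.g.:
# pip install -e .
```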
Supported Devices¶
| Device String | Description |
|---|---|
| `"webgpu"` | WebGPU device |
| `"cpu"` | CPU device (standard PyTorch) |