Using device="webgpu"¶
torch-webgpu registers a custom PyTorch device called webgpu. You can use it just like cuda or mps.
Creating Tensors¶
import torch
import torch_webgpu
# Create directly on WebGPU
x = torch.tensor([1.0, 2.0, 3.0], device="webgpu")
y = torch.randn(3, 4, device="webgpu")
z = torch.zeros(10, device="webgpu")
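If every tensor in a script lives on WebGPU, stock PyTorch's torch.set_default_device can save repeating the device argument. Whether torch-webgpu participates in this mechanism is an assumption here, so verify it before relying on it:
# Assumption: torch-webgpu hooks into PyTorch's default-device mechanism
torch.set_default_device("webgpu")
a = torch.randn(3, 4)            # allocated on WebGPU if the assumption holds
torch.set_default_device("cpu")  # restore the usual default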
Moving Tensors¶
# CPU to WebGPU
cpu_tensor = torch.randn(3, 3)
webgpu_tensor = cpu_tensor.to("webgpu")
# WebGPU to CPU
back_to_cpu = webgpu_tensor.to("cpu")
# Alternative syntax (not yet supported):
# webgpu_tensor = cpu_tensor.webgpu()
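A quick round-trip sanity check confirms that transfers preserve values:
x = torch.randn(3, 3)
# Values should survive a CPU → WebGPU → CPU round trip
assert torch.allclose(x, x.to("webgpu").to("cpu"))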
Checking Device¶
x = torch.tensor([1.0], device="webgpu")
print(x.device)  # webgpu
# x.is_webgpu is not yet supported; check the device type instead:
print(x.device.type == "webgpu")  # True
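Until is_webgpu lands, a small helper keeps the check in one place (on_webgpu is a hypothetical name, not part of the library):
def on_webgpu(t: torch.Tensor) -> bool:
    # device.type ignores an optional index such as "webgpu:0"
    return t.device.type == "webgpu"

print(on_webgpu(x))  # True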
Supported Data Types¶
Currently, torch-webgpu supports:
- ✅ torch.float32 (default, recommended)
- 🚧 torch.float16 (coming soon)
- 🚧 torch.int32
- 🚧 torch.int64
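Because only float32 is supported today, cast other dtypes on the CPU before transferring; a minimal sketch:
x64 = torch.randn(4, dtype=torch.float64)
# Cast to float32 first; moving a float64 tensor directly would fail today
x_gpu = x64.to(torch.float32).to("webgpu")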
Device Transfers¶
Data transfer between CPU and WebGPU:
# CPU → WebGPU
x_cpu = torch.randn(1000, 1000)
x_webgpu = x_cpu.to("webgpu") # Copies data to GPU memory
# WebGPU → CPU
x_back = x_webgpu.to("cpu") # Copies data back
# In-place operations stay on device
x_webgpu.mul_(2) # Still on WebGPU
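Because each transfer copies the whole buffer, chain work on the device and copy back once. A sketch of the pattern, assuming matmul, relu, and sum are implemented for the backend:
x = torch.randn(1000, 1000, device="webgpu")
# All intermediates stay in GPU memory; only the final scalar crosses back
y = (x @ x).relu().sum()
print(y.to("cpu"))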
Limitations¶
Current limitations of the device backend:
- Only float32 - Other dtypes coming soon
- Synchronous execution - Operations block until complete
- No direct CUDA interop - Must go through CPU
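For example, sharing data with a CUDA device has to be staged through host memory (assumes a CUDA-capable PyTorch build):
x_webgpu = torch.randn(4, device="webgpu")
# No direct WebGPU ↔ CUDA path: stage through the CPU
x_cuda = x_webgpu.to("cpu").to("cuda")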
Best Practices¶
- Minimize transfers - Keep data on WebGPU as long as possible
- Use torch.compile - For whole models, the compiled path is better optimized than running ops eagerly one at a time (see the sketch after this list)
- Batch operations - Larger operations are more efficient
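A minimal sketch of the compile-then-run flow; it assumes torch.compile's integration with the webgpu backend works for the model you pass it:
model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4),
).to("webgpu")

compiled = torch.compile(model)  # compiled path, per the tip above
x = torch.randn(8, 16, device="webgpu")
out = compiled(x)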