by ekzhang on 5/14/25, 2:07 PM with 1 comment
by ekzhang on 5/14/25, 5:51 PM
There are different tradeoffs here, but I’d like to replicate the speed and ergonomics of Python libraries like NumPy/PyTorch/JAX without the difficulty of distribution. jax-js is lightweight enough to be embedded in a website, yet optimized to take advantage of tech like Wasm and WebGPU.
Let me know if you have any feedback on this compiler design. It’s working well: it’s faster than tfjs at matrix multiplication, and it also supports kernel fusion.
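For readers unfamiliar with kernel fusion, here is a minimal sketch of the general idea in plain TypeScript, computing relu(a * b + c) elementwise. It is not jax-js’s API or how its compiler actually emits code. The function names `unfused` and `fused` are hypothetical, chosen only for illustration.

```typescript
// Hypothetical illustration of kernel fusion; NOT jax-js's API or compiler.
// Both functions compute out = relu(a * b + c) elementwise.

// Unfused: each op is a separate "kernel" that writes a full intermediate
// buffer, so memory is allocated and traversed once per op.
function unfused(a: Float32Array, b: Float32Array, c: Float32Array): Float32Array {
  const n = a.length;
  const t1 = new Float32Array(n);
  for (let i = 0; i < n; i++) t1[i] = a[i] * b[i]; // kernel 1: multiply
  const t2 = new Float32Array(n);
  for (let i = 0; i < n; i++) t2[i] = t1[i] + c[i]; // kernel 2: add
  const out = new Float32Array(n);
  for (let i = 0; i < n; i++) out[i] = Math.max(t2[i], 0); // kernel 3: relu
  return out;
}

// Fused: one loop applies all three ops per element and keeps the
// intermediates in locals instead of allocating two temporary arrays.
function fused(a: Float32Array, b: Float32Array, c: Float32Array): Float32Array {
  const n = a.length;
  const out = new Float32Array(n);
  for (let i = 0; i < n; i++) {
    // Math.fround rounds to float32, matching what storing into the
    // unfused version's Float32Array intermediates would do.
    const prod = Math.fround(a[i] * b[i]);
    const sum = Math.fround(prod + c[i]);
    out[i] = Math.max(sum, 0);
  }
  return out;
}

const a = new Float32Array([1, 2, -3]);
const b = new Float32Array([4, 5, 6]);
const c = new Float32Array([-1, 0.5, 2]);
console.log(Array.from(fused(a, b, c))); // [ 3, 10.5, 0 ]
```

Both versions return the same values, but the fused one reads each input once and skips two intermediate allocations. A fusing compiler produces the second form automatically from a chain of array operations; on a GPU backend the same idea generally means dispatching one kernel instead of three.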