from Hacker News

How the jax.jit() JIT compiler works in jax-js

by ekzhang on 5/14/25, 2:07 PM with 1 comment

  • by ekzhang on 5/14/25, 5:51 PM

    Hello! I wanted to share a bit about an open-source project I've been working on: a foundational library for ML and numerical computing in the browser.

    The browser comes with different tradeoffs, but I'd like to replicate the speed and ergonomics of Python libraries like NumPy, PyTorch, and JAX without the difficulty of distribution that comes with Python. jax-js is lightweight enough to embed in a website, yet optimized to take advantage of technologies like Wasm and WebGPU.

    Let me know if you have any feedback on this compiler design. It's working well: it's faster than tfjs (TensorFlow.js) at matrix multiplication, and it also supports kernel fusion.
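    For readers unfamiliar with kernel fusion, the following is a minimal TypeScript sketch of the general idea behind a tracing JIT that fuses elementwise ops. It is a toy under stated assumptions, not jax-js's API or internals: the names `Tracer`, `jitElementwise`, and `unfusedMulAdd` are hypothetical. The function is traced once with symbolic inputs into an expression tree, and that tree is lowered to a single loop so intermediate values never round-trip through memory.

```typescript
// Toy elementwise IR for illustration only; this is NOT jax-js's internal representation.
type Expr =
  | { kind: "input"; index: number }
  | { kind: "const"; value: number }
  | { kind: "add"; a: Expr; b: Expr }
  | { kind: "mul"; a: Expr; b: Expr }
  | { kind: "exp"; a: Expr };

// Hypothetical tracer: records each operation into an expression tree instead of computing it.
class Tracer {
  readonly expr: Expr;
  constructor(expr: Expr) {
    this.expr = expr;
  }
  add(other: Tracer | number): Tracer {
    return new Tracer({ kind: "add", a: this.expr, b: lift(other) });
  }
  mul(other: Tracer | number): Tracer {
    return new Tracer({ kind: "mul", a: this.expr, b: lift(other) });
  }
  exp(): Tracer {
    return new Tracer({ kind: "exp", a: this.expr });
  }
}

// Turn a plain number into a constant node; pass tracers through unchanged.
function lift(v: Tracer | number): Expr {
  return typeof v === "number" ? { kind: "const", value: v } : v.expr;
}

// Lower the whole tree to ONE scalar JS expression, evaluated per element index i.
function emit(e: Expr): string {
  switch (e.kind) {
    case "input": return `x${e.index}[i]`;
    case "const": return `(${e.value})`;
    case "add": return `(${emit(e.a)} + ${emit(e.b)})`;
    case "mul": return `(${emit(e.a)} * ${emit(e.b)})`;
    case "exp": return `Math.exp(${emit(e.a)})`;
  }
}

type Kernel = (...inputs: Float32Array[]) => Float32Array;

// Hypothetical jit: trace f once, then compile a single fused loop.
// Assumes arity >= 1 and that all inputs have the same length.
function jitElementwise(f: (...args: Tracer[]) => Tracer, arity: number): Kernel {
  const names = Array.from({ length: arity }, (_, k) => `x${k}`);
  const args = names.map((_, k) => new Tracer({ kind: "input", index: k }));
  const body = emit(f(...args).expr);
  // Every intermediate stays in a local value; no temporary arrays are allocated.
  return new Function(
    ...names,
    `const n = x0.length;
     const out = new Float32Array(n);
     for (let i = 0; i < n; i++) out[i] = ${body};
     return out;`,
  ) as Kernel;
}

// Unfused reference: each op runs its own loop and materializes a temporary array.
function unfusedMulAdd(x: Float32Array, y: Float32Array): Float32Array {
  const t = new Float32Array(x.length);
  for (let i = 0; i < x.length; i++) t[i] = x[i] * y[i]; // kernel 1 writes t
  const out = new Float32Array(x.length);
  for (let i = 0; i < x.length; i++) out[i] = t[i] + 1; // kernel 2 reads t back
  return out;
}

const fusedMulAdd = jitElementwise((x, y) => x.mul(y).add(1), 2);
console.log(Array.from(fusedMulAdd(new Float32Array([1, 2, 3]), new Float32Array([4, 5, 6]))));
// → [5, 11, 19]
```

    The unfused version makes two passes over memory and writes a temporary array between them, while the fused kernel does one pass. A real browser backend would presumably generate Wasm or WebGPU shader code rather than JavaScript via `new Function`, which strict Content Security Policies block anyway.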