from Hacker News

NNX – Neural Networks for JAX

by cgarciae on 6/12/23, 1:38 PM with 1 comment

  • by cgarciae on 6/12/23, 1:38 PM

    NNX is a neural network library for JAX that provides a simple yet powerful module system adhering to standard Python semantics. It aims to combine the robustness of Flax with a simplified, Pythonic API akin to PyTorch's.
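To make "standard Python semantics" concrete, here is a minimal sketch of what such a module system can look like. It assumes nothing about NNX's actual API. Modules are ordinary Python objects, parameters and submodules are plain attributes, and state is collected by walking attributes with `vars()`. The names `Module`, `Param`, `Linear`, and `MLP` are hypothetical, and the code uses only the standard library (no JAX or NNX) so it runs anywhere.

```python
import random


class Param:
    """Hypothetical container marking a value as a trainable parameter."""

    def __init__(self, value):
        self.value = value


class Module:
    """Hypothetical base class: a plain Python object whose attributes hold its state."""

    def parameters(self, prefix=""):
        # Collect Param attributes recursively using ordinary attribute introspection.
        params = {}
        for name, value in vars(self).items():
            path = prefix + name
            if isinstance(value, Module):
                # Submodules are just attributes; recurse with a dotted path.
                params.update(value.parameters(prefix=path + "."))
            elif isinstance(value, Param):
                params[path] = value
        return params


class Linear(Module):
    def __init__(self, din, dout, rng):
        # Weights are created eagerly in __init__, PyTorch style (hypothetical init scheme).
        self.w = Param([[rng.gauss(0.0, 1.0) for _ in range(dout)] for _ in range(din)])
        self.b = Param([0.0] * dout)

    def __call__(self, x):
        # x is a list of din floats; returns a list of dout floats (x @ w + b).
        return [
            sum(xi * self.w.value[i][j] for i, xi in enumerate(x)) + self.b.value[j]
            for j in range(len(self.b.value))
        ]


class MLP(Module):
    def __init__(self, rng):
        self.hidden = Linear(2, 3, rng)
        self.out = Linear(3, 1, rng)

    def __call__(self, x):
        h = [max(0.0, v) for v in self.hidden(x)]  # ReLU
        return self.out(h)


model = MLP(random.Random(0))
y = model([1.0, 2.0])
print(sorted(model.parameters()))  # ['hidden.b', 'hidden.w', 'out.b', 'out.w']

# Standard Python reference semantics: mutating a parameter in place
# is immediately visible to the module that holds it.
model.out.b.value[0] = 1.0
y2 = model([1.0, 2.0])
```

Because modules are plain objects, sharing a submodule is just assigning the same object to two attributes, and inspection uses ordinary tools like `vars()`. A real JAX library also has to make this state work with JAX transformations such as `jax.jit` and `jax.grad`, which this sketch ignores.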