A copy of this work was available on the public web and has been preserved in the Wayback Machine. The capture dates from 2022; you can also visit the original URL.
The file type is application/pdf.
Equinox: neural networks in JAX via callable PyTrees and filtered transformations
[article]
2021
JAX and PyTorch are two popular Python autodifferentiation frameworks. JAX is based around pure functions and functional programming. PyTorch has popularised the use of an object-oriented (OO) class-based syntax for defining parameterised functions, such as neural networks. That this seems like a fundamental difference means current libraries for building parameterised functions in JAX have either rejected the OO approach entirely (Stax) or have introduced OO-to-functional transformations,
doi:10.48550/arxiv.2111.00254
fatcat:4oaoly3cjbgpfe7dbmhqarwmmq