I linked to the website (which was updated in May, but its contents could do with more work) because it has examples of how well the suite fits together.
I don't know much about Jax. I've seen competent benchmarks showing an order of magnitude benefit for using ReverseDiff from the AutoDiff suite over Autograd, which is what PyTorch uses for reverse-mode autodiff.
I think it's better to think of JAX as a more general framework for differentiable programming, and of PyTorch as more narrowly focused on deep learning/neural networks.
The beauty of JAX is that basic usage comes down to a single function: `grad`.
You just write whatever Python function you want and can get the derivative/gradient of it trivially. It gets a bit trickier when you need more sophisticated numeric tools like numpy/scipy, but in those cases it's mostly a matter of swapping in the JAX versions of those libraries.
In this sense JAX is the spiritual successor to Autograd. However, the really amazing thing about JAX is that not only do you get the autodiff essentially for free, you also get very good performance, and GPU parallelism without needing to think about it at all.
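For a concrete sense of what that looks like, here's a minimal sketch (the loss function is just a made-up example):

```python
import jax
import jax.numpy as jnp  # drop-in replacement for most of numpy

# any plain Python function written with jax.numpy
def loss(w):
    return jnp.sum(jnp.tanh(w) ** 2)

grad_loss = jax.grad(loss)           # a new function computing d(loss)/dw
print(grad_loss(jnp.arange(3.0)))    # runs unchanged on CPU, GPU or TPU
```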
PyTorch is an awesome library, but it is largely focused on building neural networks specifically. JAX should be thought of as a tool that basically any Python programmer can throw in whenever they come across a problem that benefits from having differentiable code (which is a lot of cases once you start thinking about differentiation as a first-class feature).
Very cool! I love autograd: it had tape-based autodiff way before PyTorch, and the way it wraps numpy is much more convenient than TensorFlow/PyTorch. I've been wanting GPU support in autograd for years now, so I'm very happy to see this.
I have some academic software (https://github.com/popgenmethods/momi2) that uses autograd, was planning to port it to pytorch since it's better supported/maintained, but now I'll have to consider jax. Though I'm a little worried about the maturity of the project, seems like the numpy/scipy coverage is not all the way there yet. Then again, it would be fun to contribute back to JAX, I did contribute a couple PRs to autograd back in the day so I think I could jump right into it...
No, autograd acts similarly to PyTorch in that it builds a tape that it reverses, while PyTorch just comes with more optimized kernels (and kernels that act on GPUs). The AD that I was referencing was Tangent (https://github.com/google/tangent). It was an interesting project, but it's hard to see who the audience is. Generating Python source code makes things harder to analyze, and you cannot JIT compile the generated code unless you could JIT compile Python. So you might as well first trace to a JIT-compilable sublanguage and do the transformations there, which is precisely what JAX does. In theory Tangent is a bit more general, and maybe you could mix it with Numba, but then it's hard to justify. If it's more general then it's not for the standard ML community, for the same reason as the Julia tools, but then it had better do better than the Julia tools in the specific niche they are targeting. That generality means it cannot use XLA, and thus from day 1 it wouldn't get the extra compiler optimizations that something built on XLA (JAX) does. JAX just makes much more sense for the people who were building it; it chose its niche very well.
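To make the tracing point concrete, this is roughly what that workflow looks like in JAX (toy function, purely for illustration):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # On the first call JAX traces this Python code into its jaxpr
    # sublanguage and compiles it with XLA; later calls with the same
    # shapes/dtypes reuse the compiled kernel.
    return jnp.sin(x) ** 2 + 3.0 * x

print(f(2.0))
```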
Actually, advanced autodiff is one of its intended points of, er, differentiation :). The authors wrote the original Autograd package [0], released in 2014, that led to “autograd” becoming used as a generic term in PyTorch and other packages. JAX has all of the autodiff operations that Autograd does, including `grad`, `vjp`, `jvp`, etc.
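A quick sketch of what those look like (throwaway example function):

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * x

x, v = 2.0, 1.0
dfdx = jax.grad(f)(x)                # reverse-mode gradient
y, tangent = jax.jvp(f, (x,), (v,))  # forward mode: directional derivative
y, f_vjp = jax.vjp(f, x)             # reverse mode: returns a pullback
cotangent, = f_vjp(1.0)              # apply it to an output cotangent
```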
We’re working on the number of supported NumPy ops, which is limited right now, but it’s early days.
Try it out, we’re really excited to see what you build with it!
If you're talking about Jax, there are a couple of different reasons to bother with it for research:
1. Full numpy compatibility.
2. More efficient higher-order gradients (because of forward-mode autodiff). Naively it's an asymptotic improvement, though I believe PyTorch uses some autodiff tricks to compute higher-order gradients with reverse mode, at the cost of a decently high constant factor. (There's a sketch of this below.)
3. Some cool transformations like vmap.
4. Full code gen, which is neat especially for scientific computing purposes.
5. A neat API for higher order gradients.
2. and 5. are the most appealing for DL research; 1., 3. and 4. are appealing for those in the stats/scientific computing communities.
PyTorch is working on all of these, to various degrees of effort, but Jax currently has an advantage in these points (and may have a fundamental advantage in design in some).
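As a rough sketch of points 2 and 3 (toy function, purely illustrative):

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(x ** 3)

x = jnp.arange(3.0)
v = jnp.ones(3)

# 2. Higher-order gradients via forward-over-reverse:
#    a Hessian-vector product without materialising the full Hessian.
_, hvp = jax.jvp(jax.grad(f), (x,), (v,))

# 3. vmap: write the single-example version, get the batched one for free.
batched_grads = jax.vmap(jax.grad(f))(jnp.stack([x, x + 1.0]))
```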
I haven't heard about JAX before, but I've been tinkering in PyTorch. Would I also be able to switch the use of np arrays here to torch, then call .backward() and get kind of the same benefits as JAX, or how does it differ in this regard?
At least prior to this announcement: JAX was much faster than PyTorch for differentiable physics. (Better JIT compiler; reduced Python-level overhead.)
E.g. for numerical ODE simulation, I've found that Diffrax (https://github.com/patrick-kidger/diffrax) is ~100 times faster than torchdiffeq on the forward pass. The backward pass is much closer, and there Diffrax is about 1.5 times faster.
It remains to be seen how PyTorch 2.0 will compare, of course!
Right now my job is actually building out the scientific computing ecosystem in JAX, so feel free to ping me with any other questions.
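For reference, solving a toy ODE with Diffrax looks roughly like this (a sketch following its basic documented usage; dy/dt = -y is just a placeholder problem):

```python
import diffrax
import jax.numpy as jnp

def vector_field(t, y, args):
    return -y  # dy/dt = -y, simple exponential decay

solution = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Tsit5(),           # an explicit Runge-Kutta solver
    t0=0.0, t1=1.0, dt0=0.1,
    y0=jnp.array([1.0]),
)
print(solution.ys)
```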
For me the main point is that JAX is slightly lower level than PyTorch and has nicer abstractions (no need to worry about tensors that might not store a gradient, or whether you are on the GPU: it eliminates lots of newcomer bugs), which makes it a great fit for building deep learning frameworks, but also simulations, on top of it.
Jax might be faster than PyTorch, I don't know. I'm talking about TF. When I switched from TF to PyTorch 3 years ago, I got no slowdown on any of the computer vision models at the time. And I remember looking at a couple of independent benchmarks which also showed them to be roughly the same in speed.
JAX introduced a lot of cool concepts (e.g. autobatching (vmap), autoparallel (pmap)) and supported a lot of things that PyTorch didn't (e.g. forward mode autodiff).
And at least for my applications (scientific computing), it was much faster (~100x) due to a much better JIT compiler and reduced Python overhead.
...but! PyTorch has worked hard to introduce all of the former, and the recent PyTorch 2 announcement was primarily about a better JIT compiler for PyTorch. (I don't think anyone has done serious non-ML benchmarks for this though, so it remains to be seen how this holds up.)
There are still a few differences. E.g. JAX has a better differential equation solving ecosystem. PyTorch has a better protein language model ecosystem. JAX offers some better power-user features like custom vmap rules. PyTorch probably has a lower barrier to entry.
(FWIW I don't know how either hold up specifically for DSP.)
I'd honestly suggest just trying both; always nice to have a broader selection of tools available.
JAX enables using (parts of) existing numpy codebases in disciplines other than deep learning. Autodiff and compilation to GPUs are very useful for all kinds of algorithms and processing pipelines.
OTOH PyTorch seems to be highly explosive if you try to use it outside the mainstream use (i.e. neural networks). There's sadly no performant autodiff system for general purpose Python. Numba is fine for performance, but does not support autodiff. JAX aims to be sort of general purpose, but in practice it is quite explosive when doing something other than neural networks.
A lot of this is probably due to supporting CPUs and GPUs with the same interface. There are quite profound differences in how CPUs and GPUs are programmed, so the interface tends to restrict especially more "CPU-oriented" approaches.
I have nothing against supporting GPUs (although I think their use is overrated and most people would do fine with CPUs), but Python really needs a general purpose, high performance autodiff.
PyTorch is working on catching up — I think they’ve already got some kind of “vmap” style function transformations in beta. And I’m sure they’ll figure out good higher order derivatives too. That’s like 90% of what people want out of Jax, so I think they’ll be able to compete.
The downside of Jax is it’s not easy to debug. PyTorch, for better or for worse, will actually run your Python code as you wrote it.
I really don't feel that there is magic in PyTorch or jax, but that may be because I have written my own autograd libs.
In PyTorch you have a graph that is created at runtime by connecting the operations together in a transparent manner.
Jax may feel a bit magic, but all that's done is sending/splitting tracers, recording the operations, and compiling; by limiting the language, you get controlled branching with the proper semantics.
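You can actually inspect the recorded operations directly (small sketch with a throwaway function):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x + 1.0

# make_jaxpr runs f with abstract tracer values and shows the recorded
# sequence of primitives that later gets compiled.
print(jax.make_jaxpr(f)(2.0))
```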
---
The main reason I have had such trouble with Julia is simply how early/soon I ended up needing to use the very messed-up parts, whereas with the other languages you can get away without getting into them.
Then someone provided the reasons for that, and you brought out your own opinions in defense of not wanting to use Jax (without even trying it)!? With that in mind, a thousand more reasons would not satisfy you.
From the perspective of a person who does DL research with a math background, I find Jax way more intuitive than any of the other frameworks, PyTorch included. In math, you just don't "zero the gradient" every iteration. In the same way, you just don't have to "forward" first to get the gradient. The more one experiments with Jax, the more he/she finds it's faster / fewer steps / more intuitive for testing new ideas. And I don't like the object-oriented design of PyTorch; functional seems more intuitive to me. That's why I would bother with Jax.
this definitely limits its generality relative to jax, which makes it less than ideal for anything other than 'typical' deep neural networks
this is especially true when the research in question is related to things like physics or combining physical models and machine learning, which imho is very interesting. those are use cases that pytorch just isn't good at.
people already ported a lot of stuff from pytorch to jax.
if you're a research scientist or grad student, to a certain extent a lot of projects are "greenfield" so it's easy to jump on a new framework if it is nice to use and offers some advantage.