JAX might be faster than PyTorch, I don’t know. I’m talking about TF. When I switched from TF to PyTorch 3 years ago, I saw no slowdown on any of the computer vision models at the time. And I remember looking at a couple of independent benchmarks which also showed them to be roughly the same in speed.
I was reading this and thinking it was a pretty terrible answer - glad it is just generated by an AI and not you personally so I'm not insulting you.
JAX is basically numpy on steroids and lets you do a lot of non-standard things (like a differentiable physics simulation or something) that would be harder with Pytorch.
They are both "high-performance."
Pytorch is more geared towards traditional deep learning and has the utilities and idioms to support it.
JAX introduced a lot of cool concepts (e.g. autobatching (vmap), autoparallel (pmap)) and supported a lot of things that PyTorch didn't (e.g. forward mode autodiff).
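To make those concrete, here's a minimal sketch (the toy predict function and shapes are mine, just for illustration):

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # toy "model": one linear layer plus a nonlinearity, written for a single example
    return jnp.tanh(w @ x)

w = jnp.ones((3, 4))
xs = jnp.ones((10, 4))                        # a batch of 10 inputs

# autobatching: vmap maps the single-example function over the batch axis
batched_predict = jax.vmap(predict, in_axes=(None, 0))
ys = batched_predict(w, xs)                   # shape (10, 3)

# forward-mode autodiff: Jacobian-vector product at w along direction v
v = jnp.ones_like(w)
y, tangent = jax.jvp(lambda w_: predict(w_, xs[0]), (w,), (v,))
```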
And at least for my applications (scientific computing), it was much faster (~100x) due to a much better JIT compiler and reduced Python overhead.
...but! PyTorch has worked hard to introduce all of the former, and the recent PyTorch 2 announcement was primarily about a better JIT compiler for PyTorch. (I don't think anyone has done serious non-ML benchmarks for this though, so it remains to be seen how this holds up.)
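(For reference, the user-facing entry point in PyTorch 2 is torch.compile; a rough sketch with a toy function of my own:)

```python
import torch

def step(x):
    return torch.sin(x) ** 2 + torch.cos(x) ** 2

compiled_step = torch.compile(step)   # compiles on first call
x = torch.randn(1_000_000)
y = compiled_step(x)                  # later calls reuse the compiled graph
```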
There are still a few differences. E.g. JAX has a better differential equation solving ecosystem. PyTorch has a better protein language model ecosystem. JAX offers some better power-user features like custom vmap rules. PyTorch probably has a lower barrier to entry.
(FWIW I don't know how either hold up specifically for DSP.)
I'd honestly suggest just trying both; always nice to have a broader selection of tools available.
That's my experience as well. PyTorch dominates the ecosystem.
Which is a shame, because JAX's approach is superior.[a]
---
[a] In my experience, anytime I've had to do anything in PyTorch that isn't well supported out-of-the-box, I've quickly found myself tinkering with Triton, which usually becomes... very frustrating. Meanwhile, JAX offers decent parallelization of anything I write in plain Python, plus really nice primitives like jax.lax.while_loop, jax.lax.associative_scan, jax.lax.select, etc. And yet, I keep using PyTorch... because of the ecosystem.
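To show what I mean by those primitives (the toy computations are my own, not from any real codebase):

```python
import jax
import jax.numpy as jnp

# while_loop: carry (counter, value) state until the predicate fails
def cond(state):
    i, x = state
    return i < 10

def body(state):
    i, x = state
    return i + 1, x * 0.5

final_i, final_x = jax.lax.while_loop(cond, body, (0, 1.0))

# associative_scan: a cumulative sum expressed as a parallel prefix scan
xs = jnp.arange(8.0)
cumsum = jax.lax.associative_scan(jnp.add, xs)

# select: element-wise choice between two arrays based on a boolean mask
picked = jax.lax.select(xs > 3, xs, jnp.zeros_like(xs))
```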
IMO, this is not a fair comparison, because PyTorch spans a larger range of abstraction than JAX (I don't quite know how to explain it other than "spans a larger range of abstraction").
You can do much of the JAX stuff in PyTorch, but you can't do the high-level nn.LSTM stuff in JAX; you have to use something like flax or objax.
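(For context, defining a model in Flax looks roughly like this; the layer sizes are arbitrary and this is just a sketch:)

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(features=32)(x))
        return nn.Dense(features=1)(x)

model = MLP()
x = jnp.ones((4, 16))
params = model.init(jax.random.PRNGKey(0), x)   # parameters live outside the module
y = model.apply(params, x)
```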
Wait, what? JAX and also PyTorch are used in a lot more areas than NNs.
JAX is even considered to do better in that department, in terms of performance, than all of Julia, so what are you talking about?
I’m a researcher, not using anything in production, but I find jax more usable as a general GPU-accelerated tensor math library. PyTorch is more specifically targeted at the neural network use case. It can be shoehorned into other use cases, but is clearly designed & documented for NN training & inference.
Very true! I should have appended this under the algorithmic improvements. This is also a reason ML has exploded. TensorFlow and PyTorch enabled easy GPU usage, so we can spend a lot less time programming and debugging (even writing CUDA subroutines is easier!). I mean, PyTorch is basically NumPy but with GPU access, so it is fantastic for any optimization, even non-ML. I haven't played around with JAX, but I hear it is better for more statistical stuff because of this.
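As a toy illustration of "NumPy with GPU access" for non-ML optimization (the device choice and problem are arbitrary):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# plain least-squares fit by gradient descent, nothing neural about it
A = torch.randn(100, 3, device=device)
b = torch.randn(100, device=device)
x = torch.zeros(3, device=device, requires_grad=True)

opt = torch.optim.SGD([x], lr=0.1)
for _ in range(500):
    loss = torch.mean((A @ x - b) ** 2)
    opt.zero_grad()
    loss.backward()
    opt.step()
```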
For me, the main point is that JAX is slightly lower level than PyTorch and has nicer abstractions (no need to worry about tensors that might not store a gradient, or whether you are on the GPU: it eliminates lots of newcomer bugs), which makes it a great fit for building deep learning frameworks, but also simulations, on top of it.
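A small sketch of what I mean (the loss function is a toy):

```python
import jax
import jax.numpy as jnp

def loss(params, x):
    return jnp.sum((params * x - 1.0) ** 2)

# no requires_grad flags, no .to(device): grad/jit just transform plain functions,
# and the same code runs unchanged on CPU, GPU or TPU
grad_fn = jax.jit(jax.grad(loss))
g = grad_fn(jnp.ones(3), jnp.arange(3.0))
```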
Meh, the comparison is somewhat pointless when it doesn't account for the slowdown that the vast majority of PyTorch codebases experience on TPUs versus using JAX and its TPU-specific accelerations, and vice versa.
Tensorflow has some advantages, like being able to use tf-lite for embedded devices. JAX is amazing on the TPU, which AFAIK pytorch doesn't have great support for.
I assume most people will still research in PyTorch, but then move it over to Keras for production models if they need multi-platform support.
At least prior to this announcement: JAX was much faster than PyTorch for differentiable physics. (Better JIT compiler; reduced Python-level overhead.)
E.g. for numerical ODE simulation, I've found that Diffrax (https://github.com/patrick-kidger/diffrax) is ~100 times faster than torchdiffeq on the forward pass. The backward pass is much closer, and there Diffrax is about 1.5 times faster.
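For anyone curious what that looks like, a minimal Diffrax call (essentially README-style usage; the ODE here is a toy):

```python
import diffrax
import jax.numpy as jnp

def vector_field(t, y, args):
    return -y                         # dy/dt = -y, simple exponential decay

term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()
sol = diffrax.diffeqsolve(term, solver, t0=0.0, t1=1.0, dt0=0.1, y0=jnp.array([1.0]))
print(sol.ys)                         # solution at t1
```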
It remains to be seen how PyTorch 2.0 will compare, of course!
Right now my job is actually building out the scientific computing ecosystem in JAX, so feel free to ping me with any other questions.
It is really a matter of having faith in PyTorch (or JAX) or in third-party cross-platform support like llama-cpp. Apparently PyTorch reduces a lot of complexity and is growing its cross-platform support extremely fast.
Have you actually tried that or are you just regurgitating Google’s marketing? I’ve seen Jax perform _slower_ than PyTorch on practical GPU workloads on the exact same machine, and not by a little, by something like 20%. I too thought I’d be getting great performance and “saving money”, but reality turned out to be a bit more complicated than that - you have to benchmark and tune.
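One concrete reason naive comparisons mislead (a sketch, assuming a CUDA GPU is available): both frameworks dispatch GPU work asynchronously, so you have to synchronize before reading the clock or you mostly measure launch overhead rather than runtime.

```python
import time
import torch
import jax.numpy as jnp

# PyTorch: synchronize the GPU around the timed region
x = torch.randn(4096, 4096, device="cuda")
torch.cuda.synchronize()
t0 = time.perf_counter()
y = x @ x
torch.cuda.synchronize()
print("torch:", time.perf_counter() - t0)

# JAX: block on the result before reading the clock
a = jnp.ones((4096, 4096))
a.block_until_ready()
t0 = time.perf_counter()
b = (a @ a).block_until_ready()
print("jax:", time.perf_counter() - t0)
```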