If you're talking about Jax, there are a few different reasons to bother for research:
1. Full numpy compatibility.
2. More efficient higher-order gradients (because of forward-mode autodiff). Naively it's an asymptotic improvement, but I believe PyTorch uses some autodiff tricks to compute higher-order gradients with reverse mode, at the cost of a decently high constant factor.
3. Some cool transformations like vmap.
4. Full code gen, which is neat, especially for scientific computing purposes.
5. A neat API for higher-order gradients (see the sketch after this list).
2. and 5. are the most appealing for DL research; 1., 3. and 4. are more appealing to the stats/scientific computing communities.
PyTorch is working on all of these, with varying degrees of effort, but Jax currently has an advantage on these points (and may have a fundamental design advantage in some).
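To make 5. concrete, here's a minimal sketch of the higher-order gradient API (the cubic function is just a made-up example): grad takes a function and returns its derivative as another function, so you compose it for higher orders.

    import jax

    def f(x):
        return x ** 3

    df = jax.grad(f)    # a function computing f'(x) = 3x^2
    d2f = jax.grad(df)  # compose grad again: f''(x) = 6x

    print(df(2.0), d2f(2.0))  # 12.0 12.0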
1. Meh. PyTorch is close enough to not worry about it, and is better in some places.
2. Meh. The methods people actually use in practice for deep learning don't use higher-order gradients. Most higher-order methods are prohibitively memory-expensive, and memory is at a premium on accelerator hardware (and so is the bus bandwidth - so you can't "swap to RAM"). I do agree that higher-order gradients are the next frontier in optimization though - current optimizer research seems to have stalled, so people focus on training with huge batches and stuff like that. Most SOTA models in my field are trained with SGD+momentum - super primitive stuff. I don't see how Jax would solve the memory problem though. You still have to store those Hessians somewhere, at least partially.
3. Do agree, that's cool if it actually parallelizes nontrivial stuff that e.g. tf.vectorized_map barfs on. Although in a lot of cases you can "vectorize" by concatenating input tensors into a higher-dimensional tensor.
4. Meh. Not sure why I'd want that if I already have tracing and JIT.
5. This is #2.
With PyTorch though, you get close enough to NumPy to feel at home in both, and there's so much code written for it already that you can usually find a good starting point for your research on GitHub and build on top of that.
If you need to deploy, there's also tracing and JIT, which let you load and serve models with libtorch.
I see what you're saying regarding "advantages", I'm just pointing out that PyTorch might be "good enough" for most people. If I were on that team, I'd focus on providing a comfortable transition from TF 2.x, which is a dumpster fire (with the exception of TensorFlow Lite, which is excellent). That, IMO, would be the only way for this project to achieve mainstream success unless PyTorch disintegrates over time.
I agree from a practitioner standpoint - but you were talking about research :)
1. Much of the scientific computing/stats community is stuck in the past. Many are still using MATLAB! Unlike the CS community, who are used to learning new frameworks, that community finds real value in being able to "import jax.numpy as np" and have their scripts just run, and in keeping an API they've only just started to become familiar with (and one with far more documentation).
2. Once again, this is true for practitioners, but not research. Hessian-vector products show up in a decent number of places. For example, if you have an inner optimization loop (a la most meta-learning approaches or Deep Set Prediction Networks), you have a Hessian-vector product! Perhaps not prevalent in models that practitioners run, but definitely something to keep an eye on in research (see the HVP sketch after this list).
3. My understanding is that it actually does a pretty decent job. Enough that it's useful in the prototyping phase.
4. PyTorch JIT is neat, and is what I meant by the PyTorch team "working" on it. However, the JIT doesn't do full code gen (thus, significant operator overhead for, say, scalar networks) and has had significantly fewer man-hours poured into it compared to XLA (see the jit sketch after this list).
5. I was specifically talking about how you call grad on a function to get a function that returns its gradient. It's a cleaner API than PyTorch's autograd.
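To put some code behind point 2, a minimal Hessian-vector product sketch using forward-over-reverse (the toy objective is made up for illustration): jvp of grad differentiates the gradient along a direction v, so you never materialize the Hessian.

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sum(x ** 3)  # toy objective; its Hessian is diag(6x)

    def hvp(f, x, v):
        # forward-over-reverse: jvp of grad computes H(x) @ v directly
        return jax.jvp(jax.grad(f), (x,), (v,))[1]

    x = jnp.array([1.0, 2.0, 3.0])
    v = jnp.array([1.0, 0.0, 0.0])
    print(hvp(f, x, v))  # [6. 0. 0.]

This is exactly the shape of computation those inner-loop meta-learning methods hit.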
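And for point 4, a toy sketch of what full code gen buys you (the function and values are made up): jit hands the whole function to XLA, which fuses the small ops into one compiled kernel instead of dispatching each op separately. That's the scalar-network case where per-op overhead otherwise dominates.

    import jax
    import jax.numpy as jnp

    def scalar_net(x, w1, w2):
        # tiny "scalar network": many small ops, so per-op dispatch
        # overhead dominates unless the whole thing is fused
        h = jnp.tanh(x * w1)
        return h * w2

    fast = jax.jit(scalar_net)  # XLA compiles and fuses the function
    print(fast(1.0, 0.5, 2.0))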
Jax is definitely not meant for deployment or industry usage, and I believe their developers hope they'll never be pushed along that direction :^)
I definitely agree that PyTorch is "good enough" for most people. However, among researchers, there are a decent number of subgroups it could gain favor in.
You'd be surprised how many papers get submitted to ICML/NeurIPS that don't use PyTorch or TensorFlow at all, in favor of raw numpy, C++, or even MATLAB! I think the numbers I had were something like 30% of papers not using any ML framework. Jax could easily gain favor in this crowd.
There's also the crowd that cares a lot about higher-order gradients. Admittedly also a specific subgroup, but a growing one. Meta-learning people care a lot. So do Neural ODE people. All it takes is for one of these subfields to blow up for higher-order gradients to suddenly become a lot more appealing.
And finally, you have Google. Google researchers are never going to use PyTorch en masse (probably). If researchers at Google want to switch from TF, their only option is Jax. This is a pretty big subgroup of researchers :)
I definitely agree that Jax has a difficult hill to climb. But, they have a solid foothold within Google, and several subfields very amenable to their advantages.
PyTorch seems like the predominant research framework currently, but if any framework is going to erode its lead, I'd place my bets on Jax.
>> You'd be surprised how many papers get submitted to ICML/NeurIPS that don't use PyTorch or TensorFlow at all
I do keep up with the literature and I do some applied research as well, so yeah, I see such things from time to time. The volume of papers is so intense, though, that if a paper doesn't use the frameworks I already know (TF and PyTorch) and has no other redeeming qualities, I ignore it entirely. I wouldn't say I've missed much that could help me in practice. One exception is Leslie Smith's work on cyclic learning rates and momentum modulation - he did it on some ridiculous setup, but it works for what I do.
I'm more surprised how many papers are written for tiny little datasets that you'd never use in practice, especially optimization papers. I mean, come on guys, I get it, it's fast to train on CIFAR or Fashion-MNIST, but those results rarely translate to anything practical. And some papers are just plain not reproducible at all.
>> Google researchers are never going to use PyTorch en masse
As an ex-Googler, I'd place my bets on something else, TBH. Google projects that aren't critical to Google's bottom line tend to deteriorate over time. Just look at TF. I'm not cruel enough to suggest it to my clients anymore, even though I could charge twice as much (because it would take twice as long to get the same result).
Certainly if you ignore those papers, you'd likely have no issue in practice - I suspect many of them are about more theoretical concerns. Perhaps I'll take a look at/post a list tomorrow.
Either way, I believe our original discussion was about why somebody should bother. I provided a list of (admittedly) somewhat niche reasons. My personal opinion is that Jax will stick around, and at the very least, provide some neat ideas for PyTorch to ... independently come up with :)
>>> I'd place my bets on Jax
Hey hey hey, context! PyTorch is currently dominant in research, so who could supplant it? Anecdotally, since I published my article (https://thegradient.pub/state-of-ml-frameworks-2019-pytorch-...) there has been more momentum towards PyTorch (Preferred Networks and OpenAI).
So if not TensorFlow, then who? I think PyTorch represents a local optimum and is "good enough" for most people. So any newcomer framework needs to bring something new to the table, even if it's niche. I think Jax looks the most promising.
I'd like to see something based on a proper, high-performance, statically typed programming language, such that I could have a modicum of certainty that things would work when someone changes something. With Python, sadly, you don't know until you run things and hit error conditions at runtime. This is unacceptable in larger codebases.
Then someone provided reasons for that, and you brought out your own opinions to defend why you don't want to use Jax (without even trying it)!? With that in mind, a thousand more reasons wouldn't satisfy you.
From the perspective of someone who does DL research with a math background, I find Jax way more intuitive than any of the other frameworks, PyTorch included. In math, you just don't "zero the gradients" every iteration. In the same way, you don't have to run a "forward" pass first to get the gradient. The more you experiment with Jax, the more you find it's faster, fewer steps, and more intuitive for testing new ideas. And I don't like the object-oriented design of PyTorch; the functional style seems more intuitive to me (see the sketch below). That's why I would bother with Jax.
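A minimal sketch of what I mean (the quadratic loss and shapes are made up): grad maps the loss function to its gradient function, so a training step is a pure function with no zero_grad() and no separate forward call.

    import jax
    import jax.numpy as jnp

    def loss(params, x, y):
        return jnp.mean((x @ params - y) ** 2)

    grad_fn = jax.grad(loss)  # a function: params -> dloss/dparams

    def sgd_step(params, x, y, lr=0.1):
        # pure function: params in, updated params out, no hidden state
        return params - lr * grad_fn(params, x, y)

    params = jnp.zeros(3)
    x, y = jnp.ones((4, 3)), jnp.ones(4)
    params = sgd_step(params, x, y)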
You might also like the per-example gradients example that appears first on the JAX GitHub page: it's only one line of code, but important for research areas such as differential privacy.
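Roughly, it looks like this (my toy loss, not the exact README example): vmap over a single-example gradient yields one gradient per example instead of the usual batch-summed one.

    import jax
    import jax.numpy as jnp

    def loss(w, x, y):
        return (jnp.dot(x, w) - y) ** 2  # loss for a single example

    # the one-liner: map grad over the batch axes of x and y only
    per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

    w = jnp.ones(3)
    xs = jnp.arange(6.0).reshape(2, 3)
    ys = jnp.array([1.0, 2.0])
    print(per_example_grads(w, xs, ys).shape)  # (2, 3): one gradient per example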