PyTorch is working on catching up: I think they've already got some kind of "vmap"-style function transformation in beta, and I'm sure they'll figure out good higher-order derivatives too. That's like 90% of what people want out of Jax, so I think they'll be able to compete.
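For reference, here's roughly what that 90% looks like in Jax today; a minimal sketch, where the function `f` is just a stand-in:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) ** 2  # toy scalar function, just for illustration

# Composable transforms: take a second derivative, then vectorize it.
d2f = jax.vmap(jax.grad(jax.grad(f)))
print(d2f(jnp.linspace(0.0, 1.0, 5)))
```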
The downside of Jax is it’s not easy to debug. PyTorch, for better or for worse, will actually run your Python code as you wrote it.
> I ran it on the TPU VM, saw the loss curve go down, and it was like an electric shock. "Wow! That actually... worked? Huh. That's weird. Things never work on the first try. I'm impressed."
> Then I plopped `import pdb; pdb.set_trace()` in the middle of the `loss` function and ran it again. It dropped me into the Python debugger.
> There was a tensor named `X_bt`. I typed `X_bt`. The debugger printed the value of `X_bt`.
> I was able to print out all the values of every variable, just like you'd expect Python to be able to do.
> There was a tensor named `Y_bt`. I typed `X_bt + Y_bt`. I was now staring at exactly what I expected: the sum of those two tensors.
> I could write `x + y`, create new variables, or do anything else I wanted.
> Now I was real impressed.
> If it sounds weird that I'm so easily impressed, it's because, you gotta understand: until now, TPUs were a complete pain in the ass to use. I kept my feelings to myself, because I understood that the Cloud TPU team was working hard to improve TPUs, and the TFRC support team was wonderful, and I had so many TPUs to play with. But holy moly, if you were expecting any of the above examples to just work on the first try when using TensorFlow v1 on TPUs, you were in for a rude awakening. And if you thought "Well, TensorFlow v2 is supposedly a lot better, right? Surely I'll be able to do basic things without worrying..."
> ... no. Not even close. Not until Jax + TPU VMs.
In the subsequent year, it's been nothing but joy.
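To make that concrete, here's a minimal sketch of the kind of thing that dropped me into the debugger. The shapes and the loss itself are made up; only the `X_bt`/`Y_bt` names come from the story above:

```python
import jax.numpy as jnp

def loss(X_bt, Y_bt):
    # Outside of jit, these are concrete arrays, so the debugger can
    # print them, add them, and build new values interactively.
    import pdb; pdb.set_trace()
    return jnp.mean((X_bt - Y_bt) ** 2)

X_bt = jnp.ones((8, 128))   # made-up shapes
Y_bt = jnp.zeros((8, 128))
loss(X_bt, Y_bt)  # at the (Pdb) prompt: X_bt, Y_bt, X_bt + Y_bt all just work
```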
> This module introduces the host callback functions call(), id_tap(), and id_print(), that send their arguments from the device to the host and invoke user-defined Python functions on the host, optionally returning results back to the device computation.
If you scroll to the very bottom of that file, you'll see an example of compiling your own XLA JIT'ed code which subsequently calls back into Python. TPUs do precisely the same thing.
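Here's a minimal sketch of what that looks like from user code, using the host_callback module quoted above (the loss function itself is made up):

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb

@jax.jit
def loss(x):
    # id_print ships x from the device back to the host and prints it,
    # even though this function is JIT-compiled by XLA (CPU, GPU, or TPU).
    x = hcb.id_print(x, what="x inside jit")
    return jnp.sum(x ** 2)

print(loss(jnp.arange(4.0)))
```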
Point being:
> PyTorch, for better or for worse, will actually run your Python code as you wrote it.
... is also true of jax, to within a rounding error no bigger than "I personally don't mind writing id_print(x) instead of print(x)." :)
Thanks, this is going to be very helpful for me. I guess it's kind of like that old piece of advice: if you want some free Linux tech support, just post "Linux can't do this but Windows can" :)
I've found jax's debugging to be better in some ways and worse in others. The fact that the function transformations are traced is great: it means you can step-debug the tracing steps just as well as the actual eval steps, you just have jaxpr.Tracers instead of jnp.ndarrays, or whatever. Outside of the transformations, it's just as easy to debug as numpy, which is a blessing. That's one of the biggest selling points.
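A tiny sketch of what I mean (the function and values are made up):

```python
import jax
import jax.numpy as jnp

def loss(x):
    # A breakpoint here sees a concrete array in a plain call, and a
    # Tracer while the function is being traced under a transformation.
    # import pdb; pdb.set_trace()
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)
print(loss(x))                  # eager: debugs just like numpy
print(jax.make_jaxpr(loss)(x))  # traced: the Python body runs once with Tracers
```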
Debugging jitted and pmapped code, on the other hand, is a pain. Since you can always step out of them (run the untransformed function) to debug correctness, what's left is debugging performance issues, and that's what sucks. And boy does it suck. If anyone knows a good story for figuring out why my jitted thing is slow as hell on TPU, I'm all ears. The profiling section of the official docs is one of their weaker sections (but big props to the overall documentation quality!).
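For what it's worth, the workflow I've ended up with is the built-in profiler plus TensorBoard's profile plugin; a rough sketch, where the log dir and the jitted function are arbitrary:

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((1024, 1024))
step(x).block_until_ready()  # warm-up so compilation stays out of the trace

# Writes a trace you can open in TensorBoard's profile plugin.
with jax.profiler.trace("/tmp/jax-trace"):
    step(x).block_until_ready()  # block so the device work lands inside the trace
```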