When I first read about JAX I thought it would kill PyTorch, but I'm not sure I can get on with immutable arrays for tensor operations in deep learning.
If I have an array `x` and want to set index 0 to 10, I cannot do:
x[0] = 10
I instead have to do:
y = x.at[0].set(10)
I'm sure I could get used to it, but it really puts me off.
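For concreteness, a minimal sketch of the functional update described above (using `jax.numpy`, which mirrors the NumPy API):

```python
import jax.numpy as jnp

x = jnp.arange(5)      # [0, 1, 2, 3, 4]

# x[0] = 10 would raise a TypeError: JAX arrays are immutable.
# The functional equivalent returns a NEW array and leaves x unchanged:
y = x.at[0].set(10)    # y is [10, 1, 2, 3, 4]; x is still [0, 1, 2, 3, 4]
```

The `.at[...]` syntax also supports other out-of-place ops like `.add()`, `.mul()`, and `.min()`/`.max()`.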
Agreed that it's a bit ugly, but at least in the ML context you rarely, if ever, need to do this (personally I only do it on the input to models, where we use plain NumPy).
I feel the same. There are probably more ergonomic and generalizable ways to do whatever it is you need to do. Treat it as functional programming and let the XLA compiler handle things.
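To illustrate the "let XLA handle it" point: a hedged sketch, assuming a `jit`-compiled function of my own naming (`set_first`). Inside `jit`, XLA can often optimize a functional `.at[].set()` update so that no full copy is materialized, even though the Python-level semantics remain purely functional:

```python
import jax
import jax.numpy as jnp

@jax.jit
def set_first(x):
    # Looks like a copy-then-write, but under jit the XLA compiler
    # is frequently able to perform this as an in-place update.
    return x.at[0].set(10)

x = jnp.arange(5)
y = set_first(x)   # x is untouched; y has index 0 set to 10
```

So in practice the immutability is an API-level constraint, not necessarily a performance cost.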