Nice, this is really helpful to compare with minGPT in PyTorch. I think I like the dimensions in the variable names. Kinda miss seeing the modules used in a module being set up in __init__ as in PyTorch, but I get it's all so it can be jit-compiled. I prefer einsum to transpose etc. for multi-head, so at least you only have one line to stare at, but that could be done in torch too (sidenote, I kinda wish einsum let you have longer names, so it could be something like `... Lq H Dh, ... Lk H Dh -> ... H Lq Lk`). JAX is more modular in libraries, so then you gotta get up to speed on linen, optax, orbax... but those are the usual tradeoffs. Why were the commit messages removed?
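For what it's worth, here's roughly what I mean — a minimal sketch (shapes and variable names are just illustrative) of the one-line einsum for multi-head attention scores. Since `jnp.einsum` only takes single-letter axis labels, the longer-name spec above has to collapse to something like `...qhd,...khd->...hqk`:

```python
import jax
import jax.numpy as jnp

# Illustrative shapes: batch B, query len Lq, key len Lk, heads H, head dim Dh
B, Lq, Lk, H, Dh = 2, 5, 7, 4, 8
key = jax.random.PRNGKey(0)
kq, kk = jax.random.split(key)
q = jax.random.normal(kq, (B, Lq, H, Dh))
k = jax.random.normal(kk, (B, Lk, H, Dh))

# Single-letter spelling of "... Lq H Dh, ... Lk H Dh -> ... H Lq Lk":
# contracts over the head dim and moves heads before the length axes,
# with no explicit transpose/reshape in sight.
scores = jnp.einsum('...qhd,...khd->...hqk', q, k) / jnp.sqrt(Dh)
print(scores.shape)  # (B, H, Lq, Lk)
```

The same one-liner works with `torch.einsum` too; it's just a style choice, not a JAX thing.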