• by aconz2 on 6/5/2024, 12:47:51 PM

    Nice, this is really helpful to compare with minigpt in PyTorch. I think I like the dimensions in the variable names. I kinda miss seeing the submodules a module uses being set up in __init__ as in PyTorch, but I get that it's all structured that way to play nicely with jit. I prefer einsum to transpose etc. for multi-head attention, so at least you only have one line to stare at, but that could be done in torch too (sidenote: I kinda wish einsum let you use longer axis names, so it could be something like `... Lq H Dh, ... Lk H Dh -> ... H Lq Lk`). JAX is more modular in its libraries, so you've got to get up to speed on linen, optax, orbax... but those are the usual tradeoffs. Why are the commit messages removed?
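    To make the einsum-vs-transpose point concrete, here is a minimal JAX sketch of the multi-head attention score computation both ways. It assumes q and k are laid out as (..., length, heads, head_dim); the function names are hypothetical and not taken from the repo being discussed, and the usual 1/sqrt(Dh) scaling and softmax are left out. Since einsum only accepts single-letter axis names, the letters stand in for the longer names above: q = Lq, k = Lk, h = H, d = Dh.

```python
import jax.numpy as jnp

def attention_scores_einsum(q, k):
    # q: (..., Lq, H, Dh), k: (..., Lk, H, Dh) -> scores: (..., H, Lq, Lk)
    # One line: contract over the head dimension d for each head h.
    return jnp.einsum("...qhd,...khd->...hqk", q, k)

def attention_scores_transpose(q, k):
    # Same result via explicit axis shuffling and a batched matmul.
    qh = jnp.swapaxes(q, -3, -2)            # (..., H, Lq, Dh)
    kh = jnp.swapaxes(k, -3, -2)            # (..., H, Lk, Dh)
    return qh @ jnp.swapaxes(kh, -1, -2)    # (..., H, Lq, Dh) @ (..., H, Dh, Lk) -> (..., H, Lq, Lk)
```

    For q of shape (2, 5, 4, 8) and k of shape (2, 7, 4, 8), both functions return scores of shape (2, 4, 5, 7) that agree up to floating-point rounding; the einsum version just keeps the whole layout change in a single subscript string.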