If you find yourself needing einsum a lot, consider using the Xarray [1] package instead, which lets you name the axes of your arrays. Broadcasting and axis alignment are then done automatically by matching names, so you usually don't need einsum at all: "dot" or "sum" is enough. If you do need einsum or another NumPy function, you can easily drop down to the underlying NumPy arrays and re-attach the Xarray metadata when you're done.
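A minimal sketch of what that looks like for a matrix product. The dimension names "i", "j", "k" and the array shapes are arbitrary choices for illustration. It contracts over the shared name "j" three ways: multiply-then-sum, DataArray.dot (which sums over all dimensions the two arrays share), and plain np.einsum on the raw arrays with the names re-attached afterwards.

```python
import numpy as np
import xarray as xr

# Two arrays with named axes; "j" is the axis they share.
a = xr.DataArray(np.arange(6.0).reshape(2, 3), dims=("i", "j"))
b = xr.DataArray(np.arange(12.0).reshape(3, 4), dims=("j", "k"))

# Broadcasting is by name: a * b has dims (i, j, k); then sum out "j".
via_sum = (a * b).sum("j")

# DataArray.dot contracts over every dimension the two arrays have in common.
via_dot = a.dot(b)

# Axis order doesn't matter: names still line up after transposing b.
via_dot_transposed = a.dot(b.transpose("k", "j"))

# Dropping down to NumPy: run einsum on the raw arrays ...
raw = np.einsum("ij,jk->ik", a.values, b.values)
# ... then re-attach the dimension names to the plain ndarray result.
rewrapped = xr.DataArray(raw, dims=("i", "k"))
```

Note that einsum needs the axis letters spelled out in positional order, whereas the Xarray versions only need the name of the dimension to contract.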
[1] https://xarray.dev