1. Element-wise Multiplication
* operator (e.g. a * b)
torch.Tensor.mul()
torch.mul()
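A minimal sketch of the element-wise forms above, assuming PyTorch is installed. The tensors `a`, `b` and `row` are illustrative. The operator, the method and the function all compute the same product, and all three broadcast compatible shapes.

```python
import torch

# Two same-shaped tensors (illustrative values)
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[10.0, 20.0], [30.0, 40.0]])

# Element-wise product: three equivalent spellings
c_op = a * b              # operator
c_method = a.mul(b)       # torch.Tensor.mul()
c_func = torch.mul(a, b)  # torch.mul()
# values: [[10, 40], [90, 160]]

# Broadcasting: a shape-(2,) row is expanded across both rows of a
row = torch.tensor([1.0, 0.5])
scaled = a * row
# values: [[1, 1], [3, 2]]
```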
2. Matrix Multiplication
torch.Tensor.matmul()
torch.matmul()
torch.Tensor.mm()
torch.mm()
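A short sketch contrasting the two families, assuming PyTorch is installed (tensor names and shapes are illustrative). For two 2-D inputs all four calls agree. torch.mm() only accepts 2-D matrices and does not broadcast, while torch.matmul() also handles 1-D vectors and broadcasts leading batch dimensions.

```python
import torch

torch.manual_seed(0)

A = torch.randn(2, 3)
B = torch.randn(3, 4)

# 2-D x 2-D: all four forms produce the same (2, 4) matrix
C1 = A.matmul(B)
C2 = torch.matmul(A, B)
C3 = A.mm(B)
C4 = torch.mm(A, B)

# matmul also accepts a 1-D vector (matrix-vector product)
v = torch.randn(3)
Av = torch.matmul(A, v)  # shape (2,)

# matmul broadcasts B across a leading batch dimension
batch = torch.randn(5, 2, 3)
out = torch.matmul(batch, B)  # shape (5, 2, 4)

# mm rejects 3-D input
mm_failed = False
try:
    torch.mm(batch, B)
except RuntimeError:
    mm_failed = True
```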
3. Batch Matrix Multiplication
torch.bmm()
Example: torch.bmm(out_theta.transpose(1, 2), out_phi)
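To make the bmm call concrete, here is a sketch that assumes, purely for illustration, that out_theta and out_phi both have shape (batch, channels, positions); the sizes below are hypothetical. transpose(1, 2) swaps the last two dimensions, and torch.bmm() then multiplies the i-th matrix of one batch with the i-th matrix of the other. bmm does not broadcast, so both inputs must be 3-D with the same batch size.

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes: batch of 4, 8 channels, 16 positions
batch, channels, positions = 4, 8, 16
out_theta = torch.randn(batch, channels, positions)
out_phi = torch.randn(batch, channels, positions)

# (4, 8, 16) -> transpose(1, 2) -> (4, 16, 8)
# (4, 16, 8) bmm (4, 8, 16) -> (4, 16, 16)
scores = torch.bmm(out_theta.transpose(1, 2), out_phi)

# Equivalent per-sample loop, showing what bmm computes
looped = torch.stack(
    [torch.mm(out_theta[i].t(), out_phi[i]) for i in range(batch)]
)
```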