官方示例:
>>> loss = nn.CrossEntropyLoss() >>> input = torch.randn(3, 5, requires_grad=True) >>> target = torch.empty(3, dtype=torch.long).random_(5) >>> output = loss(input, target) >>> output.backward()
1.在loss中的输入中,target为类别的index,而非one-hot编码。
2.在输入的target的index中,数据的范围为[0, c-1],其中c为类别的总数,注意index的编码从0开始。