================================①=========================================
def encode_onehot(labels): classes = set(labels) classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)} labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32) return labels_onehot
np.identity() 生成单位矩阵,所以每一行即代表了一个label的one-hot表示,通过字典的形式保存下来
map(classes_dict.get, labels) 对于lables中的每一个lable,都带入到字典中,即得到其对应的one-hot编码。
================================②=========================================
另外可以直接调用pytorch中的one-hot方法:
labels = torch.tensor([1,2,3,2,1]) [nn.functional.one_hot(labels[i], 4) for i in range(5)]