一、标签映射与One-Hot编码过程

先进行标签映射,要为每个分类建立一个整数索引,对于每个样本的标签,使用整数索引创建一个长度为类别总数的二进制向量。这个向量的所有元素都是0,除了与整数索引相对应的位置,该位置的值为1。

二、pytorch的官方实现

在pytorch中实现了one hot编码,就在torch.nn.functional里面,下面是它的注释当中的示例,我们开看看:

Examples:>>> F.one_hot(torch.arange(0, 5) % 3)tensor([[1, 0, 0],[0, 1, 0],[0, 0, 1],[1, 0, 0],[0, 1, 0]])>>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)tensor([[1, 0, 0, 0, 0],[0, 1, 0, 0, 0],[0, 0, 1, 0, 0],[1, 0, 0, 0, 0],[0, 1, 0, 0, 0]])>>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)tensor([[[1, 0, 0], [0, 1, 0]],[[0, 0, 1], [1, 0, 0]],[[0, 1, 0], [0, 0, 1]]])

我们可以根据那自己实现的与它给出的这个示例进行比对,一样就当然没问题了。

三、手写实现

首先,在原先的函数(one_hot)当中numclass=-1,类别当然不能为1,说明这里是自动进行了计算,大家普遍使用的方式都是创建一个全零矩阵,使用 scatter_ 函数进行独热编码,作用是按照给定的索引,在指定的维度上进行赋值。

def one_hot(labels, num_classes=-1):"""将标签转为独热编码, 经过测试与torch.nn.functional里面的函数测试相同:param labels: 标签:param num_classes: 默认为-1, 表示进行自动计算类别最大的那个Examples:>>> label_1 = torch.arange(0, 5) % 3# tensor([0, 1, 2, 0, 1])>>> label_2 = torch.arange(0, 6).view(3, 2) % 3# tensor([[0, 1], [2, 0], [1, 2]])>>> print(one_hot(label_1))tensor([[1, 0, 0],[0, 1, 0],[0, 0, 1],[1, 0, 0],[0, 1, 0]])>>> print(one_hot(label_1, 5))tensor([[1, 0, 0, 0, 0],[0, 1, 0, 0, 0],[0, 0, 1, 0, 0],[1, 0, 0, 0, 0],[0, 1, 0, 0, 0]])>>> print(one_hot(label_2))tensor([[[1, 0, 0], [0, 1, 0]],[[0, 0, 1], [1, 0, 0]],[[0, 1, 0], [0, 0, 1]]])"""if num_classes == -1:num_classes = int(labels.max()) + 1one_hot_tensor = torch.zeros(labels.size() + (num_classes,), dtype=torch.int64)one_hot_tensor.scatter_(-1, labels.unsqueeze(-1).to(torch.int64), 1)return one_hot_tensorlabel_1 = torch.arange(0, 5) % 3# tensor([0, 1, 2, 0, 1])label_2 = torch.arange(0, 6).view(3, 2) % 3# tensor([[0, 1], [2, 0], [1, 2]])print(one_hot(label_1))print(one_hot(label_1, 5))print(one_hot(label_2))

首先是判断分类数是不是为-1,如果是就根据其中的最大值+1进行自动计算。然后创建一个契合分类数量的全零矩阵。

在这里,labels.unsqueeze(-1)用于在标签的最后一个维度上添加一个维度,以便与独热编码张量进行广播操作。

假设原始的 labels 张量的形状为 (batch_size,),那么经过 unsqueeze(-1) 操作后,形状变为 (batch_size, 1)。这样,每个样本的标签都被表示为一个列向量,而不再是一个标量。scatter_函数在最后一个维度进行操作,也就是对类别总数的维度进行操作,而 1 是要赋给相应位置的值。

labels.unsqueeze(-1) 已经确保了与 one_hot_tensor 的形状匹配,所以在这里能够正确地进行广播和赋值操作。

下面这一种是应用于分割网络当中,在保留输入标签张量形状的同时,将独热编码张量的最后一个维度设置为分类数num_classes,确保独热编码张量与输入标签张量具有相同的形状。

def get_one_hot(labels, num_classes=-1):"""用于分割网络的one hot"""labels = torch.as_tensor(labels)ones = one_hot(labels, num_classes)return ones.view(*labels.size(), num_classes)if __name__=="__main__":seg_labels = torch.randint(0, 3, size=[512, 512])print(get_one_hot(seg_labels))print(get_one_hot(seg_labels).shape) # torch.Size([512, 512, 3])

你可以将这里应用于自定义dataset部分。