Skip to the content.

多标签的one_hot编码与解码

Contact me


在分类问题中一个常做的处理是对标签做one_hot编码,这个处理有很多库可以解决,例如sklearn的OneHotEncoder,keras提供的to_categorical也是十分方便的函数,自己写也是很容易的,就是:

label = np.zeros(shape)
label[idx] = 1

当然还是最好用现有的函数,如果是多标签的话,没有找到直接的方法,这个时候可能就得自己写了,例如:

label = np.zeros(shape)
for idx in idxs:
    label[idx] = 1

或者就是把每个单标签相加:

label = np.zeros(shape)
for idx in idxs:
    label += to_categorical(idx)

编码还是非常直接的,但是解码呢?

如果是单标签,也有直接的函数:np.argmax,这个就是读取最大值的索引,因为都是0,1值,所以单标签可以直接获得,也算一个解码函数,但如果是多标签,这个方法就不是很好用了,或者我们可以手动写:

labs = []
for i in label.shape[0]:
    if label[i] == 1:
        labs.append(i)

或者我们可以用numpy提供的另一个函数,np.where,一个使用的简单例子是:

test = np.array([
    [1,0,1],
    [0,1,0]
])

np.where(test != 0)

# output
# (array([0, 0, 1]), array([0, 2, 1]))

它输出了一个元组,第一个是行索引,第二个是列索引,所以列索引就是我们要的标签,这样我们直接获得one_hot编码中的多标签原始信息,非常方便。