多标签的one_hot编码与解码
Contact me
- Blog -> https://cugtyt.github.io/blog/index
- Email -> cugtyt@qq.com
- GitHub -> Cugtyt@GitHub
在分类问题中一个常做的处理是对标签做one_hot编码,这个处理有很多库可以解决,例如sklearn的OneHotEncoder,keras提供的to_categorical也是十分方便的函数,自己写也是很容易的,就是:
label = np.zeros(shape)
label[idx] = 1
当然还是最好用现有的函数,如果是多标签的话,没有找到直接的方法,这个时候可能就得自己写了,例如:
label = np.zeros(shape)
for idx in idxs:
label[idx] = 1
或者就是把每个单标签相加:
label = np.zeros(shape)
for idx in idxs:
label += to_categorical(idx)
编码还是非常直接的,但是解码呢?
如果是单标签,也有直接的函数:np.argmax
,这个就是读取最大值的索引,因为都是0,1值,所以单标签可以直接获得,也算一个解码函数,但如果是多标签,这个方法就不是很好用了,或者我们可以手动写:
labs = []
for i in label.shape[0]:
if label[i] == 1:
labs.append(i)
或者我们可以用numpy提供的另一个函数,np.where
,一个使用的简单例子是:
test = np.array([
[1,0,1],
[0,1,0]
])
np.where(test != 0)
# output
# (array([0, 0, 1]), array([0, 2, 1]))
它输出了一个元组,第一个是行索引,第二个是列索引,所以列索引就是我们要的标签,这样我们直接获得one_hot编码中的多标签原始信息,非常方便。