np.argmax()

分类问题

很多问题都可以转换为分类问题,情感分析大多是二分类:pos & neg ;NER 命名实体识别多分类问题:name,place,entity等等
这种问题一般是监督问题,给定一批有标注的语料,训练得到一个分类模型,然后在无标注的语料上进行测试
以情感分析为例,输入是(sentense ,label);模型给出预测
pred=model(sentense);然后这个pred跟原本的事实 label 比对测试,下面是一个 sentence classification 问题中最后处理时候涉及到多分类的 np.argmax 的用法,代码很通用:

def flat_accuracy(preds, labels):
  pred_flat = np.argmax(preds, axis=1).flatten() # [3, 5, 8, 1, 2, ....]
   labels_flat = labels.flatten()
 return np.sum(pred_flat == labels_flat) / len(labels_flat)
'''
这个地方是一个批处理batch=32,num_class=18;所以这个地方preds 是一个np.array(32,18),对每条数据在18个类别上做出预测概率
label=np.array(18);np.argmax(preds,axis=1) axis=1指在内层每个样例18中选出最大的预测值
经过 flatten()函数之后转换为一维(只适用于array)
所以pred_flat=lables=np.array(32);预测值跟真是值都是32维
然后计算正确率返回
'''
import numpy as np
a = [[1, 2, 3], [4, 100, 6], [99, 8, 9]] 
a = np.array(a, dtype=int)
print(a)
print(np.argmax(a,axis=1))
>>
[[  1   2   3]
[  4 100   6]
[ 99   8   9]]
[2 1 0]
>>

这些代码都是很细节的地方,但是要搭建模型的话,这些是基础,很多代码是通用的,细节处理也是可以复用的,随时记录,重点是搞清楚每一种任务的处理方式,其中的数据维度梳理清楚。。。

np.argmax:

numpy.argmax(array, axis) 用于返回一个numpy数组中最大值的索引值。当一组中同时出现几个最大值时,返回第一个最大值的索引值。

一维数组:
import numpy as np
a = np.array([1.1,9.9,13.13,4.4], dtype=float)
print(a)
print(np.argmax(a))
>>[ 1.1   9.9  13.13  4.4 ]
2
二维数组:

遵循运算之后降一维的原则,因此返回的会是一个一维的array

##axis=0 外层
##axis=1 内层
import numpy as np
a = [[1, 2, 3], [4, 100, 6], [99, 8, 9]] 
a = np.array(a, dtype=int)
print(a)
print(np.argmax(a,axis=1))
axis0 = np.argmax(a, axis = 0)#外层
axis1 = np.argmax(a, axis = 1)#内层
print(axis0 )
print(axis1 )
>>
[[  1   2   3]
 [  4 100   6]
 [ 99   8   9]]
[2 1 2]
[2 1 0]

pytorch 中的tensor 用法类似:

b = torch.argmax(tag_scores, dim=1)
print(tag_scores)
print(b)
>>
tensor([[-0.3892, -1.2426, -3.3890],
        [-2.1082, -0.1328, -5.8464],
        [-3.0852, -5.9469, -0.0495],
        [-0.0499, -3.4414, -4.0961],
        [-2.4540, -0.0929, -5.8799]])
tensor([0, 1, 2, 0, 1])
>>
三维数组

三维计算之后降维,将返回一个二维数组。
一个m×n×p维的矩阵,
axis为0,舍去m,返回一个 n×p 维的矩阵
axis为1,舍去n,返回一个 m×p 维的矩阵
axis为2,舍去p,返回一个 m×n 维的矩阵

three_dim_array = [[[1, 2, 3, 4],  [-1, 0, 3, 5]],
                   [[2, 7, -1, 3], [0, 3, 12, 4]],
                   [[5, 1, 0, 19], [4, 2, -2, 13]]]
a = np.argmax(three_dim_array, axis = 0)
print(a)
b = np.argmax(three_dim_array, axis = 1)
print(b)
c = np.argmax(test, axis = 2)
print(c)

例中数组shape为 3×2×4
输出结果为:
0 对应shape 2×4
1 对应shape 3×4
2 对应shape 3×2

[[2 1 0 2]                                                                                                               
 [2 1 1 2]]

[[0 0 0 1]                                                                                                               
 [0 0 1 1]                                                                                                               
 [0 1 0 0]]

[[3 3]                                                                                                                   
 [1 2]                                                                                                                   
 [3 3]]  

参考:
英文文本关系抽取-来自wmother的博客
np.argmax的用法-通俗易懂

推荐阅读更多精彩内容