sm = torch.softmax(outputs, dim = 1)
1
# Softmax 将张量的每个元素缩放到(0,1)区间且和为1
Pmax, predicted_labels = torch.max(sm, 1)
1
2
# dim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值
# 函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。
torch.gather
1
2
3
pt = torch.gather(m1,m2)
就是取m1张量中,m2对应的值。
利用index来索引input特定位置的数值