Posts Pytorch常用语法
Post
Cancel

Pytorch常用语法

sm = torch.softmax(outputs, dim = 1)

1
# Softmax 将张量的每个元素缩放到(0,1)区间且和为1

Pmax, predicted_labels = torch.max(sm, 1)

1
2
# dim是max函数索引的维度0/1,0是每列的最大值,1是每行的最大值
# 函数会返回两个tensor,第一个tensor是每行的最大值;第二个tensor是每行最大值的索引。

torch.gather

1
2
3
pt = torch.gather(m1,m2)
就是取m1张量中,m2对应的值。
利用index来索引input特定位置的数值

具体见连接

This post is licensed under CC BY 4.0 by the author.

Contents

Trending Tags