.item()方法是,取一个元素张量里面的具体元素值并返回该值,可以将一个零维张量转换成int型或者float型,在计算loss,accuracy时常用到。
作用:
1.item()取出张量具体位置的元素元素值
2.并且返回的是该位置元素值的高精度值
3.保持原元素类型不变;必须指定位置4.节省内存(不会计入计算图)
import torch
loss = torch.randn(2, 2)
print(loss)
print(loss[1,1])
print(loss[1,1].item())
输出结果
tensor([[-2.0274, -1.5974],
[-1.4775, 1.9320]])
tensor(1.9320)
1.9319512844085693
其它:
loss = criterion(out, label)
loss_sum += loss # <--- 这里
运行着就发现显存炸了,观察发现随着每个batch显存消耗在不断增大…因为输出的loss的数据类型是Variable。PyTorch的动态图机制就是通过Variable来构建图。主要是使用Variable计算的时候,会记录下新产生的Variable的运算符号,在反向传播求导的时候进行使用。如果这里直接将loss加起来,系统会认为这里也是计算图的一部分,也就是说网络会一直延伸变大,那么消耗的显存也就越来越大。
正确的loss一般是这样写
loss_sum += loss.data[0]
其它注意事项:
使用loss += loss.detach()来获取不需要梯度回传的部分。
使用loss.item()直接获得对应的python数据类型。
补充阅读,pytorch 计算图
Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系。
Pytorch中的计算图是动态图。这里的动态主要有两重含义。
第一层含义是:计算图的正向传播是立即执行的。无需等待完整的计算图创建完毕,每条语句都会在计算图中动态添加节点和边,并立即执行正向传播得到计算结果。
第二层含义是:计算图在反向传播后立即销毁。下次调用需要重新构建计算图。如果在程序中使用了backward方法执行了反向传播,或者利用torch.autograd.grad方法计算了梯度,那么创建的计算图会被立即销毁,释放存储空间,下次调用需要重新创建。
1,计算图的正向传播是立即执行的。
import torch
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关
loss = torch.mean(torch.pow(Y_hat-Y,2))
print(loss.data)
print(Y_hat.data)
tensor(17.8969)
tensor([[3.2613],
[4.7322],
[4.5037],
[7.5899],
[7.0973],
[1.3287],
[6.1473],
[1.3492],
[1.3911],
[1.2150]])
2,计算图在反向传播后立即销毁。
import torch
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关
loss = torch.mean(torch.pow(Y_hat-Y,2))
#计算图在反向传播后立即销毁,如果需要保留计算图, 需要设置retain_graph = True
loss.backward() #loss.backward(retain_graph = True)
#loss.backward() #如果再次执行反向传播将报错
参考链接:pytorch学习:loss为什么要加item()_dlvector的博客-CSDN博客_loss.item()
https://blog.csdn.net/cs111211/article/details/126221102