看来你误会了什么 torch.tensor([1, 84, 84]) 是在做。让我们来看看:
torch.tensor([1, 84, 84])
torch.tensor([1, 84, 84]) print(x, x.shape) #tensor([ 1, 84, 84]) torch.Size([3])
您可以从上面的示例中看到,它为您提供了只有一个维度的张量。
从你的问题陈述中,你需要一个形状的张量[1,84,84]。 这是它的样子:
from collections import deque import torch import torchvision.transforms as T class ReplayBuffer: def __init__(self, buffersize, batchsize, framestack, device, nS): self.buffer = deque(maxlen=buffersize) self.phi = deque(maxlen=framestack) self.batchsize = batchsize self.device = device self._initialize_stack(nS) def get_stack(self): t = torch.cat(tuple(self.phi),dim=0) # t = torch.stack(tuple(self.phi),dim=0) return t def _initialize_stack(self, nS): while len(self.phi) < self.phi.maxlen: # self.phi.append(torch.tensor([1,nS[1], nS[2]])) self.phi.append(torch.zeros([1,nS[1], nS[2]])) a = ReplayBuffer(buffersize=50000, batchsize=64, framestack=4, device='cuda', nS=[1,84,84]) print(a.phi) s = a.get_stack() print(s, s.shape)
注意 torch.cat 给你一个形状的张量[4,84,84]和 torch.stack 给你一个形状张量[4,1,84,84]。他们的差异可以在 torch.stack()和torch.cat()函数有什么区别?
torch.cat
torch.stack