用这个:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
现在您可以使用模型和每个张量
net.to(device) input = input.to(device)
经过一些试验和错误后,我发现了两种方法:
self._train_noise = torch.randn(batch_size, embedding_size)
self.register_buffer('_train_noise', torch.randn(batch_size, embedding_size)
net.to(device)
覆盖 net.to(device) :使用这个噪音不在state_dict之内。
def to(device): new_self = super(VariationalGenerator, self).to(device) new_self._train_noise = new_self._train_noise.to(device) new_self._eval_noise = new_self._eval_noise.to(device) return new_self