我有以下代码:
class myLSTM(nn.Module):def __init (self,input_size,output_size,hidden_size,num_layers): super(myLSTM,self). init __() self.input_size = input_size + 1 …
这是预期的 - 存储模块 list , dict , set 或其他python容器没有注册它们所拥有的模块 list 等。要使代码正常工作,请使用 nn.ModuleList 代替。它就像修改你的一样简单 __init__ 要使用的代码
list
dict
set
nn.ModuleList
__init__
layers = [] new_input_size = self.input_size for i in xrange(num_layers): layers.append(LSTMCell(new_input_size, hidden_size)) new_input_size = hidden_size self.layers = nn.ModuleList(layers)