Implementation of a Mixture Density Network in PyTorch
An MDN models the conditional distribution over a scalar response as a mixture of Gaussians. To predict a multivariate response (for example, in [2]), we assume a fully factored distribution, i.e. a diagonal covariance matrix: each output dimension is treated as statistically independent and is predicted by its own univariate mixture.
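Training minimizes the negative log-likelihood of the targets under the predicted mixture. The repo's `MDN_loss` lives in `utils.loss`; the sketch below is a minimal stand-in showing the computation, assuming `pi`, `mu`, `sigma` are `(N, K)` tensors and `t` is `(N, 1)` — the actual implementation may differ.

```python
import torch

def mdn_nll(t, pi, mu, sigma):
    # t: (N, 1) targets; pi, mu, sigma: (N, K) mixture parameters.
    # Per-component Gaussian log-density; t broadcasts over the K components.
    log_prob = torch.distributions.Normal(mu, sigma).log_prob(t)
    # log sum_k pi_k * N(t | mu_k, sigma_k), computed stably via logsumexp.
    weighted = torch.log(pi) + log_prob
    return -torch.logsumexp(weighted, dim=1).mean()
```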
import torch
import torch.nn as nn
import torch.optim as optim
from models.mdn import MixtureDensityNetworks
from utils.loss import MDN_loss
from utils.utils import sample
model = nn.Sequential(
    nn.Linear(1, 20),
    nn.Tanh(),
    MixtureDensityNetworks(20, 1, 5),
)
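The `MixtureDensityNetworks` layer maps the hidden features to the parameters of a 5-component mixture. Its internals live in `models.mdn`; a hypothetical re-implementation might look like the following (constraints assumed: softmax keeps the weights on the simplex, an exponential keeps the scales positive):

```python
import torch
import torch.nn as nn

class MDNHead(nn.Module):
    """Hypothetical MDN output layer: hidden features -> (pi, mu, sigma)."""
    def __init__(self, in_features, out_dim, num_components):
        super().__init__()
        self.pi = nn.Linear(in_features, num_components)
        self.mu = nn.Linear(in_features, num_components * out_dim)
        self.log_sigma = nn.Linear(in_features, num_components * out_dim)

    def forward(self, h):
        pi = torch.softmax(self.pi(h), dim=-1)   # mixture weights sum to 1
        mu = self.mu(h)                          # unconstrained component means
        sigma = torch.exp(self.log_sigma(h))     # strictly positive scales
        return pi, mu, sigma
```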
opt = optim.Adam(model.parameters())

for e in range(num_epochs):
    opt.zero_grad()
    pi, mu, sigma = model(x_var)
    loss = MDN_loss(t_var, pi, mu, sigma)
    loss.backward()
    opt.step()
pi, mu, sigma = model(mini)
samples = sample(pi, mu, sigma)
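Sampling from the predicted mixture is a two-step draw: pick a component according to the weights, then sample that component's Gaussian. The `sample` helper is in `utils.utils`; this is a minimal sketch of the idea, not necessarily its exact implementation:

```python
import torch

def sample_mixture(pi, mu, sigma):
    # Hypothetical stand-in for utils.utils.sample.
    # pi, mu, sigma: (N, K). Draw one component index per row, ...
    idx = torch.distributions.Categorical(pi).sample().unsqueeze(1)  # (N, 1)
    # ... then sample from the selected Gaussian for each row.
    chosen_mu = mu.gather(1, idx)
    chosen_sigma = sigma.gather(1, idx)
    return torch.normal(chosen_mu, chosen_sigma).squeeze(1)
```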
[1] Bishop, C. M. (1994). Mixture Density Networks. Technical Report, Aston University.