Project author: plaxi0s

Project description:
Mixture Density Network implementation in PyTorch
Language: Python
Repository: git://github.com/plaxi0s/pytorch-MDN.git
Created: 2020-05-09T15:14:08Z
Project page: https://github.com/plaxi0s/pytorch-MDN



Mixture Density Network


Implementation of a Mixture Density Network (MDN) in PyTorch.

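An MDN models the conditional distribution of a scalar response y given an input x as a mixture of K Gaussians:

$$p_\theta(y \mid x) = \sum_{k=1}^{K} \pi^{(k)} \, \mathcal{N}\left(y \mid \mu^{(k)}, {\sigma^{(k)}}^2\right)$$

where the mixture parameters are output by a neural network $f_\theta$ that is trained to maximize the overall log-likelihood. The set of mixture parameters is:

$$\{\pi^{(k)}, \mu^{(k)}, \sigma^{(k)}\}_{k=1}^{K} = f_\theta(x)$$

with $\pi^{(k)} \ge 0$, $\sum_{k} \pi^{(k)} = 1$ and $\sigma^{(k)} > 0$.
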
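Maximizing the log-likelihood of the mixture p(y | x) = Σ_k π_k N(y | μ_k, σ_k²) is the same as minimizing its negative log-likelihood. The following is a minimal sketch for a scalar response. It assumes the network has already produced `pi`, `mu` and `sigma`, each of shape (N, K). The function name `mdn_nll` is hypothetical and is not the repository's `MDN_loss`. The sketch evaluates each component density in log space and combines the components with `torch.logsumexp`, which avoids underflow when individual densities are tiny.

```python
import math

import torch


def mdn_nll(y, pi, mu, sigma):
    """Mean negative log-likelihood of scalar targets under a Gaussian mixture.

    y:     (N,)   targets
    pi:    (N, K) mixture weights (non-negative, each row sums to 1)
    mu:    (N, K) component means
    sigma: (N, K) component standard deviations (positive)
    """
    y = y.unsqueeze(-1)  # (N, 1), broadcasts against the K components
    # log N(y | mu_k, sigma_k^2) for every component k
    log_normal = (
        -0.5 * ((y - mu) / sigma) ** 2
        - torch.log(sigma)
        - 0.5 * math.log(2 * math.pi)
    )
    # log sum_k pi_k N(y | mu_k, sigma_k^2), computed stably with log-sum-exp
    log_mix = torch.logsumexp(torch.log(pi) + log_normal, dim=-1)  # (N,)
    return -log_mix.mean()
```

As a sanity check, the result should agree with `-MixtureSameFamily(Categorical(probs=pi), Normal(mu, sigma)).log_prob(y).mean()` from `torch.distributions`.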
To predict the response as a multivariate Gaussian distribution (as in, for example, [2]), we assume a fully factored distribution, i.e. a diagonal covariance matrix, and predict each dimension separately. In other words, within each mixture component the dimensions of the response are assumed to be statistically independent.
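With a diagonal covariance, each component's D-dimensional density is a product of D one-dimensional Gaussians, so its log-density is a sum over dimensions. The sketch below applies this to the mixture log-likelihood for a D-dimensional response. The function name `mdn_nll_diag` and the shapes (`pi` as (N, K); `mu` and `sigma` as (N, K, D)) are assumptions for illustration and are not taken from the repository.

```python
import torch
from torch.distributions import Normal


def mdn_nll_diag(y, pi, mu, sigma):
    """Mean negative log-likelihood of D-dimensional targets under a mixture
    of K Gaussians with diagonal covariance.

    y:     (N, D)    targets
    pi:    (N, K)    mixture weights (non-negative, rows sum to 1)
    mu:    (N, K, D) component means
    sigma: (N, K, D) per-dimension standard deviations (positive)
    """
    y = y.unsqueeze(1)  # (N, 1, D), broadcasts over the K components
    # Diagonal covariance: each component density factorizes over dimensions,
    # so its log-density is the sum of D one-dimensional log-densities.
    log_comp = Normal(mu, sigma).log_prob(y).sum(dim=-1)  # (N, K)
    log_mix = torch.logsumexp(torch.log(pi) + log_comp, dim=-1)  # (N,)
    return -log_mix.mean()
```

Summing over the last axis gives the same result as wrapping the per-dimension Normal in `torch.distributions.Independent(Normal(mu, sigma), 1)`. It also matches a `MultivariateNormal` whose covariance is `diag(sigma**2)`.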

Usage

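    import torch
    import torch.nn as nn
    import torch.optim as optim

    from models.mdn import MixtureDensityNetworks
    from utils.loss import MDN_loss
    from utils.utils import sample

    # Placeholder data (hypothetical; replace with your own tensors):
    # x_var holds the inputs, t_var the matching targets, mini new query inputs.
    num_epochs = 1000
    x_var = torch.rand(1000, 1)
    t_var = torch.rand(1000, 1)
    mini = torch.linspace(0.0, 1.0, 100).unsqueeze(1)

    # Small feature extractor followed by the MDN layer; MixtureDensityNetworks(20, 1, 5)
    # presumably means 20 input features, 1 output dimension, 5 Gaussian components.
    model = nn.Sequential(
        nn.Linear(1, 20),
        nn.Tanh(),
        MixtureDensityNetworks(20, 1, 5),
    )
    opt = optim.Adam(model.parameters())

    for e in range(num_epochs):
        opt.zero_grad()
        pi, mu, sigma = model(x_var)           # mixture parameters for every input
        loss = MDN_loss(t_var, pi, mu, sigma)  # minimizing it maximizes the log-likelihood
        loss.backward()
        opt.step()

    # Predict mixture parameters for new inputs and draw samples from the mixture
    with torch.no_grad():
        pi, mu, sigma = model(mini)
        samples = sample(pi, mu, sigma)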
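The repository's `MixtureDensityNetworks` module is not shown here. The hypothetical `MDNHead` below is a rough sketch of what such an output layer typically does, following Bishop's formulation:

- mixture weights come from a softmax,
- means come from a plain linear layer,
- positive standard deviations come from an exponential.

The constructor arguments `(in_features, out_features, num_gaussians)` are an assumed reading of `MixtureDensityNetworks(20, 1, 5)`. The return order follows the `pi, mu, sigma` unpacking in the usage example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MDNHead(nn.Module):
    """Hypothetical MDN output layer: maps features to mixture parameters."""

    def __init__(self, in_features: int, out_features: int, num_gaussians: int):
        super().__init__()
        self.k, self.d = num_gaussians, out_features
        self.pi = nn.Linear(in_features, num_gaussians)
        self.mu = nn.Linear(in_features, num_gaussians * out_features)
        self.log_sigma = nn.Linear(in_features, num_gaussians * out_features)

    def forward(self, h: torch.Tensor):
        n = h.shape[0]
        pi = F.softmax(self.pi(h), dim=-1)                      # (N, K), rows sum to 1
        mu = self.mu(h).reshape(n, self.k, self.d)              # (N, K, D)
        log_sigma = self.log_sigma(h).reshape(n, self.k, self.d)
        sigma = torch.exp(log_sigma)                            # (N, K, D), strictly positive
        return pi, mu, sigma


# Used in the same position as in the usage example above
model = nn.Sequential(nn.Linear(1, 20), nn.Tanh(), MDNHead(20, 1, 5))
pi, mu, sigma = model(torch.randn(4, 1))
```

Returning a tuple works here because the head is the last module in `nn.Sequential`. A module placed after it would receive the tuple as its input.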

[Figure: Original data]

[Figure: Inverse data]

[Figure: Inverse and sampled data]
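The `sample(pi, mu, sigma)` helper comes from the repository's `utils.utils`, and its implementation is not shown. A common approach is ancestral sampling: first pick a component from Categorical(pi), then draw from that component's Gaussian. Sampling like this lets a plot show the several branches of a multimodal inverse mapping, which the conditional mean alone would average together.

The sketch below assumes diagonal Gaussian components, with `pi` of shape (N, K) and `mu`, `sigma` of shape (N, K, D). The name `sample_mixture` is hypothetical.

```python
import torch
from torch.distributions import Categorical, Normal


def sample_mixture(pi, mu, sigma):
    """Draw one sample per row from a mixture of diagonal Gaussians.

    pi: (N, K) mixture weights; mu, sigma: (N, K, D). Returns (N, D).
    """
    k = Categorical(probs=pi).sample()     # (N,) index of the chosen component
    rows = torch.arange(pi.shape[0])
    mu_k = mu[rows, k]                     # (N, D) mean of the chosen component
    sigma_k = sigma[rows, k]               # (N, D) std of the chosen component
    return Normal(mu_k, sigma_k).sample()  # (N, D) draw from that Gaussian
```

Calling it repeatedly on the same grid of inputs produces a scatter of plausible responses for each input.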


References

[1] Bishop, C. M. Mixture density networks. (1994).