Project author: ldeecke

Project description:
Gaussian mixture models in PyTorch.
Language: Python
Repository: git://github.com/ldeecke/gmm-torch.git
Created: 2018-06-13T10:26:27Z
Project page: https://github.com/ldeecke/gmm-torch

License: MIT License

This repository contains an implementation of a simple Gaussian mixture model (GMM) fitted with expectation-maximization (EM) in PyTorch. The interface closely follows that of scikit-learn (sklearn).
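If you haven't seen EM before, the loop below is a minimal sketch of fitting a GMM with expectation-maximization in PyTorch. It is not the repository's implementation. It assumes diagonal covariances, a farthest-point initialization and a fixed number of iterations, and the names `em_fit_diag` and `log_gaussian_diag` are hypothetical.

```python
import math

import torch


def log_gaussian_diag(x, mu, var):
    """Log-density of every point under every diagonal Gaussian.

    x: (n, d) data, mu: (k, d) means, var: (k, d) variances -> (n, k).
    """
    diff = x.unsqueeze(1) - mu.unsqueeze(0)  # (n, k, d)
    return -0.5 * (
        (diff ** 2 / var.unsqueeze(0)).sum(-1)  # squared Mahalanobis distance, (n, k)
        + torch.log(var).sum(-1)  # log-determinant, (k,) broadcast over n
        + x.shape[1] * math.log(2 * math.pi)
    )


def em_fit_diag(x, k, n_iter=50, eps=1e-6):
    """Fit a k-component, diagonal-covariance GMM to x of shape (n, d) with EM.

    Hypothetical sketch; returns weights (k,), means (k, d), variances (k, d).
    """
    n, d = x.shape
    # Farthest-point initialization: start at x[0], then repeatedly add the
    # point that is farthest from all means chosen so far.
    idx = [0]
    for _ in range(1, k):
        dist = torch.cdist(x, x[idx]).min(dim=1).values
        idx.append(int(dist.argmax()))
    mu = x[idx].clone()
    var = torch.ones(k, d, dtype=x.dtype)
    pi = torch.full((k,), 1.0 / k, dtype=x.dtype)

    for _ in range(n_iter):
        # E-step: responsibilities r[i, j] = p(component j | x_i), computed
        # as a softmax over log-joint densities for numerical stability.
        r = torch.softmax(log_gaussian_diag(x, mu, var) + torch.log(pi), dim=1)
        # M-step: closed-form, responsibility-weighted parameter updates.
        nk = r.sum(0)  # effective number of points per component, (k,)
        pi = nk / n
        mu = (r.T @ x) / (nk.unsqueeze(1) + eps)
        diff = x.unsqueeze(1) - mu.unsqueeze(0)  # (n, k, d)
        var = (r.unsqueeze(2) * diff ** 2).sum(0) / (nk.unsqueeze(1) + eps) + eps
    return pi, mu, var


# Usage: two well-separated synthetic clusters around (-5, -5) and (5, 5).
torch.manual_seed(0)
x = torch.cat([torch.randn(200, 2) - 5.0, torch.randn(200, 2) + 5.0])
weights, means, variances = em_fit_diag(x, k=2)
```

Each iteration alternates an E-step, which computes soft assignments in log space, and an M-step, which re-estimates weights, means and variances in closed form. The small `eps` keeps the variances away from zero.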

Figure: Example of a fit via a Gaussian mixture model.


A new model is instantiated by calling gmm.GaussianMixture(..), passing as arguments the number of components and the tensor dimension (the number of features per sample). Note that once instantiated, the model expects input tensors in a flattened shape (n, d), where n is the number of samples and d is the feature dimension.
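The snippet below illustrates the (n, d) requirement by flattening a batch of images so that each sample becomes one row. The helper `to_flat` and the image shape are assumptions made for this example. The commented constructor call only follows the argument description above and is not executed.

```python
import torch


def to_flat(x: torch.Tensor) -> torch.Tensor:
    """Reshape a batch of samples into an (n, d) layout.

    Hypothetical helper (not part of gmm-torch): the first axis is kept as the
    sample axis n and all remaining axes are merged into the feature axis d.
    A 1-D tensor of n scalars becomes (n, 1).
    """
    if x.dim() == 1:
        return x.unsqueeze(1)
    return torch.flatten(x, start_dim=1)


images = torch.rand(64, 1, 28, 28)  # e.g. a batch of 64 single-channel images
data = to_flat(images)
n, d = data.shape  # n = 64 samples, d = 1 * 28 * 28 = 784 features
# Per the description above, the constructor takes the number of components
# and the dimension d, e.g. gmm.GaussianMixture(n_components, d) (not run here).
```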

Usually, the first step is to fit the model with model.fit(data) and then predict with model.predict(data). To reproduce the figure above, run the provided example.py.
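The following sketch shows what a scikit-learn-style `predict` computes: the most probable component for each sample. The helper `predict_diag` and the "fitted" parameter values are hypothetical and assume diagonal covariances. gmm-torch's own `predict` may differ in its details.

```python
import math

import torch


def predict_diag(x, pi, mu, var):
    """Assign each row of x (n, d) to its most probable component.

    Hypothetical helper: weights pi (k,), means mu (k, d), diagonal
    variances var (k, d). Returns component indices of shape (n,).
    """
    diff = x.unsqueeze(1) - mu.unsqueeze(0)  # (n, k, d)
    log_pdf = -0.5 * (
        (diff ** 2 / var.unsqueeze(0)).sum(-1)
        + torch.log(var).sum(-1)
        + x.shape[1] * math.log(2 * math.pi)
    )  # (n, k)
    # The evidence p(x_i) is shared by all components, so the argmax of the
    # unnormalized log-posterior equals the argmax of the posterior.
    return torch.argmax(log_pdf + torch.log(pi), dim=1)


# Usage with parameters assumed to come from an earlier fit.
fitted_pi = torch.tensor([0.5, 0.5])
fitted_mu = torch.tensor([[-5.0, -5.0], [5.0, 5.0]])
fitted_var = torch.ones(2, 2)
labels = predict_diag(torch.tensor([[-4.0, -6.0], [6.0, 4.0]]), fitted_pi, fitted_mu, fitted_var)
# labels -> tensor([0, 1])
```

Because the mixture weights enter the log-posterior, a point that is equally close to two components goes to the one with the larger weight.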

Some sanity checks can be run with python test.py. To fit data on a GPU, make sure to call model.cuda() first.
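Here is a minimal device-handling pattern. It assumes standard PyTorch semantics, where inputs must be on the same device as the model's parameters, so it selects CUDA when available and moves the data to match. The data is placeholder, and the gmm-torch calls appear only as comments.

```python
import torch

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data = torch.randn(1000, 2)  # placeholder data of shape (n, d)
data = data.to(device)  # inputs should live on the same device as the model

# With gmm-torch, the steps described above would look like this (not run here):
#   model = gmm.GaussianMixture(n_components, d)
#   if device.type == "cuda":
#       model.cuda()
#   model.fit(data)
```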