A PyTorch implementation of the standard Variational Autoencoder (VAE). The amortized inference model (encoder) is parameterized by a convolutional network, and the generative model (decoder) by a transposed convolutional network. The approximate posterior is a fully factorized Gaussian, i.e. a Gaussian with diagonal covariance.
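Because the approximate posterior is a diagonal Gaussian, two parts of the VAE objective have simple closed forms: drawing a latent sample with the reparameterization trick, and the KL divergence from the posterior to the prior. The sketch below uses plain Python floats in place of tensors so the arithmetic is easy to check. It assumes the usual standard-normal prior N(0, I) and a squared-error reconstruction term; neither is confirmed by this README, and the function names are hypothetical rather than taken from the repository.

```python
import math
import random


def reparameterize(mu, logvar, rng=random):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), one latent dimension at a time.

    Written this way, z is a deterministic function of (mu, logvar);
    only eps carries the randomness.
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def negative_elbo(x, x_recon, mu, logvar):
    """Loss to minimize: squared-error reconstruction term (assumed) plus the KL term."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_recon))
    return recon + kl_to_standard_normal(mu, logvar)
```

In PyTorch the sampling step is typically written as `mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)`. Having the encoder predict `logvar` rather than the standard deviation keeps the variance positive without an explicit constraint.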
This implementation supports training on the CelebA dataset. Since the project is a proof of concept, the original 178 x 218 images are scaled and cropped to 64 x 64 to speed up training. For convenience, the zip file containing the dataset can be downloaded from: https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip.
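One common way to go from 178 x 218 to 64 x 64 is to resize the shorter side to 64 and then take a center crop (for example, torchvision's `transforms.Resize(64)` followed by `transforms.CenterCrop(64)`); whether this repository does exactly that is an assumption. The sketch below computes the resulting geometry, then traces the spatial size through a hypothetical encoder of five stride-2 convolutions (kernel 3, padding 1) and a mirrored transposed-convolution decoder, using the standard `Conv2d` and `ConvTranspose2d` output-size formulas with dilation 1. The layer settings are illustrative, not read from the repo.

```python
def resize_shorter_side(width, height, target):
    """Scale so the shorter side equals `target`, keeping the aspect ratio.

    The longer side is truncated to an int; library rounding may differ by a pixel.
    """
    if width <= height:
        return target, int(height * target / width)
    return int(width * target / height), target


def center_crop_box(width, height, size):
    """(left, top, right, bottom) of a size x size crop centered in the image."""
    left = (width - size) // 2
    top = (height - size) // 2
    return left, top, left + size, top + size


def conv_out(n, kernel=3, stride=2, padding=1):
    """Output length of a Conv2d along one spatial axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def conv_transpose_out(n, kernel=3, stride=2, padding=1, output_padding=1):
    """Output length of a ConvTranspose2d along one spatial axis (dilation 1)."""
    return (n - 1) * stride - 2 * padding + kernel + output_padding


if __name__ == "__main__":
    w, h = resize_shorter_side(178, 218, 64)  # (64, 78)
    print(center_crop_box(w, h, 64))          # (0, 7, 64, 71)
    sizes = [64]
    for _ in range(5):                        # hypothetical 5-layer encoder
        sizes.append(conv_out(sizes[-1]))
    print(sizes)                              # [64, 32, 16, 8, 4, 2]
```

With these settings each stride-2 convolution halves the spatial size, and `output_padding=1` makes each transposed convolution exactly double it, so a mirrored decoder returns to 64 x 64.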
The VAE model was evaluated on several downstream tasks, such as image reconstruction and image generation. Some sample results can be found in the Results section.
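The two evaluation tasks use the latent space differently: reconstruction passes an input through the encoder and decodes the result, while generation decodes latent vectors drawn from the prior. A minimal sketch, assuming hypothetical `encode` and `decode` callables (an encoder returning `(mu, logvar)` and a decoder mapping a latent vector to an output), an N(0, I) prior, and plain lists standing in for tensors. Decoding the posterior mean is one common choice for reconstruction; the repository may sample instead.

```python
import random


def reconstruct(x, encode, decode):
    """Encode an input and decode the posterior mean (no sampling noise)."""
    mu, _logvar = encode(x)
    return decode(mu)


def generate(num_samples, latent_dim, decode, rng=random):
    """Draw z ~ N(0, I) from the prior and decode each sample."""
    samples = []
    for _ in range(num_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        samples.append(decode(z))
    return samples


def toy_encode(x):
    """Toy stand-in: use the input itself as the mean, unit variance."""
    return list(x), [0.0] * len(x)


def toy_decode(z):
    """Toy stand-in: double each latent coordinate."""
    return [2.0 * v for v in z]


if __name__ == "__main__":
    print(reconstruct([1.0, 2.0], toy_encode, toy_decode))    # [2.0, 4.0]
    print(len(generate(3, 4, toy_decode, random.Random(0))))  # 3
```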
Figure 1: Visual Representation of VAE. Image source: LearnOpenCV
- Python >= 3.9
- PyTorch >= 1.9
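To check an environment against these minimums, a small standard-library sketch like the following can help. It reads the installed PyTorch version through `importlib.metadata`, so torch does not need to be imported, and compares versions as integer tuples. The helper names are hypothetical.

```python
import sys
from importlib import metadata


def parse_version(text):
    """Turn a version string like '1.9.0+cu111' into (1, 9, 0).

    Local suffixes after '+' and non-numeric tails such as 'a0' are ignored.
    """
    parts = []
    for piece in text.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_requirements(min_python=(3, 9), min_torch=(1, 9)):
    """True if Python and the installed torch package meet the minimum versions."""
    if sys.version_info[:2] < min_python:
        return False
    try:
        torch_version = metadata.version("torch")
    except metadata.PackageNotFoundError:
        return False  # torch is not installed
    return parse_version(torch_version)[:2] >= min_torch
```

Comparing tuples rather than raw strings matters here: as strings, "1.10" sorts before "1.9", whereas `(1, 10) >= (1, 9)` is correctly true.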
$ git clone https://github.com/julian-8897/Conv-VAE-PyTorch.git
$ cd Conv-VAE-PyTorch
$ pip install -r requirements.txt
To train the model, modify the config.json configuration file and run:
python train.py --config config.json
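Edits to config.json can also be scripted, which is handy when launching several runs. The sketch below loads the JSON, sets nested values addressed by key-path tuples, and writes it back. The example key paths (`("optimizer", "args", "lr")`, `("trainer", "epochs")`) follow the layout of the pytorch-template configs this project builds on, but they are assumptions: check them against the repository's actual config.json before use.

```python
import json
from pathlib import Path


def load_config(path):
    """Read a JSON config file into a dict."""
    return json.loads(Path(path).read_text())


def apply_updates(config, updates):
    """Set nested values in place; `updates` maps key-path tuples to new values."""
    for key_path, value in updates.items():
        node = config
        for key in key_path[:-1]:
            node = node[key]  # raises KeyError if a parent key is missing
        node[key_path[-1]] = value
    return config


def save_config(config, path):
    """Write the dict back as indented JSON."""
    Path(path).write_text(json.dumps(config, indent=4) + "\n")
```

For example, `save_config(apply_updates(load_config("config.json"), {("trainer", "epochs"): 50}), "config.json")` would change the number of epochs before running train.py. Failing with a `KeyError` on a missing parent key, rather than silently creating it, helps catch a mistyped key path.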
To resume training of the model from a checkpoint, you can run the following command:
python train.py --resume path/to/checkpoint
To test the model, you can run the following command:
python test.py --resume path/to/checkpoint
Generated plots are stored in the 'Reconstructions' and 'Samples' folders.
Results: Reconstructed Samples | Generated Samples
- Original VAE paper, "Auto-Encoding Variational Bayes" by Kingma & Welling: https://arxiv.org/abs/1312.6114
- Various implementations of VAEs in PyTorch: https://github.com/AntixK/PyTorch-VAE
- PyTorch template used in this project: https://github.com/victoresque/pytorch-template
- A comprehensive introduction to VAEs: https://arxiv.org/pdf/1906.02691.pdf