Friday, June 7, 2019

Variational Autoencoders 2: Implementation

Prior: isotropic multivariate Gaussian $p_{\boldsymbol{\theta}}(\mathbf{z})=\mathcal{N}(\mathbf{z} ; \mathbf{0}, \mathbf{I})$

Probabilistic encoder (recognition model): multivariate Gaussian with diagonal covariance $q_{\phi}(\mathbf{z} | \mathbf{x})$

Probabilistic decoder: multivariate Gaussian (real-valued data) or Bernoulli (binary data) $p_{\boldsymbol{\theta}}(\mathbf{x} | \mathbf{z})$

$\log q_{\phi}\left(\mathbf{z} | \mathbf{x}^{(i)}\right)=\log \mathcal{N}\left(\mathbf{z} ; \boldsymbol{\mu}^{(i)}, \boldsymbol{\sigma}^{2(i)} \mathbf{I}\right)$
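Here $\boldsymbol{\mu}^{(i)}$ and $\boldsymbol{\sigma}^{2(i)}$ are the outputs of the encoder network for datapoint $\mathbf{x}^{(i)}$, and $\phi$ denotes that network's parameters.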
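
To make the formula concrete, here is a minimal sketch of the log-density of a diagonal Gaussian in plain Python. The function name `diag_gaussian_log_prob` and the example values are my own, chosen for illustration only; a real implementation would work on tensors.

```python
import math

def diag_gaussian_log_prob(z, mu, sigma2):
    """Log-density of N(z; mu, diag(sigma2)), summed over dimensions."""
    return sum(
        -0.5 * (math.log(2.0 * math.pi * s2) + (zj - mj) ** 2 / s2)
        for zj, mj, s2 in zip(z, mu, sigma2)
    )

# Hypothetical 2-D example: evaluate a standard normal at its mean.
log_q = diag_gaussian_log_prob([0.0, 0.0], [0.0, 0.0], [1.0, 1.0])
```

Because the covariance is diagonal, the joint log-density is just a sum of independent one-dimensional terms.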

The structure of the decoder is similar.

We sample using the re-parameterization trick:
$\mathbf{z}^{(i, l)} \sim q_{\phi}\left(\mathbf{z} | \mathbf{x}^{(i)}\right)$
$\mathbf{z}^{(i, l)}=g_{\phi}\left(\mathbf{x}^{(i)}, \boldsymbol{\epsilon}^{(l)}\right)=\boldsymbol{\mu}^{(i)}+\boldsymbol{\sigma}^{(i)} \odot \boldsymbol{\epsilon}^{(l)}$

where $\boldsymbol{\epsilon}^{(l)} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
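
The sampling step above can be sketched in a few lines of plain Python. This is only an illustration under my own assumptions: `reparameterize`, the seed, and the values of `mu` and `sigma` are hypothetical, and lists stand in for tensors.

```python
import random

def reparameterize(mu, sigma, rng):
    """Return z = mu + sigma * eps (element-wise), with eps ~ N(0, I)."""
    eps = [rng.gauss(0.0, 1.0) for _ in mu]
    return [m + s * e for m, s, e in zip(mu, sigma, eps)]

rng = random.Random(0)  # fixed seed, only for reproducibility
mu = [1.0, -2.0]        # hypothetical encoder outputs
sigma = [0.5, 0.0]      # sigma = 0 makes the second coordinate deterministic

samples = [reparameterize(mu, sigma, rng) for _ in range(10000)]
mean_z0 = sum(z[0] for z in samples) / len(samples)  # should be close to mu[0]
```

All the randomness lives in `eps`, so `z` is a deterministic function of `mu` and `sigma`. That is what lets gradients flow back to the encoder when the same computation is done in an autodiff framework.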

For the implementation, I like the VAE example from PyTorch's official examples [2].

Some details:

  • MNIST images are gray-scale, with pixel values ranging from 0 to 255. But the ToTensor transform [1] will convert "a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) or if the numpy.ndarray has dtype = np.uint8"

  • As you might notice, the decoder $p(x | z)$ is chosen to be a Bernoulli distribution. I have not found a good explanation for this. MNIST images are originally gray-scale. After conversion the pixel values are indeed in the range $[0,1]$, but can you take a step further and claim that each value is the probability of the pixel being black or white? Of course this doesn't make much sense, but as a very crude approximation it might work. Rui Shu tried to explain this in his blog post [3], but I found his explanations flawed and left comments for him. You can find other discussions in [4] and [5]. IMHO, this is just an engineering artifact.
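
To see what the two points above mean numerically, here is a small sketch in plain Python. It is not the code from the PyTorch example. The helper names `to_unit_interval` and `binary_cross_entropy` and the three-pixel "image" are made up for illustration, and the scaling is simply division by 255.

```python
import math

def to_unit_interval(pixels):
    """Scale 8-bit intensities in [0, 255] to floats in [0.0, 1.0]."""
    return [p / 255.0 for p in pixels]

def binary_cross_entropy(x, x_hat, eps=1e-12):
    """Negative Bernoulli log-likelihood, summed over pixels."""
    total = 0.0
    for t, p in zip(x, x_hat):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= t * math.log(p) + (1.0 - t) * math.log(1.0 - p)
    return total

x = to_unit_interval([0, 128, 255])   # hypothetical three-pixel "image"
perfect = binary_cross_entropy(x, x)  # reconstruction equals the input
worse = binary_cross_entropy(x, [0.0, 0.3, 1.0])  # gray pixel reconstructed badly
```

Even a perfect reconstruction has a non-zero loss here, because the gray pixel is not a binary outcome, so the loss is not a proper likelihood for this data. On the other hand, for a fixed target the loss is still smallest when the reconstruction equals the target, which is probably why the crude approximation works in practice.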


  1. PyTorch document on ToTensor method
  2. PyTorch official VAE example 
  3. Using a Bernoulli VAE on Real-Valued Observations - Rui Shu. (2018). 
  4. r/MachineLearning - what loss to use for variational auto encoder with real numbers? (2011)
  5. Loss function autoencoder vs variational-autoencoder or MSE-loss vs binary-cross-entropy-loss
