Empirical Bayes

The Bayesian approach to inference expects the practitioner to carefully consider their prior beliefs and encode them in prior distributions. In practice, however, the modeller may not have enough domain knowledge to do this.

Especially when there are more features than data points, pruning the features becomes important. This is essentially Occam's Razor: when two models can explain the data, the simpler model is preferred. The Bayesian definition of "simple" is to have a high marginal likelihood. This balances the bias introduced by reducing the number of parameters against the reduction in variance that comes with it.
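To make this concrete, the marginal likelihood (or evidence) of a model M integrates the likelihood over the prior,

$$ p(\mathcal{D} \mid M) = \int p(\mathcal{D} \mid \theta, M)\, p(\theta \mid M)\, \mathrm{d}\theta, $$

so a model that spreads prior mass over many parameter settings that do not help explain the data is penalized. In empirical Bayes, the parameters of the prior are themselves optimized to (approximately) maximize this quantity - here via the ELBO of the variational approximation.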

Generating the data

Let's create a dataset with two explanatory variables, a and b, and one dependent variable, y. The variable a actually influences y, while b is independent of it.
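Written as equations, the intended generative process (matching the constants in the code below) is

$$ a, b \sim \mathcal{N}(0, 1), \qquad y = 1.0 \cdot a + 0.0 \cdot b + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, 0.5^2). $$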

// scala.util.Random is used below; scala-infer's tensor types (Dim, ArrayTensor, Value) are assumed to be in scope
import scala.util.Random

case class Record(a: Float, b: Float, y: Float)

// Tensor shape - let's make it typed!
case class Batch(size: Int) extends Dim[Batch]
val batch = Batch(1000)

val (a_vals, b_vals, y_vals) = {
    val a_weight = 1.0
    val b_weight = 0.0
    val noise = 0.5
    
    val data = for { _ <- 0 until batch.size } yield {
        val a = Random.nextGaussian()
        val b = Random.nextGaussian()
        val y = a_weight * a + b_weight * b + noise * Random.nextGaussian()
        Record(a.toFloat, b.toFloat, y.toFloat)
    }

    (
        Value(ArrayTensor(batch.sizes, data.map { _.a }.toArray), batch),
        Value(ArrayTensor(batch.sizes, data.map { _.b }.toArray), batch),
        Value(ArrayTensor(batch.sizes, data.map { _.y }.toArray), batch)
    )
}

Create the model

In the model, we now introduce not just parameters for the variational approximation (a_post_mu, a_post_s, b_post_mu and b_post_s), but also parameters for the prior distributions, a_prior_s and b_prior_s.

Note that we treat a and b identically here, as we want the optimization procedure to figure out by itself which parameters are relevant.

// Log standard deviations of the priors - the empirical-Bayes parameters, optimized along with the guides
val a_prior_s = Param(0.0)
val b_prior_s = Param(0.0)

val a_post_mu = Param(0.0)
val a_post_s = Param(0.0)
val a_guide = ReparamGuide(Normal(a_post_mu, exp(a_post_s)))

val b_post_mu = Param(0.0)
val b_post_s = Param(0.0)
val b_guide = ReparamGuide(Normal(b_post_mu, exp(b_post_s)))

val noise_mu = Param(0.0)
val noise_s = Param(0.0)
val noise_guide = ReparamGuide(Normal(noise_mu, exp(noise_s)))

// The model: sample the weights and noise scale, then observe y as a noisy linear combination of a and b
val model = infer {
    val a_weight = sample(Normal(0.0, exp(a_prior_s)), a_guide)
    val b_weight = sample(Normal(0.0, exp(b_prior_s)), b_guide)
    val noise = sample(Normal(0.0, 1.0), noise_guide)

    observe(Normal(
        broadcast(a_weight, batch) * a_vals
        + broadcast(b_weight, batch) * b_vals,
        broadcast[Batch, ArrayTensor](exp(noise), batch)
    ), y_vals)
}

Running the optimization

After optimization, we see a clear difference between the relevant and the irrelevant parameter. The parameter a_prior_s drives the standard deviation of the prior on a_weight to 1, while the corresponding prior standard deviation for b_weight gets close to 0.
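One way to see why this happens, as a short sketch: the prior scale $\sigma$ of a weight only enters the ELBO through the expected log-prior under that weight's approximate posterior $\mathcal{N}(\mu, s^2)$,

$$ \mathbb{E}_{\mathcal{N}(\mu, s^2)}\left[\log \mathcal{N}(w \mid 0, \sigma^2)\right] = -\log \sigma - \frac{\mu^2 + s^2}{2\sigma^2} + \text{const}, $$

which is maximized at $\sigma^2 = \mu^2 + s^2$. Since the data pulls the posterior for b_weight towards zero mean and small variance, the optimal prior scale for it collapses towards 0; for a_weight, with posterior mean near 1, the prior standard deviation settles around 1.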

Notebook

The Jupyter notebook with the code is available at Automatic Relevance Determination.ipynb in the scala-infer notebooks project.