Linear Regression
Here we showcase linear regression on two input variables. All variables are continuous: a data set is generated from fixed parameter values, and a model is then used to infer those parameters back from the data.
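Written out as equations, the simulated data and the Bayesian model in the code below look roughly as follows, for i = 1, ..., 100. This sketch assumes that the second argument of Normal is a standard deviation. That matches how sigma scales Random.nextGaussian() in the data generator, but it is not stated explicitly here.

```latex
\begin{aligned}
&\text{data:}  && x_{i1} \sim \mathcal{N}(0, 1), \quad x_{i2} \sim \mathcal{N}(0, 0.2^2), \quad
  y_i = \alpha + \beta_1 x_{i1} + \beta_2 x_{i2} + \sigma \varepsilon_i, \quad \varepsilon_i \sim \mathcal{N}(0, 1),\\
&              && \text{with } \alpha = 1,\ (\beta_1, \beta_2) = (1, 2.5),\ \sigma = 1;\\
&\text{model:} && \alpha, \beta_1, \beta_2, \log\sigma \sim \mathcal{N}(0, 1), \quad
  y_i \mid \alpha, \beta_1, \beta_2, \sigma \sim \mathcal{N}\!\left(\alpha + \beta_1 x_{i1} + \beta_2 x_{i2},\ \sigma^2\right).
\end{aligned}
```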
import scala.util.Random

// generate data; parameters should be recovered by the inference algorithm
val data = {
  val alpha = 1.0
  val beta = (1.0, 2.5)
  val sigma = 1.0
  for {_ <- 0 until 100} yield {
    val X = (Random.nextGaussian(), 0.2 * Random.nextGaussian())
    val Y = alpha + X._1 * beta._1 + X._2 * beta._2 + Random.nextGaussian() * sigma
    (X, Y)
  }
}
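Because the generating parameters are known, a quick sanity check that is independent of the inference library is to fit ordinary least squares by solving the normal equations and confirm that the estimates land near alpha = 1, beta = (1, 2.5). The sketch below is plain Scala. OlsCheck, generate and fit are hypothetical names, and a seeded RNG replaces the global Random. With only 100 points and x2 scaled by 0.2, the estimate of beta2 is noticeably noisier than the other two.

```scala
import scala.util.Random

// Hypothetical plain-Scala sanity check (no probabilistic programming involved):
// fit y = alpha + beta1 * x1 + beta2 * x2 by ordinary least squares.
object OlsCheck {
  // Same generative process as the data block above, but with a seeded RNG.
  def generate(n: Int, rng: Random): Seq[((Double, Double), Double)] =
    for (_ <- 0 until n) yield {
      val x = (rng.nextGaussian(), 0.2 * rng.nextGaussian())
      val y = 1.0 + x._1 * 1.0 + x._2 * 2.5 + rng.nextGaussian() * 1.0
      (x, y)
    }

  // Solve the 3x3 normal equations (X^T X) w = X^T y by Gaussian elimination.
  def fit(data: Seq[((Double, Double), Double)]): (Double, Double, Double) = {
    val xtx = Array.ofDim[Double](3, 3)
    val xty = Array.ofDim[Double](3)
    data.foreach { case ((x1, x2), y) =>
      val row = Array(1.0, x1, x2) // leading 1.0 is the intercept column
      for (i <- 0 until 3) {
        xty(i) += row(i) * y
        for (j <- 0 until 3) xtx(i)(j) += row(i) * row(j)
      }
    }
    // forward elimination; X^T X is symmetric positive definite, so no pivoting
    for (k <- 0 until 3; i <- k + 1 until 3) {
      val f = xtx(i)(k) / xtx(k)(k)
      for (j <- k until 3) xtx(i)(j) -= f * xtx(k)(j)
      xty(i) -= f * xty(k)
    }
    // back substitution
    val w = Array.ofDim[Double](3)
    for (i <- 2 to 0 by -1) {
      val s = (i + 1 until 3).map(j => xtx(i)(j) * w(j)).sum
      w(i) = (xty(i) - s) / xtx(i)(i)
    }
    (w(0), w(1), w(2))
  }

  def main(args: Array[String]): Unit = {
    val (a, b1, b2) = fit(generate(100, new Random(0)))
    println(s"alpha ~ $a, beta1 ~ $b1, beta2 ~ $b2")
  }
}
```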
// set up variational approximation to the posterior distribution
val aPost = ReparamGuide(Normal(Param(0.0), exp(Param(0.0))))
val b1Post = ReparamGuide(Normal(Param(0.0), exp(Param(0.0))))
val b2Post = ReparamGuide(Normal(Param(0.0), exp(Param(0.0))))
val errPost = ReparamGuide(Normal(Param(0.0), exp(Param(0.0))))
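The name ReparamGuide suggests the reparameterization trick. A draw from N(mu, exp(s)) is written as z = mu + exp(s) * eps with eps ~ N(0, 1), so z is a differentiable function of the parameters mu and s, both initialized to 0.0 like the Param(0.0) values above. The hypothetical GaussianGuide below is a plain-Scala sketch of that idea only. It is not the library's implementation, which this section does not show.

```scala
import scala.util.Random

// Hypothetical stand-in for a reparameterized Gaussian guide q(z) = N(mu, exp(logSigma)).
// NOT the library's ReparamGuide; it only illustrates the reparameterization trick.
final case class GaussianGuide(mu: Double, logSigma: Double) {
  def sigma: Double = math.exp(logSigma) // exp keeps the standard deviation positive

  // z = mu + sigma * eps, returned with the pathwise derivatives dz/dmu and dz/dlogSigma
  def sampleWithGrad(eps: Double): (Double, Double, Double) = {
    val z = mu + sigma * eps
    (z, 1.0, sigma * eps)
  }

  // Same as above, drawing eps ~ N(0, 1) from the given RNG
  def sample(rng: Random): (Double, Double, Double) = sampleWithGrad(rng.nextGaussian())
}

object ReparamDemo {
  def main(args: Array[String]): Unit = {
    val guide = GaussianGuide(mu = 0.0, logSigma = 0.0)
    val (z, dzdMu, dzdLogSigma) = guide.sample(new Random(0))
    println(s"z = $z, dz/dmu = $dzdMu, dz/dlogSigma = $dzdLogSigma")
  }
}
```

Because the randomness sits entirely in eps, gradients of any function of z with respect to mu and logSigma follow from the chain rule through these two derivatives.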
// Draw variables from prior distributions and link those
// variables to the posterior approximation.
val model = infer {
  val a = sample(Normal(0.0, 1.0), aPost)
  val b1 = sample(Normal(0.0, 1.0), b1Post)
  val b2 = sample(Normal(0.0, 1.0), b2Post)
  val err = exp(sample(Normal(0.0, 1.0), errPost))
  // iterate over data points to define the observations
  data.foreach[Unit] {
    case ((x1, x2), y) =>
      observe(Normal(a + b1 * x1 + b2 * x2, err), y: Real)
  }
  // return the values that we're interested in
  (a, b1, b2, err)
}
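For reference, the density that the model block describes can be written out by hand: four standard-normal log prior terms plus one Gaussian log-likelihood term per observe call. The sketch below is hypothetical plain Scala (LogJoint is not part of the library) and assumes that Normal's second argument is a standard deviation. The prior is placed on the unconstrained draw logErr, with err = exp(logErr), so no Jacobian term is needed.

```scala
// Hypothetical helper: the log joint density of the model above, written by hand.
object LogJoint {
  private val Log2Pi = math.log(2 * math.Pi)

  // log N(x | mean, sd), where sd is a standard deviation
  def normalLogPdf(x: Double, mean: Double, sd: Double): Double = {
    val z = (x - mean) / sd
    -0.5 * z * z - math.log(sd) - 0.5 * Log2Pi
  }

  def logJoint(a: Double, b1: Double, b2: Double, logErr: Double,
               data: Seq[((Double, Double), Double)]): Double = {
    val err = math.exp(logErr)
    // standard-normal priors, matching each sample(Normal(0.0, 1.0), ...)
    val logPrior = Seq(a, b1, b2, logErr).map(normalLogPdf(_, 0.0, 1.0)).sum
    // one Gaussian likelihood term per data point, matching observe(...)
    val logLik = data.map { case ((x1, x2), y) =>
      normalLogPdf(y, a + b1 * x1 + b2 * x2, err)
    }.sum
    logPrior + logLik
  }
}
```

Parameter values that fit the data well, such as the generating values, should score higher than values far from them.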
// choose an optimization algorithm
// each parameter could have its own optimizer
val adam = new Adam(alpha = 0.1)
val interpreter = new OptimizingInterpreter(adam)
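Only the optimizer's name and one argument appear here. As a hedged reference point, this is the textbook Adam update for a single scalar parameter. It assumes that the alpha argument plays the role of the step size, as in the original Adam paper, and that the usual defaults beta1 = 0.9, beta2 = 0.999 apply. ScalarAdam and AdamDemo are hypothetical names, not the library's classes.

```scala
// Hypothetical textbook Adam optimizer for one scalar parameter (not the library's Adam).
final class ScalarAdam(alpha: Double, beta1: Double = 0.9, beta2: Double = 0.999, eps: Double = 1e-8) {
  private var m = 0.0 // running mean of gradients (first moment)
  private var v = 0.0 // running mean of squared gradients (second moment)
  private var t = 0   // step counter used for bias correction

  // One descent step: returns the updated parameter given the current gradient.
  def step(param: Double, grad: Double): Double = {
    t += 1
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    val mHat = m / (1 - math.pow(beta1, t)) // bias-corrected first moment
    val vHat = v / (1 - math.pow(beta2, t)) // bias-corrected second moment
    param - alpha * mHat / (math.sqrt(vHat) + eps)
  }
}

object AdamDemo {
  // Minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3), starting from x = 0.
  def minimize(steps: Int): Double = {
    val opt = new ScalarAdam(alpha = 0.1)
    var x = 0.0
    for (_ <- 0 until steps) x = opt.step(x, 2 * (x - 3))
    x
  }

  def main(args: Array[String]): Unit = println(minimize(1000))
}
```

Thanks to the bias correction, the very first step has magnitude of about alpha regardless of the gradient's scale. This is one reason a fairly large value such as 0.1 is workable.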
// warm up
// each sample of the model triggers a gradient descent step
Range(0, 1000).foreach { _ =>
  interpreter.reset()
  model.sample(interpreter)
}
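The comment above states that each sample of the model triggers a gradient step. A self-contained analogue, under simplifying assumptions, is stochastic variational inference on a one-parameter conjugate model: data y_i ~ N(z, 1), prior z ~ N(0, 1), and a Gaussian guide N(mu, exp(logSigma)). Each iteration draws one reparameterized z, forms single-sample ELBO gradients, and updates both guide parameters. It uses plain stochastic gradient ascent instead of Adam, and TinySvi is a hypothetical name. Because the model is conjugate, the exact posterior N(sum(y)/(n+1), 1/(n+1)) is available for comparison.

```scala
import scala.util.Random

// Hypothetical analogue of the warm-up loop: SVI for the mean z of Gaussian data
// (noise sd 1, prior z ~ N(0, 1)) with a guide q(z) = N(mu, exp(logSigma)).
object TinySvi {
  def fit(ys: Seq[Double], steps: Int, lr: Double, rng: Random): (Double, Double) = {
    val n = ys.size
    val sumY = ys.sum
    var mu = 0.0       // guide mean, initialized like Param(0.0)
    var logSigma = 0.0 // guide log standard deviation, initialized like Param(0.0)
    var muAcc = 0.0
    var sigmaAcc = 0.0
    val burnIn = steps / 2
    for (t <- 0 until steps) {
      val eps = rng.nextGaussian()
      val sigma = math.exp(logSigma)
      val z = mu + sigma * eps // one reparameterized draw from the guide
      // d/dz log p(z, ys) = -z + sum_i (y_i - z)
      val dLogJoint = sumY - (n + 1) * z
      // single-sample ELBO gradients; "+ 1.0" is d/dlogSigma of the guide's entropy
      mu += lr * dLogJoint
      logSigma += lr * (dLogJoint * sigma * eps + 1.0)
      if (t >= burnIn) { muAcc += mu; sigmaAcc += math.exp(logSigma) }
    }
    val kept = steps - burnIn
    (muAcc / kept, sigmaAcc / kept) // averaging late iterates damps gradient noise
  }
}
```

The averaging step is a choice made for this sketch only. The original loop instead relies on Adam and simply reads off samples after the warm-up.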
// print some samples
Range(0, 10).foreach { _ =>
  interpreter.reset()
  val l = model.sample(interpreter)
  val values = (l._1.v, l._2.v, l._3.v, l._4.v)
  println(s" $values")
}
Here we inject not only the variational posterior distribution into the model, but the data as well. Some things to note:
- we can naturally iterate over the data and declare observations
- the data types used, Seq and Tuple2, have no special meaning, and neither does the foreach method
- while real parameters and random variables range over the whole real axis, they can be mapped to the interval (0, ∞) by the exp function, as is done for err
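To make the last point concrete: in the model, err is exp of a standard-normal draw, so it is always positive and its implied prior is log-normal. The plain-Scala sketch below (the PositiveViaExp object is hypothetical) draws such values and evaluates the log-normal log density. That density gains a change-of-variables term -log x relative to the normal density of log x.

```scala
import scala.util.Random

// Hypothetical illustration: mapping an unconstrained real draw through exp.
object PositiveViaExp {
  // Log density of x = exp(s) with s ~ N(0, 1), i.e. the standard log-normal.
  def logNormalLogPdf(x: Double): Double = {
    require(x > 0, "the log-normal density is only defined for x > 0")
    val s = math.log(x)
    // N(0, 1) log density evaluated at log(x), minus log|dx/ds| = log(x) = s
    -0.5 * s * s - 0.5 * math.log(2 * math.Pi) - s
  }

  // n draws of exp(s) with s ~ N(0, 1); every value is strictly positive
  def drawPositive(rng: Random, n: Int): Seq[Double] =
    Seq.fill(n)(math.exp(rng.nextGaussian()))
}
```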