Hello:
I am new to both logistic regression and Infer.NET. I am trying to build a binary logistic regression model and would appreciate feedback, comments or corrections on the code below. The model has two independent variables, and a third component, fixed at 1, is added to each input vector for the constant term. I based it on the single-observation example in the post 'Bayesian logistic regression' and on some code from another post (I think it was about multiple linear regression). This is what I have come up with, and I have serious doubts about whether it is correct:
Vector[] data = new Vector[] {
    new Vector(1.0, -3, 1), new Vector(2.0, -2.1, 1), new Vector(1.0, -1.3, 1), new Vector(2.0, 0.5, 1),
    new Vector(1.0, 1.2, 1), new Vector(1.0, 3.3, 1), new Vector(1.0, 4.4, 1), new Vector(1.0, 5.5, 1) };
Range rows = new Range(data.Length);
// Inputs: two explanatory variables plus a constant 1 for the intercept
VariableArray<Vector> x = Variable.Constant(data, rows).Named("x");
// Prior on the three coefficients: zero mean, identity precision
Variable<Vector> w = Variable.VectorGaussianFromMeanAndPrecision(
    new Vector(new double[] { 0, 0, 0 }), PositiveDefiniteMatrix.Identity(3)).Named("w");
VariableArray<bool> y = Variable.Array<bool>(rows);
// Log-odds = x . w plus Gaussian noise with variance 1.0
y[rows] = Variable.BernoulliFromLogOdds(
    Variable.GaussianFromMeanAndVariance(Variable.InnerProduct(x[rows], w), 1.0));
InferenceEngine engine = new InferenceEngine(new VariationalMessagePassing());
y.ObservedValue = new bool[] { true, false, true, false, false, true, false, true };
VectorGaussian postW = engine.Infer<VectorGaussian>(w);
txtResult.Text = "W = \n" + postW;
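For reference, here is one way to write out the model the code above defines. This is my reading of the code, not an official statement of the model. Each input is x_n = (x_{n1}, x_{n2}, 1), the prior has identity precision (so also identity covariance), and the log-odds carry an extra Gaussian noise term with variance 1.0 from the GaussianFromMeanAndVariance call:

```latex
\begin{aligned}
\mathbf{w} &\sim \mathcal{N}(\mathbf{0},\ I_3) \\
z_n &\sim \mathcal{N}\big(\mathbf{x}_n^\top \mathbf{w},\ 1\big), \qquad n = 1, \dots, 8 \\
y_n &\sim \mathrm{Bernoulli}\big(\sigma(z_n)\big), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}
\end{aligned}
```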
The output is:
W =
VectorGaussian(-0.3729 0.01771 0.2636, 0.2634    0.009538  -0.3017  )
                                       0.009538  0.01458   -0.02437
                                       -0.3017   -0.02437  0.4693
Please advise on whether the model is built correctly and how to read the output. I am assuming that b0 = 0.4693, b1 = -0.3017 and b2 = -0.02437.
All help is greatly appreciated.
Many thanks!
R Hasnani
This looks correct.
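The output shows the posterior mean of W (the first vector) followed by its posterior covariance matrix. The coefficient estimates are therefore the entries of the mean vector, in the same order as the components of your input vectors: -0.3729 for the first variable, 0.01771 for the second and 0.2636 for the constant term. The values 0.4693, -0.3017 and -0.02437 are entries of the covariance matrix, not coefficients. If you want to get the mean and covariance separately, you can write: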
VectorGaussian postW = engine.Infer<VectorGaussian>(w);
Vector postWMean = postW.GetMean();                     // posterior mean: one entry per coefficient
PositiveDefiniteMatrix postWVar = postW.GetVariance();  // posterior covariance (3 x 3)
Console.WriteLine(postWMean);
Console.WriteLine(postWVar);
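As a worked reading of the printed numbers, here is a sketch that treats the result as an approximate Gaussian posterior (VMP returns an approximation). The standard deviation of each coefficient is the square root of the corresponding diagonal entry of the covariance matrix:

```latex
\mathbf{w} \mid \mathbf{y} \;\approx\; \mathcal{N}(\mathbf{m}, \Sigma), \qquad
\mathbf{m} = \begin{pmatrix} -0.3729 \\ 0.01771 \\ 0.2636 \end{pmatrix}, \qquad
\Sigma = \begin{pmatrix}
 0.2634   &  0.009538 & -0.3017  \\
 0.009538 &  0.01458  & -0.02437 \\
-0.3017   & -0.02437  &  0.4693
\end{pmatrix}

\mathrm{sd}(w_i) = \sqrt{\Sigma_{ii}}: \quad
\sqrt{0.2634} \approx 0.513, \quad
\sqrt{0.01458} \approx 0.121, \quad
\sqrt{0.4693} \approx 0.685
```

Each posterior standard deviation is larger than the magnitude of its mean, so with only eight observations none of the coefficients is pinned down well.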
John