Suppose you've done a (robust) Bayesian multiple linear regression, and now you want the posterior distribution on the predicted value of \(y\) for some probe value of \( \langle x_1, x_2, x_3, \ldots \rangle \). That is, not the posterior distribution on the mean of the predicted value, but the posterior distribution on the predicted value itself. I showed how to do this for simple linear regression in a previous post; in this post I show how to do it for multiple linear regression. (A lot of commenters and emailers have asked me to do this.)
The basic idea is simple: at each step in the MCMC chain, use that step's parameter values to randomly generate a simulated datum \(y\) at the probed value of \(x\). Then examine the resulting distribution of simulated \(y\) values; that is the posterior distribution of the predicted \(y\) values.
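In symbols, using notation introduced here only for this illustration (a tilde marks the probe \(\tilde{x}\) and its prediction \(\tilde{y}\), a superscript \((s)\) marks the parameter values at MCMC step \(s\), and \(t(\cdot,\cdot,\cdot)\) lists location, scale, and normality in that order), the procedure draws one simulated value per step. The collection of draws approximates the posterior predictive distribution, which averages the \(t\) likelihood over the joint posterior of the parameters given the data \(D\):

```latex
\begin{aligned}
\tilde{y}^{(s)} &\sim t\!\left( \beta_0^{(s)} + \textstyle\sum_{j=1}^{N_x} \beta_j^{(s)} \tilde{x}_j ,\; \sigma^{(s)} ,\; \nu^{(s)} \right), \qquad s = 1, \ldots, S \\
p(\tilde{y} \mid \tilde{x}, D) &= \int p(\tilde{y} \mid \tilde{x}, \beta_0, \beta, \sigma, \nu)\, p(\beta_0, \beta, \sigma, \nu \mid D)\, d\beta_0\, d\beta\, d\sigma\, d\nu
\end{aligned}
```

Dropping the \(t\)-distributed noise and keeping only \(\beta_0^{(s)} + \sum_j \beta_j^{(s)} \tilde{x}_j\) would instead give the posterior on the mean of the predicted value, which is narrower.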
To implement the idea, the first programming choice is whether to simulate the \(y\) value with JAGS (or Stan or whatever) while it is generating the MCMC chain, or to simulate the \(y\) value after the MCMC chain has been generated. Each option has pros and cons. Generating the value in JAGS keeps the code that generates the \(y\) value close to the code that expresses the model, so there is less chance of mistakenly simulating data from a different model than the one being fit to the data. On the other hand, this method requires us to pre-specify all the \(x\) values we want to probe. If you want to choose the probed \(x\) values after JAGS has already generated the MCMC chain, then you'll need to re-express the model outside of JAGS, in R, and run the risk of expressing it differently by mistake (e.g., using precision instead of standard deviation, or thinking that y=rt(...) in R uses the same syntax as y~dt(...) in JAGS). I will show an implementation in which JAGS simulates the \(y\) values while generating the MCMC chain.
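For comparison, here is a minimal sketch of the other option: simulating the predictions in R after the chain exists. It assumes a matrix mcmcMat with one row per MCMC step and original-scale columns named beta0, beta[1], ..., beta[Nx], sigma, and nu (the names the script below monitors). The function simPostPredY is a hypothetical helper, not part of DBDA2E, and the toy parameter matrix is made up. The sketch illustrates the pitfall mentioned above: JAGS's dt(mu, tau, k) takes a precision tau = 1/sigma^2, whereas R's rt(n, df) draws only a standard \(t\) with no location or scale arguments, so location and scale have to be applied by hand.

```r
# simPostPredY: hypothetical helper (not part of DBDA2E).
# Simulates one predicted y per MCMC step at each row of xProbe,
# using original-scale parameter draws stored in mcmcMat.
simPostPredY = function( mcmcMat , xProbe ) {
  xProbe = as.matrix( xProbe )
  betaCols = grep( "^beta\\[" , colnames(mcmcMat) )  # assumes beta[1], beta[2], ... in order
  stopifnot( length(betaCols) == ncol(xProbe) )
  nStep = nrow( mcmcMat )
  yPred = matrix( NA_real_ , nrow=nStep , ncol=nrow(xProbe) )
  for ( i in 1:nrow(xProbe) ) {
    # Location of the t distribution at probe i, one value per MCMC step:
    mu = mcmcMat[,"beta0"] +
         as.vector( mcmcMat[,betaCols,drop=FALSE] %*% xProbe[i,] )
    # rt() returns a standard t draw; shift it by mu and multiply by the scale
    # sigma (a standard-deviation-like scale, not JAGS's precision 1/sigma^2):
    yPred[,i] = mu + mcmcMat[,"sigma"] * rt( nStep , df=mcmcMat[,"nu"] )
  }
  colnames(yPred) = paste0( "yPred[" , 1:nrow(xProbe) , "]" )
  yPred
}

# Usage with a made-up parameter matrix (toy values, not a fitted posterior):
set.seed(1)
nStep = 2000
toyMat = cbind( beta0     = rnorm( nStep , 1 , 0.1 ) ,
                "beta[1]" = rnorm( nStep , 2 , 0.1 ) ,
                "beta[2]" = rnorm( nStep , -1 , 0.1 ) ,
                sigma     = rep( 0.5 , nStep ) ,
                nu        = rep( 30 , nStep ) )
xProbe = matrix( c( 0 , 0 ,
                    1 , 1 ) , nrow=2 , byrow=TRUE )
yPred = simPostPredY( toyMat , xProbe )
apply( yPred , 2 , mean )  # near 1 and 2, the toy locations at the two probes
```

The spread of each yPred column combines the uncertainty in the regression coefficients with the \(t\)-distributed noise, which is exactly what distinguishes a predicted value from a predicted mean.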
To illustrate, I'll use the example of Scholastic Aptitude Test (SAT) scores from Chapter 18 of DBDA2E, plotted below:
Running robust multiple linear regression yields the posterior distribution shown below, where \(\nu\) is the normality (a.k.a. df) parameter of the \(t\) distribution:
Now for the issue of interest here: What is the predicted SAT score for a hypothetical state that spends, say, 9 thousand dollars per student and has 10% of the students take the exam? The answer, using the method described above, is shown below:
(Please note that the numerical annotation in these figures shows only the first three significant digits, so you'll need to examine the actual MCMC chain for more digits!) As another example, what is the predicted SAT score for a hypothetical state that spends, say, 9 thousand dollars per student and has 80% of the students take the exam? Answer:
A couple more examples of predictions:
Now for the R code I used to generate those results. I modified the scripts supplied with DBDA2E, named Jags-Ymet-XmetMulti-Mrobust.R and Jags-Ymet-XmetMulti-Mrobust-Example.R. Some of the key changes are described below.
The to-be-probed \(x\) values are put in a matrix called xProbe, which has one column for each \(x\) predictor and one row for each probe. The number of probed points (i.e., the number of rows of xProbe) is denoted Nprobe. The number of predictors (i.e., the number of columns of xProbe) is denoted Nx. Then, in the new low-level script, called Jags-Ymet-XmetMulti-MrobustPredict.R, the JAGS model specification looks like this:
# Standardize the data:
data {
  ym <- mean(y)
  ysd <- sd(y)
  for ( i in 1:Ntotal ) {
    zy[i] <- ( y[i] - ym ) / ysd
  }
  for ( j in 1:Nx ) {
    xm[j] <- mean(x[,j])
    xsd[j] <- sd(x[,j])
    for ( i in 1:Ntotal ) {
      zx[i,j] <- ( x[i,j] - xm[j] ) / xsd[j]
    }
    # standardize the probe values:
    for ( i in 1:Nprobe ) {
      zxProbe[i,j] <- ( xProbe[i,j] - xm[j] ) / xsd[j]
    }
  }
}
# Specify the model for standardized data:
model {
  for ( i in 1:Ntotal ) {
    zy[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zx[i,1:Nx] ) , 1/zsigma^2 , nu )
  }
  # Priors vague on standardized scale:
  zbeta0 ~ dnorm( 0 , 1/2^2 )
  for ( j in 1:Nx ) {
    zbeta[j] ~ dnorm( 0 , 1/2^2 )
  }
  zsigma ~ dunif( 1.0E-5 , 1.0E+1 )
  nu ~ dexp(1/30.0)
  # Transform to original scale:
  beta[1:Nx] <- ( zbeta[1:Nx] / xsd[1:Nx] )*ysd
  beta0 <- zbeta0*ysd + ym - sum( zbeta[1:Nx] * xm[1:Nx] / xsd[1:Nx] )*ysd
  sigma <- zsigma*ysd
  # Predicted y values at xProbe:
  for ( i in 1:Nprobe ) {
    zyP[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zxProbe[i,1:Nx] ) ,
                 1/zsigma^2 , nu )
    yP[i] <- zyP[i] * ysd + ym
  }
}
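The changes noted above are analogous to those I used for simple linear regression in the previous post: the data block standardizes the probe values with the same means and standard deviations as the data, and the model block simulates a standardized prediction zyP[i] at each probe and converts it back to the original scale as yP[i]. The MCMC chain monitors the values of yP[i], so afterward we can examine their posterior distribution.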
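Why is it legitimate to simulate zyP[i] on the standardized scale and then map it back with yP[i] <- zyP[i] * ysd + ym? Substituting the standardized probe values and the model's transformation formulas for beta0, beta, and sigma shows that the result equals a draw made directly on the original scale. The base-R sketch below checks this numerically. All numbers in it are made up for illustration, and tDraw is a single standard \(t\) draw shared by both routes, so that only the algebra differs between them.

```r
set.seed(2)
# Made-up data summaries and standardized-scale parameter values (toy numbers):
ym = 100 ; ysd = 15
xm = c( 3 , 20 ) ; xsd = c( 1 , 10 )
zbeta0 = 0.05 ; zbeta = c( 0.2 , -1.0 ) ; zsigma = 0.4
xProbeRow = c( 4 , 35 )
tDraw = rt( 1 , df=5 )   # one standard t draw, shared by both routes

# Route 1: simulate on the standardized scale, then map back (as in the JAGS model):
zxProbe = ( xProbeRow - xm ) / xsd
zyP = zbeta0 + sum( zbeta * zxProbe ) + zsigma * tDraw
yP1 = zyP * ysd + ym

# Route 2: transform the parameters to the original scale first
# (same formulas as the "Transform to original scale" lines of the model):
beta  = ( zbeta / xsd ) * ysd
beta0 = zbeta0 * ysd + ym - sum( zbeta * xm / xsd ) * ysd
sigma = zsigma * ysd
yP2 = beta0 + sum( beta * xProbeRow ) + sigma * tDraw

all.equal( yP1 , yP2 )   # TRUE, up to floating-point rounding
```

Because the two routes agree draw by draw, monitoring yP[i] gives the same posterior predictive distribution as simulating from beta0, beta, sigma, and nu on the original scale.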
I hope this helps. Here is complete R code for the high-level and low-level scripts:
High-level script:
#===== Begin high-level script ================
# Example for Jags-Ymet-XmetMulti-MrobustPredict.R
#-------------------------------------------------------------------------------
# Optional generic preliminaries:
graphics.off() # This closes all of R's graphics windows.
rm(list=ls())  # Careful! This clears all of R's memory!
#.............................................................................
# Two predictors:
myData = read.csv( file="Guber1999data.csv" )
yName = "SATT"
xName = c("Spend","PrcntTake")
xProbe = matrix( c( 4 , 10 , # Spend, PrcntTake
                    9 , 10 ,
                    4 , 80 ,
                    9 , 80 ) , nrow=4 , byrow=TRUE )
fileNameRoot = "Guber1999data-Predict-"
numSavedSteps=15000 ; thinSteps=5
graphFileType = "png"
#-------------------------------------------------------------------------------
# Load the relevant model into R's working memory:
source("Jags-Ymet-XmetMulti-MrobustPredict.R")
#-------------------------------------------------------------------------------
# Generate the MCMC chain:
mcmcCoda = genMCMC( data=myData , xName=xName , yName=yName , xProbe=xProbe ,
                    numSavedSteps=numSavedSteps , thinSteps=thinSteps ,
                    saveName=fileNameRoot )
#-------------------------------------------------------------------------------
# Display diagnostics of chain, for specified parameters:
parameterNames = varnames(mcmcCoda) # get all parameter names
for ( parName in parameterNames ) {
  diagMCMC( codaObject=mcmcCoda , parName=parName ,
            saveName=fileNameRoot , saveType=graphFileType )
}
#-------------------------------------------------------------------------------
# Get summary statistics of chain:
summaryInfo = smryMCMC( mcmcCoda , saveName=fileNameRoot )
show(summaryInfo)
# Display posterior information:
plotMCMC( mcmcCoda , data=myData , xName=xName , yName=yName ,
          pairsPlot=TRUE , showCurve=FALSE ,
          saveName=fileNameRoot , saveType=graphFileType )
#-------------------------------------------------------------------------------
# Plot posterior predicted y at xProbe:
mcmcMat = as.matrix(mcmcCoda)
xPcols = grep( "xProbe" , colnames(mcmcMat) , value=FALSE )
yPcols = grep( "yP" , colnames(mcmcMat) , value=FALSE )
xLim = quantile( mcmcMat[,yPcols] , probs=c(0.005,0.995) )
for ( i in 1:length(yPcols) ) {
  openGraph(width=4,height=3)
  xNameText = paste( "@" , paste( xName , collapse=", " ) , "=" )
  xProbeValText = paste( mcmcMat[1,xPcols[seq(i,
                                              by=length(yPcols),
                                              length=length(xName))]] ,
                         collapse=", " )
  plotPost( mcmcMat[,yPcols[i]] , xlab="Post. Pred. y" , xlim=xLim ,
            cenTend="mean" ,
            main= bquote(atop(.(xNameText),.(xProbeValText))) )
}
#-------------------------------------------------------------------------------
#===== End high-level script ================
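The expression seq(i, by=length(yPcols), length=length(xName)) in the plotting loop relies on the monitored matrix xProbe coming out of the chain in column-major order (the storage order JAGS shares with R), so its columns run xProbe[1,1], xProbe[2,1], ..., xProbe[Nprobe,1], xProbe[1,2], and so on. The base-R sketch below does not run JAGS; it only mimics those JAGS-style column names (an assumption mirroring that ordering) and checks that the stride picks out the values of probe i. The helper probeVals is hypothetical.

```r
xProbe = matrix( c( 4 , 10 ,
                    9 , 10 ,
                    4 , 80 ,
                    9 , 80 ) , nrow=4 , byrow=TRUE )
Nprobe = nrow(xProbe) ; Nx = ncol(xProbe)
# Flatten in column-major order (R's own storage order) and build JAGS-style names:
flat = as.vector( xProbe )
names(flat) = paste0( "xProbe[" , rep( 1:Nprobe , times=Nx ) , "," ,
                      rep( 1:Nx , each=Nprobe ) , "]" )
# Values for probe i, picked with the same stride as the plotting loop:
probeVals = function( i ) flat[ seq( i , by=Nprobe , length=Nx ) ]
probeVals(3)   # named xProbe[3,1] and xProbe[3,2], with values 4 and 80
```

In the high-level script the stride is written as length(yPcols) rather than Nprobe; the two are equal because there is one yP column per probe.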
Low-level script, named Jags-Ymet-XmetMulti-MrobustPredict.R and called by the high-level script:
#===== Begin low-level script ================
# Jags-Ymet-XmetMulti-MrobustPredict.R
# Accompanies the book:
#   Kruschke, J. K. (2015). Doing Bayesian Data Analysis, Second Edition:
#   A Tutorial with R, JAGS, and Stan. Academic Press / Elsevier.
source("DBDA2E-utilities.R")
#===============================================================================
genMCMC = function( data , xName="x" , yName="y" , xProbe=NULL ,
                    numSavedSteps=10000 , thinSteps=1 , saveName=NULL ,
                    runjagsMethod=runjagsMethodDefault ,
                    nChains=nChainsDefault ) {
  require(runjags)
  #-----------------------------------------------------------------------------
  # THE DATA.
  y = data[,yName]
  x = as.matrix(data[,xName],ncol=length(xName))
  # Do some checking that data make sense:
  if ( any( !is.finite(y) ) ) { stop("All y values must be finite.") }
  if ( any( !is.finite(x) ) ) { stop("All x values must be finite.") }
  cat("\nCORRELATION MATRIX OF PREDICTORS:\n ")
  show( round(cor(x),3) )
  cat("\n")
  flush.console()
  Nx = ncol(x) # number of x predictors
  Ntotal = nrow(x) # number of data points
  # Check the probe values:
  if ( !is.null(xProbe) ) {
    if ( any( !is.finite(xProbe) ) ) {
      stop("All xProbe values must be finite.") }
    if ( ncol(xProbe) != Nx ) {
      stop("xProbe must have same number of columns as x.") }
  } else { # fill in placeholder so JAGS doesn't balk
    xProbe = matrix( 0 , ncol=Nx , nrow=3 )
    for ( xIdx in 1:Nx ) {
      xProbe[,xIdx] = quantile(x[,xIdx],probs=c(0.0,0.5,1.0))
    }
  }
  # Specify the data in a list, for later shipment to JAGS:
  dataList = list(
    x = x ,
    y = y ,
    Nx = Nx ,
    Ntotal = Ntotal ,
    xProbe = xProbe ,
    Nprobe = nrow(xProbe)
  )
  #-----------------------------------------------------------------------------
  # THE MODEL.
  modelString = "
  # Standardize the data:
  data {
    ym <- mean(y)
    ysd <- sd(y)
    for ( i in 1:Ntotal ) {
      zy[i] <- ( y[i] - ym ) / ysd
    }
    for ( j in 1:Nx ) {
      xm[j] <- mean(x[,j])
      xsd[j] <- sd(x[,j])
      for ( i in 1:Ntotal ) {
        zx[i,j] <- ( x[i,j] - xm[j] ) / xsd[j]
      }
      # standardize the probe values:
      for ( i in 1:Nprobe ) {
        zxProbe[i,j] <- ( xProbe[i,j] - xm[j] ) / xsd[j]
      }
    }
  }
  # Specify the model for standardized data:
  model {
    for ( i in 1:Ntotal ) {
      zy[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zx[i,1:Nx] ) , 1/zsigma^2 , nu )
    }
    # Priors vague on standardized scale:
    zbeta0 ~ dnorm( 0 , 1/2^2 )
    for ( j in 1:Nx ) {
      zbeta[j] ~ dnorm( 0 , 1/2^2 )
    }
    zsigma ~ dunif( 1.0E-5 , 1.0E+1 )
    nu ~ dexp(1/30.0)
    # Transform to original scale:
    beta[1:Nx] <- ( zbeta[1:Nx] / xsd[1:Nx] )*ysd
    beta0 <- zbeta0*ysd + ym - sum( zbeta[1:Nx] * xm[1:Nx] / xsd[1:Nx] )*ysd
    sigma <- zsigma*ysd
    # Predicted y values at xProbe:
    for ( i in 1:Nprobe ) {
      zyP[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zxProbe[i,1:Nx] ) ,
                   1/zsigma^2 , nu )
      yP[i] <- zyP[i] * ysd + ym
    }
  }
  " # close quote for modelString
  # Write out modelString to a text file
  writeLines( modelString , con="TEMPmodel.txt" )
  #-----------------------------------------------------------------------------
  # INITIALIZE THE CHAINS.
  # Let JAGS do it...
  #-----------------------------------------------------------------------------
  # RUN THE CHAINS
  parameters = c( "beta0" , "beta" , "sigma",
                  "zbeta0" , "zbeta" , "zsigma", "nu" , "xProbe" , "yP" )
  adaptSteps = 500 # Number of steps to "tune" the samplers
  burnInSteps = 1000
  runJagsOut <- run.jags( method=runjagsMethod ,
                          model="TEMPmodel.txt" ,
                          monitor=parameters ,
                          data=dataList ,
                          #inits=initsList ,
                          n.chains=nChains ,
                          adapt=adaptSteps ,
                          burnin=burnInSteps ,
                          sample=ceiling(numSavedSteps/nChains) ,
                          thin=thinSteps ,
                          summarise=FALSE ,
                          plots=FALSE )
  codaSamples = as.mcmc.list( runJagsOut )
  # resulting codaSamples object has these indices:
  #   codaSamples[[ chainIdx ]][ stepIdx , paramIdx ]
  if ( !is.null(saveName) ) {
    save( codaSamples , file=paste(saveName,"Mcmc.Rdata",sep="") )
  }
  return( codaSamples )
} # end function
#===============================================================================
smryMCMC = function( codaSamples ,
                     saveName=NULL ) {
  summaryInfo = NULL
  mcmcMat = as.matrix(codaSamples,chains=TRUE)
  paramName = colnames(mcmcMat)
  for ( pName in paramName ) {
    summaryInfo = rbind( summaryInfo , summarizePost( mcmcMat[,pName] ) )
  }
  rownames(summaryInfo) = paramName
  summaryInfo = rbind( summaryInfo ,
                       "log10(nu)" = summarizePost( log10(mcmcMat[,"nu"]) ) )
  if ( !is.null(saveName) ) {
    write.csv( summaryInfo , file=paste(saveName,"SummaryInfo.csv",sep="") )
  }
  return( summaryInfo )
}
#===============================================================================
plotMCMC = function( codaSamples , data , xName="x" , yName="y" ,
                     showCurve=FALSE , pairsPlot=FALSE ,
                     saveName=NULL , saveType="jpg" ) {
  # showCurve is TRUE or FALSE and indicates whether the posterior should
  #   be displayed as a histogram (by default) or by an approximate curve.
  # pairsPlot is TRUE or FALSE and indicates whether scatterplots of pairs
  #   of parameters should be displayed.
  #-----------------------------------------------------------------------------
  y = data[,yName]
  x = as.matrix(data[,xName])
  mcmcMat = as.matrix(codaSamples,chains=TRUE)
  chainLength = NROW( mcmcMat )
  zbeta0 = mcmcMat[,"zbeta0"]
  zbeta = mcmcMat[,grep("^zbeta$|^zbeta\\[",colnames(mcmcMat))]
  if ( ncol(x)==1 ) { zbeta = matrix( zbeta , ncol=1 ) }
  zsigma = mcmcMat[,"zsigma"]
  beta0 = mcmcMat[,"beta0"]
  beta = mcmcMat[,grep("^beta$|^beta\\[",colnames(mcmcMat))]
  if ( ncol(x)==1 ) { beta = matrix( beta , ncol=1 ) }
  sigma = mcmcMat[,"sigma"]
  nu = mcmcMat[,"nu"]
  log10nu = log10(nu)
  #-----------------------------------------------------------------------------
  # Compute R^2 for credible parameters:
  YcorX = cor( y , x ) # correlation of y with each x predictor
  Rsq = zbeta %*% matrix( YcorX , ncol=1 )
  #-----------------------------------------------------------------------------
  if ( pairsPlot ) {
    # Plot the parameters pairwise, to see correlations:
    openGraph()
    nPtToPlot = 1000
    plotIdx = floor(seq(1,chainLength,by=chainLength/nPtToPlot))
    panel.cor = function(x, y, digits=2, prefix="", cex.cor, ...) {
      usr = par("usr"); on.exit(par(usr))
      par(usr = c(0, 1, 0, 1))
      r = (cor(x, y))
      txt = format(c(r, 0.123456789), digits=digits)[1]
      txt = paste(prefix, txt, sep="")
      if(missing(cex.cor)) cex.cor <- 0.8/strwidth(txt)
      text(0.5, 0.5, txt, cex=1.25 ) # was cex=cex.cor*r
    }
    pairs( cbind( beta0 , beta , sigma , log10nu )[plotIdx,] ,
           labels=c( "beta[0]" ,
                     paste0("beta[",1:ncol(beta),"]\n",xName) ,
                     expression(sigma) , expression(log10(nu)) ) ,
           lower.panel=panel.cor , col="skyblue" )
    if ( !is.null(saveName) ) {
      saveGraph( file=paste(saveName,"PostPairs",sep=""), type=saveType)
    }
  }
  #-----------------------------------------------------------------------------
  # Marginal histograms:
  decideOpenGraph = function( panelCount , saveName , finished=FALSE ,
                              nRow=2 , nCol=3 ) {
    # If finishing a set:
    if ( finished==TRUE ) {
      if ( !is.null(saveName) ) {
        saveGraph( file=paste0(saveName,ceiling((panelCount-1)/(nRow*nCol))),
                   type=saveType)
      }
      panelCount = 1 # re-set panelCount
      return(panelCount)
    } else {
      # If this is first panel of a graph:
      if ( ( panelCount %% (nRow*nCol) ) == 1 ) {
        # If previous graph was open, save previous one:
        if ( panelCount>1 & !is.null(saveName) ) {
          saveGraph( file=paste0(saveName,(panelCount%/%(nRow*nCol))),
                     type=saveType)
        }
        # Open new graph
        openGraph(width=nCol*7.0/3,height=nRow*2.0)
        layout( matrix( 1:(nRow*nCol) , nrow=nRow, byrow=TRUE ) )
        par( mar=c(4,4,2.5,0.5) , mgp=c(2.5,0.7,0) )
      }
      # Increment and return panel count:
      panelCount = panelCount+1
      return(panelCount)
    }
  }
  # Original scale:
  panelCount = 1
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
  histInfo = plotPost( beta0 , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(beta[0]) , main="Intercept" )
  for ( bIdx in 1:ncol(beta) ) {
    panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
    histInfo = plotPost( beta[,bIdx] , cex.lab = 1.75 , showCurve=showCurve ,
                         xlab=bquote(beta[.(bIdx)]) , main=xName[bIdx] )
  }
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
  histInfo = plotPost( sigma , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(sigma) , main=paste("Scale") )
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
  histInfo = plotPost( log10nu , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(log10(nu)) , main=paste("Normality") )
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
  histInfo = plotPost( Rsq , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(R^2) , main=paste("Prop Var Accntd") )
  panelCount = decideOpenGraph( panelCount , finished=TRUE , saveName=paste0(saveName,"PostMarg") )
  # Standardized scale:
  panelCount = 1
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
  histInfo = plotPost( zbeta0 , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(z*beta[0]) , main="Intercept" )
  for ( bIdx in 1:ncol(beta) ) {
    panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
    histInfo = plotPost( zbeta[,bIdx] , cex.lab = 1.75 , showCurve=showCurve ,
                         xlab=bquote(z*beta[.(bIdx)]) , main=xName[bIdx] )
  }
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
  histInfo = plotPost( zsigma , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(z*sigma) , main=paste("Scale") )
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
  histInfo = plotPost( log10nu , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(log10(nu)) , main=paste("Normality") )
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
  histInfo = plotPost( Rsq , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(R^2) , main=paste("Prop Var Accntd") )
  panelCount = decideOpenGraph( panelCount , finished=TRUE , saveName=paste0(saveName,"PostMargZ") )
  #-----------------------------------------------------------------------------
}
#===============================================================================
#===== End low-level script ================
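The figure annotations show only three significant digits. For full-precision numbers, the predicted-value columns can be pulled from the chain directly. The sketch below assumes a matrix like mcmcMat = as.matrix(mcmcCoda) from the high-level script, with columns named yP[1], yP[2], and so on. Both hdiSample and summarizePredY are hypothetical helpers written for this illustration (DBDA2E-utilities.R provides its own HDIofMCMC for the same purpose). hdiSample returns the narrowest interval spanning credMass of the sorted draws, and the toy draws in the usage lines are made up.

```r
# hdiSample: hypothetical helper; narrowest interval containing (at least)
# credMass of the sampled values.
hdiSample = function( sampleVec , credMass=0.95 ) {
  stopifnot( credMass > 0 , credMass < 1 )
  sortedPts = sort( sampleVec )
  nPts = length( sortedPts )
  ciIdxInc = ceiling( credMass * nPts )  # index span of each candidate interval
  nCIs = nPts - ciIdxInc                 # number of candidate intervals
  stopifnot( nCIs >= 1 )
  ciWidth = sortedPts[ (ciIdxInc+1):nPts ] - sortedPts[ 1:nCIs ]
  lowIdx = which.min( ciWidth )
  c( sortedPts[ lowIdx ] , sortedPts[ lowIdx + ciIdxInc ] )
}

# summarizePredY: hypothetical helper; one row per yP[...] column of mcmcMat.
summarizePredY = function( mcmcMat , credMass=0.95 ) {
  yPcols = grep( "^yP\\[" , colnames(mcmcMat) )
  smry = t( sapply( yPcols , function( col ) {
    hdi = hdiSample( mcmcMat[,col] , credMass )
    c( mean=mean(mcmcMat[,col]) , median=median(mcmcMat[,col]) ,
       HDIlow=hdi[1] , HDIhigh=hdi[2] )
  } ) )
  rownames(smry) = colnames(mcmcMat)[yPcols]
  smry
}

# Usage with made-up draws standing in for as.matrix(mcmcCoda):
set.seed(3)
toyMat = cbind( "yP[1]" = rnorm( 10000 , 0 , 1 ) ,
                "yP[2]" = rnorm( 10000 , 5 , 2 ) )
print( summarizePredY( toyMat ) , digits=7 )
```

With a real chain, passing as.matrix(mcmcCoda) instead of toyMat would summarize every monitored yP column, one row per probe.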