The basic idea is simple: At each step in the MCMC chain, use the parameter values to randomly generate a simulated datum \(y\) at the probed value of \(x\). Then examine the resulting distribution of simulated \(y\) values; that is the posterior distribution of the predicted \(y\) values.
To implement the idea, the first programming choice is whether to simulate the \(y\) value with JAGS (or Stan or whatever) while it is generating the MCMC chain, or to simulate the \(y\) value after the MCMC chain has been generated. There are pros and cons of each option. Generating the value in JAGS has the benefit of keeping the code that generates the \(y\) value close to the code that expresses the model, so there is less chance of mistakenly simulating data from a different model than the one being fit to the data. On the other hand, this method requires us to pre-specify all the \(x\) values we want to probe. If you want to choose the probed \(x\) values after JAGS has already generated the MCMC chain, then you'll need to re-express the model outside of JAGS, in R, and run the risk of mistakenly expressing it differently (e.g., using precision instead of standard deviation, or thinking that y=rt(...) in R will use the same syntax as y~dt(...) in JAGS). I will show an implementation in which JAGS simulates the \(y\) values while generating the MCMC chain.
(A SUBSEQUENT POST SHOWS HOW TO DO THIS IN R AFTER JAGS.)
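To make the pitfalls of the after-the-fact option concrete, here is a minimal sketch in R. It is not the code from the subsequent post. The "chain" is faked with random numbers purely for illustration, and the column names beta1 and beta2 are hypothetical stand-ins (a real chain from as.matrix(mcmcCoda) would have columns named like beta[1]). The point is only the last line: JAGS's dt takes a location, a precision, and df, whereas R's rt takes only df, so the location and scale must be applied by hand.

```r
set.seed(47405)
nSteps = 5000
# Fake "chain" for illustration only; real values would come from the MCMC output.
chain = cbind( beta0 = rnorm(nSteps,1000,20) ,
               beta1 = rnorm(nSteps,12,4) ,
               beta2 = rnorm(nSteps,-2.8,0.2) ,
               sigma = runif(nSteps,25,40) ,
               nu    = 1 + rexp(nSteps,1/30) )
xProbeRow = c( 9 , 10 )  # Spend, PrcntTake
# Linear predictor at the probe, one value per step in the chain:
mu = as.vector( chain[,"beta0"] + chain[,c("beta1","beta2")] %*% xProbeRow )
# JAGS: y ~ dt( mu , 1/sigma^2 , nu ) uses location, PRECISION, and df.
# R: rt( n , df ) is a standard t, so shift by mu and scale by sigma (an SD-like scale):
yPred = mu + chain[,"sigma"] * rt( nSteps , df=chain[,"nu"] )
print( quantile( yPred , probs=c(0.025,0.5,0.975) ) )
```

Each step in the chain contributes one simulated \(y\), drawn with that step's own parameter values, so the spread of yPred reflects both parameter uncertainty and the noise distribution.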
To illustrate, I'll use the example of scholastic aptitude test (SAT) scores from Chapter 18 of DBDA2E, illustrated below:
Running robust multiple linear regression yields a posterior distribution as shown below:
where \(\nu\) is the normality (a.k.a. df) parameter of the \(t\) distribution.
Now for the issue of interest here: What is the predicted SAT score for a hypothetical state that spends, say, 9 thousand dollars per student and has 10% of the students take the exam? The answer, using the method described above, is shown below:
(Please note that the numerical annotation in these figures only shows the first three significant digits, so you'll need to examine the actual MCMC chain for more digits!) As another example, what is the predicted SAT score for a hypothetical state that spends, say, 9 thousand dollars per student and has 80% of the students take the exam? Answer:
A couple more examples of predictions:
Now for the R code I used to generate those results. I modified the scripts supplied with DBDA2E, named Jags-Ymet-XmetMulti-Mrobust.R and Jags-Ymet-XmetMulti-Mrobust-Example.R. Some of the key changes are highlighted below.
The to-be-probed \(x\) values are put in a matrix called xProbe, which has columns that correspond to the \(x\) predictors and one row for each probe. The number of probed points (i.e., the number of rows of xProbe) is denoted Nprobe. The number of predictors (i.e., the number of columns of xProbe) is denoted Nx. Then, in the new low-level script, called Jags-Ymet-XmetMulti-MrobustPredict.R, the JAGS model specification looks like this:
# Standardize the data:
data {
  ym <- mean(y)
  ysd <- sd(y)
  for ( i in 1:Ntotal ) {
    zy[i] <- ( y[i] - ym ) / ysd
  }
  for ( j in 1:Nx ) {
    xm[j] <- mean(x[,j])
    xsd[j] <- sd(x[,j])
    for ( i in 1:Ntotal ) {
      zx[i,j] <- ( x[i,j] - xm[j] ) / xsd[j]
    }
    # standardize the probe values:
    for ( i in 1:Nprobe ) {
      zxProbe[i,j] <- ( xProbe[i,j] - xm[j] ) / xsd[j]
    }
  }
}
# Specify the model for standardized data:
model {
  for ( i in 1:Ntotal ) {
    zy[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zx[i,1:Nx] ) , 1/zsigma^2 , nu )
  }
  # Priors vague on standardized scale:
  zbeta0 ~ dnorm( 0 , 1/2^2 )
  for ( j in 1:Nx ) {
    zbeta[j] ~ dnorm( 0 , 1/2^2 )
  }
  zsigma ~ dunif( 1.0E-5 , 1.0E+1 )
  nu ~ dexp(1/30.0)
  # Transform to original scale:
  beta[1:Nx] <- ( zbeta[1:Nx] / xsd[1:Nx] )*ysd
  beta0 <- zbeta0*ysd + ym - sum( zbeta[1:Nx] * xm[1:Nx] / xsd[1:Nx] )*ysd
  sigma <- zsigma*ysd
  # Predicted y values at xProbe:
  for ( i in 1:Nprobe ) {
    zyP[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zxProbe[i,1:Nx] ) ,
                 1/zsigma^2 , nu )
    yP[i] <- zyP[i] * ysd + ym
  }
}
The changes noted above are analogous to those I used for simple linear regression in the previous post. The MCMC chain monitors the values of yP[i], and subsequently we can examine the posterior distribution.
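The transform-to-original-scale lines in the model can be checked outside of JAGS. The following sketch is not part of the DBDA2E scripts. It uses made-up data and ordinary least squares (lm) instead of the robust \(t\) model, only because the algebra of the transformation is the same for any linear predictor. If the formulas are right, the transformed standardized coefficients should match the coefficients from fitting the original-scale data directly.

```r
set.seed(1)
Ntotal = 50
x = cbind( runif(Ntotal,3,10) , runif(Ntotal,4,81) )  # made-up predictors
y = 1000 + 12*x[,1] - 3*x[,2] + rnorm(Ntotal,0,30)    # made-up outcome
# Standardize, as in the data{} block:
ym = mean(y) ; ysd = sd(y)
xm = colMeans(x) ; xsd = apply(x,2,sd)
zy = ( y - ym ) / ysd
zx = scale(x)  # centers by column means, divides by column sd's
# Fit on the standardized scale:
zfit = lm( zy ~ zx )
zbeta0 = coef(zfit)[1] ; zbeta = coef(zfit)[-1]
# Transform to original scale, same formulas as in the model{} block:
beta  = ( zbeta / xsd ) * ysd
beta0 = zbeta0*ysd + ym - sum( zbeta * xm / xsd )*ysd
# Compare with fitting on the original scale directly:
fit = lm( y ~ x )
print( rbind( transformed=c(beta0,beta) , direct=coef(fit) ) )
```

The same logic explains the last line of the prediction loop: zyP[i] is simulated on the standardized scale, so multiplying by ysd and adding ym puts the predicted value back on the scale of the original \(y\).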
I hope this helps. Here is complete R code for the high-level and low-level scripts:
High-level script:
#===== Begin high-level script ================
# Example for Jags-Ymet-XmetMulti-MrobustPredict.R
#-------------------------------------------------------------------------------
# Optional generic preliminaries:
graphics.off() # This closes all of R's graphics windows.
rm(list=ls()) # Careful! This clears all of R's memory!
#.............................................................................
# # Two predictors:
myData = read.csv( file="Guber1999data.csv" )
yName = "SATT"
xName = c("Spend","PrcntTake")
xProbe = matrix( c( 4 , 10 , # Spend, PrcntTake
9 , 10 ,
4 , 80 ,
9 , 80 ) , nrow=4 , byrow=TRUE )
fileNameRoot = "Guber1999data-Predict-"
numSavedSteps=15000 ; thinSteps=5
graphFileType = "png"
#-------------------------------------------------------------------------------
# Load the relevant model into R's working memory:
source("Jags-Ymet-XmetMulti-MrobustPredict.R")
#-------------------------------------------------------------------------------
# Generate the MCMC chain:
mcmcCoda = genMCMC( data=myData , xName=xName , yName=yName , xProbe=xProbe ,
numSavedSteps=numSavedSteps , thinSteps=thinSteps ,
saveName=fileNameRoot )
#-------------------------------------------------------------------------------
# Display diagnostics of chain, for specified parameters:
parameterNames = varnames(mcmcCoda) # get all parameter names
for ( parName in parameterNames ) {
diagMCMC( codaObject=mcmcCoda , parName=parName ,
saveName=fileNameRoot , saveType=graphFileType )
}
#-------------------------------------------------------------------------------
# Get summary statistics of chain:
summaryInfo = smryMCMC( mcmcCoda , saveName=fileNameRoot )
show(summaryInfo)
# Display posterior information:
plotMCMC( mcmcCoda , data=myData , xName=xName , yName=yName ,
pairsPlot=TRUE , showCurve=FALSE ,
saveName=fileNameRoot , saveType=graphFileType )
#-------------------------------------------------------------------------------
# Plot posterior predicted y at xProbe:
mcmcMat = as.matrix(mcmcCoda)
xPcols = grep( "xProbe" , colnames(mcmcMat) , value=FALSE )
yPcols = grep( "yP" , colnames(mcmcMat) , value=FALSE )
xLim = quantile( mcmcMat[,yPcols] , probs=c(0.005,0.995) )
for ( i in 1:length(yPcols) ) {
openGraph(width=4,height=3)
xNameText = paste( "@" , paste( xName , collapse=", " ) , "=" )
xProbeValText = paste(mcmcMat[1,xPcols[seq(i,
by=length(yPcols),
length=length(xName))]],
collapse=", ")
plotPost( mcmcMat[,yPcols[i]] , xlab="Post. Pred. y" , xlim=xLim ,
cenTend="mean" ,
main= bquote(atop(.(xNameText),.(xProbeValText))) )
}
#-------------------------------------------------------------------------------
#===== End high-level script ================
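The indexing that builds xProbeValText in the script above deserves a note. It relies on the monitored xProbe columns appearing in column-major order (all rows of predictor 1, then all rows of predictor 2, and so on), so the values for probe i sit at positions i, i+Nprobe, and so on. The sketch below mimics those column names by hand, under that ordering assumption, without running JAGS; the names xPnames and oneStep are hypothetical and used only here.

```r
xProbe = matrix( c( 4 , 10 ,   # Spend, PrcntTake
                    9 , 10 ,
                    4 , 80 ,
                    9 , 80 ) , nrow=4 , byrow=TRUE )
Nprobe = nrow(xProbe) ; Nx = ncol(xProbe)
# Mimic the monitored column names, assuming column-major order:
xPnames = paste0( "xProbe[" , rep(1:Nprobe,times=Nx) , "," ,
                  rep(1:Nx,each=Nprobe) , "]" )
# One row of the chain; as.vector() on a matrix is also column-major:
oneStep = setNames( as.vector(xProbe) , xPnames )
for ( i in 1:Nprobe ) {
  idx = seq( i , by=Nprobe , length.out=Nx )  # positions of probe i's x values
  cat( "probe" , i , ":" , names(oneStep)[idx] , "=" , oneStep[idx] , "\n" )
}
```

Because xProbe is constant data, every step of the chain holds the same values, which is why the script reads them from row 1 of mcmcMat.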
Low-level script, named Jags-Ymet-XmetMulti-MrobustPredict.R
and called by high-level script:
#===== Begin low-level script ================
# Jags-Ymet-XmetMulti-MrobustPredict.R
# Accompanies the book:
# Kruschke, J. K. (2015). Doing Bayesian Data Analysis, Second Edition:
# A Tutorial with R, JAGS, and Stan. Academic Press / Elsevier.
source("DBDA2E-utilities.R")
#===============================================================================
genMCMC = function( data , xName="x" , yName="y" , xProbe=NULL ,
numSavedSteps=10000 , thinSteps=1 , saveName=NULL ,
runjagsMethod=runjagsMethodDefault ,
nChains=nChainsDefault ) {
require(runjags)
#-----------------------------------------------------------------------------
# THE DATA.
y = data[,yName]
x = as.matrix(data[,xName],ncol=length(xName))
# Do some checking that data make sense:
if ( any( !is.finite(y) ) ) { stop("All y values must be finite.") }
if ( any( !is.finite(x) ) ) { stop("All x values must be finite.") }
cat("\nCORRELATION MATRIX OF PREDICTORS:\n ")
show( round(cor(x),3) )
cat("\n")
flush.console()
Nx = ncol(x) # number of x predictors
Ntotal = nrow(x) # number of data points
# Check the probe values:
if ( !is.null(xProbe) ) {
if ( any( !is.finite(xProbe) ) ) {
stop("All xProbe values must be finite.") }
if ( ncol(xProbe) != Nx ) {
stop("xProbe must have same number of columns as x.") }
} else { # fill in placeholder so JAGS doesn't balk
xProbe = matrix( 0 , ncol=Nx , nrow=3 )
for ( xIdx in 1:Nx ) {
xProbe[,xIdx] = quantile(x[,xIdx],probs=c(0.0,0.5,1.0))
}
}
# Specify the data in a list, for later shipment to JAGS:
dataList = list(
x = x ,
y = y ,
Nx = Nx ,
Ntotal = Ntotal ,
xProbe = xProbe ,
Nprobe = nrow(xProbe)
)
#-----------------------------------------------------------------------------
# THE MODEL.
modelString = "
# Standardize the data:
data {
ym <- mean(y)
ysd <- sd(y)
for ( i in 1:Ntotal ) {
zy[i] <- ( y[i] - ym ) / ysd
}
for ( j in 1:Nx ) {
xm[j] <- mean(x[,j])
xsd[j] <- sd(x[,j])
for ( i in 1:Ntotal ) {
zx[i,j] <- ( x[i,j] - xm[j] ) / xsd[j]
}
# standardize the probe values:
for ( i in 1:Nprobe ) {
zxProbe[i,j] <- ( xProbe[i,j] - xm[j] ) / xsd[j]
}
}
}
# Specify the model for standardized data:
model {
for ( i in 1:Ntotal ) {
zy[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zx[i,1:Nx] ) , 1/zsigma^2 , nu )
}
# Priors vague on standardized scale:
zbeta0 ~ dnorm( 0 , 1/2^2 )
for ( j in 1:Nx ) {
zbeta[j] ~ dnorm( 0 , 1/2^2 )
}
zsigma ~ dunif( 1.0E-5 , 1.0E+1 )
nu ~ dexp(1/30.0)
# Transform to original scale:
beta[1:Nx] <- ( zbeta[1:Nx] / xsd[1:Nx] )*ysd
beta0 <- zbeta0*ysd + ym - sum( zbeta[1:Nx] * xm[1:Nx] / xsd[1:Nx] )*ysd
sigma <- zsigma*ysd
# Predicted y values at xProbe:
for ( i in 1:Nprobe ) {
zyP[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zxProbe[i,1:Nx] ) ,
1/zsigma^2 , nu )
yP[i] <- zyP[i] * ysd + ym
}
}
" # close quote for modelString
# Write out modelString to a text file
writeLines( modelString , con="TEMPmodel.txt" )
#-----------------------------------------------------------------------------
# INITIALIZE THE CHAINS.
# Let JAGS do it...
#-----------------------------------------------------------------------------
# RUN THE CHAINS
parameters = c( "beta0" , "beta" , "sigma",
"zbeta0" , "zbeta" , "zsigma", "nu" , "xProbe" , "yP" )
adaptSteps = 500 # Number of steps to "tune" the samplers
burnInSteps = 1000
runJagsOut <- run.jags( method=runjagsMethod ,
model="TEMPmodel.txt" ,
monitor=parameters ,
data=dataList ,
#inits=initsList ,
n.chains=nChains ,
adapt=adaptSteps ,
burnin=burnInSteps ,
sample=ceiling(numSavedSteps/nChains) ,
thin=thinSteps ,
summarise=FALSE ,
plots=FALSE )
codaSamples = as.mcmc.list( runJagsOut )
# resulting codaSamples object has these indices:
# codaSamples[[ chainIdx ]][ stepIdx , paramIdx ]
if ( !is.null(saveName) ) {
save( codaSamples , file=paste(saveName,"Mcmc.Rdata",sep="") )
}
return( codaSamples )
} # end function
#===============================================================================
smryMCMC = function( codaSamples ,
saveName=NULL ) {
summaryInfo = NULL
mcmcMat = as.matrix(codaSamples,chains=TRUE)
paramName = colnames(mcmcMat)
for ( pName in paramName ) {
summaryInfo = rbind( summaryInfo , summarizePost( mcmcMat[,pName] ) )
}
rownames(summaryInfo) = paramName
summaryInfo = rbind( summaryInfo ,
"log10(nu)" = summarizePost( log10(mcmcMat[,"nu"]) ) )
if ( !is.null(saveName) ) {
write.csv( summaryInfo , file=paste(saveName,"SummaryInfo.csv",sep="") )
}
return( summaryInfo )
}
#===============================================================================
plotMCMC = function( codaSamples , data , xName="x" , yName="y" ,
showCurve=FALSE , pairsPlot=FALSE ,
saveName=NULL , saveType="jpg" ) {
# showCurve is TRUE or FALSE and indicates whether the posterior should
# be displayed as a histogram (by default) or by an approximate curve.
# pairsPlot is TRUE or FALSE and indicates whether scatterplots of pairs
# of parameters should be displayed.
#-----------------------------------------------------------------------------
y = data[,yName]
x = as.matrix(data[,xName])
mcmcMat = as.matrix(codaSamples,chains=TRUE)
chainLength = NROW( mcmcMat )
zbeta0 = mcmcMat[,"zbeta0"]
zbeta = mcmcMat[,grep("^zbeta$|^zbeta\\[",colnames(mcmcMat))]
if ( ncol(x)==1 ) { zbeta = matrix( zbeta , ncol=1 ) }
zsigma = mcmcMat[,"zsigma"]
beta0 = mcmcMat[,"beta0"]
beta = mcmcMat[,grep("^beta$|^beta\\[",colnames(mcmcMat))]
if ( ncol(x)==1 ) { beta = matrix( beta , ncol=1 ) }
sigma = mcmcMat[,"sigma"]
nu = mcmcMat[,"nu"]
log10nu = log10(nu)
#-----------------------------------------------------------------------------
# Compute R^2 for credible parameters:
YcorX = cor( y , x ) # correlation of y with each x predictor
Rsq = zbeta %*% matrix( YcorX , ncol=1 )
#-----------------------------------------------------------------------------
if ( pairsPlot ) {
# Plot the parameters pairwise, to see correlations:
openGraph()
nPtToPlot = 1000
plotIdx = floor(seq(1,chainLength,by=chainLength/nPtToPlot))
panel.cor = function(x, y, digits=2, prefix="", cex.cor, ...) {
usr = par("usr"); on.exit(par(usr))
par(usr = c(0, 1, 0, 1))
r = (cor(x, y))
txt = format(c(r, 0.123456789), digits=digits)[1]
txt = paste(prefix, txt, sep="")
if(missing(cex.cor)) cex.cor <- 0.8/strwidth(txt)
text(0.5, 0.5, txt, cex=1.25 ) # was cex=cex.cor*r
}
pairs( cbind( beta0 , beta , sigma , log10nu )[plotIdx,] ,
labels=c( "beta[0]" ,
paste0("beta[",1:ncol(beta),"]\n",xName) ,
expression(sigma) , expression(log10(nu)) ) ,
lower.panel=panel.cor , col="skyblue" )
if ( !is.null(saveName) ) {
saveGraph( file=paste(saveName,"PostPairs",sep=""), type=saveType)
}
}
#-----------------------------------------------------------------------------
# Marginal histograms:
decideOpenGraph = function( panelCount , saveName , finished=FALSE ,
nRow=2 , nCol=3 ) {
# If finishing a set:
if ( finished==TRUE ) {
if ( !is.null(saveName) ) {
saveGraph( file=paste0(saveName,ceiling((panelCount-1)/(nRow*nCol))),
type=saveType)
}
panelCount = 1 # re-set panelCount
return(panelCount)
} else {
# If this is first panel of a graph:
if ( ( panelCount %% (nRow*nCol) ) == 1 ) {
# If previous graph was open, save previous one:
if ( panelCount>1 && !is.null(saveName) ) {
saveGraph( file=paste0(saveName,(panelCount%/%(nRow*nCol))),
type=saveType)
}
# Open new graph
openGraph(width=nCol*7.0/3,height=nRow*2.0)
layout( matrix( 1:(nRow*nCol) , nrow=nRow, byrow=TRUE ) )
par( mar=c(4,4,2.5,0.5) , mgp=c(2.5,0.7,0) )
}
# Increment and return panel count:
panelCount = panelCount+1
return(panelCount)
}
}
# Original scale:
panelCount = 1
panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
histInfo = plotPost( beta0 , cex.lab = 1.75 , showCurve=showCurve ,
xlab=bquote(beta[0]) , main="Intercept" )
for ( bIdx in 1:ncol(beta) ) {
panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
histInfo = plotPost( beta[,bIdx] , cex.lab = 1.75 , showCurve=showCurve ,
xlab=bquote(beta[.(bIdx)]) , main=xName[bIdx] )
}
panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
histInfo = plotPost( sigma , cex.lab = 1.75 , showCurve=showCurve ,
xlab=bquote(sigma) , main=paste("Scale") )
panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
histInfo = plotPost( log10nu , cex.lab = 1.75 , showCurve=showCurve ,
xlab=bquote(log10(nu)) , main=paste("Normality") )
panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
histInfo = plotPost( Rsq , cex.lab = 1.75 , showCurve=showCurve ,
xlab=bquote(R^2) , main=paste("Prop Var Accntd") )
panelCount = decideOpenGraph( panelCount , finished=TRUE , saveName=paste0(saveName,"PostMarg") )
# Standardized scale:
panelCount = 1
panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
histInfo = plotPost( zbeta0 , cex.lab = 1.75 , showCurve=showCurve ,
xlab=bquote(z*beta[0]) , main="Intercept" )
for ( bIdx in 1:ncol(beta) ) {
panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
histInfo = plotPost( zbeta[,bIdx] , cex.lab = 1.75 , showCurve=showCurve ,
xlab=bquote(z*beta[.(bIdx)]) , main=xName[bIdx] )
}
panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
histInfo = plotPost( zsigma , cex.lab = 1.75 , showCurve=showCurve ,
xlab=bquote(z*sigma) , main=paste("Scale") )
panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
histInfo = plotPost( log10nu , cex.lab = 1.75 , showCurve=showCurve ,
xlab=bquote(log10(nu)) , main=paste("Normality") )
panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
histInfo = plotPost( Rsq , cex.lab = 1.75 , showCurve=showCurve ,
xlab=bquote(R^2) , main=paste("Prop Var Accntd") )
panelCount = decideOpenGraph( panelCount , finished=TRUE , saveName=paste0(saveName,"PostMargZ") )
#-----------------------------------------------------------------------------
}
#===============================================================================
#===== End low-level script ================
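As mentioned earlier, the annotations in the figures show only three significant digits, so more precise numbers have to come from the chain itself. Here is a minimal sketch of one way to do that. The matrix below is a fake stand-in for as.matrix(mcmcCoda), with made-up yP columns, so the printed numbers mean nothing. Also note that it reports equal-tailed 95% limits from quantile(), which are not the same as the HDI limits that plotPost displays.

```r
set.seed(2)
# Stand-in for as.matrix(mcmcCoda): two fake yP columns, for illustration only.
mcmcMat = cbind( "yP[1]" = rnorm(15000,1050,35) ,
                 "yP[2]" = rnorm(15000,1075,36) )
# Anchor the pattern so that only yP[...] columns are selected:
yPcols = grep( "^yP\\[" , colnames(mcmcMat) )
# One row of summary values per probed point:
predSummary = t( apply( mcmcMat[,yPcols,drop=FALSE] , 2 , function(v) {
  c( mean=mean(v) , median=median(v) ,
     quantile( v , probs=c(0.025,0.975) ) )
} ) )
print( predSummary , digits=7 )
```

With a real chain, replace the fake mcmcMat with as.matrix(mcmcCoda) and keep the rest unchanged.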
I know it is a bit beside the point, but would you be able to share the code for creating Fig 18.3 as well?
How would it be implemented for the Jags-Ymet-XmetSsubj-MrobustHier.R code?