Saturday, October 22, 2016

Posterior predictive distribution for multiple linear regression

Suppose you've done a (robust) Bayesian multiple linear regression, and now you want the posterior distribution on the predicted value of \(y\) for some probe value of \( \langle x_1,x_2,x_3, ... \rangle \). That is, not the posterior distribution on the mean of the predicted value, but the posterior distribution on the predicted value itself. I showed how to do this for simple linear regression in a previous post; in this post I show how to do it for multiple linear regression. (A lot of commenters and emailers have asked me to do this.)

The basic idea is simple: At each step in the MCMC chain, use the parameter values to randomly generate a simulated datum \(y\) at the probed value of \(x\). Then examine the resulting distribution of simulated \(y\) values; that is the posterior distribution of the predicted \(y\) values.
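Here is the same idea stated a bit more formally. This is a sketch in generic notation: \(x^*\) is the probed vector of predictor values, \(D\) is the data, and \(\theta = \langle \beta_0, \beta_1, \ldots, \sigma, \nu \rangle\) are the model's parameters. The goal is the posterior predictive distribution, and each MCMC step \(s = 1, \ldots, S\) supplies one draw from it:

\[
p(y^* \mid x^*, D) = \int p(y^* \mid x^*, \theta)\, p(\theta \mid D)\, d\theta
\]

\[
y^{*(s)} \sim \mathrm{t}\!\left( \beta_0^{(s)} + \sum_j \beta_j^{(s)} x_j^* ,\; \sigma^{(s)} ,\; \nu^{(s)} \right)
\]

Here \(\mathrm{t}(\mu, \sigma, \nu)\) denotes the \(t\) distribution with location \(\mu\), scale \(\sigma\), and normality \(\nu\). Because each \(\theta^{(s)}\) is a draw from \(p(\theta \mid D)\), the simulated values \(y^{*(s)}\) form a sample from \(p(y^* \mid x^*, D)\). The values \(\beta_0^{(s)} + \sum_j \beta_j^{(s)} x_j^*\) alone would instead give the posterior distribution of the mean.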

To implement the idea, the first programming choice is whether to simulate the \(y\) value with JAGS (or Stan or whatever) while it is generating the MCMC chain, or to simulate the \(y\) value after the MCMC chain has been generated. Each option has pros and cons. Generating the value in JAGS keeps the code that generates the \(y\) value close to the code that expresses the model, so there is less chance of mistakenly simulating data from a different model than the one being fit to the data. On the other hand, this method requires us to pre-specify all the \(x\) values we want to probe. If you want to choose the probed \(x\) values after JAGS has already generated the MCMC chain, then you'll need to re-express the model outside of JAGS, in R, and run the risk of mistakenly expressing it differently (e.g., using precision instead of standard deviation, or thinking that y=rt(...) in R uses the same syntax as y~dt(...) in JAGS). I will show an implementation in which JAGS simulates the \(y\) values while generating the MCMC chain.
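For comparison, here is a minimal sketch of the other option, which simulates in R after the chain has been generated. It assumes a matrix mcmcMat = as.matrix(mcmcCoda) with original-scale columns named beta0, beta[1], ..., sigma, and nu, as monitored by the low-level script below. The helper name postPredY and the synthetic draws are hypothetical. The sketch illustrates the main pitfall: R's rt(n, df) samples a standard \(t\) distribution (its ncp argument gives a noncentral \(t\), not a shifted one). The location and scale therefore have to be applied by hand, whereas JAGS's dt(mu, tau, nu) takes a precision tau = 1/sigma^2.

```r
# Sketch: posterior predicted y simulated in R AFTER the MCMC chain exists.
# Assumes mcmcMat has original-scale columns "beta0", "beta[1]", ...,
# "beta[Nx]", "sigma", "nu" (e.g., mcmcMat = as.matrix(mcmcCoda)).
postPredY = function( mcmcMat , xProbeRow ) {  # hypothetical helper name
  Nx = length( xProbeRow )
  betaCols = paste0( "beta[" , 1:Nx , "]" )
  # Mean of the t distribution at the probe, one value per MCMC step:
  mu = mcmcMat[,"beta0"] +
       as.vector( mcmcMat[,betaCols,drop=FALSE] %*% xProbeRow )
  # rt(n,df) draws from a STANDARD t (location 0, scale 1), so shift and
  # scale it here. sigma is a scale, not the JAGS precision 1/sigma^2.
  mu + mcmcMat[,"sigma"] * rt( nrow(mcmcMat) , df=mcmcMat[,"nu"] )
}

# Usage with SYNTHETIC draws standing in for a real chain (made-up values):
set.seed(1)
nStep = 20000
fakeMat = cbind( beta0 = rnorm( nStep , 1000 , 5 ) ,
                 "beta[1]" = rnorm( nStep , 10 , 1 ) ,
                 "beta[2]" = rnorm( nStep , -3 , 0.1 ) ,
                 sigma = 30 , nu = 30 )
yPred = postPredY( fakeMat , xProbeRow=c( 9 , 10 ) )
```

The spread of yPred includes both the uncertainty in the regression coefficients and the residual scale sigma. That is why it is wider than the spread of the mean mu alone.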

To illustrate, I'll use the example of scholastic aptitude test (SAT) scores from Chapter 18 of DBDA2E, illustrated below:


Running robust multiple linear regression yields a posterior distribution as shown below:

where \(\nu\) is the normality (a.k.a. df) parameter of the \(t\) distribution.

Now for the issue of interest here: What is the predicted SAT score for a hypothetical state that spends, say, 9 thousand dollars per student and has 10% of the students take the exam? The answer, using the method described above, is shown below:

(Please note that the numerical annotation in these figures only shows the first three significant digits, so you'll need to examine the actual MCMC chain for more digits!) As another example, what is the predicted SAT score for a hypothetical state that spends, say, 9 thousand dollars per student and has 80% of the students take the exam? Answer:

A couple more examples of predictions:



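As noted above, the figure annotations show only three significant digits. The smryMCMC function in the low-level script below already tabulates every monitored column, including each yP[i], and writes the table to a CSV file. For a quick look at a single probe, here is a minimal sketch that assumes mcmcMat = as.matrix(mcmcCoda) as in the high-level script. summarizeDraws is a hypothetical helper. It reports an equal-tailed interval, which can differ from the HDI limits shown in the figures when the distribution is skewed.

```r
# Hypothetical helper: mean, median, and equal-tailed interval of a vector
# of MCMC draws, returned at full numeric precision.
summarizeDraws = function( draws , credMass=0.95 ) {
  tailProb = ( 1 - credMass ) / 2
  limits = quantile( draws , probs=c( tailProb , 1 - tailProb ) , names=FALSE )
  c( mean=mean(draws) , median=median(draws) ,
     lower=limits[1] , upper=limits[2] )
}
# Example call on a real chain (requires mcmcMat from the high-level script):
#   print( summarizeDraws( mcmcMat[,"yP[2]"] ) , digits=7 )
exampleSummary = summarizeDraws( 1:101 )  # toy input, for illustration only
```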
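Now for the R code I used to generate those results. I modified the scripts supplied with DBDA2E, named Jags-Ymet-XmetMulti-Mrobust.R and Jags-Ymet-XmetMulti-Mrobust-Example.R. The key changes are in the JAGS model specification: the probe values are standardized in the data block, and the predicted values are simulated at the end of the model block.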

The to-be-probed \(x\) values are put in a matrix called xProbe, which has one column for each \(x\) predictor and one row for each probe. The number of probed points (i.e., the number of rows of xProbe) is denoted Nprobe, and the number of predictors (i.e., the number of columns of xProbe) is denoted Nx. Then, in the new low-level script, called Jags-Ymet-XmetMulti-MrobustPredict.R, the JAGS model specification looks like this:

  # Standardize the data:
  data {
    ym <- mean(y)
    ysd <- sd(y)
    for ( i in 1:Ntotal ) {
      zy[i] <- ( y[i] - ym ) / ysd
    }
    for ( j in 1:Nx ) {
      xm[j]  <- mean(x[,j])
      xsd[j] <-   sd(x[,j])
      for ( i in 1:Ntotal ) {
        zx[i,j] <- ( x[i,j] - xm[j] ) / xsd[j]
      }
      # standardize the probe values:
      for ( i in 1:Nprobe ) {
        zxProbe[i,j] <- ( xProbe[i,j] - xm[j] ) / xsd[j]
      }


    }
  }
  # Specify the model for standardized data:
  model {
    for ( i in 1:Ntotal ) {
      zy[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zx[i,1:Nx] ) , 1/zsigma^2 , nu )
    }
    # Priors vague on standardized scale:
    zbeta0 ~ dnorm( 0 , 1/2^2 ) 
    for ( j in 1:Nx ) {
      zbeta[j] ~ dnorm( 0 , 1/2^2 )
    }
    zsigma ~ dunif( 1.0E-5 , 1.0E+1 )
    nu ~ dexp(1/30.0)
    # Transform to original scale:
    beta[1:Nx] <- ( zbeta[1:Nx] / xsd[1:Nx] )*ysd
    beta0 <- zbeta0*ysd  + ym - sum( zbeta[1:Nx] * xm[1:Nx] / xsd[1:Nx] )*ysd
    sigma <- zsigma*ysd
    # Predicted y values at xProbe:
    for ( i in 1:Nprobe ) {
      zyP[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zxProbe[i,1:Nx] ) ,
                   1/zsigma^2 , nu )
      yP[i] <- zyP[i] * ysd + ym
    }


  }


These changes are analogous to those I used for simple linear regression in the previous post. The MCMC chain monitors the values of yP[i], so afterwards we can examine their posterior distribution directly.
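The "Transform to original scale" lines follow from the standardizing formulas in the data block. Substitute zy = (y - ym)/ysd and zx[j] = (x[j] - xm[j])/xsd[j] into the standardized prediction zbeta0 + sum(zbeta[j]*zx[j]) and solve for y. The result is an intercept of zbeta0*ysd + ym - ysd*sum(zbeta[j]*xm[j]/xsd[j]) and slopes of zbeta[j]*ysd/xsd[j], which match the beta0 and beta lines. The same substitution explains why yP[i] <- zyP[i]*ysd + ym puts the predicted values back on the original scale. Below is a quick numerical check in R. It uses arbitrary made-up numbers, not the SAT data.

```r
# Arbitrary illustrative values (NOT estimates from the SAT data):
ym = 100 ; ysd = 10                      # mean and sd of y
xm = c( 2 , 50 ) ; xsd = c( 1 , 20 )     # means and sds of the predictors
zbeta0 = 0.3 ; zbeta = c( 0.5 , -0.8 )   # standardized coefficients
xProbeRow = c( 4 , 10 )                  # one probe, original scale

# Same formulas as the "Transform to original scale" lines of the model:
beta  = ( zbeta / xsd ) * ysd
beta0 = zbeta0 * ysd + ym - sum( zbeta * xm / xsd ) * ysd
muOrig = beta0 + sum( beta * xProbeRow )   # predict on the original scale

# Standardize the probe, predict on the z scale, then back-transform (as yP):
zxProbe = ( xProbeRow - xm ) / xsd
muZ = ( zbeta0 + sum( zbeta * zxProbe ) ) * ysd + ym
```

Both routes give the same predicted mean. The scale transforms in the same way, as in sigma <- zsigma*ysd.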

I hope this helps. Here is the complete R code for the high-level and low-level scripts:

High-level script:

#===== Begin high-level script ================
# Example for Jags-Ymet-XmetMulti-MrobustPredict.R
#-------------------------------------------------------------------------------
# Optional generic preliminaries:
graphics.off() # This closes all of R's graphics windows.
rm(list=ls())  # Careful! This clears all of R's memory!
#.............................................................................
# # Two predictors:
myData = read.csv( file="Guber1999data.csv" )
yName = "SATT"
xName = c("Spend","PrcntTake")
xProbe = matrix( c( 4 , 10 , # Spend, PrcntTake
                    9 , 10 ,
                    4 , 80 ,
                    9 , 80 ) , nrow=4 , byrow=TRUE )


fileNameRoot = "Guber1999data-Predict-"
numSavedSteps=15000 ; thinSteps=5
graphFileType = "png"
#-------------------------------------------------------------------------------
# Load the relevant model into R's working memory:
source("Jags-Ymet-XmetMulti-MrobustPredict.R")
#-------------------------------------------------------------------------------
# Generate the MCMC chain:
mcmcCoda = genMCMC( data=myData , xName=xName , yName=yName , xProbe=xProbe ,
                    numSavedSteps=numSavedSteps , thinSteps=thinSteps ,
                    saveName=fileNameRoot )
#-------------------------------------------------------------------------------
# Display diagnostics of chain, for specified parameters:
parameterNames = varnames(mcmcCoda) # get all parameter names
for ( parName in parameterNames ) {
  diagMCMC( codaObject=mcmcCoda , parName=parName ,
            saveName=fileNameRoot , saveType=graphFileType )
}
#-------------------------------------------------------------------------------
# Get summary statistics of chain:
summaryInfo = smryMCMC( mcmcCoda , saveName=fileNameRoot )
show(summaryInfo)
# Display posterior information:
plotMCMC( mcmcCoda , data=myData , xName=xName , yName=yName ,
          pairsPlot=TRUE , showCurve=FALSE ,
          saveName=fileNameRoot , saveType=graphFileType )
#-------------------------------------------------------------------------------
# Plot posterior predicted y at xProbe:
mcmcMat = as.matrix(mcmcCoda)
xPcols = grep( "xProbe" , colnames(mcmcMat) , value=FALSE )
yPcols = grep( "yP" , colnames(mcmcMat) , value=FALSE )
xLim = quantile( mcmcMat[,yPcols] , probs=c(0.005,0.995) )
for ( i in 1:length(yPcols) ) {
  openGraph(width=4,height=3)
  xNameText = paste( "@" , paste( xName , collapse=", " ) , "=" )
  xProbeValText = paste(mcmcMat[1,xPcols[seq(i,
                                          by=length(yPcols),
                                          length=length(xName))]],
                     collapse=", ")
  plotPost( mcmcMat[,yPcols[i]] , xlab="Post. Pred. y" , xlim=xLim ,
            cenTend="mean" ,
            main= bquote(atop(.(xNameText),.(xProbeValText))) )
}


#-------------------------------------------------------------------------------

#===== End high-level script ================
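The plotting loop at the end of the high-level script finds each probe's values for the plot title by position. It assumes that the monitored matrix xProbe appears in the chain column by column, as xProbe[1,1], xProbe[2,1], ..., xProbe[1,2], ..., so that probe i occupies positions seq(i, by=Nprobe, length=Nx) among the xProbe columns. The sketch below mimics that assumed layout in plain R, without running JAGS. It also shows a lookup by column name, which does not depend on the column order. The variable names flatVals, byPosition, and byName are illustrative.

```r
# Same probe matrix as in the high-level script:
xProbe = matrix( c( 4 , 10 ,
                    9 , 10 ,
                    4 , 80 ,
                    9 , 80 ) , nrow=4 , byrow=TRUE )
Nprobe = nrow( xProbe ) ; Nx = ncol( xProbe )
# Flatten column by column and label entries the way JAGS/coda name them
# (assumed layout: "xProbe[1,1]", "xProbe[2,1]", ..., "xProbe[1,2]", ...):
flatVals = as.vector( xProbe )
names( flatVals ) = paste0( "xProbe[" , row( xProbe ) , "," , col( xProbe ) , "]" )
probeIdx = 2
# Positional lookup, as in the plotting loop of the high-level script:
byPosition = flatVals[ seq( probeIdx , by=Nprobe , length=Nx ) ]
# Name-based lookup, which does not depend on the column order:
byName = flatVals[ paste0( "xProbe[" , probeIdx , "," , 1:Nx , "]" ) ]
```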


Low-level script, named Jags-Ymet-XmetMulti-MrobustPredict.R and called by the high-level script:
#===== Begin low-level script ================
# Jags-Ymet-XmetMulti-MrobustPredict.R
# Accompanies the book:
#  Kruschke, J. K. (2015). Doing Bayesian Data Analysis, Second Edition:
#  A Tutorial with R, JAGS, and Stan. Academic Press / Elsevier.

source("DBDA2E-utilities.R")

#===============================================================================

genMCMC = function( data , xName="x" , yName="y" ,  xProbe=NULL ,
                    numSavedSteps=10000 , thinSteps=1 , saveName=NULL  ,
                    runjagsMethod=runjagsMethodDefault ,
                    nChains=nChainsDefault ) {
  require(runjags)
  #-----------------------------------------------------------------------------
  # THE DATA.
  y = data[,yName]
  x = as.matrix(data[,xName],ncol=length(xName))
  # Do some checking that data make sense:
  if ( any( !is.finite(y) ) ) { stop("All y values must be finite.") }
  if ( any( !is.finite(x) ) ) { stop("All x values must be finite.") }
  cat("\nCORRELATION MATRIX OF PREDICTORS:\n ")
  show( round(cor(x),3) )
  cat("\n")
  flush.console()
  Nx = ncol(x) # number of x predictors
  Ntotal = nrow(x) # number of data points
  # Check the probe values:
  if ( !is.null(xProbe) ) {
    if ( any( !is.finite(xProbe) ) ) {
      stop("All xProbe values must be finite.") }
    if ( ncol(xProbe) != Nx ) {
      stop("xProbe must have same number of columns as x.") }
  } else { # fill in placeholder so JAGS doesn't balk
    xProbe = matrix( 0 , ncol=Nx , nrow=3 )
    for ( xIdx in 1:Nx ) {
      xProbe[,xIdx] = quantile(x[,xIdx],probs=c(0.0,0.5,1.0))
    }
  }
  # Specify the data in a list, for later shipment to JAGS:
  dataList = list(
    x = x ,
    y = y ,
    Nx = Nx ,
    Ntotal = Ntotal ,
    xProbe = xProbe ,
    Nprobe = nrow(xProbe)


  )
  #-----------------------------------------------------------------------------
  # THE MODEL.
  modelString = "
  # Standardize the data:
  data {
    ym <- mean(y)
    ysd <- sd(y)
    for ( i in 1:Ntotal ) {
      zy[i] <- ( y[i] - ym ) / ysd
    }
    for ( j in 1:Nx ) {
      xm[j]  <- mean(x[,j])
      xsd[j] <-   sd(x[,j])
      for ( i in 1:Ntotal ) {
        zx[i,j] <- ( x[i,j] - xm[j] ) / xsd[j]
      }
      # standardize the probe values:
      for ( i in 1:Nprobe ) {
        zxProbe[i,j] <- ( xProbe[i,j] - xm[j] ) / xsd[j]
      }


    }
  }
  # Specify the model for standardized data:
  model {
    for ( i in 1:Ntotal ) {
      zy[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zx[i,1:Nx] ) , 1/zsigma^2 , nu )
    }
    # Priors vague on standardized scale:
    zbeta0 ~ dnorm( 0 , 1/2^2 ) 
    for ( j in 1:Nx ) {
      zbeta[j] ~ dnorm( 0 , 1/2^2 )
    }
    zsigma ~ dunif( 1.0E-5 , 1.0E+1 )
    nu ~ dexp(1/30.0)
    # Transform to original scale:
    beta[1:Nx] <- ( zbeta[1:Nx] / xsd[1:Nx] )*ysd
    beta0 <- zbeta0*ysd  + ym - sum( zbeta[1:Nx] * xm[1:Nx] / xsd[1:Nx] )*ysd
    sigma <- zsigma*ysd
    # Predicted y values at xProbe:
    for ( i in 1:Nprobe ) {
      zyP[i] ~ dt( zbeta0 + sum( zbeta[1:Nx] * zxProbe[i,1:Nx] ) ,
                   1/zsigma^2 , nu )
      yP[i] <- zyP[i] * ysd + ym
    }


  }
  " # close quote for modelString
  # Write out modelString to a text file
  writeLines( modelString , con="TEMPmodel.txt" )
  #-----------------------------------------------------------------------------
  # INITIALIZE THE CHAINS.
  # Let JAGS do it...
  #-----------------------------------------------------------------------------
  # RUN THE CHAINS
  parameters = c( "beta0" ,  "beta" ,  "sigma",
                  "zbeta0" , "zbeta" , "zsigma", "nu" , "xProbe" , "yP" )
  adaptSteps = 500  # Number of steps to "tune" the samplers
  burnInSteps = 1000
  # Use the runjagsMethod argument rather than hard-coding "parallel":
  runJagsOut <- run.jags( method=runjagsMethod ,
                          model="TEMPmodel.txt" ,
                          monitor=parameters ,
                          data=dataList , 
                          #inits=initsList ,
                          n.chains=nChains ,
                          adapt=adaptSteps ,
                          burnin=burnInSteps ,
                          sample=ceiling(numSavedSteps/nChains) ,
                          thin=thinSteps ,
                          summarise=FALSE ,
                          plots=FALSE )
  codaSamples = as.mcmc.list( runJagsOut )
  # resulting codaSamples object has these indices:
  #   codaSamples[[ chainIdx ]][ stepIdx , paramIdx ]
  if ( !is.null(saveName) ) {
    save( codaSamples , file=paste(saveName,"Mcmc.Rdata",sep="") )
  }
  return( codaSamples )
} # end function

#===============================================================================

smryMCMC = function(  codaSamples ,
                      saveName=NULL ) {
  summaryInfo = NULL
  mcmcMat = as.matrix(codaSamples,chains=TRUE)
  paramName = colnames(mcmcMat)
  for ( pName in paramName ) {
    summaryInfo = rbind( summaryInfo , summarizePost( mcmcMat[,pName] ) )
  }
  rownames(summaryInfo) = paramName
  summaryInfo = rbind( summaryInfo ,
                       "log10(nu)" = summarizePost( log10(mcmcMat[,"nu"]) ) )
  if ( !is.null(saveName) ) {
    write.csv( summaryInfo , file=paste(saveName,"SummaryInfo.csv",sep="") )
  }
  return( summaryInfo )
}

#===============================================================================

plotMCMC = function( codaSamples , data , xName="x" , yName="y" ,
                     showCurve=FALSE ,  pairsPlot=FALSE ,
                     saveName=NULL , saveType="jpg" ) {
  # showCurve is TRUE or FALSE and indicates whether the posterior should
  #   be displayed as a histogram (by default) or by an approximate curve.
  # pairsPlot is TRUE or FALSE and indicates whether scatterplots of pairs
  #   of parameters should be displayed.
  #-----------------------------------------------------------------------------
  y = data[,yName]
  x = as.matrix(data[,xName])
  mcmcMat = as.matrix(codaSamples,chains=TRUE)
  chainLength = NROW( mcmcMat )
  zbeta0 = mcmcMat[,"zbeta0"]
  zbeta  = mcmcMat[,grep("^zbeta$|^zbeta\\[",colnames(mcmcMat))]
  if ( ncol(x)==1 ) { zbeta = matrix( zbeta , ncol=1 ) }
  zsigma = mcmcMat[,"zsigma"]
  beta0 = mcmcMat[,"beta0"]
  beta  = mcmcMat[,grep("^beta$|^beta\\[",colnames(mcmcMat))]
  if ( ncol(x)==1 ) { beta = matrix( beta , ncol=1 ) }
  sigma = mcmcMat[,"sigma"]
  nu = mcmcMat[,"nu"]
  log10nu = log10(nu)
  #-----------------------------------------------------------------------------
  # Compute R^2 for credible parameters:
  YcorX = cor( y , x ) # correlation of y with each x predictor
  Rsq = zbeta %*% matrix( YcorX , ncol=1 )
  #-----------------------------------------------------------------------------
  if ( pairsPlot ) {
    # Plot the parameters pairwise, to see correlations:
    openGraph()
    nPtToPlot = 1000
    plotIdx = floor(seq(1,chainLength,by=chainLength/nPtToPlot))
    panel.cor = function(x, y, digits=2, prefix="", cex.cor, ...) {
      usr = par("usr"); on.exit(par(usr))
      par(usr = c(0, 1, 0, 1))
      r = (cor(x, y))
      txt = format(c(r, 0.123456789), digits=digits)[1]
      txt = paste(prefix, txt, sep="")
      if(missing(cex.cor)) cex.cor <- 0.8/strwidth(txt)
      text(0.5, 0.5, txt, cex=1.25 ) # was cex=cex.cor*r
    }
    pairs( cbind( beta0 , beta , sigma , log10nu )[plotIdx,] ,
           labels=c( "beta[0]" ,
                     paste0("beta[",1:ncol(beta),"]\n",xName) ,
                     expression(sigma) ,  expression(log10(nu)) ) ,
           lower.panel=panel.cor , col="skyblue" )
    if ( !is.null(saveName) ) {
      saveGraph( file=paste(saveName,"PostPairs",sep=""), type=saveType)
    }
  }
  #-----------------------------------------------------------------------------
  # Marginal histograms:
 
  decideOpenGraph = function( panelCount , saveName , finished=FALSE ,
                              nRow=2 , nCol=3 ) {
    # If finishing a set:
    if ( finished==TRUE ) {
      if ( !is.null(saveName) ) {
        saveGraph( file=paste0(saveName,ceiling((panelCount-1)/(nRow*nCol))),
                   type=saveType)
      }
      panelCount = 1 # re-set panelCount
      return(panelCount)
    } else {
    # If this is first panel of a graph:
    if ( ( panelCount %% (nRow*nCol) ) == 1 ) {
      # If previous graph was open, save previous one:
      if ( panelCount>1 & !is.null(saveName) ) {
        saveGraph( file=paste0(saveName,(panelCount%/%(nRow*nCol))),
                   type=saveType)
      }
      # Open new graph
      openGraph(width=nCol*7.0/3,height=nRow*2.0)
      layout( matrix( 1:(nRow*nCol) , nrow=nRow, byrow=TRUE ) )
      par( mar=c(4,4,2.5,0.5) , mgp=c(2.5,0.7,0) )
    }
    # Increment and return panel count:
    panelCount = panelCount+1
    return(panelCount)
    }
  }
 
  # Original scale:
  panelCount = 1
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
  histInfo = plotPost( beta0 , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(beta[0]) , main="Intercept" )
  for ( bIdx in 1:ncol(beta) ) {
    panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
    histInfo = plotPost( beta[,bIdx] , cex.lab = 1.75 , showCurve=showCurve ,
                         xlab=bquote(beta[.(bIdx)]) , main=xName[bIdx] )
  }
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
  histInfo = plotPost( sigma , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(sigma) , main=paste("Scale") )
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
  histInfo = plotPost( log10nu , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(log10(nu)) , main=paste("Normality") )
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMarg") )
  histInfo = plotPost( Rsq , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(R^2) , main=paste("Prop Var Accntd") )
  panelCount = decideOpenGraph( panelCount , finished=TRUE , saveName=paste0(saveName,"PostMarg") )
 
  # Standardized scale:
  panelCount = 1
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
  histInfo = plotPost( zbeta0 , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(z*beta[0]) , main="Intercept" )
  for ( bIdx in 1:ncol(beta) ) {
    panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
    histInfo = plotPost( zbeta[,bIdx] , cex.lab = 1.75 , showCurve=showCurve ,
                         xlab=bquote(z*beta[.(bIdx)]) , main=xName[bIdx] )
  }
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
  histInfo = plotPost( zsigma , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(z*sigma) , main=paste("Scale") )
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
  histInfo = plotPost( log10nu , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(log10(nu)) , main=paste("Normality") )
  panelCount = decideOpenGraph( panelCount , saveName=paste0(saveName,"PostMargZ") )
  histInfo = plotPost( Rsq , cex.lab = 1.75 , showCurve=showCurve ,
                       xlab=bquote(R^2) , main=paste("Prop Var Accntd") )
  panelCount = decideOpenGraph( panelCount , finished=TRUE , saveName=paste0(saveName,"PostMargZ") )
 
  #-----------------------------------------------------------------------------
}
#===============================================================================

#===== End low-level script ================
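If you want to probe many combinations of predictor values, it may be easier to build xProbe with expand.grid than to type the matrix by hand. Here is a minimal sketch. makeProbeGrid is a hypothetical helper. The only requirement taken from genMCMC above is that xProbe be a matrix with one column per predictor, in the order of xName.

```r
# Hypothetical helper: every combination of the listed predictor values,
# as a matrix whose columns follow the order of xName.
makeProbeGrid = function( probeValues , xName ) {
  stopifnot( all( xName %in% names( probeValues ) ) )
  grid = expand.grid( probeValues[ xName ] , KEEP.OUT.ATTRS=FALSE )
  as.matrix( grid )
}
xName = c( "Spend" , "PrcntTake" )
xProbe = makeProbeGrid( list( PrcntTake=c( 10 , 80 ) , Spend=c( 4 , 9 ) ) ,
                        xName )
# With a single predictor, xProbe must still be a one-column matrix,
# because genMCMC checks ncol(xProbe):
xProbe1 = matrix( c( 4 , 9 ) , ncol=1 )
```

Because expand.grid varies its first factor fastest, these inputs reproduce the four probes typed out in the high-level script.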

Comments:

  1. I know it is a bit beside the point, but would you be able to share the code for creating Fig 18.3 as well?

  2. How would it be implemented for the Jags-Ymet-XmetSsubj-MrobustHier.R code?
