
Conformal Inference with Tidymodels - posit::conf(2023)
Presented by Max Kuhn

Conformal inference theory enables any model to produce probabilistic predictions, such as prediction intervals. We'll demonstrate how these analytical methods can be used with tidymodels. Simulations will show that the results have good coverage (i.e., a 90% interval should include the real point 90% of the time).

Presented at posit::conf(2023), September 19-20, 2023. Learn more at posit.co/conference.

Talk Track: Tidy up your models. Session Code: TALK-1085
Transcript
This transcript was generated automatically and may contain errors.
So, I'm here to talk about conformal inference and how to do that in tidymodels. Conformal inference, you put those two words together and if you had asked me maybe a year ago what they mean, I'd essentially be like, I don't know. So you know, it's kind of an oddly named technique, so in hindsight maybe a better title for this presentation would have been: how you can make prediction intervals for any type of model without making many statistical assumptions about your data or your model.
Just to remind you, if you want to put in some questions, here's the link for it. And also in this wee tiny little font down here, you can see the link to the slides if you want them.
What is a prediction interval?
All right, so maybe I should start first by saying, well, what's a prediction interval if you've never heard of that? A prediction interval is like a confidence interval, but it's an interval on a different type of quantity. If you had a 95% prediction interval, that means you have bounds where 95% of the time a new observation that you acquire later will fall into that interval. So whereas confidence intervals are on something like the mean prediction, this is about new observations.
And the little diagram here shows you both confidence interval, which is kind of very narrow for this data set, and a wider prediction interval. And you can see there's a couple of data points on the bottom and a few on the top that don't fall in. So I think that's a 95% interval right there.
People use these when they can get them, which is not that often for models. It gives you a sense of uncertainty about your prediction. So it gives you a sense of how much you should or shouldn't trust the prediction, and you can calibrate your expectations as to how good the prediction is.
The intuition behind conformal inference
All right, so let's start with just some data, right? So this is just like a histogram on its side. It's 500 data points. It's centered around zero, maybe like plus or minus like, I don't know, 0.15. And so you spend some time collecting this data, and then you get the 501st data point, and let's say it falls down here. And it's like, well, it's not outside the range per se, but is that like a new data point from the same distribution, or has something changed, like the model drift question and things like that?
And so what we can do, from a statistical standpoint, if we want to make some sort of judgment about whether this new data point is from our original distribution, one thing we could do is use good old-fashioned quantiles. So let's say you wanted a 90% interval: what you could do is get the 0.05 quantile on the lower end and the 0.95 quantile on the upper end. And if you were to make a completely distribution-free probabilistic statement about that, you could say that 90% of the time, when I get data from this distribution, they're going to fall inside that interval.
And so this is where the conformal part comes in. You would say that that new data point, if it fell within that interval, conforms to this original reference distribution. So if you were to take this data set and compute those particular quantiles, the lower 0.05 and the upper 0.95, you would think that any new data point that falls in between would be consistent with the original data. And you can see, of course, there are false positives, 5% of the time down here and 5% of the time up here. This particular data point you would not really consider to be consistent with the original data.
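The conformity check described above can be sketched in a few lines. This is just the quantile logic from the slide, written in Python with made-up numbers standing in for the histogram:

```python
import numpy as np

rng = np.random.default_rng(1)
# 500 reference values centered around zero (a stand-in for the histogram)
reference = rng.normal(0.0, 0.075, size=500)

# a 90% interval: the 0.05 and 0.95 quantiles of the reference distribution
lower, upper = np.quantile(reference, [0.05, 0.95])

def conforms(x):
    """A new point 'conforms' if it falls inside the reference quantiles."""
    return lower <= x <= upper
```

With this setup, a central value like 0.0 conforms, while a far-out value like 0.5 is flagged as inconsistent with the reference distribution.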
So why am I telling you this? Suppose instead of me just saying we have some data, let's say these were out-of-sample residuals. So let's say you fit some sort of regression model, you had an extra data set just laying around like we do all the time, and you took that model and predicted on this different data set, so you can compute those residuals. And that gives you a sense, for the data that was collected, specifically what we call the calibration data set, of what, on average, you'd expect the noise around your predicted values to be, based on those residuals.
So if this is a training set, and this is just a nonlinear function I fit to it: we build the model on this data set, we take that same model, apply it to the calibration data set, the 500 points I just showed you, and calculate the residuals. And that generates the histogram I just showed you. And then what we can do is take this zero-centered histogram and basically center it around the predicted values of our model. So this is like a test set: we get new data points and apply the same model fit here.
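Sketched in code (Python, with a toy nonlinear function and simulated calibration data standing in for the fitted model), the band is just the residual quantiles re-centered on each prediction:

```python
import numpy as np

rng = np.random.default_rng(2)

def model(x):
    # stand-in for the fitted nonlinear model
    return np.sin(2 * x)

# calibration set: compute out-of-sample residuals
x_cal = rng.uniform(-2, 2, size=500)
y_cal = model(x_cal) + rng.normal(0, 0.1, size=500)
resid = y_cal - model(x_cal)

# the 0.05/0.95 quantiles of the residuals define the band...
lo, hi = np.quantile(resid, [0.05, 0.95])

# ...which gets centered on the predictions for new (test) points
x_new = np.array([-1.0, 0.0, 1.0])
pred = model(x_new)
lower, upper = pred + lo, pred + hi
```

Note that `hi - lo` is a single number, which is why this band has the same width everywhere along the curve.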
Oh, and by the way, I swear to God, it's completely accidental that these colors match the ones of our T-shirts. I didn't realize until like an hour ago. I was like, oof, that's not like me.
So anyway, these form a consistent-width band around the predicted values here, and I guess we're calling that purple. And you can see some of the data points don't fall inside the band, but mostly they do. And this is basically something akin to a prediction interval, it's just going about it in a completely different way.
Typical prediction intervals are what we call parametric: we have to make some sort of statistical or probabilistic assumption about the data and the model. For a typical prediction interval you get from straight-up linear regression, we'd say that a particular data point has a coverage of, let's say, 95%. But with conformal intervals, what we would say is that, on average, across all the samples that we use, the coverage for that interval is 95%. So the additional bit here is "on average", usually, for most conformal methods at least.
Methodology and assumptions
So if you're a statistician or something like that, and you're interested in the methodology, there's been quite a lot written about this, especially lately; it's been kind of exploding both in terms of papers and in terms of visibility. It has a very strong frequentist theme to it, if you know what that means. We're basically using quantiles and empirical distributions, and when you read about this, it ties in, in an indirect way, to what we would call non-parametric inference.
So anyway, that's all well and good. So what's good about it? Well, these intervals require very, very minimal assumptions about your data. Exchangeable data means that if I just reorder the data, I get the same results. That's not true for time series, but it's generally true for a lot of other situations. And of course, time series are important, so there are specialized conformal methods for time series. Matt Dancho's modeltime package has those, so you can get those if that's your thing.
It can work with any type of regression or classification model, but so far in the probably package, I've only implemented it for regression. So we're working on that. And it's relatively fast. There are a couple of different methods I'll show you, and the ones I'll show you are all relatively fast both to train the conformal interval method and to make new predictions. There's one called full conformal, which is honestly the one where it makes the most sense to me how this thing works, but it's abominably slow. So we have that in tidymodels, but you probably don't want to use it.
Anyway, the cons: because it's making only minimal assumptions, as you'll see in a little bit, these intervals don't necessarily extrapolate well. So if you're making a prediction on a new sample and it falls well outside of your training or calibration set, the coverage, or even the interval itself, may not be very good. And that's a little bit different from what we would see in things like linear regression.
So if that's a problem, you might be saying to yourself, well, how am I going to know if I'm really extrapolating if I have a bunch of predictors? We do have something called the applicable package. It was created by one of our interns, Marley, a couple of years ago, and it's used to quantify how much you're extrapolating from your training set. So you can use applicable to get a score that gives you a sense of, like, how far out am I, or am I sort of in the middle of my training set? Also, conformal intervals are probably not great for small sample sizes, for some definition of small. You can still get intervals from probably, but it's hard to say what the coverage would be.
Method 1: Split conformal inference
So I'm going to show you some code. I'm going to have a training set that is 1,000 data points, a test set that's 500, and a calibration set, which is where we compute those residuals, about the same size at 500. So you load the tidymodels package down here, and we're going to be looking at some results that use cross-validation, like V-fold cross-validation. So I generated a cross-validation object here using tidymodels.
The model I'm going to use, which doesn't ordinarily generate prediction intervals, is called a support vector machine. Here's some code to specify the model, and then on line seven here we just fit that support vector machine model. And we can use that to generate the curve on the data set that I showed you.
All right. So there's three different conformal methods I'll go through really quickly. The first one's called split conformal inference, and you've already seen it. So it's basically you take a model, you fit it to your training set, you predict your calibration set, compute the residuals, and go on your merry way. So it's pretty simple.
In tidymodels, you get that from the probably package. The functions we're going to show start with int_conformal; the last bit of the name is the method. So what you do is you give it your fit and your calibration set. It basically predicts the calibration set and gets the residuals. And then when you want to actually get prediction intervals on new samples, you give it, let's say, the test set, and you specify whatever confidence level you want. And then you can see you get new columns here, .pred_lower and .pred_upper.
What does that look like for the data set that we just used? This, again, the purple line is our support vector machine model prediction, and here are our prediction intervals. You see some of the points down here aren't really covered by the band, and so on. So as you might imagine, it's pretty fast, pretty simple. Again, one of the downsides is that the interval widths are always the same. So whether the variability in your outcome data is really small or really large at different points, the interval is always going to be the same width.
Method 2: Cross-validation conformal inference
The downside also to this is you have to have an extra set of data laying around that you can use to estimate that distribution, and we don't always have that. So one thing you can do is use cross-validation. Every time you do cross-validation, you have some data that you're predicting, and those data points were not used by the corresponding model that generated the predictions. So the data you use for prediction is not the same data you used to fit the model inside the cross-validation. If you do tenfold cross-validation, you have ten sets of held-out residuals. And it turns out there's some math behind this that says, yes, with some slight alterations, you can do something very similar to split conformal inference.
The theory has only really been done for default cross-validation. So you can use other resampling methods in tidymodels to do this, but you get a warning saying, like, you're at your own risk. I have a link near the end to a GitHub repo where I've done a ton of simulations. And I did try bootstrapping with this, and it seemed to work pretty well. So no guarantees, but it doesn't seem like a horrible idea to try other resampling methods.
How that works is a little bit different from what you would normally do with resampling. In your control object, you have to save the out-of-sample predictions; that's not very exotic. But the other thing you have to do, if you're doing tenfold cross-validation, is save the ten fitted models that you generated during cross-validation. The easiest way to do that is an argument called extract: just give it the identity function, which returns its input unchanged. Then you use the fit_resamples function like you normally would. And then the int_conformal_cv function just processes all that data, and the predict step is the same as last time to generate the intervals.
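The residual-collection part can be sketched like this (Python, with a simple linear fit standing in for the model). The actual CV+ interval math in probably is more involved than just pooling residuals, so treat this only as an illustration of where the held-out residuals come from:

```python
import numpy as np

rng = np.random.default_rng(3)
x = rng.uniform(-2, 2, size=200)
y = 2.0 * x + rng.normal(0, 0.5, size=200)

# 10-fold CV: fit on nine folds, keep residuals from the held-out fold
folds = np.array_split(rng.permutation(200), 10)
resid = []
for held_out in folds:
    train = np.setdiff1d(np.arange(200), held_out)
    coefs = np.polyfit(x[train], y[train], 1)   # the "fold model"
    resid.extend(y[held_out] - np.polyval(coefs, x[held_out]))

# every data point contributes exactly one out-of-sample residual
lo, hi = np.quantile(resid, [0.05, 0.95])
```

The key property is that each point's residual comes from a model that never saw that point during fitting.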
And so, again, on that same data set: in this particular case, these intervals are very, very close to the split intervals. That's not always the case, but it's what happened here.
Also, one little note about this I found with very small sample sizes: when CV+ conformal methods center their interval, they center it on the average prediction of the ten cross-validated models, not on the actual model that you fit on the training set. So if you have a very small sample size, those two things might be different, and your intervals might be shifted in some places. I might try to fix that, if it can be fixed, but it's just a little note.
Method 3: Conformalized quantile regression
And then the third and final method, which is very different from the others, is called conformalized quantile regression. Quantile regression is a pretty well-known technique in statistics. When you fit a linear regression and get a prediction, that's the mean of the outcome distribution. In a quantile regression, like a quantile linear regression, you can actually make predictions for whatever quantile you want. So if you want something around the center of the distribution, you would use a quantile regression with a 0.5 quantile to get the median prediction.
But if we set it to be, let's say, the 0.05 or the 0.95 quantile, that's really what we're trying to do in conformal inference: we're trying to estimate the boundaries of our predictive distribution. So it's really a more direct approach to solving this.
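Under the hood, a quantile regression swaps squared error for the asymmetric "pinball" loss. A minimal Python illustration with a constant predictor shows that minimizing that loss recovers the empirical quantile:

```python
import numpy as np

def pinball_loss(y, pred, tau):
    # penalize under-predictions by tau and over-predictions by (1 - tau)
    err = y - pred
    return np.mean(np.where(err >= 0, tau * err, (tau - 1) * err))

rng = np.random.default_rng(4)
y = rng.normal(10, 2, size=5000)

# search over constant predictions for the minimizer of the 0.95 pinball loss
grid = np.linspace(0, 20, 2001)
losses = [pinball_loss(y, g, 0.95) for g in grid]
best = grid[np.argmin(losses)]
# 'best' lands essentially on np.quantile(y, 0.95)
```

A quantile regression model does the same thing, but lets the predicted quantile vary with the predictors instead of being a constant.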
This is the same data set I showed in that second slide where I had linear regression. And you can see, it might be hard to see, but these intervals are not of equal width. So the upside to conformal inference using quantile regression is that you can get intervals whose width varies, unlike the other two methods.
Now, linear methods here are maybe not the best idea. So what we do, like others in the literature, is use a tree-based ensemble, specifically quantile random forests. And that can give us a prediction interval. And here's the same data set. For lack of a better term, I'd call this a chunky sort of line. But it will vary across the range if the variance changes across the range.
The other thing to note here, and this is true of all tree-based methods, not just random forests, is that when the data stops, the predictions just sort of go off to infinity along the same lines. And this is especially bad if you start extrapolating. You can imagine if I fit a linear regression to this particular data set, this cloud of points, that line would just keep increasing. But the conformalized quantile intervals will just keep going off in one direction. So if you extrapolate, you can get intervals that are so poor they don't even contain the predicted value. So with this particular method especially, you have to be very sure you're not extrapolating.
All right, the code's a little bit different for this one. You need to give it your model fit, and you need to give it both the training set and the calibration set; so it works best with a split sample. And also, since you're estimating the quantiles, you have to set the confidence level upfront: instead of setting it in the predict method, you have to do it here. We use a quantile regression forest under the hood, so anything you want to pass through to that function, you can do it here; I'm going to bump up the number of trees to 2,000. And then you get the actual intervals.
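The arithmetic that makes the quantile predictions "conformalized" is compact. This Python sketch uses made-up quantile predictions (a fixed ±1 band around a line) in place of a quantile forest, and is not the probably implementation; it just shows how the calibration set adjusts the band:

```python
import numpy as np

rng = np.random.default_rng(5)

# hypothetical 0.05 / 0.95 quantile predictions from some quantile model
def q_lo(x): return x - 1.0
def q_hi(x): return x + 1.0

# calibration data whose noise is wider than the model's band assumes
x_cal = rng.uniform(0, 10, size=500)
y_cal = x_cal + rng.normal(0, 1.0, size=500)

# conformity score: how far outside the quantile band each point falls
scores = np.maximum(q_lo(x_cal) - y_cal, y_cal - q_hi(x_cal))
adj = np.quantile(scores, 0.90)        # adjustment for ~90% coverage

# widen (or shrink, if adj < 0) the band for a new prediction
x_new = 5.0
interval = (q_lo(x_new) - adj, q_hi(x_new) + adj)
```

Because the underlying band already varies with the predictors, the final intervals inherit that varying width, which is the whole appeal of this method.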
What's that look like for our data set again? It's kind of interesting. This seems like you might be like, wow, I don't know, that's kind of scary. But performance of this is actually pretty good despite the lack of smoothness of these intervals. So when you do simulations and things like this, the idea that you're getting step functions across a range is not that big of a deal, performance-wise.
Checking coverage with simulations
Speaking of performance, we have to make sure it works. So this is a little bit different than your average machine learning method because we're saying, if you get intervals from this function, we think they'll have, let's say, 90 percent coverage. And so with this GitHub repo here, I did a bunch of simulations to make sure these things actually work the way they're supposed to work.
Now, for the data sets I showed you, like this one, the coverages were very close to 90 percent, for the split conformal intervals, for CV+, and for the quantile method, on those individual data sets. And there's a README at that repo. Generally speaking, I tried it with trees and regression and neural networks and things like that. The sample size matters a lot, so your coverage may not be spectacular if you have a very small sample size, but I was pretty happy with the average coverage of those methods. So they seem to be working and doing what they're supposed to do.
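The shape of such a coverage check is simple. Here's a Python sketch for the split method, assuming pure noise so the "residuals" are known to be exchangeable with new draws:

```python
import numpy as np

rng = np.random.default_rng(6)

def one_trial(n_cal=500):
    # build a 90% split-conformal interval from calibration residuals...
    resid_cal = rng.normal(0, 1, size=n_cal)
    lo, hi = np.quantile(resid_cal, [0.05, 0.95])
    # ...and check whether one new draw lands inside it
    new = rng.normal(0, 1)
    return lo <= new <= hi

coverage = np.mean([one_trial() for _ in range(2000)])
# coverage should land near the nominal 0.90
```

The real simulations swap the known noise distribution for fitted models on simulated data, but the question being answered is the same: does the observed coverage match the nominal level?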
What's next and resources
So what's next? As I mentioned, we've already implemented regression models. So probably sometime in the new year, I'll start looking at doing the same thing for classification models. In classification, this particular domain typically focuses on situations where we have a lot of classes. So let's say you're doing some sort of image classification and there are like 30 things you might classify an image as. What they focus on with conformal inference is, like, clustering them: well, you know, what is it, dog or hot dog, that whole thing? If you have a prediction and you have two class probabilities that are pretty close, one thing conformal inference does for you is tell you whether they're sort of equivocal probabilities or not. And again, it's a fast-growing field, so if a new methodology pops up, we will definitely take a look at it.
Thanks to the tidymodels and tidyverse groups for listening to me prattle about this and testing my presentation. And Joe Rickert did a lot of reading and helped me get a sense of how these things actually work. Speaking of which, if you want to learn more, there's an article on tidymodels.org. The two references I would suggest most: Christoph Molnar wrote a book on conformal prediction; it says conformal prediction with Python, but it's probably the best reference if you want a layman's sense of how these things work. It's an excellent book; you can get it off Amazon. If you're more of a statistical kind of person, Ryan Tibshirani has a really nice set of notes on the various methods and how they work. And then the Awesome Conformal Prediction repo on GitHub lists pretty much any reference that ever comes up, probably including this talk.
Q&A
We have a few questions coming in. Do these prediction intervals break down in the presence of heteroscedasticity? The first two, well, they don't break down; they're just right on average. So if you go to the article on tidymodels.org, we did a little simulation where we took something whose variance changes as you go along. And if you do the split intervals, they're bad here, they're bad there, they overdo it over here, so on average they're fine. In other words, it's not that you can't compute them; they're just not especially great. And just in case you were wondering, you would probably use the conformalized quantile regression there; those intervals align pretty well. You can see the quantile method is very tight in that area and widens out where it's supposed to. So that's what I would suggest if you think that's the case.
Can these be used for anomaly detection in deployed models and then drive some sort of emergency action to avoid anomalous predictions? Yeah, that's a good question. I've seen people on social media say that. I don't know that I would buy into it so much. If you were using this, you could look at the width of the interval to get a sense of whether you're within the mainstream of your training data. So as long as you stay within your training data, you're fine. But it actually wouldn't be very good at detecting an anomaly that's an aberration due to extrapolation, because look what happens as you go a little bit further: boop, these go off into infinity, and my fitted curve goes out here. So this is the one you would probably use if you want to judge the uncertainty by the interval width, because the other two methods have a fixed width. And it would really only detect anomalies that are sort of where your data already live. So that's my answer to that question.
When evaluating the performance of prediction intervals, for example coverage, is that something you picture living in probably or some kind of new package like yardstick? Oh, like the evaluation of whether these things work or not? No, I'll just stuff that in a GitHub repo, because it's like a million files of simulations and things like that. But if the question is, if you're going to develop methods for interval estimates like this, then yeah, probably is the best place for them. Edgar actually did a lot of work and added a bunch of calibration tools for models, and we put those in probably. probably tends to focus on things you would do after the model, and this seems like the kind of thing that would fit there. So yeah, I'd most likely put them there.

