Resources

Max Kuhn - TabPFN: A Deep-Learning Solution for Tabular Data

TabPFN: A Deep-Learning Solution for Tabular Data (Max Kuhn)

Abstract: There have been numerous proposals for deep neural networks for tabular data, i.e., rectangular data sets (e.g., data frames). To date, none have really worked well, and they take far too long to train. TabPFN is a model that emulates a Bayesian approach and trains a deep learning model on a prior of simulated tabular datasets. Version 2 was released this year and offers several significant advantages, but also has one notable disadvantage. I'll introduce this model and show an example.

Presented at the 2025 R/Pharma Conference, Europe/US Track.

Resources mentioned in the presentation:
- Presentation slides: https://topepo.github.io/2025-r-pharma/

Dec 17, 2025
22 min

image: thumbnail.jpg

Transcript

This transcript was generated automatically and may contain errors.

All right. So, as Eric mentioned, I want to talk about some deep learning work that is not mine.

If you're interested in the slides, there are links to papers and things in here. So on GitHub, just go to topepo and it's 2025-r-pharma.

All right, so I want to talk about TabPFN. It's a deep learning model that can be used for what they call tabular data. It can do regression or classification or density estimation. TabPFN stands for Tabular Prior-Data Fitted Network. And about this tabular term: in deep learning, for the last, I don't know, 20 years, the field has been much more concerned with image prediction, video, text, and unstructured data. And so in their terminology, tabular data is probably the data we've lived with forever, which is a rectangular data set where you have columns and rows, and the columns are attributes of some characteristic, like predictors and outcomes, and the rows are samples.

How TabPFN works: the data prior

So there are two main references, and there are two versions of the model, at least in terms of software. The first reference there, from 2022, is the original one. And then I'll mostly be describing version two, where the paper came out this year. I don't know that there's a lot of difference in terms of methodology, but the model is, as you'll see in a minute, more complex.

Now, why would you wanna know about this? Deep learning has not been successful with rectangular data sets. People can create very large models with millions of parameters that often do worse than, like, a boosted tree or random forest. And so this model actually seems to live up to the hype to some degree, and I'm a very skeptical person, especially when I first read this clickbait-ish headline about, like, we do everything in seconds. But I think actually they do kind of do that. So let me walk you through it; it's very unconventional and it's somewhat complicated to describe how this thing works.

So there are a couple of main attributes of these models. The first one is this notion of a data prior. These authors wrote an earlier paper where they showed that they could take a very large deep learning model and, comparing it to known benchmarks, do something that very well approximates Bayesian inference. So we think about Bayesian models as having a prior and some data that produce a posterior probability or posterior distribution. And they found that with a sufficiently large type of deep learning model, they can basically do the same thing, within, like, an epsilon.

The really interesting thing is, when you think about Bayesian analysis, let's say we're trying to model a single proportion. Our prior is on some single scalar parameter, like the probability of a coin being heads or something like that. In this particular case, the prior is indexed by data sets. And so when we think about what the prior is, it's not, like, a scalar function; it's this artificial, abstract notion of a data set. And so what they did is they developed this really complicated causal graph (and we don't know all the details of this part) that mimics common data-generating mechanisms that we would see in rectangular data. And I'll show you a list in just a minute, but basically they have a simulation system that can generate data sets with different numbers of columns and rows. And depending on how they set up the hyperparameters for the prior, it could have, like, missing data or highly correlated predictors and things.

So when they talk about having a prior, basically they can generate this prior notion of what a data set would be, where a data set is symbolized by the D here. And we don't know, again, all the nuances of this particular one; they mention a few aspects of it. But things that are probably in there are distributional effects: when they simulate data from their prior, it could be skewed data, so they probably have some parameter governing that, plus correlation structures between columns, maybe zero correlations between rows, and things like that. In their causal graph, they can, with some probability, generate missing data based on some mechanism, and so on. You can imagine there being part of that network that handles latent variables. Like in chemistry, when I used to work in computational chemistry, almost everything that we would do was based on the size and charge of the molecule. And so you might have, like, 12 predictors that are actually driven by a few handfuls of latent variables. And that's probably something that they simulate in their data. And then lastly, one thing they do, and they call this the task, is they generate some simulated equation that relates the predictors to the outcome, and then simulate the parameters of that relationship and add some random error to it. So at the end of this simulation system, you basically get a data set that's artificial, and it has some task, which is, like, a classification problem or regression problem with different coefficients and functional forms.
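To make the idea concrete, here is a toy sketch of a "prior over datasets" like the one described above. Everything here is invented for illustration (the real TabPFN prior is a far more elaborate causal-graph system whose details are not fully public): we sample a dataset's shape, a skew parameter, a missingness rate, a latent-variable structure, and a random "task" relating predictors to the outcome.

```python
import math
import random

def sample_dataset(rng, max_rows=200, max_latent=4, max_cols=12):
    """Draw one synthetic dataset from a toy 'prior over datasets'.

    Hypothetical stand-in for the TabPFN simulation system: the sampled
    hyperparameters (shape, skew, missingness, latent structure) are
    illustrative, not the authors' actual prior.
    """
    n_rows = rng.randint(20, max_rows)
    n_latent = rng.randint(1, max_latent)            # hidden drivers (e.g., size, charge)
    n_cols = rng.randint(n_latent, max_cols)         # observed predictors
    skew = rng.uniform(0.5, 2.0)                     # distributional shape
    miss_rate = rng.choice([0.0, 0.0, 0.05, 0.2])    # sometimes inject missingness

    # Observed columns are random mixtures of the latents, which induces
    # correlation structure between predictors.
    mix = [[rng.gauss(0, 1) for _ in range(n_latent)] for _ in range(n_cols)]
    # The "task": a random function of the latents, plus random error.
    beta = [rng.gauss(0, 1) for _ in range(n_latent)]

    X, y = [], []
    for _ in range(n_rows):
        z = [rng.gauss(0, 1) for _ in range(n_latent)]
        row = []
        for j in range(n_cols):
            v = sum(m * zi for m, zi in zip(mix[j], z))
            v = math.copysign(abs(v) ** skew, v)     # apply the skew
            row.append(None if rng.random() < miss_rate else v)
        X.append(row)
        y.append(sum(b * zi for b, zi in zip(beta, z)) + rng.gauss(0, 0.5))
    return X, y

rng = random.Random(1)
X, y = sample_dataset(rng)
```

Generating 130 million draws from a system like this (but much richer) is what produces the training material for the network.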

And so when they build their models, as you'll see in a minute, what they do is they generate quite a lot of these data sets and that's what they build their deep learning model on.

So a little bit of math here. If X and Y are the actual data that we have, or the data we wanna predict, like a new x, what they have is a simulation system that generates a ton of these data sets and tasks. And they basically have their deep learning model approximate what Bayes' rule would do, which is factoring together the likelihood and the prior, normalizing that, and producing a posterior distribution. And so this is just a really fancy way of writing Bayes' rule in the context of their deep learning system.
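In symbols, the scheme he describes is roughly the following (the notation here is my own reconstruction of the general form, not copied from the slides):

```latex
% D = (X, y) is the data at hand; \phi indexes one simulated
% data-generating mechanism ("dataset + task") drawn from the prior p(\phi).
% The posterior predictive distribution being approximated:
p(y_{\mathrm{new}} \mid x_{\mathrm{new}}, D)
  \;\propto\; \int p(y_{\mathrm{new}} \mid x_{\mathrm{new}}, \phi)\,
              p(D \mid \phi)\, p(\phi)\, d\phi

% TabPFN trains one network q_\theta, on many simulated (D, task) pairs,
% so that afterwards, for real data,
q_\theta(y_{\mathrm{new}} \mid x_{\mathrm{new}}, D)
  \;\approx\; p(y_{\mathrm{new}} \mid x_{\mathrm{new}}, D)
```

The integral over simulated mechanisms is the "factoring together the likelihood and the prior, normalizing that" step; the network is the thing that learns to shortcut it.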

Now, the important thing is you just have X and Y; you never really see the data sets that were in the model and the tasks that they simulated. For version one of the model, they used, I think, about 500 data sets. With version two, they scaled it up to 130 million realizations of the data set and the task. So they really scaled up the system. And one really interesting thing about this model is, in the Python and R libraries for this, the model is already fit. So all the weights and everything are already estimated by the time the software is given to us, and we don't actually do any more estimation after that. And so when we come in with our training set and our samples that we wanna predict, we're not updating the network, like using a forward pass or backpropagation to change the weights. We do no additional estimation based on the actual data that we have to get good predictions.

Attention mechanisms and in-context learning

And let's talk about how that works, because that's kind of a different thing.

So if you've read anything about neural networks or deep learning in the last 10 years, probably the one word you may have heard of is attention. There's a really famous paper called "Attention Is All You Need," and those people came up with a mechanism to really greatly improve how these models work. And I think it's not unreasonable to say most of the advances we've had in LLMs and other complicated sentence-prediction models and things like that are largely due to these things called attention mechanisms. And so the way it works is, again, remember we have a tabular data set. The TabPFN people do two types of attention. They do what they call feature attention, which basically allows the model to relate columns, let's say between-predictor relationships, such as interactions or things like that. And then they also look across the rows at once, so you can get relationships between different rows of a particular column. So you can imagine the feature attention being really good at things like, I don't know, spline terms, you know, implicitly, or interaction effects and things like that. Whereas the sample attention might do really well at things like serial correlation. And so in their network, they have the ability to really estimate these relationships in either dimension.
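As a toy illustration of attending along the two axes of a table (this is plain scaled dot-product self-attention on a tiny matrix, not the TabPFN architecture, which interleaves these operations inside a large trained transformer):

```python
import math

def softmax(v):
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def self_attention(vectors):
    """Plain scaled dot-product self-attention, with queries = keys = values.
    Each vector is replaced by a weighted average of all the vectors."""
    d = len(vectors[0])
    out = []
    for q in vectors:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d)
                  for k in vectors]
        w = softmax(scores)
        out.append([sum(wi * v[j] for wi, v in zip(w, vectors)) for j in range(d)])
    return out

# A tiny 3-row x 2-column table.
table = [[1.0, 0.0],
         [0.0, 1.0],
         [1.0, 1.0]]

# "Sample attention": each row attends over the other rows,
# capturing between-row relationships within the columns.
rows_mixed = self_attention(table)

# "Feature attention": transpose, so each column attends over the
# other columns, capturing between-predictor relationships.
cols = [list(c) for c in zip(*table)]
cols_mixed = self_attention(cols)
```

Running attention along rows and then along columns is what gives the model a chance to pick up both interaction-style effects and between-sample structure.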

And from that model, once you push your data through, you basically get some posterior distribution that's symbolized here. Now, one of the interesting things is, like I said, there's no actual estimation. In fact, the training set isn't really a training set, because you're not estimating anything from it. So the way their model works is you take the data that you have, which we would normally think of as a training set, and you concatenate the data you want to predict to the bottom of that, like shown here. And what they do is they push this data set, the entirety of it, through the network at once. So you need both the training set and the test samples, or the unknown samples, to make a prediction on the unknown samples. And the way that works is through the attention mechanism. So what it actually does is not really a pure similarity search like we would think of it; it basically has a weighting mechanism that can look for the parameters of the neural network that are most relevant for the samples you're trying to predict. So as you start to push the training set through the model, it's doing what they call in-context learning, and it's kind of getting an idea of what part of this really huge neural network is relevant to the type of data it's being asked to predict. And so what it does is it generates a set of weights that upweights model parameters in the neural network that are relevant to the data you just gave it, and downweights things that are more irrelevant. So, you know, let's say you have completely uncorrelated predictors; it would downweight the part of that neural network that handles latent factors, like PCA-type effects. And so in essence, the training set just primes the pump for what the model should be doing.
So by the time that the new samples you want to predict are pushed to the network, it's relatively localized the parts of the larger model that will be relevant for the samples you want to predict.
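The calling pattern this implies can be sketched in a few lines. The class below is purely illustrative: fit() estimates nothing and only stores the data, and all the work happens at predict time, when stored rows and new rows are processed together. The "network" here is a hypothetical stand-in (an inverse-distance-weighted average), not the real model.

```python
class InContextPredictor:
    """Toy illustration of the TabPFN calling pattern: fit() just
    remembers the data; predict() pushes stored + new rows through
    the (already-trained) model together.  The similarity-weighted
    average below is a stand-in for the real attention mechanism."""

    def fit(self, X, y):
        self.X_, self.y_ = X, y          # no estimation: just store the data
        return self

    def predict(self, X_new):
        # In the real model, self.X_ / self.y_ and X_new are concatenated
        # and one forward pass yields predictions for X_new.
        preds = []
        for q in X_new:
            # stand-in for attention: inverse squared-distance weights
            w = [1.0 / (1e-9 + sum((a - b) ** 2 for a, b in zip(q, x)))
                 for x in self.X_]
            s = sum(w)
            preds.append(sum(wi * yi for wi, yi in zip(w, self.y_)) / s)
        return preds

model = InContextPredictor().fit([[0.0], [1.0], [2.0]], [0.0, 1.0, 2.0])
preds = model.predict([[1.0]])
```

The point of the sketch is the shape of the API, not the arithmetic: "training" is free, and prediction cannot happen without the training rows in hand.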

Version 2, hardware requirements, and performance

Like I said, version two is trained on 130 million data sets of different sizes. It's intended for what they think of as small data sets from the deep learning perspective. And they consider that to be, you know, a training set of up to 10,000 samples, 500 features, and 10 classes. And you can do more than 10,000 samples or 500 features, but the maximum of 10 classes is sort of baked into the algorithm right now. And as you'll see in a minute, I don't know that we need up to 10,000 training set samples; it depends on your data, but I'll show you some simulations in a minute. And I don't have any fancy plot to show you, but I've probably applied this to about a dozen data sets that I have, and it always ranks in the top five. So it's doing very, very well performance-wise compared to other models. So from a predictive performance standpoint, it's very, very effective.

Now, one thing to keep in mind is, if you're like me and you've lived in this tabular world your whole life, you know, the notion of needing what I think of as specialized hardware to make a prediction is a little bit odd. But basically a GPU is required. You can do the calculations on the CPU, but it uses an inordinate amount of memory and takes quite a long time to compute. Whereas with a basic GPU, nothing extravagant, you can basically get a lot of predictions in a second, like they say. So really, if you have a CUDA-based GPU, which is, like, an NVIDIA card, your prediction time is very fast. Otherwise you're gonna be waiting a long time. Like, I have a Mac desktop right now that has a fancy GPU with, like, 24 cores, but it doesn't use CUDA instructions, so it's kind of useless to me for this type of model. You really need a CUDA-based graphics card, which tends to be an NVIDIA-type card.

That might seem odd if you're coming from a background like mine, where I've never needed special hardware to make predictions on a couple of samples. But in this particular case, that's really what's required.

A quick question I'll answer real quick is pre-processing and feature engineering. Technically, no, you don't need it; the model does a lot for you. When I've tested this, I've given it the data as is. You don't have to convert to dummy variables or do transformations; that is all handled inside the network.

Robustness to irrelevant predictors

So I've been sort of studying this sort of indirectly, because we don't have a lot of the details. They're not, like, obfuscating information, but there's a lot that is in the papers but not in the software. And so one thing that I always think about when we look at new models is: well, what if I have a lot of predictors in my training set that have nothing to do with the outcome? And so I have a simulation I've been running for, like, 10 years now for different books, where, for a particular simulation that I like, I fit models of different types as you add irrelevant predictors. So the X-axis here is how many noise predictors we have. The Y-axis here is basically the inflation in the error, so you want this to stay flat. And if you look at something like glmnet, it doesn't really matter how many extra predictors you put in. Some of them might be included periodically, but the performance doesn't get markedly worse. And then a counterexample is basic feed-forward neural networks: as you add more noise predictors, you really, really degrade the root mean squared error, and the error gets a lot worse. And K-nearest neighbors is really awful at this. And these methods in black don't have any real mechanism for feature selection, but methods like Cubist or MARS do, and they tend to be very resistant. And then when you come down here, we don't really know what TabPFN does, if it does anything at all, but for some reason it's incredibly flat when it comes to irrelevant predictors. It seems to do a really good job of knowing that it should not let these noise features interfere.
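The structure of that experiment is easy to reproduce in miniature. The sketch below is not his simulation; it just shows the design, using a plain k-NN regressor (one of the methods that degrades badly in his plot) and measuring the inflation in RMSE as pure-noise predictors are added.

```python
import math
import random

def knn_rmse(X_tr, y_tr, X_te, y_te, k=5):
    """Root mean squared error of a k-nearest-neighbor regressor."""
    err = 0.0
    for q, truth in zip(X_te, y_te):
        order = sorted(range(len(X_tr)),
                       key=lambda i: sum((a - b) ** 2
                                         for a, b in zip(q, X_tr[i])))
        pred = sum(y_tr[i] for i in order[:k]) / k
        err += (pred - truth) ** 2
    return math.sqrt(err / len(X_te))

def make_data(rng, n, n_noise):
    """Two informative predictors plus n_noise irrelevant ones."""
    X, y = [], []
    for _ in range(n):
        x1, x2 = rng.gauss(0, 1), rng.gauss(0, 1)
        noise = [rng.gauss(0, 1) for _ in range(n_noise)]
        X.append([x1, x2] + noise)
        y.append(x1 + 2 * x2 + rng.gauss(0, 0.1))
    return X, y

rng = random.Random(42)
rmse = {}
for n_noise in (0, 20):
    X_tr, y_tr = make_data(rng, 200, n_noise)
    X_te, y_te = make_data(rng, 100, n_noise)
    rmse[n_noise] = knn_rmse(X_tr, y_tr, X_te, y_te)

inflation = rmse[20] / rmse[0]   # > 1 means the noise predictors hurt
```

A method with no feature-selection mechanism (like k-NN here) shows a large inflation ratio; the surprising result in the talk is that TabPFN's curve stays nearly flat.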

I'm not gonna talk about this model too much, but this is a more conventional tabular neural network model, and, ironically, it is designed to deal with irrelevant features. It has an architecture that will downweight features, and it actually does kind of poorly when you have a lot of predictors that are irrelevant. And so we don't know why, but TabPFN seems to be resistant to this, or does some sort of automatic feature selection. Whether they do it explicitly or implicitly, it's handling this pretty well.

Training set size and convergence

Now, one other question is, I talked about attention and upweighting model parameters. And so one question I had is: well, will that happen after 10 data points, or a thousand data points? What do we need for it to do the right thing here?

And so I took a real data set that has about 4,800 training set samples and about 1,300 in a test set, but we'll just call it a holdout. And then I just randomly sampled training set rows and then predicted this holdout set. And what I found, where the X-axis here is the number of training set rows in log units, is that if you only have, like, four rows, your performance is really awful, which is to be expected. As we increase the number of rows and get up to, like, a hundred, we're actually getting decent performance. And then when we get to about a thousand for this particular data set, we've really stabilized our predictions and are getting very good performance. That's true for the Brier score, and you can see the area under the ROC curve gets pretty high after a certain point, like 500 or a thousand samples.
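That kind of learning-curve experiment looks like the sketch below. Again, this is not his data or model: a hypothetical two-class problem and a trivial class-centroid classifier stand in for TabPFN, just to show the design of subsampling training rows and scoring a fixed holdout with the Brier score.

```python
import math
import random

def brier(probs, labels):
    """Mean squared error between predicted P(class = 1) and the 0/1 labels."""
    return sum((p - t) ** 2 for p, t in zip(probs, labels)) / len(labels)

def make_rows(rng, n):
    """Two Gaussian classes, one predictor, 50/50 mix."""
    X, y = [], []
    for _ in range(n):
        cls = rng.random() < 0.5
        X.append([rng.gauss(1.0 if cls else -1.0, 1.0)])
        y.append(1 if cls else 0)
    return X, y

def fit_predict(X_tr, y_tr, X_te):
    """Toy stand-in model: probability from distance to each class centroid."""
    ones = [x[0] for x, t in zip(X_tr, y_tr) if t == 1]
    zeros = [x[0] for x, t in zip(X_tr, y_tr) if t == 0]
    c1 = sum(ones) / len(ones) if ones else 0.0
    c0 = sum(zeros) / len(zeros) if zeros else 0.0
    probs = []
    for x in X_te:
        score = (x[0] - c0) ** 2 - (x[0] - c1) ** 2   # nearer class 1 -> positive
        probs.append(1.0 / (1.0 + math.exp(-score)))
    return probs

rng = random.Random(7)
X_te, y_te = make_rows(rng, 500)          # the fixed holdout set
scores = {}
for n in (4, 64, 1024):                   # training set sizes, log-ish spacing
    total = 0.0
    for _ in range(20):                   # average over random subsamples
        X_tr, y_tr = make_rows(rng, n)
        total += brier(fit_predict(X_tr, y_tr, X_te), y_te)
    scores[n] = total / 20
```

The Brier score drops sharply between a handful of rows and a few hundred, then flattens, which mirrors the shape of the curve in the talk.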

Now the prediction time (this is done on a GPU) actually does increase exponentially, but again, it's two seconds for predicting about 1,300 samples. So even though it's getting larger, this prediction time, with the right hardware, is not really that much of an obstacle.

To support this, I just randomly selected, like, 12 data points and looked at their class probability estimates (this is a classification problem) to see how well they converged. And here's one that, in the end, has a very low probability of being the event, but you can see on the initial data set, with, like, four rows sampled at random, almost every row here is showing something like 50% class probability. And many of them do converge; you can see this last one does. I'm not sure what this jump is around the middle, but then you do have one or two that still have a high degree of uncertainty over the training set sizes, and that's probably just because they're difficult to predict. So we don't know that we need 10,000 data points for a lot of data sets; it depends on their complexity. But we certainly do need, you know, more than a handful to sort of get attention working the way we want it to.

Additional features and software integration

Now, there is a lot more that we could talk about, and I don't have time, but if you look at the link to the second paper, they talk a lot about actual inference. Meaning, when the machine learning people (and I sometimes consider myself one) talk about doing inference, they really are saying prediction. I mean here getting, like, posterior probabilities for predictions, or, you know, predictive intervals and things like that. That's easy to do because they approximate a posterior distribution. You can also ensemble this model, and that seems to be effective. They have their own internal variable importance measurement. And finally, you can use regular explainability measures, but they have a built-in Shapley API. So if you like those particular methods, you can get them from this model pretty easily. And then finally, there's a Python library, and I think somebody has posted the link to the R port. The code is very similar; it's basically what you would expect. You load your library in Python, you initialize the classifier or the regressor, give it the training set, and then make predictions. It's the same thing here. Nothing really happens, you know, at the fit step; I think all it does is, like, embed the data, but it's just waiting to concatenate these data with whatever data is in the test set.

It is kind of being incorporated into tidymodels. One problem that we have, and this is really a Python problem, is we use reticulate here. This is awful, but if you load any R package that uses OpenMP and then, via reticulate, load a Python package that also uses OpenMP, as soon as you invoke it, like running a function, it segfaults. And that is a known problem with Python. So that's the main thing for me with fully integrating it into tidymodels: if you don't read the fine print and initialize your Python library first, you're gonna get, like, a bunch of segfaults. So we're working on that, but yes, it will have broader usage in tidymodels.

And that's pretty much it. Again, thank you to everybody for the attention or for the invitation to speak. And thanks to Simon, Daniel, and Tomas at Posit who helped with various aspects of this presentation.

Q&A

Thank you so much, Max. That was very, very informative. I guess one question that did come in was: is there actually a limit on the size, say the number of rows or columns, of a tabular dataset that TabPFN can handle effectively? Is there kind of a limitation on that?

They have limitations built into the software, but they also have an option to ignore those. So I think it's 10,000 rows in the training set and 500 columns. I don't know that we need all of that; I mean, it depends on how difficult your problem is. I don't know that we need 10,000, but I did test it on a spectroscopy dataset that I have that had about 600 predictors, and I just turned on the switch that says to ignore the limitations. And it still did pretty well. I'm not sure what it does in that instance, but you can get around those limits if you want to. You'd probably have to play around a little bit to see, for your dataset, how much of that training set you need. But again, with the GPU, you get pretty fast predictions. So it may not be that big of a deal if you have, like, 20,000. I don't know. But that's something to look at, yeah.

So, okay. So speaking of speed, then: do you feel like the training time with TabPFN, maybe compared to more traditional models, is quite a bit faster? Or is there a certain percentage faster? Is that something you measured?

I did, but it's hard to say, because there's really no training time; it doesn't actually estimate anything, right? So what I did, and I'm gonna reproduce this on a GPU, is kind of on this slide. I did look to see what the total time is from getting your data into the system to getting the final prediction, over this range of sizes. So I'm predicting 1,300 samples in two seconds when I have a training set that's about 5,000. So per sample, that's pretty quick. So I can imagine that going more exponential further up this curve, as you have, like, a wider training set or more rows, but I can't imagine it's gonna be minutes. So I would suspect it's fairly competitive in terms of time. And there are some things you can tune, but, much like random forest, the defaults seem to work pretty well. It's possible you can get a little bit more performance by tuning these. But again, I'm not sure what the notion of resampling is here, because it's not really estimating anything. So these are really weird and interesting problems for this particular model to figure out.