Predict using TabPFN
Value
predict() returns a tibble of predictions, and augment() returns those
predictions alongside the columns in new_data. In either case, the tibble
is guaranteed to have the same number of rows as new_data.
For regression, the prediction is in the column .pred. For
classification, the class predictions are in .pred_class and the
probability estimates are in columns named .pred_{level}, with one
column for each level of the outcome factor.
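The naming pattern for the classification columns can be sketched in base R. This is only an illustration of the .pred_{level} convention described above, not the package's internal code; the outcome factor and its levels are made up for the example, and the order in which the columns are listed is not meant to reflect the order in the returned tibble.

```r
# Hypothetical outcome factor; these levels are illustrative only.
outcome <- factor(
  c("low", "high", "low", "mid"),
  levels = c("low", "mid", "high")
)

# One probability column per factor level, following the .pred_{level} pattern.
prob_cols <- paste0(".pred_", levels(outcome))

# Column names one would look for in a classification prediction
# (listing order here is an assumption, not a statement about the tibble).
expected_cols <- c(".pred_class", prob_cols)
print(expected_cols)
```

A check such as all(expected_cols %in% names(predictions)), where predictions is the tibble returned by predict(), is one way to confirm that a classification result has the expected columns.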
Examples
# Minimal example for quick execution
car_train <- mtcars[ 1:5, ]
car_test <- mtcars[10:15, -1]
# Fit
if (is_tab_pfn_installed() && interactive()) {
  mod <- tab_pfn(mpg ~ cyl + log(drat), car_train)
  # Predict
  predict(mod, car_test)
  augment(mod, car_test)
}