TabPFNClassifier causes segmentation faults
ahayler opened this issue
Describe the bug
I have repeatedly observed that TabPFN causes (non-deterministic) segmentation faults in my pipeline. Similar issues have been reported by other users (see here).
While the suggested fix of setting OMP_NUM_THREADS to 1 seems to make the pipeline more stable, I still observe segmentation faults in my rather complex pipeline. After quite a bit of debugging, I have now managed to create a "minimal" reproducible example that eventually produces segmentation faults during the run.
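Setting OMP_NUM_THREADS only helps if it happens before the native libraries read it. That usually means before numpy or torch is imported, and in the reproduction script below numpy is imported first. The following is a minimal sketch of a guard for this. It assumes that OMP_NUM_THREADS, MKL_NUM_THREADS and OPENBLAS_NUM_THREADS are the relevant variables for the installed builds. `pin_native_threads` is a hypothetical helper, not part of TabPFN:

```python
import os
import sys
import warnings

# Assumption: these are the thread-count variables honoured by the installed
# numpy/torch builds; which ones actually matter depends on how they were built.
THREAD_ENV_VARS = ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS")

# Modules that may load a native thread pool when first imported.
NATIVE_MODULES = ("numpy", "scipy", "torch", "sklearn", "tabpfn")


def pin_native_threads(n_threads: int = 1) -> list[str]:
    """Set thread-count env vars and return modules that were imported too early.

    Native libraries typically read these variables once, when they are first
    loaded, so this should be called before numpy/torch are imported.
    """
    for var in THREAD_ENV_VARS:
        os.environ[var] = str(n_threads)
    already_loaded = [m for m in NATIVE_MODULES if m in sys.modules]
    if already_loaded:
        warnings.warn(
            f"{', '.join(already_loaded)} already imported; "
            "thread settings may not apply to them",
            RuntimeWarning,
        )
    return already_loaded


if __name__ == "__main__":
    early = pin_native_threads(1)
    # Heavy imports go after this point, e.g. `from tabpfn import TabPFNClassifier`.
    print(os.environ["OMP_NUM_THREADS"], early)
```

The warning is only a hint. It cannot tell whether a library that was already loaded has actually read the variables yet.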
Steps/Code to Reproduce
The compute node I am running this script on has around 750 GB of RAM, and I have observed this behaviour on several different L40S GPUs installed in that node.
#!/usr/bin/env python3
"""
Minimal example demonstrating segmentation errors with TabPFN.

This script reproduces segmentation errors that occur when using TabPFN
with datasets containing:
- 400 features
- 5000 training samples
- Large numbers of test samples

Usage:
    python misc/segmentation_error_simple_example.py
"""
import numpy as np
import os
import faulthandler

# Dump the Python traceback if the process receives a fatal signal such as SIGSEGV.
faulthandler.enable()

# https://github.com/PriorLabs/TabPFN/issues/328
# Note: numpy is already imported at this point, so its BLAS thread pool may
# not honour this setting; torch is first imported (via tabpfn) below.
os.environ["OMP_NUM_THREADS"] = "1"

from tabpfn import TabPFNClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from tabpfn import display_debug_info


def main():
    display_debug_info()

    # Create a dataset that triggers segmentation errors
    n_train_samples = 5000
    n_features = 400
    n_classes = 5
    n_runs_per_test_size = 3

    for i in range(10):
        print(f"RUN {i + 1}")
        for n_test_samples in [100, 500, 1000, 5000, 10000, 50000]:
            y_pred_proba_list = []
            for run in range(n_runs_per_test_size):
                random_state = 42 + run * 1000
                print(f"Running run {run + 1}/{n_runs_per_test_size} for {n_test_samples} test samples...")

                # Generate synthetic data
                X, y = make_classification(
                    n_samples=n_train_samples + n_test_samples,
                    n_features=n_features,
                    n_informative=50,
                    n_redundant=20,
                    n_classes=n_classes,
                    n_clusters_per_class=1,
                    class_sep=0.8,
                    random_state=random_state
                )

                # Split into train/test
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y,
                    test_size=n_test_samples / (n_train_samples + n_test_samples),
                    random_state=random_state,
                    stratify=y
                )

                clf = TabPFNClassifier(device='cuda:0', random_state=random_state)
                clf.fit(X_train, y_train)

                print(f"Running inference on {n_test_samples} test samples...")
                y_pred_proba = clf.predict_proba(X_test)
                y_pred_proba_list.append(y_pred_proba)

                y_pred = np.argmax(y_pred_proba, axis=1)
                accuracy = np.mean(y_pred == y_test)
                print(f"Accuracy: {accuracy:.3f} for {n_test_samples} test samples")


if __name__ == "__main__":
    main()

To rule out interactions with other packages, I created a minimal conda environment to run this script. The environment file is:
name: tabpfn_segfault
channels:
  - conda-forge
  - pytorch
dependencies:
  - python=3.11
  - numpy
  - scikit-learn
  - pip
  - pip:
      - tabpfn
Expected Results
The script runs through without a segmentation fault.
Actual Results
As stated above, exactly when a segmentation fault occurs is non-deterministic. The output below is from one of my runs. I have not been able to run the above script to completion without a segmentation fault occurring.
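Because the crash is non-deterministic, it can help to measure how often it happens. The sketch below assumes a POSIX system. It runs the script repeatedly in child processes and tallies how each attempt ended, so a segfault is recorded as SIGSEGV instead of taking down the caller. `run_repeatedly` is a hypothetical helper. The script path is the one given in the reproduction script's docstring:

```python
import signal
import subprocess
import sys
from collections import Counter


def run_repeatedly(cmd: list[str], n_attempts: int) -> Counter:
    """Run `cmd` n_attempts times and tally how each attempt ended."""
    outcomes = Counter()
    for _ in range(n_attempts):
        proc = subprocess.run(cmd, capture_output=True, text=True)
        if proc.returncode < 0:
            # On POSIX, a negative return code -N means the child was killed
            # by signal N (e.g. -11 for SIGSEGV on Linux).
            outcomes[signal.Signals(-proc.returncode).name] += 1
        elif proc.returncode == 0:
            outcomes["ok"] += 1
        else:
            outcomes[f"exit {proc.returncode}"] += 1
    return outcomes


if __name__ == "__main__":
    script = "misc/segmentation_error_simple_example.py"
    print(run_repeatedly([sys.executable, script], n_attempts=5))
```

The child is executed directly, not through a shell. A shell would typically report a signal death as exit status 128 + N (e.g. 139), not as a negative return code.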
[COPIED TABPFN_INFO TO THE FIELD BELOW]
[OMITTED RUNS 1-3 FOR READABILITY]
RUN 4
Running run 1/3 for 100 test samples...
Running inference on 100 test samples...
Accuracy: 0.810 for 100 test samples
Running run 2/3 for 100 test samples...
Running inference on 100 test samples...
Accuracy: 0.830 for 100 test samples
Running run 3/3 for 100 test samples...
Running inference on 100 test samples...
Accuracy: 0.810 for 100 test samples
Running run 1/3 for 500 test samples...
Running inference on 500 test samples...
Accuracy: 0.778 for 500 test samples
Running run 2/3 for 500 test samples...
Running inference on 500 test samples...
Accuracy: 0.844 for 500 test samples
Running run 3/3 for 500 test samples...
Running inference on 500 test samples...
Accuracy: 0.836 for 500 test samples
Running run 1/3 for 1000 test samples...
Running inference on 1000 test samples...
Accuracy: 0.789 for 1000 test samples
Running run 2/3 for 1000 test samples...
Running inference on 1000 test samples...
Accuracy: 0.854 for 1000 test samples
Running run 3/3 for 1000 test samples...
Running inference on 1000 test samples...
Accuracy: 0.834 for 1000 test samples
Running run 1/3 for 5000 test samples...
Running inference on 5000 test samples...
Accuracy: 0.816 for 5000 test samples
Running run 2/3 for 5000 test samples...
Running inference on 5000 test samples...
Accuracy: 0.840 for 5000 test samples
Running run 3/3 for 5000 test samples...
Running inference on 5000 test samples...
Accuracy: 0.860 for 5000 test samples
Running run 1/3 for 10000 test samples...
Running inference on 10000 test samples...
Fatal Python error: Segmentation fault
Current thread 0x00007fa63e9db740 (most recent call first):
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/mlp.py", line 97 in _compute
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/memory.py", line 100 in method_
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/mlp.py", line 132 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/layer.py", line 440 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/transformer.py", line 89 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/transformer.py", line 605 in _forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/transformer.py", line 383 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/inference.py", line 512 in iter_outputs
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/classifier.py", line 754 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/classifier.py", line 685 in predict_proba
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/contextlib.py", line 81 in inner
File "/workspaces/[some_repo]/misc/segmentation_error_simple_example.py", line 77 in main
File "/workspaces/[some_repo]/misc/segmentation_error_simple_example.py", line 84 in <module>
Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, sklearn.__check_build._check_build, scipy._lib._ccallback_c, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._pcg64, numpy.random._mt19937, numpy.random._generator, numpy.random._philox, numpy.random._sfc64, numpy.random.mtrand, charset_normalizer.md, scipy.sparse._sparsetools, _csparsetools, _cyutility, scipy._cyutility, scipy.sparse._csparsetools, scipy.special._ufuncs_cxx, scipy.special._ellip_harm_2, scipy.special._special_ufuncs, scipy.special._gufuncs, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_schur_sqrtm, scipy.linalg._matfuncs_expm, scipy.linalg._linalg_pythran, scipy.linalg.cython_blas, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._hausdorff, scipy.spatial._distance_wrap, scipy.spatial.transform._rotation, scipy.spatial.transform._rigid_transform, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._slsqplib, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy._lib._uarray._uarray, scipy.linalg._decomp_interpolative, 
scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.interpolate._fitpack, scipy.interpolate._dfitpack, scipy.interpolate._dierckx, scipy.interpolate._ppoly, scipy.interpolate._interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.special.cython_special, scipy.stats._stats, scipy.stats._biasedurn, scipy.stats._stats_pythran, scipy.stats._levy_stable.levyst, scipy.stats._ansari_swilk_statistics, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._rcont.rcont, scipy.stats._qmvnt_cy, scipy.ndimage._nd_image, scipy.ndimage._rank_filter_1d, _ni_label, scipy.ndimage._ni_label, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.lib, pandas._libs.ops, pandas._libs.hashing, pandas._libs.arrays, pandas._libs.tslib, pandas._libs.sparse, pandas._libs.internals, pandas._libs.indexing, pandas._libs.index, pandas._libs.writers, pandas._libs.join, pandas._libs.window.aggregations, pandas._libs.window.indexers, pandas._libs.reshape, pandas._libs.groupby, pandas._libs.json, pandas._libs.parsers, 
pandas._libs.testing, sklearn.utils._isfinite, sklearn.utils.sparsefuncs_fast, sklearn.utils.murmurhash, sklearn.utils._openmp_helpers, sklearn.preprocessing._csr_polynomial_expansion, sklearn.preprocessing._target_encoder_fast, sklearn.utils._random, sklearn.utils._seq_dataset, sklearn.metrics.cluster._expected_mutual_info_fast, sklearn.metrics._dist_metrics, sklearn.metrics._pairwise_distances_reduction._datasets_pair, sklearn.utils._cython_blas, sklearn.metrics._pairwise_distances_reduction._base, sklearn.metrics._pairwise_distances_reduction._middle_term_computer, sklearn.utils._heap, sklearn.utils._sorting, sklearn.metrics._pairwise_distances_reduction._argkmin, sklearn.metrics._pairwise_distances_reduction._argkmin_classmode, sklearn.utils._vector_sentinel, sklearn.metrics._pairwise_distances_reduction._radius_neighbors, sklearn.metrics._pairwise_distances_reduction._radius_neighbors_classmode, sklearn.metrics._pairwise_fast, sklearn.linear_model._cd_fast, _loss, sklearn._loss._loss, sklearn.utils.arrayfuncs, sklearn.svm._liblinear, sklearn.svm._libsvm, sklearn.svm._libsvm_sparse, sklearn.linear_model._sag_fast, sklearn.utils._weight_vector, sklearn.linear_model._sgd_fast, sklearn.decomposition._online_lda_fast, sklearn.decomposition._cdnmf_fast, sklearn.neighbors._partition_nodes, sklearn.neighbors._ball_tree, sklearn.neighbors._kd_tree, sklearn._isotonic, sklearn.utils._fast_dict, sklearn.cluster._hierarchical_fast, sklearn.cluster._k_means_common, sklearn.cluster._k_means_elkan, sklearn.cluster._k_means_lloyd, sklearn.cluster._k_means_minibatch, sklearn.cluster._dbscan_inner, sklearn.cluster._hdbscan._tree, sklearn.cluster._hdbscan._linkage, sklearn.cluster._hdbscan._reachability, sklearn.tree._utils, sklearn.tree._tree, sklearn.tree._partitioner, sklearn.tree._splitter, sklearn.tree._criterion, sklearn.neighbors._quad_tree, sklearn.manifold._barnes_hut_tsne, sklearn.manifold._utils, sklearn.ensemble._gradient_boosting, 
sklearn.ensemble._hist_gradient_boosting.common, sklearn.ensemble._hist_gradient_boosting._gradient_boosting, sklearn.ensemble._hist_gradient_boosting._binning, sklearn.ensemble._hist_gradient_boosting._bitset, sklearn.ensemble._hist_gradient_boosting.histogram, sklearn.ensemble._hist_gradient_boosting._predictor, sklearn.ensemble._hist_gradient_boosting.splitting, scipy.io.matlab._mio_utils, scipy.io.matlab._streams, scipy.io.matlab._mio5_utils, sklearn.datasets._svmlight_format_fast, sklearn.feature_extraction._hashing_fast (total: 218)
Segmentation fault (core dumped)
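In the log above, every 5000-sample run completes and the crash occurs during `predict_proba` on 10000 test samples. The second log further down shows a crash at the same test size. It has not been verified that the failure correlates with the size of a single `predict_proba` call. If it does, one thing to try is splitting the test set into smaller chunks. This is a minimal sketch only. `predict_proba_in_chunks` and the default chunk size of 5000 are hypothetical choices, and nobody has checked whether chunked predictions match a single call exactly:

```python
import numpy as np


def predict_proba_in_chunks(clf, X, chunk_size: int = 5000) -> np.ndarray:
    """Call clf.predict_proba on consecutive row slices of X and stack the results.

    Works with any object exposing a scikit-learn-style predict_proba.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be positive")
    if len(X) == 0:
        # Nothing to split; defer to the model for the empty-input behaviour.
        return clf.predict_proba(X)
    parts = [
        clf.predict_proba(X[start:start + chunk_size])
        for start in range(0, len(X), chunk_size)
    ]
    return np.concatenate(parts, axis=0)
```

In the reproduction script, this would replace `y_pred_proba = clf.predict_proba(X_test)` with `y_pred_proba = predict_proba_in_chunks(clf, X_test)`.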
Versions
Collecting system and dependency information...
PyTorch version: 2.7.1+cu126
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 12 (bookworm) (x86_64)
GCC version: (Debian 12.2.0-14+deb12u1) 12.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.36
Python version: 3.11.13 | packaged by conda-forge | (main, Jun 4 2025, 14:48:23) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-143-generic-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA L40S
Nvidia driver version: 575.57.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 256
On-line CPU(s) list: 0-255
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9554 64-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 64
Socket(s): 2
Stepping: 1
BogoMIPS: 6190.91
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d
Virtualization: AMD-V
L1d cache: 4 MiB (128 instances)
L1i cache: 4 MiB (128 instances)
L2 cache: 128 MiB (128 instances)
L3 cache: 512 MiB (16 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-63,128-191
NUMA node1 CPU(s): 64-127,192-255
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Dependency Versions:
--------------------
tabpfn: 2.1.0
torch: 2.7.1
numpy: 2.3.1
scipy: 1.16.0
pandas: 2.3.1
scikit-learn: 1.6.1
typing_extensions: 4.14.1
einops: 0.8.1
huggingface-hub: 0.33.4

For good measure, I also ran this script on a different compute node (with an NVIDIA A40, see the log below), adding torch.cuda.synchronize() before and after the fit and predict_proba calls. Unfortunately, this neither fixes the error nor produces a more informative error message.
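The synchronization experiment can be expressed as a small wrapper. This is a sketch of one way to place `torch.cuda.synchronize()` around a call; `cuda_synchronized` is a hypothetical name, and it becomes a no-op when torch or CUDA is unavailable. One plausible reason the experiment changed nothing is that the gdb backtrace further down places the crash in host code (the CUDA caching allocator's `malloc`) rather than inside a GPU kernel. That is an interpretation, not a confirmed diagnosis.

```python
import functools

try:
    import torch
except ImportError:  # lets the sketch run (as a no-op wrapper) without torch
    torch = None


def cuda_synchronized(fn):
    """Return a wrapper that synchronizes the current CUDA device around fn."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        use_cuda = torch is not None and torch.cuda.is_available()
        if use_cuda:
            # Wait for all previously queued GPU work to finish.
            torch.cuda.synchronize()
        result = fn(*args, **kwargs)
        if use_cuda:
            # Surface errors from GPU work queued by fn before returning.
            torch.cuda.synchronize()
        return result

    return wrapper
```

Usage in the reproduction script would look like `cuda_synchronized(clf.fit)(X_train, y_train)` and `y_pred_proba = cuda_synchronized(clf.predict_proba)(X_test)`. The usual blanket alternative is setting the environment variable `CUDA_LAUNCH_BLOCKING=1`, which makes every kernel launch synchronous.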
Full error log
Collecting system and dependency information...
PyTorch version: 2.7.1+cu126
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 12 (bookworm) (x86_64)
GCC version: (Debian 12.2.0-14+deb12u1) 12.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.36
Python version: 3.11.13 | packaged by conda-forge | (main, Jun 4 2025, 14:48:23) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-139-generic-x86_64-with-glibc2.36
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A40
Nvidia driver version: 560.35.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7513 32-Core Processor
CPU family: 25
Model: 1
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 1
Frequency boost: enabled
CPU(s) scaling MHz: 73%
CPU max MHz: 3681.6399
CPU min MHz: 1500.0000
BogoMIPS: 5190.47
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Virtualization: AMD-V
L1d cache: 2 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 32 MiB (64 instances)
L3 cache: 256 MiB (8 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Dependency Versions:
--------------------
tabpfn: 2.1.0
torch: 2.7.1
numpy: 2.3.1
scipy: 1.16.0
pandas: 2.3.1
scikit-learn: 1.6.1
typing_extensions: 4.14.1
einops: 0.8.1
huggingface-hub: 0.33.4
RUN 1
Running run 1/3 for 100 test samples...
Running inference on 100 test samples...
Accuracy: 0.810 for 100 test samples
Running run 2/3 for 100 test samples...
Running inference on 100 test samples...
Accuracy: 0.840 for 100 test samples
Running run 3/3 for 100 test samples...
Running inference on 100 test samples...
Accuracy: 0.810 for 100 test samples
Running run 1/3 for 500 test samples...
Running inference on 500 test samples...
Accuracy: 0.778 for 500 test samples
Running run 2/3 for 500 test samples...
Running inference on 500 test samples...
Accuracy: 0.846 for 500 test samples
Running run 3/3 for 500 test samples...
Running inference on 500 test samples...
Accuracy: 0.836 for 500 test samples
Running run 1/3 for 1000 test samples...
Running inference on 1000 test samples...
Accuracy: 0.788 for 1000 test samples
Running run 2/3 for 1000 test samples...
Running inference on 1000 test samples...
Accuracy: 0.854 for 1000 test samples
Running run 3/3 for 1000 test samples...
Running inference on 1000 test samples...
Accuracy: 0.834 for 1000 test samples
Running run 1/3 for 5000 test samples...
Running inference on 5000 test samples...
Accuracy: 0.816 for 5000 test samples
Running run 2/3 for 5000 test samples...
Running inference on 5000 test samples...
Accuracy: 0.840 for 5000 test samples
Running run 3/3 for 5000 test samples...
Running inference on 5000 test samples...
Accuracy: 0.859 for 5000 test samples
Running run 1/3 for 10000 test samples...
Running inference on 10000 test samples...
Accuracy: 0.812 for 10000 test samples
Running run 2/3 for 10000 test samples...
Running inference on 10000 test samples...
Fatal Python error: Segmentation fault
Current thread 0x00007ffbdb262540 (most recent call first):
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/mlp.py", line 97 in _compute
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/memory.py", line 100 in method_
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/mlp.py", line 132 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/layer.py", line 440 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/transformer.py", line 89 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/transformer.py", line 605 in _forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/model/transformer.py", line 383 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1762 in _call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1751 in _wrapped_call_impl
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/inference.py", line 512 in iter_outputs
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/classifier.py", line 754 in forward
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/tabpfn/classifier.py", line 685 in predict_proba
File "/opt/conda/envs/tabpfn_segfault/lib/python3.11/contextlib.py", line 81 in inner
File "/workspaces/[some_repo]/misc/segmentation_error_simple_example.py", line 90 in main
File "/workspaces/[some_repo]/misc/segmentation_error_simple_example.py", line 101 in <module>
Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, google._upb._message, yaml._yaml, charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, sklearn.__check_build._check_build, scipy._lib._ccallback_c, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._pcg64, numpy.random._mt19937, numpy.random._generator, numpy.random._philox, numpy.random._sfc64, numpy.random.mtrand, scipy.sparse._sparsetools, _csparsetools, _cyutility, scipy._cyutility, scipy.sparse._csparsetools, scipy.special._ufuncs_cxx, scipy.special._ellip_harm_2, scipy.special._special_ufuncs, scipy.special._gufuncs, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_schur_sqrtm, scipy.linalg._matfuncs_expm, scipy.linalg._linalg_pythran, scipy.linalg.cython_blas, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._hausdorff, scipy.spatial._distance_wrap, scipy.spatial.transform._rotation, scipy.spatial.transform._rigid_transform, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._slsqplib, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, 
scipy.optimize._zeros, scipy._lib._uarray._uarray, scipy.linalg._decomp_interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.interpolate._fitpack, scipy.interpolate._dfitpack, scipy.interpolate._dierckx, scipy.interpolate._ppoly, scipy.interpolate._interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.special.cython_special, scipy.stats._stats, scipy.stats._biasedurn, scipy.stats._stats_pythran, scipy.stats._levy_stable.levyst, scipy.stats._ansari_swilk_statistics, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._rcont.rcont, scipy.stats._qmvnt_cy, scipy.ndimage._nd_image, scipy.ndimage._rank_filter_1d, _ni_label, scipy.ndimage._ni_label, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.lib, pandas._libs.ops, pandas._libs.hashing, pandas._libs.arrays, pandas._libs.tslib, pandas._libs.sparse, pandas._libs.internals, pandas._libs.indexing, pandas._libs.index, pandas._libs.writers, pandas._libs.join, pandas._libs.window.aggregations, pandas._libs.window.indexers, 
pandas._libs.reshape, pandas._libs.groupby, pandas._libs.json, pandas._libs.parsers, pandas._libs.testing, sklearn.utils._isfinite, sklearn.utils.sparsefuncs_fast, sklearn.utils.murmurhash, sklearn.utils._openmp_helpers, sklearn.preprocessing._csr_polynomial_expansion, sklearn.preprocessing._target_encoder_fast, sklearn.utils._random, sklearn.utils._seq_dataset, sklearn.metrics.cluster._expected_mutual_info_fast, sklearn.metrics._dist_metrics, sklearn.metrics._pairwise_distances_reduction._datasets_pair, sklearn.utils._cython_blas, sklearn.metrics._pairwise_distances_reduction._base, sklearn.metrics._pairwise_distances_reduction._middle_term_computer, sklearn.utils._heap, sklearn.utils._sorting, sklearn.metrics._pairwise_distances_reduction._argkmin, sklearn.metrics._pairwise_distances_reduction._argkmin_classmode, sklearn.utils._vector_sentinel, sklearn.metrics._pairwise_distances_reduction._radius_neighbors, sklearn.metrics._pairwise_distances_reduction._radius_neighbors_classmode, sklearn.metrics._pairwise_fast, sklearn.linear_model._cd_fast, _loss, sklearn._loss._loss, sklearn.utils.arrayfuncs, sklearn.svm._liblinear, sklearn.svm._libsvm, sklearn.svm._libsvm_sparse, sklearn.linear_model._sag_fast, sklearn.utils._weight_vector, sklearn.linear_model._sgd_fast, sklearn.decomposition._online_lda_fast, sklearn.decomposition._cdnmf_fast, sklearn.neighbors._partition_nodes, sklearn.neighbors._ball_tree, sklearn.neighbors._kd_tree, sklearn._isotonic, sklearn.utils._fast_dict, sklearn.cluster._hierarchical_fast, sklearn.cluster._k_means_common, sklearn.cluster._k_means_elkan, sklearn.cluster._k_means_lloyd, sklearn.cluster._k_means_minibatch, sklearn.cluster._dbscan_inner, sklearn.cluster._hdbscan._tree, sklearn.cluster._hdbscan._linkage, sklearn.cluster._hdbscan._reachability, sklearn.tree._utils, sklearn.tree._tree, sklearn.tree._partitioner, sklearn.tree._splitter, sklearn.tree._criterion, sklearn.neighbors._quad_tree, sklearn.manifold._barnes_hut_tsne, 
sklearn.manifold._utils, sklearn.ensemble._gradient_boosting, sklearn.ensemble._hist_gradient_boosting.common, sklearn.ensemble._hist_gradient_boosting._gradient_boosting, sklearn.ensemble._hist_gradient_boosting._binning, sklearn.ensemble._hist_gradient_boosting._bitset, sklearn.ensemble._hist_gradient_boosting.histogram, sklearn.ensemble._hist_gradient_boosting._predictor, sklearn.ensemble._hist_gradient_boosting.splitting, scipy.io.matlab._mio_utils, scipy.io.matlab._streams, scipy.io.matlab._mio5_utils, sklearn.datasets._svmlight_format_fast, sklearn.feature_extraction._hashing_fast (total: 222)
Segmentation fault (core dumped)
A temporary solution
I have tried to diagnose the segmentation fault further. For this, I ran the above script under gdb a few times. The backtraces all look very similar; one example is shown below:
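Backtraces like the one below can also be collected without an interactive gdb session by passing the `run` and `bt` commands on the command line. This is a minimal sketch that assumes gdb is installed and on PATH. `gdb_batch_command` is a hypothetical helper, and the script path comes from the reproduction script's docstring:

```python
import shlex
import sys


def gdb_batch_command(script: str, python: str = sys.executable) -> list[str]:
    """Build a gdb command line that runs `python script` and prints a backtrace.

    -batch makes gdb exit after executing the -ex commands; "bt" only prints
    frames if the program stopped on a signal such as SIGSEGV.
    """
    return [
        "gdb",
        "-batch",
        "-ex", "run",
        "-ex", "bt",
        "--args", python, script,
    ]


if __name__ == "__main__":
    cmd = gdb_batch_command("misc/segmentation_error_simple_example.py")
    print(shlex.join(cmd))
    # To actually run it (requires gdb): subprocess.run(cmd, check=False)
```

If the run finishes without a crash, `bt` just reports that there is no stack.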
gdb backtrace example
Accuracy: 0.816 for 5000 test samples
Running run 2/3 for 5000 test samples...
Running inference on 5000 test samples...
Accuracy: 0.840 for 5000 test samples
Running run 3/3 for 5000 test samples...
Running inference on 5000 test samples...
Accuracy: 0.859 for 5000 test samples
Running run 1/3 for 10000 test samples...
Running inference on 10000 test samples...
Accuracy: 0.812 for 10000 test samples
Running run 2/3 for 10000 test samples...
Running inference on 10000 test samples...
Thread 1 "python" received signal SIGSEGV, Segmentation fault.
0x00007f159f1d3193 in c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::malloc(signed char, unsigned long, CUstream_st*) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libc10_cuda.so
(gdb) bt
#0 0x00007f159f1d3193 in c10::cuda::CUDACachingAllocator::Native::DeviceCachingAllocator::malloc(signed char, unsigned long, CUstream_st*) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libc10_cuda.so
#1 0x00007f159f1d4b85 in c10::cuda::CUDACachingAllocator::Native::NativeCachingAllocator::malloc(void**, signed char, unsigned long, CUstream_st*) () from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libc10_cuda.so
#2 0x00007f159f1d524f in c10::cuda::CUDACachingAllocator::Native::NativeCachingAllocator::allocate(unsigned long) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libc10_cuda.so
#3 0x00007f1583b3d15a in c10::StorageImpl::StorageImpl(c10::StorageImpl::use_byte_size_t, c10::SymInt const&, c10::Allocator*, bool) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
#4 0x00007f1583b3dde5 in at::TensorBase at::detail::_empty_generic<long>(c10::ArrayRef<long>, c10::Allocator*, c10::DispatchKeySet, c10::ScalarType, std::optional<c10::MemoryFormat>) () from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
#5 0x00007f1583b399b4 in at::detail::empty_generic(c10::ArrayRef<long>, c10::Allocator*, c10::DispatchKeySet, c10::ScalarType, std::optional<c10::MemoryFormat>) () from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
#6 0x00007f1549623e1c in at::detail::empty_cuda(c10::ArrayRef<long>, c10::ScalarType, std::optional<c10::Device>, std::optional<c10::MemoryFormat>) () from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
#7 0x00007f1549623fae in at::detail::empty_cuda(c10::ArrayRef<long>, std::optional<c10::ScalarType>, std::optional<c10::Layout>, std::optional<c10::Device>, std::optional<bool>, std::optional<c10::MemoryFormat>) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
#8 0x00007f15496240d9 in at::detail::empty_cuda(c10::ArrayRef<long>, c10::TensorOptions const&) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
#9 0x00007f154bdf0d2b in at::(anonymous namespace)::create_out(c10::ArrayRef<long>, c10::ArrayRef<long>, c10::TensorOptions const&) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
#10 0x00007f154bee79e9 in at::(anonymous namespace)::structured_gelu_out_cuda_functional::set_output_raw_strided(long, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::TensorOptions, c10::ArrayRef<at::Dimname>) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
#11 0x00007f1583be6ed8 in at::TensorIteratorBase::fast_set_up(at::TensorIteratorConfig const&) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
#12 0x00007f1583be9fca in at::TensorIteratorBase::build(at::TensorIteratorConfig&) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
#13 0x00007f1583beca47 in at::TensorIteratorBase::build_unary_op(at::TensorBase const&, at::TensorBase const&) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
#14 0x00007f154be71176 in at::(anonymous namespace)::wrapper_CUDA_gelu(at::Tensor const&, std::basic_string_view<char, std::char_traits<char> >)
() from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
#15 0x00007f154be71233 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, std::basic_string_view<char, std::char_traits<char> >), &at::(anonymous namespace)::wrapper_CUDA_gelu>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, std::basic_string_view<char, std::char_traits<char> > > >, at::Tensor (at::Tensor const&, std::basic_string_view<char, std::char_traits<char> >)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, std::basic_string_view<char, std::char_traits<char> >) () from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so
#16 0x00007f1584945ef6 in at::_ops::gelu::call(at::Tensor const&, std::basic_string_view<char, std::char_traits<char> >) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so
#17 0x00007f15968e2eff in torch::autograd::THPVariable_gelu(_object*, _object*, _object*) ()
from /opt/conda/envs/tabpfn_segfault/lib/python3.11/site-packages/torch/lib/libtorch_python.so
#18 0x0000558f42eed196 in cfunction_call (func=0x7f159a9d91c0, args=<optimized out>, kwargs=<optimized out>)
at /usr/local/src/conda/python-3.11.13/Objects/methodobject.c:542
#19 0x0000558f42eca6ab in _PyObject_MakeTpCall (tstate=0x558f432654d8 <_PyRuntime+166328>, callable=0x7f159a9d91c0, args=0x7f19e2f71008,
nargs=<optimized out>, keywords=0x0) at /usr/local/src/conda/python-3.11.13/Objects/call.c:214
#20 0x0000558f42ed7d2a in _PyEval_EvalFrameDefault (tstate=tstate@entry=0x558f432654d8 <_PyRuntime+166328>, frame=<optimized out>,
Given that this is my first time working with gdb and CUDA, it was difficult to dig further from here. As far as I can interpret the backtrace, the crash happens inside PyTorch's CUDA caching allocator (DeviceCachingAllocator::malloc) while it allocates the output tensor of a GELU activation. When I logged the cached and allocated memory during my runs, I quickly noticed that TabPFN seems to produce very large amounts of cached memory (essentially all VRAM that isn't allocated is cached). While this is generally not a problem, I suspect it is related to the error above.
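The issue does not include the logging code itself. A minimal sketch of how allocated versus cached memory can be logged, assuming PyTorch is the backend (in current PyTorch, "cached" memory is reported as "reserved"). The helper name `cuda_memory_report` is hypothetical; `torch.cuda.memory_allocated` and `torch.cuda.memory_reserved` are standard PyTorch calls.

```python
def cuda_memory_report(device=None):
    """Return allocated and reserved (cached) CUDA memory in MiB, or None without CUDA.

    Hypothetical helper for logging; not part of TabPFN.
    """
    try:
        import torch
    except ImportError:  # PyTorch not installed
        return None
    if not torch.cuda.is_available():
        return None
    # Memory currently occupied by live tensors.
    allocated = torch.cuda.memory_allocated(device) / 2**20
    # Memory held by the caching allocator (live tensors plus cached, unused blocks).
    reserved = torch.cuda.memory_reserved(device) / 2**20
    return {
        "allocated_mib": allocated,
        "reserved_mib": reserved,
        "cached_unused_mib": reserved - allocated,
    }
```

Calling it before and after each `clf.predict_proba` makes the growth of `cached_unused_mib` visible across runs.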
I therefore ran the above script with a call to torch.cuda.empty_cache() after every clf.predict_proba call. Surprisingly, this seems to prevent the segmentation faults.
Of course, this is more of a band-aid than a permanent solution. Still, I hope it helps anyone who is facing similar issues.
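A minimal sketch of the workaround described above, assuming a scikit-learn-style classifier with a `predict_proba` method and PyTorch as the backend. The helper names are hypothetical; only `torch.cuda.empty_cache()` comes from the issue.

```python
def release_cuda_cache():
    """Return cached-but-unused CUDA memory to the driver; no-op without CUDA."""
    try:
        import torch
    except ImportError:  # PyTorch not installed
        return False
    if not torch.cuda.is_available():
        return False
    # Frees blocks held by PyTorch's caching allocator that no live tensor uses.
    torch.cuda.empty_cache()
    return True


def predict_proba_with_cache_release(clf, X_test):
    """Call clf.predict_proba, then empty the CUDA cache (temporary workaround)."""
    try:
        return clf.predict_proba(X_test)
    finally:
        # Runs even if predict_proba raises, so the cache is cleared either way.
        release_cuda_cache()
```

In the reproduction script, `clf.predict_proba(X_test)` would be replaced by `predict_proba_with_cache_release(clf, X_test)`. The `finally` block keeps the cleanup in one place, at the cost of the performance hit of re-allocating memory on the next call.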
Things I tried that didn't work
#194 mentions this issue; it was closed, but the fix was never merged. One workaround suggested there was to disable the mkldnn backend. I tried this approach, but it did not work for me.
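The issue does not say exactly how the mkldnn backend was disabled. One common way, assuming a PyTorch build with oneDNN (MKLDNN) support, is to clear the global `torch.backends.mkldnn.enabled` flag before running inference; the helper name below is hypothetical. Note that this flag only affects CPU operators, whereas the backtrace above ends in the CUDA caching allocator.

```python
def disable_mkldnn():
    """Globally disable PyTorch's oneDNN (MKLDNN) CPU backend for this process.

    Hypothetical helper; returns False if PyTorch is not installed.
    """
    try:
        import torch
    except ImportError:  # PyTorch not installed
        return False
    # Process-wide switch; CPU ops fall back to their non-oneDNN kernels.
    torch.backends.mkldnn.enabled = False
    return True
```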
Thanks for the detailed investigation and for finding a workaround, really appreciate it!
We realize that at the moment the model is quite memory-hungry, and activation tensor allocations depend on the number of features as well as the number of samples. I suspect the large tensor allocations cause the segfault here. Given the performance hit of torch.cuda.empty_cache(), I'm a bit reluctant to add it to the inference code directly for now. I'll discuss with the team what a long-term solution might be.
Hey @priorphil,
Thanks for getting back to me! I agree with you about adding torch.cuda.empty_cache() to the code base: at best, it should be a temporary workaround. Do you have any idea why I am getting a segfault instead of a CUDA out-of-memory error?
I've seen segfaults before as well, so while they're not common, they're not unheard of either :) Back then, all I could find was "that can have many reasons", such as the ones mentioned in this post:
https://forums.developer.nvidia.com/t/segmentation-fault-when-calling-backward-after-moving-data-to-gpu-pytorch-cuda-12-1/328464/3