huawei-noah / HEBO

Bayesian optimisation & Reinforcement Learning library developed by Huawei Noah's Ark Lab

ValueError: NaN in distribution

NotNANtoN opened this issue · comments

Hi, thanks for this repository! So far it works quite well, but now I suddenly encountered a weird error after 11 optimization steps of non-batched HEBO:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_2773121/4102601230.py in <module>
     35 
     36 for i in range(opt_steps):
---> 37     rec = opt.suggest()
     38     if "bs" in rec:
     39         rec["bs"] = 2 ** rec["bs"]

~/.local/lib/python3.8/site-packages/hebo/optimizers/hebo.py in suggest(self, n_suggestions, fix_input)
    151             sig = Sigma(model, linear_a = -1.)
    152             opt = EvolutionOpt(self.space, acq, pop = 100, iters = 100, verbose = False, es=self.es)
--> 153             rec = opt.optimize(initial_suggest = best_x, fix_input = fix_input).drop_duplicates()
    154             rec = rec[self.check_unique(rec)]
    155 

~/.local/lib/python3.8/site-packages/hebo/acq_optimizers/evolution_optimizer.py in optimize(self, initial_suggest, fix_input, return_pop)
    125         crossover = self.get_crossover()
    126         algo      = get_algorithm(self.es, pop_size = self.pop, sampling = init_pop, mutation = mutation, crossover = crossover, repair = self.repair)
--> 127         res       = minimize(prob, algo, ('n_gen', self.iter), verbose = self.verbose)
    128         if res.X is not None and not return_pop:
    129             opt_x = res.X.reshape(-1, len(lb)).astype(float)

~/.local/lib/python3.8/site-packages/pymoo/optimize.py in minimize(problem, algorithm, termination, copy_algorithm, copy_termination, **kwargs)
     81 
     82     # actually execute the algorithm
---> 83     res = algorithm.run()
     84 
     85     # store the deep copied algorithm in the result object

~/.local/lib/python3.8/site-packages/pymoo/core/algorithm.py in run(self)
    211         # while termination criterion not fulfilled
    212         while self.has_next():
--> 213             self.next()
    214 
    215         # create the result object to be returned

~/.local/lib/python3.8/site-packages/pymoo/core/algorithm.py in next(self)
    231         # call the advance with them after evaluation
    232         if infills is not None:
--> 233             self.evaluator.eval(self.problem, infills, algorithm=self)
    234             self.advance(infills=infills)
    235 

~/.local/lib/python3.8/site-packages/pymoo/core/evaluator.py in eval(self, problem, pop, skip_already_evaluated, evaluate_values_of, count_evals, **kwargs)
     93         # actually evaluate all solutions using the function that can be overwritten
     94         if len(I) > 0:
---> 95             self._eval(problem, pop[I], evaluate_values_of=evaluate_values_of, **kwargs)
     96 
     97             # set the feasibility attribute if cv exists

~/.local/lib/python3.8/site-packages/pymoo/core/evaluator.py in _eval(self, problem, pop, evaluate_values_of, **kwargs)
    110         evaluate_values_of = self.evaluate_values_of if evaluate_values_of is None else evaluate_values_of
    111 
--> 112         out = problem.evaluate(pop.get("X"),
    113                                return_values_of=evaluate_values_of,
    114                                return_as_dictionary=True,

~/.local/lib/python3.8/site-packages/pymoo/core/problem.py in evaluate(self, X, return_values_of, return_as_dictionary, *args, **kwargs)
    122 
    123         # do the actual evaluation for the given problem - calls in _evaluate method internally
--> 124         self.do(X, out, *args, **kwargs)
    125 
    126         # make sure the array is 2d before doing the shape check

~/.local/lib/python3.8/site-packages/pymoo/core/problem.py in do(self, X, out, *args, **kwargs)
    160 
    161     def do(self, X, out, *args, **kwargs):
--> 162         self._evaluate(X, out, *args, **kwargs)
    163         out_to_2d_ndarray(out)
    164 

~/.local/lib/python3.8/site-packages/hebo/acq_optimizers/evolution_optimizer.py in _evaluate(self, x, out, *args, **kwargs)
     46 
     47         with torch.no_grad():
---> 48             acq_eval = self.acq(xcont, xenum).numpy().reshape(num_x, self.acq.num_obj + self.acq.num_constr)
     49             out['F'] = acq_eval[:, :self.acq.num_obj]
     50 

~/.local/lib/python3.8/site-packages/hebo/acquisitions/acq.py in __call__(self, x, xe)
     37 
     38     def __call__(self, x : Tensor,  xe : Tensor):
---> 39         return self.eval(x, xe)
     40 
     41 class SingleObjectiveAcq(Acquisition):

~/.local/lib/python3.8/site-packages/hebo/acquisitions/acq.py in eval(self, x, xe)
    155             normed    = ((self.tau - self.eps - py - noise * torch.randn(py.shape)) / ps)
    156             dist      = Normal(0., 1.)
--> 157             log_phi   = dist.log_prob(normed)
    158             Phi       = dist.cdf(normed)
    159             PI        = Phi

~/.local/lib/python3.8/site-packages/torch/distributions/normal.py in log_prob(self, value)
     71     def log_prob(self, value):
     72         if self._validate_args:
---> 73             self._validate_sample(value)
     74         # compute the variance
     75         var = (self.scale ** 2)

~/.local/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    286         valid = support.check(value)
    287         if not valid.all():
--> 288             raise ValueError(
    289                 "Expected value argument "
    290                 f"({type(value).__name__} of shape {tuple(value.shape)}) "

ValueError: Expected value argument (Tensor of shape (100, 1)) to be within the support (Real()) of the distribution Normal(loc: 0.0, scale: 1.0), but found invalid values:
tensor([[ -1.1836],
        [ -1.2862],
        [-11.6360],
        [-11.3412],
        [  0.3811],
        [ -2.0235],
        [ -1.7288],
        [ -8.3472],
        [-10.1714],
        [ -2.6084],
        [ -0.8098],
        [ -0.9687],
        [ -9.0626],
        [ -2.2273],
        [ -9.0942],
        [ -1.6956],
        [ -6.6197],
        [ -9.3882],
        [ -6.1594],
        [ -9.2895],
        [ -1.7074],
        [  0.8382],
        [-14.6693],
        [ -0.8303],
        [-10.2741],
        [  0.2808],
        [ -9.3681],
        [ -0.6729],
        [ -2.0288],
        [ -1.4389],
        [ -7.1975],
        [-11.5732],
        [-10.2751],
        [ -1.3800],
        [ -1.9773],
        [ -1.4668],
        [ -9.7166],
        [ -8.3093],
        [-15.5914],
        [ -0.0808],
        [  0.3732],
        [-16.2714],
        [ -2.3120],
        [ -8.7503],
        [ -1.6276],
        [     nan],
        [-15.3692],
        [ -9.1615],
        [ -9.8093],
        [ -2.0716],
        [ -1.9259],
        [  0.9543],
        [ -8.1521],
        [ -2.5709],
        [ -1.6153],
        [-10.7236],
        [ -0.0763],
        [  0.0543],
        [ -7.2755],
        [-10.6411],
        [ -7.9253],
        [-19.4996],
        [ -2.0001],
        [-11.7616],
        [-11.0187],
        [-12.0727],
        [ -1.3243],
        [-11.2528],
        [ -1.5527],
        [ -0.9219],
        [ -1.0130],
        [-10.1825],
        [-18.3420],
        [-11.1005],
        [ -8.5818],
        [-11.1588],
        [ -8.8115],
        [ -1.0410],
        [-15.2722],
        [ -1.8399],
        [ -1.0827],
        [ -1.0277],
        [ -6.4768],
        [ -8.3902],
        [ -0.9513],
        [ -1.3429],
        [ -1.0889],
        [ -7.2952],
        [ -7.8548],
        [ -0.0231],
        [ -7.1898],
        [-20.4194],
        [ -1.2503],
        [-19.6157],
        [ -0.3398],
        [-15.7221],
        [-10.3210],
        [ -9.5764],
        [ -0.2335],
        [ -0.3788]])

It seems like a NaN is being fed to some distribution inside HEBO, but my input parameters (opt.X) and losses (opt.y) are never NaN.
This is the design space I'm using:

space = DesignSpace().parse([
    {'name': 'lr', 'type' : 'num', 'lb' : 0.00005, 'ub' : 0.1},
    {'name': 'n_estimators', 'type' : 'int', 'lb' : 1, 'ub' : 20},  # multiplied by 10
    {'name': 'max_depth', 'type' : 'int', 'lb' : 1, 'ub' : 10},
    {'name': 'subsample', 'type' : 'num', 'lb' : 0.5, 'ub' : 0.99},
    {'name': 'colsample_bytree', 'type' : 'num', 'lb' : 0.5, 'ub' : 0.99},
    {'name': 'gamma', 'type' : 'num', 'lb' : 0.01, 'ub' : 10.0},
    {'name': 'min_child_weight', 'type' : 'int', 'lb' : 1, 'ub' : 10},

    {'name': 'fill_type', 'type' : 'cat', 'categories' : ['median', 'pat_median', 'pat_ema']},
    {'name': 'flat_block_size', 'type' : 'int', 'lb' : 1, 'ub' : 1}
])

opt = HEBO(space)

I already tried commenting out flat_block_size, as I thought it might be a problem if lb == ub, but it still crashes.

Any ideas on how I can debug this?

It looks like some acquisition function values calculated from the Gaussian process model are NaN; this might be due to the numerical stability of the GP model. Do you have sample code that I can use to reproduce this crash?
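To rule out data issues on the user side, a quick sanity check before each observe call could look roughly like this (check_observations here is just an illustrative helper, not part of HEBO; it only assumes opt.X is a DataFrame and opt.y a numeric array, as used above):

import numpy as np

def check_observations(opt):
    # Illustrative helper: verify that the data the GP surrogate is fit on is finite.
    y = np.asarray(opt.y, dtype=float)
    assert np.isfinite(y).all(), "non-finite objective values in opt.y"
    x_num = opt.X.select_dtypes(include=[np.number]).to_numpy(dtype=float)
    assert np.isfinite(x_num).all(), "non-finite numeric parameters in opt.X"
    # Very large spreads in y can also destabilise the GP fit numerically.
    if y.std() > 0 and np.abs((y - y.mean()) / y.std()).max() > 10:
        print("warning: extreme outliers in opt.y may hurt GP numerical stability")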

Based on your design space, I assume you are tuning an XGBoost model. I wrote the code below using hebo.sklearn_tuner.sklearn_tuner to see if anything goes wrong, but it looks like everything was OK:

import warnings
warnings.filterwarnings('ignore')
import numpy as np
import matplotlib.pyplot as plt

from hebo.optimizers.hebo import HEBO
from hebo.sklearn_tuner import sklearn_tuner
from hebo.design_space import DesignSpace

from xgboost import XGBRegressor
from sklearn.metrics import r2_score
from sklearn.model_selection import KFold, cross_val_predict
from sklearn.datasets import load_boston

space_cfg = [
    {'name': 'lr', 'type' : 'num', 'lb' : 0.00005, 'ub' : 0.1},
    {'name': 'n_estimators', 'type' : 'int', 'lb' : 10, 'ub' : 200},
    {'name': 'max_depth', 'type' : 'int', 'lb' : 1, 'ub' : 10},
    {'name': 'subsample', 'type' : 'num', 'lb' : 0.5, 'ub' : 0.99},
    {'name': 'colsample_bytree', 'type' : 'num', 'lb' : 0.5, 'ub' : 0.99},
    {'name': 'gamma', 'type' : 'num', 'lb' : 0.01, 'ub' : 10.0},
    {'name': 'min_child_weight', 'type' : 'int', 'lb' : 1, 'ub' : 10},
    {'name': 'fill_type', 'type' : 'cat', 'categories' : ['median', 'pat_median', 'pat_ema']},
    {'name': 'flat_block_size', 'type' : 'int', 'lb' : 1, 'ub' : 1},
    {'name': 'verbosity', 'type' : 'int', 'lb' : 0, 'ub' : 0}
]

X, y = load_boston(return_X_y = True)
cv   = KFold(n_splits = 10, shuffle = True, random_state = 42)
result, report = sklearn_tuner(XGBRegressor, space_cfg, X, y, r2_score, cv = cv, max_iter = 64, report = True)
print(report)
report.metric.plot()
plt.show()

Hi, thanks a lot for your answer!
I tried your code and that seems to run fine. I am not using the sklearn_tuner; instead I call opt.suggest manually. I modified my code to use the Boston dataset and that also seems to work... But I noticed that it gets quite slow after about 50 steps - is that due to the changed hyperparameters, or does HEBO itself take much longer as the number of iterations grows?

As for my problem, it still occurs. Unfortunately my dataset is not publicly available. I'm basically running this loop, where the obj function returns 1 - r2_score as the loss to minimize:

for i in range(opt_steps):
    rec = opt.suggest()
    if "bs" in rec:
        rec["bs"] = 2 ** rec["bs"]
    if "n_estimators" in rec:
        rec["n_estimators"] *= 10
    print(i)
    print(list(zip(rec.columns, rec.values[0])))
    start_time = time.time()
    opt.observe(rec, obj(df, cfg, rec))
    print("Opt time: ", time.time() - start_time)
    min_idx = np.argmin(opt.y)
    print("Current score:", 1 - opt.y[-1][0])
    print("Best score so far:", 1 - opt.y[min_idx][0])
    print(f'After {i} iterations, best obj is {1 - opt.y[min_idx][0]:.4f} with params {opt.X.iloc[min_idx][0]}')
    print()

This is my full output:

0
[('lr', 4.999999873689376e-05), ('n_estimators', 10), ('max_depth', 1), ('subsample', 0.5), ('colsample_bytree', 0.5), ('gamma', 0.009999999776482582), ('min_child_weight', 0.009999999776482582), ('fill_type', 'median')]
Opt time:  8.373520851135254
Current score: 0.06503021650669472
Best score so far: 0.06503021650669472
After 0 iterations, best obj is 0.0650 with params 4.999999873689376e-05

1
[('lr', 0.05002500116825104), ('n_estimators', 100), ('max_depth', 6), ('subsample', 0.7450000047683716), ('colsample_bytree', 0.7450000047683716), ('gamma', 2.504999876022339), ('min_child_weight', 2.504999876022339), ('fill_type', 'pat_ema')]
Opt time:  12.982976198196411
Current score: 0.22393181747644808
Best score so far: 0.22393181747644808
After 1 iterations, best obj is 0.2239 with params 0.05002500116825104

2
[('lr', 0.09783410604757857), ('n_estimators', 90), ('max_depth', 10), ('subsample', 0.9768197084360646), ('colsample_bytree', 0.9498448347173593), ('gamma', 4.438085471364415), ('min_child_weight', 4.949910987243103), ('fill_type', 'pat_ema')]
Opt time:  19.771536111831665
Current score: 0.27596342318731093
Best score so far: 0.27596342318731093
After 2 iterations, best obj is 0.2760 with params 0.09783410604757857

3
[('lr', 0.09534893514045893), ('n_estimators', 20), ('max_depth', 10), ('subsample', 0.958603449440306), ('colsample_bytree', 0.9897392694341535), ('gamma', 4.597594328222324), ('min_child_weight', 4.54544847187763), ('fill_type', 'pat_ema')]
Opt time:  13.476808786392212
Current score: 0.2911832440771226
Best score so far: 0.2911832440771226
After 3 iterations, best obj is 0.2912 with params 0.09534893514045893

4
[('lr', 0.08974096345257622), ('n_estimators', 30), ('max_depth', 10), ('subsample', 0.9875473709445614), ('colsample_bytree', 0.9899865032565488), ('gamma', 4.502709642709985), ('min_child_weight', 0.1675799138458133), ('fill_type', 'pat_ema')]
Opt time:  14.120468378067017
Current score: 0.281744205852656
Best score so far: 0.2911832440771226
After 4 iterations, best obj is 0.2912 with params 0.09534893514045893

5
[('lr', 0.056794210828607895), ('n_estimators', 100), ('max_depth', 5), ('subsample', 0.7841850110546084), ('colsample_bytree', 0.761872148068391), ('gamma', 4.660798376085397), ('min_child_weight', 1.706409948969703), ('fill_type', 'median')]
Opt time:  9.955632448196411
Current score: 0.31906698897163577
Best score so far: 0.31906698897163577
After 5 iterations, best obj is 0.3191 with params 0.056794210828607895

6
[('lr', 0.06194953234522322), ('n_estimators', 190), ('max_depth', 9), ('subsample', 0.5245152196754684), ('colsample_bytree', 0.7979904402789458), ('gamma', 4.999710831827684), ('min_child_weight', 1.31406233972266), ('fill_type', 'median')]
Opt time:  18.016623735427856
Current score: 0.23482889590088885
Best score so far: 0.31906698897163577
After 6 iterations, best obj is 0.3191 with params 0.056794210828607895

7
[('lr', 0.04779509898029476), ('n_estimators', 80), ('max_depth', 5), ('subsample', 0.7323350700727138), ('colsample_bytree', 0.7339707453149883), ('gamma', 4.681703025075991), ('min_child_weight', 1.731618124440871), ('fill_type', 'median')]
Opt time:  9.83765172958374
Current score: 0.3134999604020726
Best score so far: 0.31906698897163577
After 7 iterations, best obj is 0.3191 with params 0.056794210828607895

8
[('lr', 0.026268868648699047), ('n_estimators', 180), ('max_depth', 5), ('subsample', 0.6306731708933159), ('colsample_bytree', 0.7468692282491458), ('gamma', 3.7312811738115284), ('min_child_weight', 1.661273660312038), ('fill_type', 'median')]
Opt time:  11.977625846862793
Current score: 0.34071786913084745
Best score so far: 0.34071786913084745
After 8 iterations, best obj is 0.3407 with params 0.026268868648699047

9
[('lr', 0.017358057301979493), ('n_estimators', 190), ('max_depth', 5), ('subsample', 0.7265885258298752), ('colsample_bytree', 0.755437663728787), ('gamma', 3.294058784514677), ('min_child_weight', 2.0467435360563058), ('fill_type', 'median')]
Opt time:  11.51999807357788
Current score: 0.3195188613478257
Best score so far: 0.34071786913084745
After 9 iterations, best obj is 0.3407 with params 0.026268868648699047

10
[('lr', 0.043367372765737065), ('n_estimators', 200), ('max_depth', 4), ('subsample', 0.6074111596682685), ('colsample_bytree', 0.7127508642287075), ('gamma', 4.236515970760105), ('min_child_weight', 1.2258569949494702), ('fill_type', 'none')]
Opt time:  9.249167919158936
Current score: 0.3131071323890887
Best score so far: 0.34071786913084745
After 10 iterations, best obj is 0.3407 with params 0.026268868648699047

11
[('lr', 0.02077600494088448), ('n_estimators', 200), ('max_depth', 5), ('subsample', 0.8374782341247184), ('colsample_bytree', 0.545661817026894), ('gamma', 3.9953090379449208), ('min_child_weight', 0.7574154138166923), ('fill_type', 'median')]
Opt time:  11.864658832550049
Current score: 0.3306416509958666
Best score so far: 0.34071786913084745
After 11 iterations, best obj is 0.3407 with params 0.026268868648699047

12
[('lr', 0.06935039082110936), ('n_estimators', 50), ('max_depth', 5), ('subsample', 0.6314640561076242), ('colsample_bytree', 0.9636399285390094), ('gamma', 3.6190603244245025), ('min_child_weight', 3.0198621856406427), ('fill_type', 'median')]
Opt time:  9.606015682220459
Current score: 0.3283411937461205
Best score so far: 0.34071786913084745
After 12 iterations, best obj is 0.3407 with params 0.026268868648699047

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_2820306/4048605528.py in <module>
     45 
     46 for i in range(opt_steps):
---> 47     rec = opt.suggest()
     48     if "bs" in rec:
     49         rec["bs"] = 2 ** rec["bs"]

~/.local/lib/python3.8/site-packages/hebo/optimizers/hebo.py in suggest(self, n_suggestions, fix_input)
    151             sig = Sigma(model, linear_a = -1.)
    152             opt = EvolutionOpt(self.space, acq, pop = 100, iters = 100, verbose = False, es=self.es)
--> 153             rec = opt.optimize(initial_suggest = best_x, fix_input = fix_input).drop_duplicates()
    154             rec = rec[self.check_unique(rec)]
    155 

~/.local/lib/python3.8/site-packages/hebo/acq_optimizers/evolution_optimizer.py in optimize(self, initial_suggest, fix_input, return_pop)
    125         crossover = self.get_crossover()
    126         algo      = get_algorithm(self.es, pop_size = self.pop, sampling = init_pop, mutation = mutation, crossover = crossover, repair = self.repair)
--> 127         res       = minimize(prob, algo, ('n_gen', self.iter), verbose = self.verbose)
    128         if res.X is not None and not return_pop:
    129             opt_x = res.X.reshape(-1, len(lb)).astype(float)

~/.local/lib/python3.8/site-packages/pymoo/optimize.py in minimize(problem, algorithm, termination, copy_algorithm, copy_termination, **kwargs)
     81 
     82     # actually execute the algorithm
---> 83     res = algorithm.run()
     84 
     85     # store the deep copied algorithm in the result object

~/.local/lib/python3.8/site-packages/pymoo/core/algorithm.py in run(self)
    211         # while termination criterion not fulfilled
    212         while self.has_next():
--> 213             self.next()
    214 
    215         # create the result object to be returned

~/.local/lib/python3.8/site-packages/pymoo/core/algorithm.py in next(self)
    231         # call the advance with them after evaluation
    232         if infills is not None:
--> 233             self.evaluator.eval(self.problem, infills, algorithm=self)
    234             self.advance(infills=infills)
    235 

~/.local/lib/python3.8/site-packages/pymoo/core/evaluator.py in eval(self, problem, pop, skip_already_evaluated, evaluate_values_of, count_evals, **kwargs)
     93         # actually evaluate all solutions using the function that can be overwritten
     94         if len(I) > 0:
---> 95             self._eval(problem, pop[I], evaluate_values_of=evaluate_values_of, **kwargs)
     96 
     97             # set the feasibility attribute if cv exists

~/.local/lib/python3.8/site-packages/pymoo/core/evaluator.py in _eval(self, problem, pop, evaluate_values_of, **kwargs)
    110         evaluate_values_of = self.evaluate_values_of if evaluate_values_of is None else evaluate_values_of
    111 
--> 112         out = problem.evaluate(pop.get("X"),
    113                                return_values_of=evaluate_values_of,
    114                                return_as_dictionary=True,

~/.local/lib/python3.8/site-packages/pymoo/core/problem.py in evaluate(self, X, return_values_of, return_as_dictionary, *args, **kwargs)
    122 
    123         # do the actual evaluation for the given problem - calls in _evaluate method internally
--> 124         self.do(X, out, *args, **kwargs)
    125 
    126         # make sure the array is 2d before doing the shape check

~/.local/lib/python3.8/site-packages/pymoo/core/problem.py in do(self, X, out, *args, **kwargs)
    160 
    161     def do(self, X, out, *args, **kwargs):
--> 162         self._evaluate(X, out, *args, **kwargs)
    163         out_to_2d_ndarray(out)
    164 

~/.local/lib/python3.8/site-packages/hebo/acq_optimizers/evolution_optimizer.py in _evaluate(self, x, out, *args, **kwargs)
     46 
     47         with torch.no_grad():
---> 48             acq_eval = self.acq(xcont, xenum).numpy().reshape(num_x, self.acq.num_obj + self.acq.num_constr)
     49             out['F'] = acq_eval[:, :self.acq.num_obj]
     50 

~/.local/lib/python3.8/site-packages/hebo/acquisitions/acq.py in __call__(self, x, xe)
     37 
     38     def __call__(self, x : Tensor,  xe : Tensor):
---> 39         return self.eval(x, xe)
     40 
     41 class SingleObjectiveAcq(Acquisition):

~/.local/lib/python3.8/site-packages/hebo/acquisitions/acq.py in eval(self, x, xe)
    155             normed    = ((self.tau - self.eps - py - noise * torch.randn(py.shape)) / ps)
    156             dist      = Normal(0., 1.)
--> 157             log_phi   = dist.log_prob(normed)
    158             Phi       = dist.cdf(normed)
    159             PI        = Phi

~/.local/lib/python3.8/site-packages/torch/distributions/normal.py in log_prob(self, value)
     71     def log_prob(self, value):
     72         if self._validate_args:
---> 73             self._validate_sample(value)
     74         # compute the variance
     75         var = (self.scale ** 2)

~/.local/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    286         valid = support.check(value)
    287         if not valid.all():
--> 288             raise ValueError(
    289                 "Expected value argument "
    290                 f"({type(value).__name__} of shape {tuple(value.shape)}) "

ValueError: Expected value argument (Tensor of shape (100, 1)) to be within the support (Real()) of the distribution Normal(loc: 0.0, scale: 1.0), but found invalid values:
tensor([[-1.9959],
        [-1.2675],
        [-1.2204],
        [-0.5946],
        [-1.3163],
        [-0.8091],
        [-2.3450],
        [-1.1690],
        [-1.2374],
        [-0.5374],
        [-0.8852],
        [-1.5104],
        [-1.8167],
        [ 0.3373],
        [-1.0077],
        [-1.5388],
        [ 0.9909],
        [-0.9809],
        [-1.0140],
        [-0.1807],
        [-0.5176],
        [-0.3398],
        [-1.5057],
        [-1.3493],
        [-1.3827],
        [-0.7947],
        [-2.6809],
        [-0.7844],
        [-1.4292],
        [-0.8269],
        [-1.6755],
        [-1.6348],
        [-0.7895],
        [-0.8264],
        [-1.3902],
        [-0.5924],
        [-1.4093],
        [-0.8154],
        [ 0.2801],
        [-0.6707],
        [-1.0585],
        [-1.5289],
        [-1.2883],
        [-0.6418],
        [-3.6011],
        [    nan],
        [-1.3098],
        [-2.6957],
        [-0.9912],
        [ 0.4284],
        [-1.6822],
        [-0.5964],
        [-0.1601],
        [-1.2632],
        [-0.8173],
        [-0.1966],
        [ 1.8093],
        [ 0.5075],
        [-0.6223],
        [-1.1435],
        [-0.7424],
        [-1.6756],
        [ 1.7556],
        [-1.5124],
        [-1.4938],
        [-0.6549],
        [-0.6919],
        [-0.4789],
        [-1.6914],
        [-1.8472],
        [-0.3958],
        [-1.9369],
        [-1.5689],
        [-0.7813],
        [-0.8114],
        [-0.9482],
        [-0.9427],
        [-1.5766],
        [-0.6994],
        [-1.2480],
        [-1.1529],
        [-1.0359],
        [-1.6211],
        [-1.1925],
        [-0.7662],
        [-0.9530],
        [-0.0925],
        [ 0.1829],
        [-1.6802],
        [-1.7956],
        [-1.6634],
        [-1.8606],
        [-1.1047],
        [-0.5844],
        [-1.0566],
        [-1.6968],
        [-0.9914],
        [-0.8555],
        [-1.4518],
        [-1.6394]])

And opt.y is:

array([[0.93496978],
       [0.77606818],
       [0.72403658],
       [0.70881676],
       [0.71825579],
       [0.68093301],
       [0.7651711 ],
       [0.68650004],
       [0.65928213],
       [0.68048114],
       [0.68689287],
       [0.66935835],
       [0.67165881]])

I'm quite sure that this error is not dependent on XGBoost, because it also happens when I train an RNN. It even crashed there on the third suggest step.

I really don't know what is happening here. I'm now trying different surrogate models. As far as I can see there are gpy (the default), gp, gpy_mlp and rf. Since the error seems related to the Gaussian process, I'm trying rf first. But are there any metrics or evaluations of how well it performs? I could not really find anything in the arXiv paper at https://arxiv.org/pdf/2012.03826.pdf
rf did not crash so far, even after 70 steps - but if it performs worse, then of course I don't want to use it.
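For reference, switching the surrogate looks like a one-line change; a minimal sketch, assuming the HEBO constructor accepts a model_name argument matching the names listed above (worth checking against the installed version):

from hebo.design_space import DesignSpace
from hebo.optimizers.hebo import HEBO

# Sketch only: choose the surrogate by name ('gpy', 'gp', 'gpy_mlp' or 'rf');
# the model_name keyword is an assumption here, verify it in your HEBO version.
space  = DesignSpace().parse([{'name': 'lr', 'type': 'num', 'lb': 5e-5, 'ub': 0.1}])
opt_rf = HEBO(space, model_name='rf')   # random-forest surrogate instead of the default GP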

10
[('lr', 0.043367372765737065), ('n_estimators', 200), ('max_depth', 4), ('subsample', 0.6074111596682685), ('colsample_bytree', 0.7127508642287075), ('gamma', 4.236515970760105), ('min_child_weight', 1.2258569949494702), ('fill_type', 'none')]

Why is fill_type 'none' here?

@MdAsifKhan I think I have found the reason; it's because of these lines of code:

    if "bs" in rec:
        rec["bs"] = 2 ** rec["bs"]
    if "n_estimators" in rec:
        rec["n_estimators"] *= 10

By doing this you modify rec, so the rec you pass to observe is not the same as the one returned by suggest. For example, your n_estimators is defined within [1, 20], but the n_estimators you pass to observe would be within [10, 200].

It looks like you want n_estimators to be a multiple of 10 and bs to be an integer power of 2. HEBO actually has built-in support for these requirements, so you don't need to do the manual transformation.

You can write the space configuration like this:

import pandas as pd
import numpy as np
from hebo.design_space import DesignSpace
from hebo.optimizers.hebo import HEBO
np.random.seed(42)
space = DesignSpace().parse([
    {'name': 'n_estimators', 'type' : 'step_int', 'lb' : 10, 'ub' : 200, 'step' : 10},  # multiples of 10
    {'name': 'bs', 'type' : 'int_exponent', 'lb' : 16, 'ub' : 1024, 'base' : 2},  # 2**(int)
    ])
print(space.sample(10))

The output would look like this:

   n_estimators   bs
0            70   64
1           200  512
2           150  256
3           110   32
4            80  128
5            70  512
6           190  512
7           110   32
8           110  128
9            40  256
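Alternatively, if you prefer to keep the manual transformation, a minimal sketch of your loop (reusing your opt, obj, df and cfg from above) that avoids the problem is to transform only a copy and pass the untouched rec back to observe:

for i in range(opt_steps):
    rec = opt.suggest()
    rec_eval = rec.copy()              # transform only a copy of the suggestion
    if "bs" in rec_eval:
        rec_eval["bs"] = 2 ** rec_eval["bs"]
    if "n_estimators" in rec_eval:
        rec_eval["n_estimators"] *= 10
    # evaluate the objective on the transformed copy, but observe the untouched rec
    opt.observe(rec, obj(df, cfg, rec_eval))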

Ah, thank you so much! That fixes the issue for me. Perhaps you could raise a warning if the rec passed to observe differs from the last output of suggest?

Also, thanks for the tip about the parameter distributions - I wanted to do the transformation myself because I'm comparing Optuna and HEBO and therefore want to apply these parameter modifications independently of either library.
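In the meantime, a user-side version of that warning could look roughly like this (purely illustrative, not HEBO API):

import warnings

# Illustrative guard: keep the untouched suggestion and warn before observing
# if the frame was modified in the meantime.
last_suggestion = opt.suggest()
rec = last_suggestion.copy()
# ... any user-side transformations on rec would happen here ...
if not rec.equals(last_suggestion):
    warnings.warn("rec passed to observe() differs from the last suggest() output")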