fix xgboost device flag
mieskolainen committed Nov 7, 2023
1 parent 166ad3b commit f615873
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion configs/hnl/models.yml
@@ -111,7 +111,7 @@ xgb0:
booster: 'gbtree' # 'gbtree' (default), 'dart' (dropout boosting)
tree_method: 'hist'
device: 'auto' # 'auto', 'cpu', 'cuda'

learning_rate: 0.1
gamma: 1.67
max_depth: 8
2 changes: 1 addition & 1 deletion configs/trg/models.yml
@@ -155,7 +155,7 @@ xgb0:

booster: 'gbtree' # 'gbtree' (default), 'dart' (dropout boosting)
tree_method: 'hist'
- device: 'auto' # 'auto', 'cpu:0', 'cuda:0'
+ device: 'auto' # 'auto', 'cpu', 'cuda:0'

learning_rate: 0.1
gamma: 1.67
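The config changes above replace the `'cpu:0'` device string, which XGBoost 2.x does not accept, with the valid forms. A minimal sketch of a validator for the `device` values used in these `models.yml` snippets, assuming XGBoost 2.x semantics where valid values are `'cpu'`, `'cuda'`, or `'cuda:<ordinal>'`, and `'auto'` is resolved by the training code rather than by XGBoost itself (the function name is hypothetical, not part of the repository):

```python
import re

# Allowed device strings: 'auto' (resolved later by the training code),
# 'cpu', 'cuda', or 'cuda:<ordinal>' such as 'cuda:0'.
# Note: 'cpu:0' is deliberately rejected; removing it is the point of this commit.
_ALLOWED_DEVICE = re.compile(r"^(auto|cpu|cuda(:\d+)?)$")

def validate_device(device: str) -> str:
    """Return the device string unchanged if valid, else raise ValueError."""
    if not _ALLOWED_DEVICE.match(device):
        raise ValueError(f"Unsupported XGBoost device string: {device!r}")
    return device
```

For example, `validate_device('cuda:0')` passes while `validate_device('cpu:0')` raises, matching the before/after states of this diff.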
4 changes: 2 additions & 2 deletions icenet/deep/iceboost.py
@@ -171,10 +171,10 @@ def train_xgb(config={'params': {}}, data_trn=None, data_val=None, y_soft=None,
# ---------------------------------------------------

if param['model_param']['device'] == 'auto':
- param['model_param'].update({'device': 'cuda:0' if torch.cuda.is_available() else 'cpu:0'})
+ param['model_param'].update({'device': 'cuda' if torch.cuda.is_available() else 'cpu'})

print(__name__ + f'.train_xgb: Training <{param["label"]}> classifier ...')

### ** Optimization hyperparameters [possibly from Raytune] **
param['model_param'] = aux.replace_param(default=param['model_param'], raytune=config['params'])

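The `train_xgb` change above resolves `device: 'auto'` to a concrete string before training. A minimal sketch of that resolution step, using a plain boolean in place of `torch.cuda.is_available()` so it runs without a GPU or torch installed (the standalone function is a hypothetical restating of the fixed line, not code from the repository):

```python
def resolve_device(model_param: dict, cuda_available: bool) -> dict:
    """Replace device='auto' with a concrete XGBoost 2.x device string.

    Returns a new dict; non-'auto' values pass through unchanged.
    """
    if model_param.get('device') == 'auto':
        # Mirrors the fixed line: 'cuda'/'cpu', not the invalid 'cuda:0'/'cpu:0'.
        return {**model_param, 'device': 'cuda' if cuda_available else 'cpu'}
    return model_param
```

In the actual code the flag comes from `torch.cuda.is_available()`, so a CUDA-capable machine trains with `device='cuda'` and everything else falls back to `device='cpu'`.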
