-
Notifications
You must be signed in to change notification settings - Fork 0
/
xgboost_callback.rb
91 lines (72 loc) · 2.79 KB
/
xgboost_callback.rb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
module Wandb
  # Training callback that mirrors an XGBoost run into the active Wandb run.
  #
  # Hooks (each takes keyword arguments):
  # * before_training - store the booster params in the run config and log them
  # * after_iteration - log every evaluation metric of the round, plus the epoch
  # * after_training  - optionally log the model artifact and a feature-importance
  #                     chart, then log best_score / best_iteration if present
  #
  # NOTE(review): the hook names and keyword signatures are presumably what the
  # XGBoost training loop invokes; confirm against the caller.
  class XGBoostCallback
    # Metric names (lower-cased, "@..." parameter suffix and trailing "-" removed)
    # whose run summary records the minimum. Any name containing "loss" is also minimized.
    MINIMIZE_METRICS = %w[rmse rmsle mae mape mphe logloss error merror].freeze
    # Metric names (same normalization) whose run summary records the maximum.
    MAXIMIZE_METRICS = %w[auc aucpr accuracy ndcg map].freeze

    # @param log_model [Boolean] save the trained model and log it as a "model" artifact
    # @param log_feature_importance [Boolean] log a feature-importance bar chart after training
    # @param importance_type [String] forwarded as +importance_type:+ to +model.score+
    # @param define_metric [Boolean] declare min/max summaries for recognised metrics
    #   during the first call to #after_iteration
    # @raise [RuntimeError] if there is no current Wandb run
    def initialize(log_model: false, log_feature_importance: true, importance_type: "gain", define_metric: true)
      raise "You must call Wandb.init() before XGBoostCallback.new" unless Wandb.current_run

      @log_model = log_model
      @log_feature_importance = log_feature_importance
      @importance_type = importance_type
      @define_metric = define_metric
    end

    # Copies the booster's params into the run config and also logs them.
    # NOTE(review): the Wandb.log call records the params as a history row in
    # addition to the config; confirm that is intended.
    #
    # @return the value returned by Wandb.log
    def before_training(model:)
      params = model.params
      Wandb.current_run.config = params
      Wandb.log(params)
    end

    # Optionally logs the model artifact and feature importance, then logs
    # best_score / best_iteration when the model reports a best score.
    def after_training(model:)
      log_model_as_artifact(model) if @log_model
      log_feature_importance(model) if @log_feature_importance
      return unless model.best_score

      Wandb.log({ "best_score" => model.best_score.to_f, "best_iteration" => model.best_iteration.to_i })
    end

    # Nothing to record before a boosting round.
    def before_iteration(model:, epoch:, evals:)
      # noop
    end

    # Logs all metrics of one boosting round, plus "epoch", in a single Wandb.log
    # call so the round occupies one history step instead of one step per metric.
    #
    # @param res [Enumerable] [metric_name, value] pairs; names are expected to look
    #   like "<dataset>-<metric>" (e.g. "train-rmse") and are logged unchanged
    # @return [false] the value of the final assignment
    def after_iteration(model:, epoch:, evals:, res:)
      metrics = {}
      res.each do |metric_name, value|
        name = metric_name.to_s
        if @define_metric
          data, metric = name.split("-", 2)
          # A name without a "-" has no metric part to classify.
          define_metric(data, metric) if metric && !metric.empty?
        end
        metrics[name] = value
      end
      metrics["epoch"] = epoch
      Wandb.log(metrics)

      # Summaries only need to be declared once, on the first round.
      @define_metric = false
    end

    private

    # Saves the model as "<run id>_model.json" in the run directory and logs that
    # file as a "model" artifact on the current run.
    def log_model_as_artifact(model)
      model_name = "#{Wandb.current_run.id}_model.json"
      model_path = File.join(Wandb.current_run.dir, model_name)
      model.save_model(model_path)
      artifact = Wandb.Artifact(name: model_name, type: "model")
      artifact.add_file(model_path)
      Wandb.current_run.log_artifact(artifact)
    end

    # Logs a "Feature Importance" bar chart built from
    # model.score(importance_type: @importance_type).
    # When the model reports no scores (nil or empty), a warning is emitted and
    # nothing is logged, rather than logging an empty table.
    def log_feature_importance(model)
      scores = model.score(importance_type: @importance_type)
      fi_data = scores ? scores.map { |feature, importance| [feature, importance] } : []
      if fi_data.empty?
        warn "Wandb::XGBoostCallback: no feature importance scores; skipping Feature Importance plot"
        return
      end

      table = Wandb.Table(data: fi_data, columns: %w[Feature Importance])
      bar_plot = Wandb.plot.bar(table, "Feature", "Importance", title: "Feature Importance")
      Wandb.log({ "Feature Importance" => bar_plot })
    end

    # Declares a "min" or "max" summary for "<data>-<metric_name>" when the
    # metric's direction is known; unrecognised metrics are left alone.
    # Classification ignores case, a parameter suffix such as "@0.5"
    # ("error@0.5", "ndcg@5") and a trailing "-" ("ndcg-").
    def define_metric(data, metric_name)
      return if metric_name.nil?

      full_metric_name = "#{data}-#{metric_name}"
      base = metric_name.to_s.downcase.sub(/@.*\z/, "").delete_suffix("-")
      if base.include?("loss") || MINIMIZE_METRICS.include?(base)
        Wandb.define_metric(full_metric_name, summary: "min")
      elsif MAXIMIZE_METRICS.include?(base)
        Wandb.define_metric(full_metric_name, summary: "max")
      end
    end
  end
end