docs(examples) Update XGBoost tutorial (#3634)
yan-gao-GY authored Jul 1, 2024
1 parent 04fb0d7 commit 4418171
Showing 1 changed file with 134 additions and 78 deletions: doc/source/tutorial-quickstart-xgboost.rst
Prior to local training, we require loading the HIGGS dataset from Flower Datasets:

.. code-block:: python

    fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})

    # Load the partition for this `partition_id`
    partition = fds.load_partition(partition_id=args.partition_id, split="train")
    partition.set_format("numpy")
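For reference, a minimal sketch of how the :code:`partitioner` used above can be constructed with Flower Datasets, matching the 30-partition setup described next:

.. code-block:: python

    from flwr_datasets import FederatedDataset
    from flwr_datasets.partitioner import IidPartitioner

    # Uniformly split the HIGGS training set into 30 partitions, one per client
    partitioner = IidPartitioner(num_partitions=30)
    fds = FederatedDataset(dataset="jxie/higgs", partitioners={"train": partitioner})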
In this example, we split the dataset into 30 partitions with uniform distribution (:code:`IidPartitioner(num_partitions=30)`).
Then, we load the partition for the given client based on :code:`partition_id`:

.. code-block:: python

    # We first define an argument parser for the user to specify the client/partition ID.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--partition-id",
        default=0,
        type=int,
        help="Partition ID used for the current client.",
    )
    args = parser.parse_args()

    # Load the partition for this `partition_id`.
    partition = fds.load_partition(partition_id=args.partition_id, split="train")
    partition.set_format("numpy")
After that, we do train/test splitting on the given partition (the client's local data) and transform the data format for the :code:`xgboost` package.
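The exact helper functions are defined elsewhere in the tutorial; as a rough sketch of this step, assuming the HIGGS columns are named :code:`inputs` and :code:`label` (an assumption, not confirmed by this page), it could look like:

.. code-block:: python

    import xgboost as xgb

    # 80/20 train/validation split of the client's local partition (sketch only)
    partition_splits = partition.train_test_split(test_size=0.2, seed=42)
    train_data = partition_splits["train"]
    valid_data = partition_splits["test"]

    # Reformat the numpy data into xgboost's DMatrix
    train_dmatrix = xgb.DMatrix(train_data["inputs"], label=train_data["label"])
    valid_dmatrix = xgb.DMatrix(valid_data["inputs"], label=valid_data["label"])
    num_train, num_val = len(train_data), len(valid_data)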
We follow the general rule to define an :code:`XgbClient` class inherited from :code:`fl.client.Client`:
.. code-block:: python

    class XgbClient(fl.client.Client):
        def __init__(
            self,
            train_dmatrix,
            valid_dmatrix,
            num_train,
            num_val,
            num_local_round,
            params,
        ):
            self.train_dmatrix = train_dmatrix
            self.valid_dmatrix = valid_dmatrix
            self.num_train = num_train
            self.num_val = num_val
            self.num_local_round = num_local_round
            self.params = params
All required parameters defined above are passed to :code:`XgbClient`'s constructor.
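For illustration, :code:`params` is an ordinary :code:`xgboost` training-parameter dictionary and :code:`num_local_round` is the number of boosting rounds per FL round. The values below are plausible for binary classification on HIGGS, not prescribed by this page:

.. code-block:: python

    num_local_round = 1
    params = {
        "objective": "binary:logistic",
        "eta": 0.1,  # learning rate
        "max_depth": 8,
        "eval_metric": "auc",
        "num_parallel_tree": 1,
        "subsample": 1,
        "tree_method": "hist",
    }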

Then, we override the :code:`get_parameters`, :code:`fit` and :code:`evaluate` methods inside the :code:`XgbClient` class as follows.

As a result, let's return an empty tensor in :code:`get_parameters` when it is called.
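A minimal sketch of such a :code:`get_parameters`, assuming the usual :code:`flwr.common` types (:code:`GetParametersIns`, :code:`GetParametersRes`, :code:`Status`, :code:`Code`, :code:`Parameters`) are imported:

.. code-block:: python

    def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
        # No model exists before the first round of local training,
        # so reply with an empty tensor list.
        return GetParametersRes(
            status=Status(
                code=Code.OK,
                message="OK",
            ),
            parameters=Parameters(tensor_type="", tensors=[]),
        )

Next, :code:`fit` handles the local training: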
.. code-block:: python

    def fit(self, ins: FitIns) -> FitRes:
        global_round = int(ins.config["global_round"])
        if global_round == 1:
            # First round local training
            log(INFO, "Start training at round 1")
            bst = xgb.train(
                self.params,
                self.train_dmatrix,
                num_boost_round=self.num_local_round,
                evals=[(self.valid_dmatrix, "validate"), (self.train_dmatrix, "train")],
            )
        else:
            bst = xgb.Booster(params=self.params)
            for item in ins.parameters.tensors:
                global_model = bytearray(item)

            # Load global model into booster
            bst.load_model(global_model)

            # Local training
            bst = self._local_boost(bst)

        # Save model
        local_model = bst.save_raw("json")
        local_model_bytes = bytes(local_model)

        return FitRes(
            status=Status(
                code=Code.OK,
                message="OK",
            ),
            parameters=Parameters(tensor_type="", tensors=[local_model_bytes]),
            num_examples=self.num_train,
            metrics={},
        )
In :code:`fit`, at the first round, we call :code:`xgb.train()` to build up the first set of trees.
From the second round onwards, we load the global model sent from the server into a newly built Booster object,
and then update the model weights on the local training data with the :code:`_local_boost` function as follows:

.. code-block:: python

    def _local_boost(self, bst_input):
        # Update trees based on local training data.
        for i in range(self.num_local_round):
            bst_input.update(self.train_dmatrix, bst_input.num_boosted_rounds())

        # Bagging: extract the last N=num_local_round trees for server aggregation
        bst = bst_input[
            bst_input.num_boosted_rounds()
            - self.num_local_round : bst_input.num_boosted_rounds()
        ]

        return bst
Given :code:`num_local_round`, we update the trees by calling the :code:`bst_input.update` method.
After training, the last :code:`N=num_local_round` trees are extracted and sent to the server for aggregation.
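The extraction relies on XGBoost's Booster slicing. A self-contained toy example (synthetic data, illustrative only):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    # Toy data, just to demonstrate Booster slicing
    rng = np.random.default_rng(0)
    X = rng.random((64, 4))
    y = rng.integers(0, 2, size=64)
    dtrain = xgb.DMatrix(X, label=y)

    bst = xgb.train({"objective": "binary:logistic"}, dtrain, num_boost_round=6)

    # Keep only the last 2 of the 6 boosted trees
    last_two = bst[bst.num_boosted_rounds() - 2 : bst.num_boosted_rounds()]
    print(last_two.num_boosted_rounds())  # 2

The :code:`evaluate` method follows the same load-then-use pattern: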

.. code-block:: python

    def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
        # Load global model
        bst = xgb.Booster(params=self.params)
        for para in ins.parameters.tensors:
            para_b = bytearray(para)
        bst.load_model(para_b)

        # Run evaluation
        eval_results = bst.eval_set(
            evals=[(self.valid_dmatrix, "valid")],
            iteration=bst.num_boosted_rounds() - 1,
        )
        auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)

        global_round = ins.config["global_round"]
        log(INFO, f"AUC = {auc} at round {global_round}")

        return EvaluateRes(
            status=Status(
                code=Code.OK,
                message="OK",
            ),
            loss=0.0,
            num_examples=self.num_val,
            metrics={"AUC": auc},
        )
In :code:`evaluate`, after loading the global model, we call the :code:`bst.eval_set` function to conduct evaluation on the validation set.
The AUC value is returned.
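:code:`eval_set` returns a plain string, which is why the AUC is extracted with string splitting. A quick illustration with a made-up result string:

.. code-block:: python

    # `eval_set` output has the form "[<iteration>]\t<name>-<metric>:<value>"
    eval_results = "[4]\tvalid-auc:0.767500"
    auc = round(float(eval_results.split("\t")[1].split(":")[1]), 4)
    print(auc)  # 0.7675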

Now, we can create an instance of our :code:`XgbClient` class and add one line to actually run this client:

.. code-block:: python

    fl.client.start_client(
        server_address="127.0.0.1:8080",
        client=XgbClient(
            train_dmatrix,
            valid_dmatrix,
            num_train,
            num_val,
            num_local_round,
            params,
        ).to_client(),
    )
That's it for the client. We only have to implement :code:`Client` and call :code:`fl.client.start_client()`.
The string :code:`"127.0.0.1:8080"` tells the client which server to connect to.
In our case we can run the server and the client on the same machine, therefore we use
:code:`"127.0.0.1:8080"`. If we run a truly federated workload with the server and
clients running on different machines, all that needs to change is the :code:`server_address` we pass to the client.
We first define a strategy for XGBoost bagging aggregation:

.. code-block:: python

    # Define strategy
    strategy = FedXgbBagging(
        min_evaluate_clients=2,
        fraction_evaluate=1.0,
        evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation,
        on_evaluate_config_fn=config_func,
        on_fit_config_fn=config_func,
    )
.. code-block:: python

    def evaluate_metrics_aggregation(eval_metrics):
        """Return an aggregated metric (AUC) for evaluation."""
        total_num = sum([num for num, _ in eval_metrics])
        auc_aggregated = (
            sum([metrics["AUC"] * num for num, metrics in eval_metrics]) / total_num
        )
        metrics_aggregated = {"AUC": auc_aggregated}
        return metrics_aggregated


    def config_func(rnd: int) -> Dict[str, str]:
        """Return a configuration with global epochs."""
        config = {
            "global_round": str(rnd),
        }
        return config
We use two clients for this example.
An :code:`evaluate_metrics_aggregation` function is defined to collect and weighted-average the AUC values from clients.
The :code:`config_func` function returns the current FL round number to the client's :code:`fit()` and :code:`evaluate()` methods.
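To make the weighting concrete, here is a hypothetical round with two clients reporting different AUCs and example counts:

.. code-block:: python

    # (num_examples, metrics) pairs as received from two clients
    eval_metrics = [(100, {"AUC": 0.75}), (300, {"AUC": 0.80})]

    total_num = sum(num for num, _ in eval_metrics)
    auc_aggregated = sum(m["AUC"] * num for num, m in eval_metrics) / total_num
    print(round(auc_aggregated, 4))  # 0.7875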

Then, we start the server:

.. code-block:: python

    # Start Flower server
    fl.server.start_server(
        server_address="0.0.0.0:8080",
        config=fl.server.ServerConfig(num_rounds=5),
        strategy=strategy,
    )
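Assuming the server code lives in a script called :code:`server.py` (mirroring the :code:`client.py` used below; the filename is an assumption), it is launched from its own terminal:

.. code-block:: shell

    $ python3 server.py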
Open a new terminal and start the first client:

.. code-block:: shell

    $ python3 client.py --partition-id=0
Open another terminal and start the second client:

.. code-block:: shell

    $ python3 client.py --partition-id=1
Each client will have its own dataset.
You should now see how the training proceeds in the very first terminal (the one that started the server):

.. code-block:: shell
INFO : Starting Flower server, config: num_rounds=5, no round_timeout
INFO : Flower ECE: gRPC server running (5 rounds), SSL is disabled
INFO : [INIT]
INFO : Requesting initial parameters from one random client
INFO : Received initial parameters from one random client
INFO : Evaluating initial global parameters
INFO :
INFO : [ROUND 1]
INFO : configure_fit: strategy sampled 2 clients (out of 2)
INFO : aggregate_fit: received 2 results and 0 failures
INFO : configure_evaluate: strategy sampled 2 clients (out of 2)
INFO : aggregate_evaluate: received 2 results and 0 failures
INFO :
INFO : [ROUND 2]
INFO : configure_fit: strategy sampled 2 clients (out of 2)
INFO : aggregate_fit: received 2 results and 0 failures
INFO : configure_evaluate: strategy sampled 2 clients (out of 2)
INFO : aggregate_evaluate: received 2 results and 0 failures
INFO :
INFO : [ROUND 3]
INFO : configure_fit: strategy sampled 2 clients (out of 2)
INFO : aggregate_fit: received 2 results and 0 failures
INFO : configure_evaluate: strategy sampled 2 clients (out of 2)
INFO : aggregate_evaluate: received 2 results and 0 failures
INFO :
INFO : [ROUND 4]
INFO : configure_fit: strategy sampled 2 clients (out of 2)
INFO : aggregate_fit: received 2 results and 0 failures
INFO : configure_evaluate: strategy sampled 2 clients (out of 2)
INFO : aggregate_evaluate: received 2 results and 0 failures
INFO :
INFO : [ROUND 5]
INFO : configure_fit: strategy sampled 2 clients (out of 2)
INFO : aggregate_fit: received 2 results and 0 failures
INFO : configure_evaluate: strategy sampled 2 clients (out of 2)
INFO : aggregate_evaluate: received 2 results and 0 failures
INFO :
INFO : [SUMMARY]
INFO : Run finished 5 round(s) in 1.67s
INFO : History (loss, distributed):
INFO : round 1: 0
INFO : round 2: 0
INFO : round 3: 0
INFO : round 4: 0
INFO : round 5: 0
INFO : History (metrics, distributed, evaluate):
INFO : {'AUC': [(1, 0.76755), (2, 0.775), (3, 0.77935), (4, 0.7836), (5, 0.7872)]}
Congratulations!
You've successfully built and run your first federated XGBoost system.