-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathserver.py
57 lines (45 loc) · 2.16 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import flwr as fl
import utils
import pandas as pd
import keras
import tensorflow as tf
# df = pd.read_csv("train_split2.csv")
# # X_test, y_test = utils.preprocess(df)
# X_train, X_test, y_train, y_test = utils.preprocess(df)
# X_test = X_train + X_test
# y_test = y_train + y_test
from flwr.common import NDArrays, Scalar, EvaluateRes, FitRes
from typing import Dict, Optional, Tuple, Union, List
from flwr.server.client_proxy import ClientProxy
def _load_initial_weights(model_path):
    """Load a saved Keras model and return ``(model, weight_arrays)``."""
    model = tf.keras.models.load_model(model_path)
    return model, model.get_weights()


# Pre-initialised global model; its weight arrays become the strategy's
# initial global parameters, so all clients start from the same state.
loaded_model, weights = _load_initial_weights("initialised_model.h5")
# Flower exchanges model state as `Parameters`, so wrap the ndarray list.
parameters = fl.common.ndarrays_to_parameters(weights)
class AggregateCustomMetricStrategy(fl.server.strategy.FedAvg):
    """FedAvg strategy that also reports an example-weighted mean accuracy.

    Loss aggregation is delegated unchanged to ``FedAvg``. On top of that,
    each evaluating client's ``"accuracy"`` metric is weighted by the number
    of examples it evaluated on. The result is printed and added to the
    returned metrics.
    """

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation loss and example-weighted accuracy.

        Args:
            server_round: Current federated round (used in the printed log).
            results: ``(client, EvaluateRes)`` pairs from clients that
                evaluated successfully this round.
            failures: Evaluation failures reported this round. This is passed
                through to ``FedAvg``, which decides whether they are fatal.

        Returns:
            ``(loss, metrics)``. ``loss`` and the base ``metrics`` come from
            ``FedAvg.aggregate_evaluate``. ``"accuracy"`` is added when at
            least one client reported it over a non-zero number of examples.
            Returns ``(None, {})`` when there are no results.
        """
        if not results:
            return None, {}

        # Let FedAvg aggregate the loss (and any metrics it is configured for).
        aggregated_loss, aggregated_metrics = super().aggregate_evaluate(
            server_round, results, failures
        )
        # FedAvg yields no loss when it declines to aggregate the round
        # (e.g. failures present with accept_failures=False). Respect that
        # instead of reporting an accuracy for a rejected round.
        if aggregated_loss is None:
            return aggregated_loss, aggregated_metrics

        # Weigh each client's accuracy by its example count. Clients that did
        # not report "accuracy" are skipped rather than raising KeyError.
        weighted_sum = 0.0
        total_examples = 0
        for _, res in results:
            accuracy = res.metrics.get("accuracy")
            if accuracy is None:
                continue
            weighted_sum += accuracy * res.num_examples
            total_examples += res.num_examples

        # No usable accuracy reports (or all over zero examples): avoid a
        # ZeroDivisionError and return just the base aggregation.
        if total_examples == 0:
            return aggregated_loss, aggregated_metrics

        aggregated_accuracy = weighted_sum / total_examples
        print(f"Round {server_round} accuracy aggregated from client results: {aggregated_accuracy}")
        # Keep FedAvg's metrics and add the custom accuracy alongside them.
        return aggregated_loss, {**aggregated_metrics, "accuracy": aggregated_accuracy}
# Number of federated learning rounds the server runs.
NUM_ROUNDS = 5

strategy = AggregateCustomMetricStrategy(
    # Only these two FedAvg arguments are overridden; the rest keep defaults.
    min_available_clients=2,  # minimum number of clients FedAvg requires
    initial_parameters=parameters,  # seed the global model with loaded weights
)

# Start the Flower server for NUM_ROUNDS rounds of federated learning.
fl.server.start_server(
    server_address="localhost:3000",
    strategy=strategy,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
)