-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathsshe_lr_launcher.py
62 lines (52 loc) · 1.78 KB
/
sshe_lr_launcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import logging
import pprint
import typing
from dataclasses import dataclass, field
from fate.arch.launchers.argparser import HfArgumentParser
from fate.arch.launchers.multiprocess_launcher import launch
if typing.TYPE_CHECKING:
from fate.arch import Context
logger = logging.getLogger(__name__)
@dataclass
class SSHEArguments:
    """Command-line options for the SSHE logistic regression launcher.

    An instance is populated by ``HfArgumentParser`` inside ``run_sshe_lr``.
    """

    # Forwarded to SSHELogisticRegression as ``learning_rate``.
    lr: float = 0.15
    # Location handed to the CSV reader when the process runs as guest.
    guest_data: str = None
    # Location handed to the CSV reader when the process does not run as guest.
    host_data: str = None
def run_sshe_lr(ctx: "Context"):
    """Fit an SSHE logistic regression model on this party's CSV data.

    Initialises ``ctx.mpc``, parses ``SSHEArguments`` from the command line
    (arguments the parser does not recognise are returned by it and ignored
    here), loads the guest or host CSV depending on ``ctx.is_on_guest``, fits
    the model and logs the result of ``get_model()``.

    Args:
        ctx: Context supplied by the launcher for this party.

    Raises:
        ValueError: If the data argument needed by this party's role
            (``guest_data`` on guest, ``host_data`` otherwise) was not given.
    """
    # Imported inside the function so these packages are loaded only when a
    # party actually runs.
    from fate.ml.glm.hetero.sshe import SSHELogisticRegression
    from fate.arch import dataframe

    ctx.mpc.init()
    args, _ = HfArgumentParser(SSHEArguments).parse_args_into_dataclasses(return_remaining_strings=True)

    # Any party that is not the guest reads ``host_data``.
    on_guest = ctx.is_on_guest
    if on_guest:
        role, data_path = "guest", args.guest_data
    else:
        role, data_path = "host", args.host_data
    # Fail with an explicit message instead of handing None to the CSV reader.
    if data_path is None:
        raise ValueError(f"'{role}_data' must be provided when running as {role}")

    inst = SSHELogisticRegression(
        epochs=5,
        batch_size=300,
        tol=0.01,
        early_stop="diff",
        learning_rate=args.lr,
        init_param={"method": "random_uniform", "fit_intercept": True, "random_state": 1},
        reveal_every_epoch=False,
        reveal_loss_freq=2,
        threshold=0.5,
    )

    # Reader options shared by both roles.
    reader_kwargs = {
        "sample_id_name": None,
        "match_id_name": "id",
        "delimiter": ",",
        "dtype": "float32",
    }
    if on_guest:
        # Only the guest's CSV is read with a label column.
        reader_kwargs["label_name"] = "y"
        reader_kwargs["label_type"] = "int32"

    input_data = dataframe.CSVReader(**reader_kwargs).to_frame(ctx, data_path)
    inst.fit(ctx, train_data=input_data)
    logger.info("model: %s", pprint.pformat(inst.get_model()))
if __name__ == "__main__":
    # Script entry point: hand the per-party function to the multiprocess
    # launcher, with SSHEArguments passed as its extra_args_desc.
    launch(
        run_sshe_lr,
        extra_args_desc=[SSHEArguments],
    )