From f5775c4068b8154719e9fb0ce65df98dbcc6c653 Mon Sep 17 00:00:00 2001 From: mralexdmitriy Date: Tue, 9 Apr 2024 18:58:19 +0300 Subject: [PATCH] added weight_decay added default ds_size fixed reproducibility fixed docs other minor fixes --- docs/tutorial.rst | 2 + docs/tutorials/batching_tutorial.rst | 62 +++++++++++++++++++ .../efficientnet-splitNN-mnist-local.yml | 1 + examples/configs/linreg-mnist-local.yml | 2 +- examples/configs/linreg-mnist-single.yml | 2 +- examples/configs/logreg-sbol-smm-local.yml | 1 + .../configs/mlp-splitNN-sbol-smm-local.yml | 3 +- examples/utils/local_arbitered_experiment.py | 3 +- examples/utils/local_experiment.py | 16 +++-- examples/utils/local_experiment_single.py | 6 +- examples/utils/prepare_mnist.py | 5 +- examples/utils/prepare_sbol_smm.py | 8 +-- stalactite/base.py | 4 +- stalactite/communications/local.py | 16 +++-- stalactite/configs.py | 8 +-- .../data_preprocessors/full_data_tensor.py | 2 +- .../tabular_preprocessor.py | 1 - stalactite/helpers.py | 6 +- stalactite/ml/honest/base.py | 4 +- .../honest/linear_regression/party_master.py | 4 +- .../honest/linear_regression/party_member.py | 7 +++ .../logistic_regression/party_member.py | 1 + stalactite/ml/honest/split_learning/base.py | 17 +++-- .../efficientnet/party_master.py | 6 +- .../efficientnet/party_member.py | 6 +- .../honest/split_learning/mlp/party_master.py | 6 +- .../honest/split_learning/mlp/party_member.py | 6 +- .../split_learning/resnet/party_master.py | 4 +- .../split_learning/resnet/party_member.py | 4 +- stalactite/models/linreg_batch.py | 2 +- .../split_learning/efficientnet_bottom.py | 8 ++- .../models/split_learning/efficientnet_top.py | 13 ++-- .../models/split_learning/mlp_bottom.py | 12 ++-- stalactite/models/split_learning/mlp_top.py | 12 ++-- stalactite/party_single_impl.py | 2 - stalactite/utils.py | 24 +++++++ stalactite/utils_main.py | 29 ++++++++- 37 files changed, 238 insertions(+), 77 deletions(-) create mode 100644 docs/tutorials/batching_tutorial.rst create mode 100644 stalactite/utils.py diff --git a/docs/tutorial.rst b/docs/tutorial.rst index 2112013..4baa06a 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -19,5 +19,7 @@ And to launch distributed multinode or multiprocess experiment go to :ref:`distr tutorials/local_communicator_tutorial tutorials/distr_communicator_tutorial tutorials/inference_tutorial + tutorials/batching_tutorial + diff --git a/docs/tutorials/batching_tutorial.rst b/docs/tutorials/batching_tutorial.rst new file mode 100644 index 0000000..22b3bf9 --- /dev/null +++ b/docs/tutorials/batching_tutorial.rst @@ -0,0 +1,62 @@ +.. _batching_tutorial: + +*how-to:* Make batchers +====================================== +We need to remind that the Stalactite framework utilizes train/infer loops inside each agent: PartyMembers and PartyMaster. +And each of these loops contains its batchers, which have to be synchronized. + +To make batch initialization you need to override the ``make_batcher`` method in your PartyMaster class. +For example, you can initialize batcher like this in ``stalactite/ml/honest/linear_regression/party_master.py`` or make your implementation. + +.. 
code-block:: python
+
+    def make_batcher(
+        self,
+        uids: Optional[List[str]] = None,
+        party_members: Optional[List[str]] = None,
+        is_infer: bool = False,
+    ) -> Batcher:
+
+        logger.info("Master %s: making a make_batcher for uids %s" % (self.id, len(uids)))
+        self.check_if_ready()
+        batch_size = self._eval_batch_size if is_infer else self._batch_size
+        epochs = 1 if is_infer else self.epochs
+        if uids is None:
+            raise RuntimeError('Master must initialize batcher with collected uids.')
+        assert party_members is not None, "Master is trying to initialize make_batcher without members list"
+        return ListBatcher(epochs=epochs, members=party_members, uids=uids, batch_size=batch_size)
+
+
+Party members use their own batchers, which are initialized in the PartyMember class.
+You can find an example of such an initialization in ``stalactite/ml/honest/base.py``.
+All batchers must be built from the same ``uids``, ``batch_size`` and ``epochs``, otherwise the training and inference loops of the agents will not stay synchronized.
+
+.. code-block:: python
+
+    def make_batcher(
+        self,
+        uids: Optional[List[str]] = None,
+        party_members: Optional[List[str]] = None,
+        is_infer: bool = False,
+    ) -> Batcher:
+        epochs = 1 if is_infer else self.epochs
+        batch_size = self._eval_batch_size if is_infer else self._batch_size
+        uids_to_use = self._uids_to_use_test if is_infer else self._uids_to_use
+        if uids_to_use is None:
+            raise RuntimeError("Cannot create make_batcher, you must `register_records_uids` first.")
+        return self._create_batcher(epochs=epochs, uids=uids_to_use, batch_size=batch_size)
+
+    def _create_batcher(self, epochs: int, uids: List[str], batch_size: int) -> Batcher:
+        """Create a make_batcher for training.
+
+        :param epochs: Number of training epochs.
+        :param uids: List of unique identifiers for dataset rows.
+        :param batch_size: Size of the training batch.
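+        :return: Batcher instance.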
+ """ + logger.info("Member %s: making a make_batcher for uids" % self.id) + if not self.is_consequently: + return ListBatcher(epochs=epochs, members=None, uids=uids, batch_size=batch_size) + else: + return ConsecutiveListBatcher( + epochs=epochs, members=self.members, uids=uids, batch_size=batch_size + ) diff --git a/examples/configs/efficientnet-splitNN-mnist-local.yml b/examples/configs/efficientnet-splitNN-mnist-local.yml index ddfd68c..9d1e793 100644 --- a/examples/configs/efficientnet-splitNN-mnist-local.yml +++ b/examples/configs/efficientnet-splitNN-mnist-local.yml @@ -4,6 +4,7 @@ common: world_size: 2 experiment_label: experiment-efficientnet-mnist-local reports_export_folder: "../../reports" + seed: 22 vfl_model: epochs: 1 diff --git a/examples/configs/linreg-mnist-local.yml b/examples/configs/linreg-mnist-local.yml index 4408c8d..f028b2c 100644 --- a/examples/configs/linreg-mnist-local.yml +++ b/examples/configs/linreg-mnist-local.yml @@ -4,6 +4,7 @@ common: world_size: 2 experiment_label: test-experiment-mnist-local reports_export_folder: ../../reports + seed: 22 vfl_model: epochs: 2 @@ -19,7 +20,6 @@ vfl_model: vfl_model_path: ../../saved_models/linreg_model data: - random_seed: 0 dataset_size: 5000 dataset: 'mnist' host_path_data_dir: ../../data/sber_ds_vfl/mnist_vfl_parts2 diff --git a/examples/configs/linreg-mnist-single.yml b/examples/configs/linreg-mnist-single.yml index dd4cfb8..fd8f700 100644 --- a/examples/configs/linreg-mnist-single.yml +++ b/examples/configs/linreg-mnist-single.yml @@ -1,7 +1,7 @@ common: report_train_metrics_iteration: 1 report_test_metrics_iteration: 1 - world_size: 2 + world_size: 1 experiment_label: experiment-mnist-centralized reports_export_folder: "../../reports" diff --git a/examples/configs/logreg-sbol-smm-local.yml b/examples/configs/logreg-sbol-smm-local.yml index 081027e..261f200 100644 --- a/examples/configs/logreg-sbol-smm-local.yml +++ b/examples/configs/logreg-sbol-smm-local.yml @@ -13,6 +13,7 @@ vfl_model: is_consequently: False use_class_weights: True learning_rate: 0.05 + weight_decay: 0.02 do_train: True do_predict: True do_save_model: True diff --git a/examples/configs/mlp-splitNN-sbol-smm-local.yml b/examples/configs/mlp-splitNN-sbol-smm-local.yml index ed20604..eb5761d 100644 --- a/examples/configs/mlp-splitNN-sbol-smm-local.yml +++ b/examples/configs/mlp-splitNN-sbol-smm-local.yml @@ -4,6 +4,7 @@ common: world_size: 2 experiment_label: experiment-mlp-sbol-smm-local reports_export_folder: "../../reports" + seed: 22 vfl_model: epochs: 1 @@ -11,7 +12,7 @@ vfl_model: eval_batch_size: 200 vfl_model_name: mlp is_consequently: False - use_class_weights: False + use_class_weights: True learning_rate: 0.01 do_train: True do_predict: False diff --git a/examples/utils/local_arbitered_experiment.py b/examples/utils/local_arbitered_experiment.py index 8a13036..58050c6 100644 --- a/examples/utils/local_arbitered_experiment.py +++ b/examples/utils/local_arbitered_experiment.py @@ -19,7 +19,7 @@ from examples.utils.prepare_mnist import load_data as load_mnist from examples.utils.prepare_sbol_smm import load_data as load_sbol_smm from stalactite.helpers import reporting, run_local_agents - +from stalactite.utils import seed_all logging.basicConfig(level=logging.DEBUG) logging.getLogger("urllib3").setLevel(logging.CRITICAL) @@ -90,6 +90,7 @@ def run(config_path: Optional[str] = None): ) config = VFLConfig.load_and_validate(config_path) + seed_all(config.common.seed) master_processor, processors = load_processors(config) with reporting(config): 
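For illustration, a minimal sketch of the synchronization requirement described in the batching tutorial above: the master and each member build their batchers from the same ``uids``, ``batch_size`` and ``epochs``. Only the ``ListBatcher`` constructor arguments shown in the tutorial are used; the import path is an assumption made for this sketch.

.. code-block:: python

    # Assumed import path -- adjust to wherever ListBatcher lives in the package.
    from stalactite.batching import ListBatcher

    uids = [str(i) for i in range(1000)]        # record ids shared by all agents
    members = ["member-0", "member-1"]

    # Master-side and member-side batchers built from identical arguments.
    master_batcher = ListBatcher(epochs=2, members=members, uids=uids, batch_size=250)
    member_batcher = ListBatcher(epochs=2, members=None, uids=uids, batch_size=250)

    # Because uids, batch_size and epochs match, batch k on the master covers
    # exactly the same records as batch k on every member; if any of the three
    # arguments differed, the train/infer loops would exchange mismatched batches.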
diff --git a/examples/utils/local_experiment.py b/examples/utils/local_experiment.py index 8cba8d7..ed5d850 100644 --- a/examples/utils/local_experiment.py +++ b/examples/utils/local_experiment.py @@ -25,6 +25,7 @@ from examples.utils.prepare_mnist import load_data as load_mnist from examples.utils.prepare_sbol_smm import load_data as load_sbol_smm from stalactite.helpers import reporting, run_local_agents +from stalactite.utils import seed_all logging.basicConfig(level=logging.DEBUG) logging.getLogger("urllib3").setLevel(logging.CRITICAL) @@ -43,10 +44,12 @@ def load_processors(config: VFLConfig): """ if config.data.dataset.lower() == "mnist": - binary = False if config.vfl_model.vfl_model_name == "efficientnet" else True + binary = False if config.vfl_model.vfl_model_name in ["efficientnet", "logreg"] else True if len(os.listdir(config.data.host_path_data_dir)) == 0: - load_mnist(config.data.host_path_data_dir, config.common.world_size, binary=binary) + load_mnist( + save_path=Path(config.data.host_path_data_dir), parts_num=config.common.world_size, binary=binary + ) dataset = {} for m in range(config.common.world_size): @@ -96,7 +99,10 @@ def run(config_path: Optional[str] = None): ) config = VFLConfig.load_and_validate(config_path) + seed_all(config.common.seed) master_processor, processors = load_processors(config) + if config.data.dataset_size == -1: + config.data.dataset_size = len(master_processor.dataset[config.data.train_split][config.data.uids_key]) with reporting(config): shared_party_info = dict() @@ -134,7 +140,8 @@ def run(config_path: Optional[str] = None): do_predict=config.vfl_model.do_predict, model_name=config.vfl_model.vfl_model_name if config.vfl_model.vfl_model_name in ["resnet", "mlp", "efficientnet"] else None, - model_params=config.master.master_model_params + model_params=config.master.master_model_params, + seed=config.common.seed ) member_ids = [f"member-{member_rank}" for member_rank in range(config.common.world_size)] @@ -158,7 +165,8 @@ def run(config_path: Optional[str] = None): do_save_model=config.vfl_model.do_save_model, model_path=config.vfl_model.vfl_model_path, model_params=config.member.member_model_params, - use_inner_join=True if member_rank == 0 else False + use_inner_join=True if member_rank == 0 else False, + seed=config.common.seed ) for member_rank, member_uid in enumerate(member_ids) diff --git a/examples/utils/local_experiment_single.py b/examples/utils/local_experiment_single.py index 14e05b5..fe1b609 100644 --- a/examples/utils/local_experiment_single.py +++ b/examples/utils/local_experiment_single.py @@ -11,6 +11,7 @@ from examples.utils.prepare_mnist import load_data as load_mnist from examples.utils.prepare_sbol_smm import load_data as load_sbol_smm from stalactite.helpers import reporting +from stalactite.utils import seed_all logging.basicConfig(level=logging.DEBUG) logging.getLogger("urllib3").setLevel(logging.CRITICAL) @@ -23,7 +24,7 @@ def load_processors(config_path: str): if config.data.dataset.lower() == "mnist": if len(os.listdir(config.data.host_path_data_dir)) == 0: - load_mnist(config.data.host_path_data_dir, parts_num=1, is_single=True, binary=True) + load_mnist(Path(config.data.host_path_data_dir), parts_num=1, binary=True) dataset = {0: datasets.load_from_disk( os.path.join(f"{config.data.host_path_data_dir}/part_{0}") @@ -36,7 +37,7 @@ def load_processors(config_path: str): elif config.data.dataset.lower() == "sbol_smm": if len(os.listdir(config.data.host_path_data_dir)) == 0: - 
load_sbol_smm(os.path.dirname(config.data.host_path_data_dir), parts_num=1, is_single=True, sbol_only=False) + load_sbol_smm(os.path.dirname(config.data.host_path_data_dir), parts_num=1, sbol_only=False) dataset = {0: datasets.load_from_disk( os.path.join(f"{config.data.host_path_data_dir}/part_{0}") @@ -59,6 +60,7 @@ def run(config_path: Optional[str] = None): os.path.join(Path(__file__).parent.parent.parent, 'configs/linreg-mnist-local.yml') ) config = VFLConfig.load_and_validate(config_path) + seed_all(config.common.seed) model_name = config.vfl_model.vfl_model_name with reporting(config): diff --git a/examples/utils/prepare_mnist.py b/examples/utils/prepare_mnist.py index 8c56a52..e9ebd73 100644 --- a/examples/utils/prepare_mnist.py +++ b/examples/utils/prepare_mnist.py @@ -136,14 +136,13 @@ def save_splitted_dataset(ds_list, path, part_dir_name='part_', clean_dir=False) ds.save_to_disk(part_path) -def load_data(save_path, parts_num, binary: bool = True, is_single: bool = False): +def load_data(save_path: Path, parts_num: int, binary: bool = True): """ The input is the original MNIST dataset. 1. Labels filtered and replaced so that the task is binary. """ - make_validation = True test_size = 0.15 stratify_by_column = 'label' @@ -170,7 +169,7 @@ def load_data(save_path, parts_num, binary: bool = True, is_single: bool = False if make_validation: train_train, train_val = make_train_val_split(mnist['train'], test_size=test_size, stratify_by_column=stratify_by_column, shuffle=shuffle, seed=seed) - if not is_single: + if parts_num != 1: train_train_labels = train_train.select_columns(["image_idx", "label"]) train_val_labels = train_val.select_columns(["image_idx", "label"]) diff --git a/examples/utils/prepare_sbol_smm.py b/examples/utils/prepare_sbol_smm.py index 3eba1f2..3d4f389 100644 --- a/examples/utils/prepare_sbol_smm.py +++ b/examples/utils/prepare_sbol_smm.py @@ -28,7 +28,7 @@ def split_save_datasets(df, train_users, test_users, columns, postfix_sample, di ) -def load_data(data_dir_path: str, parts_num: int = 2, is_single: bool = False, sbol_only: bool = False): +def load_data(data_dir_path: str, parts_num: int = 2, sbol_only: bool = False): sbol_path = os.path.join(data_dir_path, "sbol") smm_path = os.path.join(data_dir_path, "smm") @@ -44,7 +44,7 @@ def load_data(data_dir_path: str, parts_num: int = 2, is_single: bool = False, s sbol_labels[["user_id"]], shuffle=True, random_state=seed, test_size=0.15 ) - if not is_single: + if parts_num != 1: logger.info("Save vfl dataset labels part...") split_save_datasets(df=sbol_labels, train_users=users_train, test_users=users_test, columns=["user_id", "labels"], postfix_sample=sample, part_postfix="master_part", @@ -61,7 +61,7 @@ def load_data(data_dir_path: str, parts_num: int = 2, is_single: bool = False, s lambda x: list(x), axis=1) sbol_user_features = sbol_user_features[["user_id", "features_part_0", "labels"]] - if not is_single: + if parts_num != 1: logger.info("Save vfl dataset part 0...") split_save_datasets(df=sbol_user_features, train_users=users_train, test_users=users_test, columns=["user_id", "features_part_0"], postfix_sample=sample, part_postfix="part_0", @@ -85,7 +85,7 @@ def load_data(data_dir_path: str, parts_num: int = 2, is_single: bool = False, s smm_user_factors = sbol_labels[["user_id"]].merge(smm_user_factors, on="user_id", how="inner") smm_user_factors.rename(columns={"user_factors": "features_part_1"}, inplace=True) - if not is_single: + if parts_num != 1: logger.info("Save vfl dataset part 1...") 
split_save_datasets(df=smm_user_factors, train_users=users_train, test_users=users_test, columns=["user_id", "features_part_1"], postfix_sample=sample, part_postfix="part_1", diff --git a/stalactite/base.py b/stalactite/base.py index 8046d48..e48f1a3 100644 --- a/stalactite/base.py +++ b/stalactite/base.py @@ -233,10 +233,12 @@ def make_batcher( party_members: Optional[List[str]] = None, is_infer: bool = False ) -> Batcher: - """ Make a make_batcher for training. + """ Make a make_batcher for training or inference. :param uids: List of unique identifiers of dataset records. :param party_members: List of party members` identifiers. + :param is_infer: Flag indicating whether to use inference mode. + Inference mode means that the batcher will have only one epoch and use eval_batch_size. :return: Batcher instance. """ diff --git a/stalactite/communications/local.py b/stalactite/communications/local.py index 09a73e9..fe0b824 100644 --- a/stalactite/communications/local.py +++ b/stalactite/communications/local.py @@ -198,10 +198,11 @@ def broadcast( return tasks def gather(self, tasks: List[Task], recv_results: bool = False) -> List[Task]: - _recv_futures = [ - RecvFuture(method_name=task.method_name, receive_from_id=task.from_id if not recv_results else task.to_id) - for task in tasks - ] + receive_from_ids, _recv_futures = [], [] + for task in tasks: + rfi = task.from_id if not recv_results else task.to_id + receive_from_ids.append(rfi) + _recv_futures.append(RecvFuture(method_name=task.method_name, receive_from_id=rfi)) for recv_f in _recv_futures: threading.Thread(target=self._get_from_recv, args=(recv_f,), daemon=True).start() @@ -209,7 +210,12 @@ def gather(self, tasks: List[Task], recv_results: bool = False) -> List[Task]: if number_to := len(pending_tasks): raise TimeoutError(f"{self.participant.id} could not gather tasks from {number_to} members.") - return [task.result() for task in done_tasks] + + results = {} + for task in done_tasks: + results[task.receive_from_id] = task.result() + + return [results[idx] for idx in receive_from_ids] def raise_if_not_ready(self): """Raise an exception if the communicator was not initialized properly.""" diff --git a/stalactite/configs.py b/stalactite/configs.py index 7b78975..c945a31 100644 --- a/stalactite/configs.py +++ b/stalactite/configs.py @@ -55,6 +55,7 @@ class CommonConfig(BaseModel): default=Path(__file__).parent, description="Folder for exporting tests` and experiments` reports" ) rendezvous_timeout: float = Field(default=3600, description="Initial agents rendezvous timeout in sec") + seed: int = Field(default=42, description="Initial random seed") logging_level: Literal["debug", "info", "warning"] = Field(default="info", description="Logging level") @model_validator(mode="after") @@ -81,6 +82,7 @@ class VFLModelConfig(BaseModel): learning_rate: float = Field(default=0.01, description='Learning rate') l2_alpha: Optional[float] = Field(default=None, description='Alpha used for L2 regularization') momentum: Optional[float] = Field(default=0, description='Optimizer momentum') + weight_decay: Optional[float] = Field(default=0.01, description='Optimizer weight decay') do_train: bool = Field(default=True, description='Whether to run a training loop.') do_predict: bool = Field(default=True, description='Whether to run an inference loop.') do_save_model: bool = Field(default=True, description='Whether to save the model after training.') @@ -93,11 +95,7 @@ class VFLModelConfig(BaseModel): class DataConfig(BaseModel): """Experimental data 
parameters config.""" - random_seed: int = Field( - default=0, - description="Experiment data random seed (including random, numpy, torch)" - ) - dataset_size: int = Field(default=1000, description="Number of dataset rows to use") + dataset_size: int = Field(default=100, description="Number of dataset rows to use") host_path_data_dir: str = Field(default='.', description="Path to datasets` directory") dataset: Literal[ 'mnist', 'sbol', 'sbol_smm', 'home_credit', 'home_credit_bureau_pos', 'avito', 'avito_texts_images'] = Field( diff --git a/stalactite/data_preprocessors/full_data_tensor.py b/stalactite/data_preprocessors/full_data_tensor.py index 4662a36..dc73f62 100644 --- a/stalactite/data_preprocessors/full_data_tensor.py +++ b/stalactite/data_preprocessors/full_data_tensor.py @@ -17,7 +17,7 @@ def transform(self, inp_data): def _transform(self, inp_data): - num_rows = inp_data.num_rows #todo: refactor + num_rows = inp_data.num_rows tnsr = torch.as_tensor(inp_data[self.input_feature_name][0:num_rows]) diff --git a/stalactite/data_preprocessors/tabular_preprocessor.py b/stalactite/data_preprocessors/tabular_preprocessor.py index 383af0d..8a07a7f 100644 --- a/stalactite/data_preprocessors/tabular_preprocessor.py +++ b/stalactite/data_preprocessors/tabular_preprocessor.py @@ -49,7 +49,6 @@ def fit_transform(self): split_dict[feature_name] = standard_scaler.fit_transform(split_dict[feature_name]) split_dict[uids_name] = split_data[uids_name] - # split_dict[label_name] = [x[6] for x in split_dict[label_name]] # todo: remove (for debugging only) if self.is_master and isinstance(train_split_data[label_name][0], list): self.multilabel = True diff --git a/stalactite/helpers.py b/stalactite/helpers.py index 43512b2..dc4932a 100644 --- a/stalactite/helpers.py +++ b/stalactite/helpers.py @@ -1,4 +1,5 @@ import logging +import os import time from contextlib import contextmanager from threading import Thread @@ -52,7 +53,8 @@ def log_timing(name: str, log_func: Callable = print): @contextmanager def reporting(config: VFLConfig): if config.master.run_mlflow: - mlflow.set_tracking_uri(f"http://{config.prerequisites.mlflow_host}:{config.prerequisites.mlflow_port}") + mlflow_host = os.environ.get('STALACTITE_MLFLOW_HOST', config.prerequisites.mlflow_host) + mlflow.set_tracking_uri(f"http://{mlflow_host}:{config.prerequisites.mlflow_port}") mlflow.set_experiment(config.common.experiment_label) mlflow.start_run() @@ -66,6 +68,7 @@ def reporting(config: VFLConfig): "is_consequently": config.vfl_model.is_consequently, "model_name": config.vfl_model.vfl_model_name, "learning_rate": config.vfl_model.learning_rate, + "weight_decay": config.vfl_model.weight_decay, "dataset": config.data.dataset, } @@ -106,4 +109,3 @@ def run_local_agents( for thread in threads: thread.join() - diff --git a/stalactite/ml/honest/base.py b/stalactite/ml/honest/base.py index 257ff0b..f97e584 100644 --- a/stalactite/ml/honest/base.py +++ b/stalactite/ml/honest/base.py @@ -331,7 +331,8 @@ def __init__( do_predict: bool = False, do_save_model: bool = False, use_inner_join: bool = False, - model_params: dict = None + model_params: dict = None, + seed: int = None ) -> None: """ Initialize PartyMemberImpl. 
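The ``seed`` stored here feeds the deterministic initialization helpers added in ``stalactite/utils.py`` later in this patch. A minimal standalone sketch of their behaviour, using plain ``nn.Linear`` layers outside any party class:

.. code-block:: python

    import torch
    from torch import nn

    from stalactite.utils import init_linear_np, seed_all

    seed_all(42)                      # seeds torch, numpy and random; makes cuDNN deterministic
    layer_a = nn.Linear(16, 4)
    layer_b = nn.Linear(16, 4)

    # init_linear_np re-seeds and draws weights from numpy's uniform distribution
    # in [-1/sqrt(out_features), 1/sqrt(out_features)], so equal seeds give equal weights.
    init_linear_np(layer_a, seed=42)
    init_linear_np(layer_b, seed=42)
    assert torch.equal(layer_a.weight, layer_b.weight)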
@@ -371,6 +372,7 @@ def __init__( self.do_save_model = do_save_model self._optimizer = None self.use_inner_join = use_inner_join + self.seed = seed if self.is_consequently: if self.members is None: diff --git a/stalactite/ml/honest/linear_regression/party_master.py b/stalactite/ml/honest/linear_regression/party_master.py index 8be48bc..798c11a 100644 --- a/stalactite/ml/honest/linear_regression/party_master.py +++ b/stalactite/ml/honest/linear_regression/party_master.py @@ -36,7 +36,8 @@ def __init__( do_train: bool = True, do_predict: bool = False, model_name: str = None, - model_params: dict = None + model_params: dict = None, + seed: int = None ) -> None: """ Initialize PartyMaster. @@ -73,6 +74,7 @@ def __init__( self._model_name = model_name self.aggregated_output = None self._model_params = model_params + self.seed = seed self.uid2tensor_idx = None self.uid2tensor_idx_test = None diff --git a/stalactite/ml/honest/linear_regression/party_member.py b/stalactite/ml/honest/linear_regression/party_member.py index 2c14519..fca16e0 100644 --- a/stalactite/ml/honest/linear_regression/party_member.py +++ b/stalactite/ml/honest/linear_regression/party_member.py @@ -2,12 +2,16 @@ import os from abc import ABC from typing import Optional, Any +import math import torch +from torch import nn +import numpy as np from stalactite.base import RecordsBatch, DataTensor from stalactite.ml.honest.base import HonestPartyMember from stalactite.models import LinearRegressionBatch +from stalactite.utils import init_linear_np logger = logging.getLogger(__name__) @@ -27,6 +31,9 @@ def initialize_model(self, do_load_model: bool = False) -> None: **self._model_params ) + init_linear_np(self._model.linear, seed=self.seed) + + def initialize_optimizer(self) -> None: pass diff --git a/stalactite/ml/honest/logistic_regression/party_member.py b/stalactite/ml/honest/logistic_regression/party_member.py index 3c6f138..000b0d2 100644 --- a/stalactite/ml/honest/logistic_regression/party_member.py +++ b/stalactite/ml/honest/logistic_regression/party_member.py @@ -26,4 +26,5 @@ def initialize_optimizer(self) -> None: ], lr=self._common_params.learning_rate, momentum=self._common_params.momentum, + weight_decay=self._common_params.weight_decay, ) diff --git a/stalactite/ml/honest/split_learning/base.py b/stalactite/ml/honest/split_learning/base.py index 3d7cba9..484336a 100644 --- a/stalactite/ml/honest/split_learning/base.py +++ b/stalactite/ml/honest/split_learning/base.py @@ -116,11 +116,8 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: participating_members=titer.participating_members, ) - ordered_gather = sorted(party.gather(update_predict_tasks, recv_results=True), - key=lambda x: int(x.from_id.split('-')[-1])) - party_members_predictions = [ - task.result for task in ordered_gather + task.result for task in party.gather(update_predict_tasks, recv_results=True) ] agg_members_predictions = self.aggregate(titer.participating_members, party_members_predictions) @@ -147,9 +144,9 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: participating_members=titer.participating_members, ) - ordered_gather = sorted(party.gather(predict_tasks, recv_results=True), - key=lambda x: int(x.from_id.split('-')[-1])) - party_members_predictions = [task.result for task in ordered_gather] + party_members_predictions = [ + task.result for task in party.gather(predict_tasks, recv_results=True) + ] agg_members_predictions = self.aggregate(party.members, party_members_predictions, is_infer=True) 
master_predictions = self.predict(x=agg_members_predictions, use_activation=True) @@ -170,10 +167,10 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: method_kwargs=MethodKwargs(other_kwargs={"uids": None, "is_infer": True}), participating_members=titer.participating_members, ) - ordered_gather = sorted(party.gather(predict_test_tasks, recv_results=True), - key=lambda x: int(x.from_id.split('-')[-1])) - party_members_predictions = [task.result for task in ordered_gather] + party_members_predictions = [ + task.result for task in party.gather(predict_test_tasks, recv_results=True) + ] agg_members_predictions = self.aggregate(party.members, party_members_predictions, is_infer=True) master_predictions = self.predict(x=agg_members_predictions, use_activation=True) self.report_metrics( diff --git a/stalactite/ml/honest/split_learning/efficientnet/party_master.py b/stalactite/ml/honest/split_learning/efficientnet/party_master.py index 2382e51..d1ed512 100644 --- a/stalactite/ml/honest/split_learning/efficientnet/party_master.py +++ b/stalactite/ml/honest/split_learning/efficientnet/party_master.py @@ -18,7 +18,7 @@ class HonestPartyMasterEfficientNetSplitNN(HonestPartyMasterSplitNN): def initialize_model(self, do_load_model: bool = False) -> None: """ Initialize the model based on the specified model name. """ - self._model = EfficientNetTop(**self._model_params) + self._model = EfficientNetTop(**self._model_params, seed=self.seed) class_weights = None if self.class_weights is None else self.class_weights.type(torch.FloatTensor) self._criterion = nn.CrossEntropyLoss(weight=class_weights) self._activation = nn.Softmax(dim=1) @@ -28,7 +28,9 @@ def initialize_optimizer(self) -> None: {"params": self._model.parameters()}, ], lr=self._common_params.learning_rate, - momentum=self._common_params.momentum + momentum=self._common_params.momentum, + weight_decay=self._common_params.weight_decay, + ) def aggregate( diff --git a/stalactite/ml/honest/split_learning/efficientnet/party_member.py b/stalactite/ml/honest/split_learning/efficientnet/party_member.py index 854fcd1..45757c6 100644 --- a/stalactite/ml/honest/split_learning/efficientnet/party_member.py +++ b/stalactite/ml/honest/split_learning/efficientnet/party_member.py @@ -13,12 +13,14 @@ def initialize_model(self, do_load_model: bool = False) -> None: if do_load_model: self._model = self.load_model() else: - self._model = EfficientNetBottom(**self._model_params) + self._model = EfficientNetBottom(**self._model_params, seed=self.seed) def initialize_optimizer(self) -> None: self._optimizer = SGD([ {"params": self._model.parameters()}, ], lr=self._common_params.learning_rate, - momentum=self._common_params.momentum + momentum=self._common_params.momentum, + weight_decay=self._common_params.weight_decay, + ) diff --git a/stalactite/ml/honest/split_learning/mlp/party_master.py b/stalactite/ml/honest/split_learning/mlp/party_master.py index b2d4279..d1d2cca 100644 --- a/stalactite/ml/honest/split_learning/mlp/party_master.py +++ b/stalactite/ml/honest/split_learning/mlp/party_master.py @@ -14,7 +14,7 @@ class HonestPartyMasterMLPSplitNN(HonestPartyMasterSplitNN): def initialize_model(self, do_load_model: bool = False) -> None: """ Initialize the model based on the specified model name. 
""" - self._model = MLPTop(**self._model_params) + self._model = MLPTop(**self._model_params, seed=self.seed) self._criterion = torch.nn.BCEWithLogitsLoss(pos_weight=self.class_weights) if self.binary else torch.nn.CrossEntropyLoss(weight=self.class_weights) def initialize_optimizer(self) -> None: @@ -22,7 +22,9 @@ def initialize_optimizer(self) -> None: {"params": self._model.parameters()}, ], lr=self._common_params.learning_rate, - momentum=self._common_params.momentum + momentum=self._common_params.momentum, + weight_decay=self._common_params.weight_decay, + ) def aggregate( diff --git a/stalactite/ml/honest/split_learning/mlp/party_member.py b/stalactite/ml/honest/split_learning/mlp/party_member.py index dba8d16..72a7e85 100644 --- a/stalactite/ml/honest/split_learning/mlp/party_member.py +++ b/stalactite/ml/honest/split_learning/mlp/party_member.py @@ -12,12 +12,14 @@ def initialize_model(self, do_load_model: bool = False) -> None: if do_load_model: self._model = self.load_model() else: - self._model = MLPBottom(input_dim=input_dim, **self._model_params) + self._model = MLPBottom(input_dim=input_dim, **self._model_params, seed=self.seed) def initialize_optimizer(self) -> None: self._optimizer = SGD([ {"params": self._model.parameters()}, ], lr=self._common_params.learning_rate, - momentum=self._common_params.momentum + momentum=self._common_params.momentum, + weight_decay=self._common_params.weight_decay, + ) diff --git a/stalactite/ml/honest/split_learning/resnet/party_master.py b/stalactite/ml/honest/split_learning/resnet/party_master.py index a55f6c7..85b318c 100644 --- a/stalactite/ml/honest/split_learning/resnet/party_master.py +++ b/stalactite/ml/honest/split_learning/resnet/party_master.py @@ -23,7 +23,9 @@ def initialize_optimizer(self) -> None: {"params": self._model.parameters()}, ], lr=self._common_params.learning_rate, - momentum=self._common_params.momentum + momentum=self._common_params.momentum, + weight_decay=self._common_params.weight_decay, + ) def aggregate( diff --git a/stalactite/ml/honest/split_learning/resnet/party_member.py b/stalactite/ml/honest/split_learning/resnet/party_member.py index de7d78d..3fbc73c 100644 --- a/stalactite/ml/honest/split_learning/resnet/party_member.py +++ b/stalactite/ml/honest/split_learning/resnet/party_member.py @@ -21,5 +21,7 @@ def initialize_optimizer(self) -> None: {"params": self._model.parameters()}, ], lr=self._common_params.learning_rate, - momentum=self._common_params.momentum + momentum=self._common_params.momentum, + weight_decay=self._common_params.weight_decay, + ) diff --git a/stalactite/models/linreg_batch.py b/stalactite/models/linreg_batch.py index e3cf41b..225bd25 100644 --- a/stalactite/models/linreg_batch.py +++ b/stalactite/models/linreg_batch.py @@ -60,4 +60,4 @@ def init_params(self): 'input_dim': self.input_dim, 'output_dim': self.output_dim, 'reg_lambda': self.reg_lambda, - } \ No newline at end of file + } diff --git a/stalactite/models/split_learning/efficientnet_bottom.py b/stalactite/models/split_learning/efficientnet_bottom.py index 851f8a3..d879a8c 100644 --- a/stalactite/models/split_learning/efficientnet_bottom.py +++ b/stalactite/models/split_learning/efficientnet_bottom.py @@ -10,6 +10,8 @@ from torchvision.utils import _log_api_usage_once from torchvision.ops.misc import Conv2dNormActivation +from stalactite.utils import seed_all + def _efficientnet_conf( width_mult: float, @@ -38,7 +40,8 @@ def __init__( depth_mult: float = 1.0, stochastic_depth_prob: float = 0.2, norm_layer: 
Optional[Callable[..., nn.Module]] = None, - init_weights: float = None + init_weights: float = None, + seed: int = None, ) -> None: """ EfficientNet V1 and V2 main class @@ -49,7 +52,7 @@ def __init__( """ super().__init__() _log_api_usage_once(self) - + self.seed = seed inverted_residual_setting, last_channel = _efficientnet_conf(width_mult=width_mult, depth_mult=depth_mult) if norm_layer is None: @@ -108,6 +111,7 @@ def __init__( if init_weights: nn.init.constant_(m.weight, init_weights) else: + seed_all(self.seed) nn.init.kaiming_normal_(m.weight, mode="fan_out") if m.bias is not None: nn.init.zeros_(m.bias) diff --git a/stalactite/models/split_learning/efficientnet_top.py b/stalactite/models/split_learning/efficientnet_top.py index 2e87697..3d55235 100644 --- a/stalactite/models/split_learning/efficientnet_top.py +++ b/stalactite/models/split_learning/efficientnet_top.py @@ -3,11 +3,14 @@ from typing import Optional, Sequence, Union, Tuple import torch +import numpy as np from torch import nn, Tensor from torchvision.models.efficientnet import MBConvConfig, FusedMBConvConfig from torchvision.utils import _log_api_usage_once +from stalactite.utils import init_linear_np + def _efficientnet_conf( width_mult: float, @@ -33,9 +36,10 @@ class EfficientNetTop(nn.Module): def __init__( self, dropout: float = 0.1, - input_dim=None, # todo: get it somewhere + input_dim=None, num_classes: int = 1000, - init_weights: float = None + init_weights: float = None, + seed: int = None, ) -> None: """ @@ -51,7 +55,7 @@ def __init__( _log_api_usage_once(self) self.criterion = torch.nn.CrossEntropyLoss() - + self.seed = seed self.avgpool = nn.AdaptiveAvgPool2d(1) self.classifier = nn.Sequential( nn.Dropout(p=dropout, inplace=True), @@ -63,8 +67,7 @@ def __init__( if init_weights: nn.init.constant_(m.weight, init_weights) else: - init_range = 1.0 / math.sqrt(m.out_features) - nn.init.uniform_(m.weight, -init_range, init_range) + init_linear_np(m, seed=self.seed) nn.init.zeros_(m.bias) def _forward_impl(self, x: Tensor) -> Tensor: diff --git a/stalactite/models/split_learning/mlp_bottom.py b/stalactite/models/split_learning/mlp_bottom.py index 8da7f12..5b73b61 100644 --- a/stalactite/models/split_learning/mlp_bottom.py +++ b/stalactite/models/split_learning/mlp_bottom.py @@ -1,13 +1,13 @@ import logging -import math from typing import Callable, List, Optional - import torch from torch import nn, Tensor from torchvision.utils import _log_api_usage_once +from stalactite.utils import init_linear_np + logger = logging.getLogger(__name__) @@ -22,12 +22,13 @@ def __init__( dropout: float = 0.0, multilabel: bool = True, init_weights: float = None, - class_weights: torch.Tensor = None + class_weights: torch.Tensor = None, + seed: int = None, ) -> None: super().__init__() _log_api_usage_once(self) - + self.seed = seed if multilabel: self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights) @@ -48,8 +49,7 @@ def __init__( if init_weights: nn.init.constant_(m.weight, init_weights) else: - init_range = 1.0 / math.sqrt(m.out_features) - nn.init.uniform_(m.weight, -init_range, init_range) + init_linear_np(m, seed=self.seed) nn.init.zeros_(m.bias) def _forward_impl(self, x: Tensor) -> Tensor: diff --git a/stalactite/models/split_learning/mlp_top.py b/stalactite/models/split_learning/mlp_top.py index 2946be0..1181e00 100644 --- a/stalactite/models/split_learning/mlp_top.py +++ b/stalactite/models/split_learning/mlp_top.py @@ -1,12 +1,12 @@ import logging -import math - import torch from torch import nn, Tensor 
from torchvision.utils import _log_api_usage_once +from stalactite.utils import init_linear_np + logger = logging.getLogger(__name__) @@ -18,12 +18,13 @@ def __init__( bias: bool = True, multilabel: bool = True, init_weights: float = None, - class_weights: torch.Tensor = None + class_weights: torch.Tensor = None, + seed: int = None, ) -> None: super().__init__() _log_api_usage_once(self) - + self.seed = seed if multilabel: self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights) else: @@ -36,8 +37,7 @@ def __init__( if init_weights: nn.init.constant_(m.weight, init_weights) else: - init_range = 1.0 / math.sqrt(m.out_features) - nn.init.uniform_(m.weight, -init_range, init_range) + init_linear_np(m, seed=self.seed) nn.init.zeros_(m.bias) def _forward_impl(self, x: Tensor) -> Tensor: diff --git a/stalactite/party_single_impl.py b/stalactite/party_single_impl.py index a9f594e..49e38c2 100644 --- a/stalactite/party_single_impl.py +++ b/stalactite/party_single_impl.py @@ -167,9 +167,7 @@ def initialize(self) -> None: self._dataset = self.processor.fit_transform() self.x_train = self._dataset[self.processor.data_params.train_split][self.processor.data_params.features_key] - a = torch.isnan(self.x_train).any() # todo: remove self.x_test = self._dataset[self.processor.data_params.test_split][self.processor.data_params.features_key] - b = torch.isnan(self.x_test).any() # todo: remove self.target = self._dataset[self.processor.data_params.train_split][self.processor.data_params.label_key] self.test_target = self._dataset[self.processor.data_params.test_split][self.processor.data_params.label_key] diff --git a/stalactite/utils.py b/stalactite/utils.py new file mode 100644 index 0000000..a94385c --- /dev/null +++ b/stalactite/utils.py @@ -0,0 +1,24 @@ +import math +import os +import random + +import numpy as np +import torch +from torch import nn as nn + + +def init_linear_np(module: nn.Linear, seed: int): + seed_all(seed) + init_range = 1.0 / math.sqrt(module.out_features) + np_uniform = np.random.uniform(low=-init_range, high=init_range, size=module.weight.shape) + module.weight.data = torch.from_numpy(np_uniform).type(torch.float) + + +def seed_all(seed: int): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + np.random.seed(seed) + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) diff --git a/stalactite/utils_main.py b/stalactite/utils_main.py index 859ca54..d7aa800 100644 --- a/stalactite/utils_main.py +++ b/stalactite/utils_main.py @@ -26,6 +26,7 @@ BASE_IMAGE_FILE_CPU = "grpc-base-cpu.dockerfile" BASE_IMAGE_TAG = "grpc-base:latest" PREREQUISITES_NETWORK = "prerequisites_vfl-network" # Do not change this value +MLFLOW_NETWORK = "stalactite-mlflow_default" logger = logging.getLogger(__name__) logging.getLogger('docker').setLevel(logging.ERROR) @@ -186,6 +187,30 @@ def create_and_start_container( return container +def get_mlflow_endpoint(config: VFLConfig) -> str: + mlflow_host = config.prerequisites.mlflow_host + if mlflow_host in ['0.0.0.0', 'localhost']: + logger.info('Searching the MlFlow container locally') + client = APIClient() + try: + container_info = client.inspect_container('stalactite-mlflow-mlflow-vfl-1') + except NotFound as exc: + logger.error( + 'Could not find the `stalactite-mlflow-mlflow-vfl-1` container locally. Are you sure, that you have ' + 'started prerequisites group `mlflow` on current machine?' 
+ ) + raise exc + try: + mlflow_host = container_info['NetworkSettings']['Networks'][MLFLOW_NETWORK]['Gateway'] + except KeyError: + raise ValueError( + 'MlFlow container does not configured via `stalactite prerequisites`, rerun the command or use' + ' host machine IP address in the `config.prerequisites.mlflow_host` configuration parameter' + ) + logger.info(f'Found MlFlow at {mlflow_host}') + return mlflow_host + + def start_distributed_agent( config_path: str, role: str, @@ -240,6 +265,7 @@ def start_distributed_agent( port_binds = {config.grpc_server.port: config.grpc_server.port} ports = [config.grpc_server.port] name = ctx.obj["master_container_name"] + ("-predict" if infer else "") + env_vars['STALACTITE_MLFLOW_HOST'] = get_mlflow_endpoint(config) elif role == Role.arbiter: ports = [config.grpc_arbiter.port] port_binds = {config.grpc_arbiter.port: config.grpc_arbiter.port} @@ -362,11 +388,12 @@ def start_multiprocess_agents( "--infer", "--role", "master" ] if is_infer else None + create_and_start_container( client=client, image=BASE_IMAGE_TAG, container_label=container_label, - environment={"GRPC_ARBITER_HOST": grpc_arbiter_host}, + environment={"GRPC_ARBITER_HOST": grpc_arbiter_host, 'STALACTITE_MLFLOW_HOST': get_mlflow_endpoint(config)}, volumes=volumes, host_config=mounts_host_config, network_config=networking_config,