From 97fd9ae96192fabdb10a3359f3b6893ead48c7f9 Mon Sep 17 00:00:00 2001 From: zakharova-anastasiia Date: Thu, 6 Jun 2024 01:36:36 +0400 Subject: [PATCH] Add GPU support and fix logging --- docker/grpc-base-cpu.dockerfile | 2 +- docker/grpc-base.dockerfile | 2 +- docs/tutorial.rst | 1 + docs/tutorials/plugins.rst | 40 ++++ ...fficientnet-splitNN-mnist-multiprocess.yml | 64 +++++++ .../configs/linreg-mnist-multiprocess.yml | 72 ++++---- .../mlp-splitNN-sbol-smm-multiprocess.yml | 62 +++++++ .../resnet-splitNN-sbol-smm-multiprocess.yml | 61 +++++++ plugins/logistic_regression/__init__.py | 0 plugins/logistic_regression/party_master.py | 129 +++++++++++++ plugins/logistic_regression/party_member.py | 42 +++++ stalactite/base.py | 12 +- .../communications/distributed_grpc_comm.py | 3 +- .../grpc_utils/grpc_arbiter_servicer.py | 2 +- stalactite/configs.py | 2 +- stalactite/data_utils.py | 81 ++++++--- stalactite/helpers.py | 27 ++- stalactite/main.py | 171 +----------------- stalactite/ml/arbitered/base.py | 108 +++++------ .../logistic_regression/party_agent.py | 49 +++-- .../logistic_regression/party_arbiter.py | 17 +- .../logistic_regression/party_master.py | 48 +++-- .../logistic_regression/party_member.py | 26 ++- stalactite/ml/honest/base.py | 89 +++++---- .../honest/linear_regression/party_master.py | 57 +++--- .../honest/linear_regression/party_member.py | 49 ++--- .../logistic_regression/party_master.py | 29 +-- .../logistic_regression/party_member.py | 17 +- stalactite/ml/honest/split_learning/base.py | 97 +++++++--- .../efficientnet/party_master.py | 33 +++- .../efficientnet/party_member.py | 18 +- .../honest/split_learning/mlp/party_master.py | 23 ++- .../honest/split_learning/mlp/party_member.py | 19 +- .../split_learning/resnet/party_master.py | 24 ++- .../split_learning/resnet/party_member.py | 17 +- stalactite/models/efficient_net.py | 19 +- stalactite/models/linreg_batch.py | 8 +- stalactite/models/logreg_batch.py | 1 + stalactite/models/mlp.py | 32 +++- stalactite/models/resnet.py | 67 ++++--- .../split_learning/efficientnet_bottom.py | 15 ++ .../models/split_learning/efficientnet_top.py | 41 +++-- .../models/split_learning/mlp_bottom.py | 18 ++ stalactite/models/split_learning/mlp_top.py | 18 +- .../models/split_learning/resnet_bottom.py | 23 ++- .../models/split_learning/resnet_top.py | 19 ++ stalactite/run_grpc_agent.py | 3 +- stalactite/utils.py | 7 + stalactite/utils_main.py | 2 +- 49 files changed, 1204 insertions(+), 562 deletions(-) create mode 100644 docs/tutorials/plugins.rst create mode 100644 examples/configs/efficientnet-splitNN-mnist-multiprocess.yml create mode 100644 examples/configs/mlp-splitNN-sbol-smm-multiprocess.yml create mode 100644 examples/configs/resnet-splitNN-sbol-smm-multiprocess.yml create mode 100644 plugins/logistic_regression/__init__.py create mode 100644 plugins/logistic_regression/party_master.py create mode 100644 plugins/logistic_regression/party_member.py diff --git a/docker/grpc-base-cpu.dockerfile b/docker/grpc-base-cpu.dockerfile index cd55649..0953117 100644 --- a/docker/grpc-base-cpu.dockerfile +++ b/docker/grpc-base-cpu.dockerfile @@ -24,5 +24,5 @@ RUN poetry install --only-root WORKDIR /opt/stalactite ENV GIT_PYTHON_REFRESH="quiet" LABEL framework="stalactite" - +COPY ./plugins /opt/plugins # docker build -f ./docker/grpc-base.dockerfile -t grpc-base:latest . 
diff --git a/docker/grpc-base.dockerfile b/docker/grpc-base.dockerfile index f7d4cf0..1d9c823 100644 --- a/docker/grpc-base.dockerfile +++ b/docker/grpc-base.dockerfile @@ -24,5 +24,5 @@ WORKDIR /opt/stalactite ENV GIT_PYTHON_REFRESH="quiet" ENV CUDA_DEVICE_ORDER="PCI_BUS_ID" LABEL framework="stalactite" - +COPY ./plugins /opt/plugins # docker build -f ./docker/grpc-base.dockerfile -t grpc-base:latest . diff --git a/docs/tutorial.rst b/docs/tutorial.rst index b7a3615..14d1fa7 100644 --- a/docs/tutorial.rst +++ b/docs/tutorial.rst @@ -20,6 +20,7 @@ And to launch distributed multinode or multiprocess experiment go to :ref:`distr tutorials/distr_communicator_tutorial tutorials/inference_tutorial tutorials/batching_tutorial + tutorials/plugins tutorials/master_types diff --git a/docs/tutorials/plugins.rst b/docs/tutorials/plugins.rst new file mode 100644 index 0000000..5f2af65 --- /dev/null +++ b/docs/tutorials/plugins.rst @@ -0,0 +1,40 @@ +.. _plugins_tutorial: + +*how-to:* Implement your own ML algorithm (plugin) +==================================================== + +The algorithms already implemented in the framework are listed in :ref:`master_types`. If you want to incorporate your own +logic into the framework, implement agent classes that carry the specifics of your algorithm. +For the framework to discover your plugins, create (or reuse the existing) `plugins` folder +alongside the `Stalactite` source code. + + + +.. code-block:: bash + + |-- ... + |-- plugins + |-- stalactite + `-- ... + +Inside the `plugins` folder, create a subfolder containing your agents. The name of this subfolder does not matter, but +agent discovery requires the files inside it to be named correctly: + +- the master class implementation should be placed in a file named: `party_master.py` +- the member class implementation should be placed in a file named: `party_member.py` +- the arbiter (if implemented) should be placed in a file named: `party_arbiter.py` + +We have copied the honest logistic regression implementation into the repository `plugins` folder as an example. + +To use a plugin in an experiment, adjust the configuration file accordingly. For example, to make the framework use +the honest logistic regression implementation from the plugins folder, set ``vfl_model.vfl_model_name`` +to the dotted path from `plugins` to the directory with your agents' files. + + +.. code-block:: yaml + + vfl_model: + vfl_model_name: plugins.logistic_regression + +After performing these steps, the framework will discover the implemented agents and use them +in the experiment.
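+
+For orientation, below is a minimal sketch of such a plugin. It assumes a hypothetical folder
+``plugins/my_algorithm``; the class names follow the discovery convention (the file names are fixed and the class
+names must contain ``PartyMaster`` / ``PartyMember``), while the overridden method body is purely illustrative
+and should be replaced with your algorithm's logic.
+
+.. code-block:: python
+
+    # plugins/my_algorithm/party_master.py  (hypothetical example)
+    from typing import List
+
+    import torch
+
+    from stalactite.base import DataTensor, PartyDataTensor
+    from stalactite.ml.honest.linear_regression.party_master import HonestPartyMasterLinReg
+
+
+    class MyAlgorithmPartyMaster(HonestPartyMasterLinReg):
+        """Master-side logic of the custom algorithm."""
+
+        def aggregate(
+            self, participating_members: List[str], party_predictions: PartyDataTensor, is_infer: bool = False
+        ) -> DataTensor:
+            # Placeholder aggregation rule: sum the members' partial predictions.
+            return torch.sum(torch.stack(party_predictions, dim=1), dim=1)
+
+
+    # plugins/my_algorithm/party_member.py  (hypothetical example)
+    from stalactite.ml.honest.linear_regression.party_member import HonestPartyMemberLinReg
+
+
+    class MyAlgorithmPartyMember(HonestPartyMemberLinReg):
+        """Member-side logic of the custom algorithm; model handling is inherited from the linear regression member."""
+
+With ``vfl_model_name: plugins.my_algorithm`` in the config, the framework would import these modules and
+select the classes by name.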
\ No newline at end of file diff --git a/examples/configs/efficientnet-splitNN-mnist-multiprocess.yml b/examples/configs/efficientnet-splitNN-mnist-multiprocess.yml new file mode 100644 index 0000000..8f8070a --- /dev/null +++ b/examples/configs/efficientnet-splitNN-mnist-multiprocess.yml @@ -0,0 +1,64 @@ +common: + report_train_metrics_iteration: 1 + report_test_metrics_iteration: 1 + world_size: 2 + experiment_label: experiment-efficientnet-mnist-local + reports_export_folder: "../../reports" + seed: 22 + +vfl_model: + epochs: 2 + batch_size: 250 + eval_batch_size: 9000 + vfl_model_name: efficientnet + is_consequently: False + use_class_weights: True + learning_rate: 0.01 + do_train: True + do_predict: False + do_save_model: True + vfl_model_path: ../../saved_models/efficientnet_model + +prerequisites: + mlflow_host: 'node3.bdcl' + mlflow_port: '5555' + +master: + external_host: 'node3.bdcl' + run_mlflow: True + master_model_params: { + input_dim: 128, + dropout: 0.2, + num_classes: 10, + } + run_prometheus: False + port: "50051" + logging_level: 'debug' + disconnect_idle_client_time: 500. + recv_timeout: 3600. + cuda_visible_devices: "0" + +member: + member_model_params: { + width_mult: 0.1, + depth_mult: 0.1, + } + heartbeat_interval: 2. + logging_level: 'info' + recv_timeout: 3600. + +data: + dataset_size: 750 + dataset: 'mnist' + host_path_data_dir: ../../data/sber_ds_vfl/mnist_efficientnet_multiclass + dataset_part_prefix: 'part_' # used in dataset folder structure inspection. Concatenated with the index of a party: 0,1,... etc. + train_split: "train_train" # name of the train split + test_split: "train_val" # name of the test split + features_key: "image_part_" + label_key: "label" + uids_key: "image_idx" + +docker: + docker_compose_command: "docker compose" + docker_compose_path: '../../prerequisites' + use_gpu: True \ No newline at end of file diff --git a/examples/configs/linreg-mnist-multiprocess.yml b/examples/configs/linreg-mnist-multiprocess.yml index 7778550..46b31d1 100644 --- a/examples/configs/linreg-mnist-multiprocess.yml +++ b/examples/configs/linreg-mnist-multiprocess.yml @@ -1,66 +1,58 @@ common: - experiment_label: experiment-vm - reports_export_folder: "../../reports" report_train_metrics_iteration: 1 report_test_metrics_iteration: 1 world_size: 2 + experiment_label: test-experiment-mnist-local + reports_export_folder: ../../reports + seed: 22 vfl_model: - epochs: 5 - batch_size: 1000 + epochs: 2 + batch_size: 5000 + eval_batch_size: 200 vfl_model_name: linreg is_consequently: False use_class_weights: False - learning_rate: 0.01 - -prerequisites: - mlflow_host: 'node3.bdcl' - mlflow_port: '5555' - prometheus_host: 'node3.bdcl' - prometheus_port: '9090' - grafana_port: '3001' + learning_rate: 0.2 + do_train: True + do_predict: True + do_save_model: True + vfl_model_path: ../../saved_models/linreg_model data: - dataset_size: 1000 - host_path_data_dir: ../../data/sber_ds_vfl/mnist_binary38_parts2 + dataset_size: 5000 dataset: 'mnist' - dataset_part_prefix: 'part_' - train_split: "train_train" - test_split: "train_val" + host_path_data_dir: ../../data/sber_ds_vfl/mnist_vfl_parts2 + dataset_part_prefix: 'part_' # used in dataset folder structure inspection. Concatenated with the index of a party: 0,1,... etc. 
+ train_split: "train_train" # name of the train split + test_split: "train_val" # name of the test split features_key: "image_part_" label_key: "label" + uids_key: "image_idx" -grpc_server: - port: '50051' - max_message_size: -1 - # server_threadpool_max_workers: 10 - +prerequisites: + mlflow_host: 'node3.bdcl' + mlflow_port: '5555' master: external_host: 'node3.bdcl' - run_mlflow: True - run_prometheus: True + run_prometheus: False + port: "50051" logging_level: 'debug' - disconnect_idle_client_time: 120. - # time_between_idle_connections_checks: 3 - # recv_timeout: 360 + disconnect_idle_client_time: 500. + recv_timeout: 3600. + cuda_visible_devices: "0" member: - logging_level: 'debug' + member_model_params: { + output_dim: 1, + reg_lambda: 0.5 + } heartbeat_interval: 2. - # heartbeat_interval: 2 - # sent_task_timout: 3600 + logging_level: 'info' + recv_timeout: 3600. docker: docker_compose_command: "docker compose" docker_compose_path: '../../prerequisites' - use_gpu: False - - -#grpc_arbiter: -# use_arbiter: False - - - - - + use_gpu: True \ No newline at end of file diff --git a/examples/configs/mlp-splitNN-sbol-smm-multiprocess.yml b/examples/configs/mlp-splitNN-sbol-smm-multiprocess.yml new file mode 100644 index 0000000..3ae1b06 --- /dev/null +++ b/examples/configs/mlp-splitNN-sbol-smm-multiprocess.yml @@ -0,0 +1,62 @@ +common: + report_train_metrics_iteration: 10 + report_test_metrics_iteration: 10 + world_size: 2 + experiment_label: experiment-mlp-sbol-smm-local + reports_export_folder: "../../reports" + seed: 22 + +vfl_model: + epochs: 2 + batch_size: 250 + eval_batch_size: 200 + vfl_model_name: mlp + is_consequently: False + use_class_weights: True + learning_rate: 0.01 + do_train: True + do_predict: False + do_save_model: True + vfl_model_path: ../../saved_models/mlp_model + +prerequisites: + mlflow_host: 'node3.bdcl' + mlflow_port: '5555' + +master: + external_host: 'node3.bdcl' + run_mlflow: True + master_model_params: { + input_dim: 100, + output_dim: 19, + multilabel: True, + } + run_prometheus: False + port: "50051" + logging_level: 'debug' + disconnect_idle_client_time: 500. + recv_timeout: 3600. + cuda_visible_devices: "0" + +member: + member_model_params: { + hidden_channels:[1000, 300, 100], + } + heartbeat_interval: 2. + logging_level: 'info' + recv_timeout: 3600. 
+ +data: + dataset_size: 10000 + dataset: 'sbol_smm' + host_path_data_dir: ../../data/sber_ds_vfl/multilabel_sber_sample10000_smm_parts2 + dataset_part_prefix: 'part_' + train_split: "train_train" + test_split: "train_val" + features_key: "features_part_" + label_key: "labels" + +docker: + docker_compose_command: "docker compose" + docker_compose_path: '../../prerequisites' + use_gpu: True \ No newline at end of file diff --git a/examples/configs/resnet-splitNN-sbol-smm-multiprocess.yml b/examples/configs/resnet-splitNN-sbol-smm-multiprocess.yml new file mode 100644 index 0000000..39f0c3f --- /dev/null +++ b/examples/configs/resnet-splitNN-sbol-smm-multiprocess.yml @@ -0,0 +1,61 @@ +common: + report_train_metrics_iteration: 1 + report_test_metrics_iteration: 1 + world_size: 2 + experiment_label: experiment-resnet-sbol-smm-local + reports_export_folder: "../../reports" + +vfl_model: + epochs: 2 + batch_size: 2500 + eval_batch_size: 2000 + vfl_model_name: resnet + is_consequently: False + use_class_weights: False + learning_rate: 0.01 + do_train: True + do_predict: False + do_save_model: True + vfl_model_path: ../../saved_models/resnet_model + +prerequisites: + mlflow_host: 'node3.bdcl' + mlflow_port: '5555' + +master: + external_host: 'node3.bdcl' + run_mlflow: True + master_model_params: { + input_dim: 1356, + output_dim: 19, + use_bn: True, + } + run_prometheus: False + port: "50051" + logging_level: 'debug' + disconnect_idle_client_time: 500. + recv_timeout: 3600. + cuda_visible_devices: "0" + +member: + member_model_params: { + hid_factor: [ 1, 1 ], + } + heartbeat_interval: 2. + logging_level: 'info' + recv_timeout: 3600. + +data: + dataset_size: 10000 + dataset: 'sbol_smm' + host_path_data_dir: ../../data/sber_ds_vfl/multilabel_sber_sample10000_smm_parts2 + dataset_part_prefix: 'part_' + train_split: "train_train" + test_split: "train_val" + features_key: "features_part_" + label_key: "labels" + +docker: + docker_compose_command: "docker compose" + docker_compose_path: '../../prerequisites' + use_gpu: True \ No newline at end of file diff --git a/plugins/logistic_regression/__init__.py b/plugins/logistic_regression/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/plugins/logistic_regression/party_master.py b/plugins/logistic_regression/party_master.py new file mode 100644 index 0000000..cd4dda4 --- /dev/null +++ b/plugins/logistic_regression/party_master.py @@ -0,0 +1,129 @@ +import logging +from typing import List + +import mlflow +import torch +from sklearn.metrics import roc_auc_score, root_mean_squared_error + +from stalactite.base import DataTensor, PartyDataTensor +from stalactite.ml.honest.linear_regression.party_master import HonestPartyMasterLinReg + +logger = logging.getLogger(__name__) + + +class HonestPartyMasterLogReg(HonestPartyMasterLinReg): + """ Implementation class of the VFL honest PartyMaster specific to the Logistic Regression algorithm. """ + + def make_init_updates(self, world_size: int) -> PartyDataTensor: + """ Make initial updates for logistic regression. + + :param world_size: Number of party members. + :return: Initial updates as a list of zero tensors. 
+ """ + logger.info(f"Master {self.id}: makes initial updates for {world_size} members") + self.check_if_ready() + return [torch.zeros(self._batch_size).to(self.device) for _ in range(world_size)] + + def aggregate( + self, participating_members: List[str], party_predictions: PartyDataTensor, is_infer: bool = False + ) -> DataTensor: + """ Aggregate party predictions for logistic regression. + + :param participating_members: List of participating party member identifiers. + :param party_predictions: List of party predictions. + :param is_infer: Flag indicating whether to perform inference. + + :return: Aggregated predictions after applying sigmoid function. + """ + logger.info(f"Master {self.id}: aggregates party predictions (number of predictions {len(party_predictions)})") + self.check_if_ready() + if not is_infer: + for member_id, member_prediction in zip(participating_members, party_predictions): + self.party_predictions[member_id] = member_prediction + party_predictions = list(self.party_predictions.values()) + predictions = torch.sum(torch.stack(party_predictions, dim=1).to(self.device), dim=1) + else: + predictions = self.activation(torch.sum(torch.stack(party_predictions, dim=1).to(self.device), dim=1)) + return predictions + + def compute_updates( + self, + participating_members: List[str], + predictions: DataTensor, + party_predictions: PartyDataTensor, + world_size: int, + uids: list[str], + ) -> List[DataTensor]: + """ Compute updates for logistic regression. + + :param participating_members: List of participating party members identifiers. + :param predictions: Model predictions. + :param party_predictions: List of party predictions. + :param world_size: Number of party members. + :param subiter_seq_num: Sub-iteration sequence number. + + :return: List of gradients as tensors. + """ + logger.info(f"Master {self.id}: computes updates (world size {world_size})") + self.check_if_ready() + self.iteration_counter += 1 + tensor_idx = [self.uid2tensor_idx[uid] for uid in uids] + y = self.target[tensor_idx] + criterion = torch.nn.BCEWithLogitsLoss(pos_weight=self.class_weights) \ + if self.binary else torch.nn.CrossEntropyLoss(weight=self.class_weights) + targets_type = torch.LongTensor if isinstance(criterion, + torch.nn.CrossEntropyLoss) else torch.FloatTensor + predictions = predictions.to(self.device) + loss = criterion(torch.squeeze(predictions), y.type(targets_type).to(self.device)) + grads = torch.autograd.grad(outputs=loss, inputs=predictions) + + for i, member_id in enumerate(participating_members): + self.updates[member_id] = grads[0] + logger.debug(f"Master {self.id}: computed updates") + return [self.updates[member_id] for member_id in participating_members] + + def report_metrics(self, y: DataTensor, predictions: DataTensor, name: str, step: int) -> None: + """Report metrics for logistic regression. + + Compute main classification metrics, if `use_mlflow` parameter was set to true, log them to MlFLow, log them to + stdout. + + :param y: Target values. + :param predictions: Model predictions. + :param name: Name of the dataset ("Train" or "Test"). + + :return: None. 
+ """ + logger.info(f"Master {self.id} reporting metrics") + logger.debug(f"Predictions size: {predictions.size()}, Target size: {y.size()}") + + y = y.cpu().numpy() + predictions = predictions.cpu().detach().numpy() + postfix = '-infer' if step == -1 else "" + step = step if step != -1 else None + + if self.binary: + for avg in ["macro", "micro"]: + try: + roc_auc = roc_auc_score(y, predictions, average=avg) + except ValueError: + roc_auc = 0 + + rmse = root_mean_squared_error(y, predictions) + + logger.info(f'{name} RMSE on step {step}: {rmse}') + logger.info(f'{name} ROC AUC {avg} on step {step}: {roc_auc}') + if self.run_mlflow: + mlflow.log_metric(f"{name.lower()}_roc_auc_{avg}{postfix}", roc_auc, step=step) + mlflow.log_metric(f"{name.lower()}_rmse{postfix}", rmse, step=step) + else: + avg = "macro" + try: + roc_auc = roc_auc_score(y, predictions, average=avg, multi_class="ovr") + except ValueError: + roc_auc = 0 + + logger.info(f'{name} ROC AUC {avg} on step {step}: {roc_auc}') + + if self.run_mlflow: + mlflow.log_metric(f"{name.lower()}_roc_auc_{avg}{postfix}", roc_auc, step=step) diff --git a/plugins/logistic_regression/party_member.py b/plugins/logistic_regression/party_member.py new file mode 100644 index 0000000..e03ae92 --- /dev/null +++ b/plugins/logistic_regression/party_member.py @@ -0,0 +1,42 @@ +import logging +from typing import Any + +from torch.optim import SGD + +from stalactite.ml.honest.linear_regression.party_member import HonestPartyMemberLinReg +from stalactite.models import LogisticRegressionBatch +from stalactite.utils import init_linear_np + +logger = logging.getLogger(__name__) + +class HonestPartyMemberLogReg(HonestPartyMemberLinReg): + def initialize_model_from_params(self, **model_params) -> Any: + return LogisticRegressionBatch(**model_params).to(self.device) + + def initialize_model(self, do_load_model: bool = False) -> None: + """ Initialize the model based on the specified model name. 
""" + logger.info(f"Member {self.id} initializes model on device: {self.device}") + logger.debug(f"Model is loaded from path: {do_load_model}") + if do_load_model: + self._model = self.load_model().to(self.device) + else: + self._model = LogisticRegressionBatch( + input_dim=self._dataset[self._data_params.train_split][self._data_params.features_key].shape[1], + **self._model_params + ).to(self.device) + init_linear_np(self._model.linear, seed=self.seed) + self._model.linear.to(self.device) + + def initialize_optimizer(self) -> None: + self._optimizer = SGD([ + {"params": self._model.parameters()}, + ], + lr=self._common_params.learning_rate, + momentum=self._common_params.momentum, + weight_decay=self._common_params.weight_decay, + ) + + def move_model_to_device(self): + # As the class is inherited from the Linear regression model, we need to skip this step with returning the model + # to device after weights updates + pass diff --git a/stalactite/base.py b/stalactite/base.py index e48f1a3..1376dd2 100644 --- a/stalactite/base.py +++ b/stalactite/base.py @@ -15,7 +15,6 @@ logger = logging.getLogger(__name__) - DataTensor = torch.Tensor # in reality, it will be a DataTensor but with one more dimension PartyDataTensor = List[torch.Tensor] @@ -327,6 +326,7 @@ def save_model(self, is_ovr_models: bool = False): if self.model_path is None: raise RuntimeError('If `do_save_model` is True, the `model_path` must be not None.') os.makedirs(os.path.join(self.model_path, f'agent_{self.id}'), exist_ok=True) + logger.info(f"Agent {self.id} saving the model to {self.model_path}") if is_ovr_models: for idx, model in enumerate(self._model): torch.save(model.state_dict(), os.path.join(self.model_path, f'agent_{self.id}', f'model-{idx}.pt')) @@ -356,6 +356,10 @@ def load_model(self, is_ovr_models: bool = False) -> Any: with open(os.path.join(agent_model_path, 'model_init_params.json')) as f: init_model_params = json.load(f) + if (model_params := getattr(self, '_model_params', None)) is not None: + for k, v in model_params.items(): + init_model_params[k] = v + if is_ovr_models: logger.info(f'Loading OVR models from {agent_model_path}') model = [] @@ -397,7 +401,7 @@ def synchronize_uids( :return: Common records identifiers among the agents used in training loop. 
""" - logger.debug("Master %s: synchronizing uids for party of size %s" % (self.id, world_size)) + logger.debug(f"Master {self.id}: synchronizing uuids for party of size {world_size}") inner_collected_uids = [col_uids[0] for col_uids in collected_uids if col_uids[1]] uids = self.inference_target_uids if is_infer else self.target_uids if len(inner_collected_uids) > 0: @@ -407,12 +411,12 @@ def synchronize_uids( shared_uids = sorted( [uid for uid, count in collections.Counter(uids).items() if count == len(inner_collected_uids) + 1] ) - logger.debug("Master %s: registering shared uids f size %s" % (self.id, len(shared_uids))) + logger.debug(f"Master {self.id}: registering shared uuids of size {len(shared_uids)}") if is_infer: self.inference_target_uids = shared_uids else: self.target_uids = shared_uids - logger.debug("Master %s: record uids has been successfully synchronized") + logger.debug(f"Master {self.id}: record uuids has been successfully synchronized") return shared_uids @abstractmethod diff --git a/stalactite/communications/distributed_grpc_comm.py b/stalactite/communications/distributed_grpc_comm.py index 97106ef..94b1e04 100644 --- a/stalactite/communications/distributed_grpc_comm.py +++ b/stalactite/communications/distributed_grpc_comm.py @@ -37,7 +37,8 @@ start_thread, ) from stalactite.communications.helpers import METHOD_VALUES, Method, MethodKwargs -from stalactite.ml.arbitered.base import PartyArbiter, Role +from stalactite.ml.arbitered.base import PartyArbiter +from stalactite.utils import Role logger = logging.getLogger(__name__) diff --git a/stalactite/communications/grpc_utils/grpc_arbiter_servicer.py b/stalactite/communications/grpc_utils/grpc_arbiter_servicer.py index 3d9f69e..ba2e126 100644 --- a/stalactite/communications/grpc_utils/grpc_arbiter_servicer.py +++ b/stalactite/communications/grpc_utils/grpc_arbiter_servicer.py @@ -9,7 +9,7 @@ from stalactite.communications.grpc_utils.generated_code import arbitered_communicator_pb2, \ arbitered_communicator_pb2_grpc from stalactite.communications.grpc_utils.utils import Status -from stalactite.ml.arbitered.base import Role +from stalactite.utils import Role logger = logging.getLogger(__name__) logging.getLogger('asyncio').setLevel(logging.ERROR) diff --git a/stalactite/configs.py b/stalactite/configs.py index f29ffc2..24d068a 100644 --- a/stalactite/configs.py +++ b/stalactite/configs.py @@ -73,7 +73,7 @@ class VFLModelConfig(BaseModel): epochs: int = Field(default=3, description="Number of epochs to train a model") batch_size: int = Field(default=100, description="Batch size used for training") eval_batch_size: int = Field(default=100, description="Batch size used for evaluation") - vfl_model_name: Literal['linreg', 'logreg', 'logreg_sklearn', 'efficientnet', 'mlp', 'resnet'] = Field( + vfl_model_name: str = Field( default='linreg', description='Model type. 
One of `linreg`, `logreg`, `logreg_sklearn`, `efficientnet`, `mlp`, `resnet`' ) diff --git a/stalactite/data_utils.py b/stalactite/data_utils.py index 6f238c3..236c467 100644 --- a/stalactite/data_utils.py +++ b/stalactite/data_utils.py @@ -19,6 +19,8 @@ ) from stalactite.configs import VFLConfig +from stalactite.helpers import get_plugin_agent +from stalactite.utils import Role from examples.utils.local_experiment import load_processors as load_processors_honest from examples.utils.local_arbitered_experiment import load_processors as load_processors_arbitered @@ -35,7 +37,10 @@ def get_party_master(config_path: str, is_infer: bool = False) -> PartyMaster: master_processor = master_processor if config.data.dataset.lower() == "sbol_smm" else processors[0] if config.data.dataset_size == -1: config.data.dataset_size = len(master_processor.dataset[config.data.train_split][config.data.uids_key]) - master_class = ArbiteredPartyMasterLogReg + if config.vfl_model.vfl_model_name in ['logreg']: + master_class = ArbiteredPartyMasterLogReg + else: + master_class = get_plugin_agent(config.vfl_model.vfl_model_name, Role.master) if config.grpc_arbiter.security_protocol_params is not None: if config.grpc_arbiter.security_protocol_params.he_type == 'paillier': sp_agent = SecurityProtocolPaillier(**config.grpc_arbiter.security_protocol_params.init_params) @@ -62,26 +67,31 @@ def get_party_master(config_path: str, is_infer: bool = False) -> PartyMaster: do_train=not is_infer, do_save_model=config.vfl_model.do_save_model, model_path=config.vfl_model.vfl_model_path, - seed=config.common.seed + seed=config.common.seed, + device='cuda' if config.docker.use_gpu else 'cpu', ) else: master_processor, processors = load_processors_honest(config) if config.data.dataset_size == -1: config.data.dataset_size = len(master_processor.dataset[config.data.train_split][config.data.uids_key]) - if 'logreg' in config.vfl_model.vfl_model_name: - master_class = HonestPartyMasterLogReg - elif "resnet" in config.vfl_model.vfl_model_name: - master_class = HonestPartyMasterResNetSplitNN - elif "efficientnet" in config.vfl_model.vfl_model_name: - master_class = HonestPartyMasterEfficientNetSplitNN - elif "mlp" in config.vfl_model.vfl_model_name: - master_class = HonestPartyMasterMLPSplitNN - else: - if config.vfl_model.is_consequently: - master_class = HonestPartyMasterLinRegConsequently + if config.vfl_model.vfl_model_name in ['logreg', 'resnet', 'efficientnet', 'mlp', 'linreg']: + if 'logreg' in config.vfl_model.vfl_model_name: + master_class = HonestPartyMasterLogReg + elif "resnet" in config.vfl_model.vfl_model_name: + master_class = HonestPartyMasterResNetSplitNN + elif "efficientnet" in config.vfl_model.vfl_model_name: + master_class = HonestPartyMasterEfficientNetSplitNN + elif "mlp" in config.vfl_model.vfl_model_name: + master_class = HonestPartyMasterMLPSplitNN else: - master_class = HonestPartyMasterLinReg + if config.vfl_model.is_consequently: + master_class = HonestPartyMasterLinRegConsequently + else: + master_class = HonestPartyMasterLinReg + else: + master_class = get_plugin_agent(config.vfl_model.vfl_model_name, Role.master) + return master_class( uid="master", epochs=config.vfl_model.epochs, @@ -100,7 +110,10 @@ def get_party_master(config_path: str, is_infer: bool = False) -> PartyMaster: model_name=config.vfl_model.vfl_model_name if config.vfl_model.vfl_model_name in ["resnet", "mlp", "efficientnet"] else None, model_params=config.master.master_model_params, - seed=config.common.seed + seed=config.common.seed, + 
device='cuda' if config.docker.use_gpu else 'cpu', + do_save_model=config.vfl_model.do_save_model, + model_path=config.vfl_model.vfl_model_path, ) @@ -108,7 +121,10 @@ def get_party_member(config_path: str, member_rank: int, is_infer: bool = False) config = VFLConfig.load_and_validate(config_path) if config.grpc_arbiter.use_arbiter: master_processor, processors = load_processors_arbitered(config) - member_class = ArbiteredPartyMemberLogReg + if config.vfl_model.vfl_model_name in ['logreg']: + member_class = ArbiteredPartyMemberLogReg + else: + member_class = get_plugin_agent(config.vfl_model.vfl_model_name, Role.member) if config.grpc_arbiter.security_protocol_params is not None: if config.grpc_arbiter.security_protocol_params.he_type == 'paillier': sp_agent = SecurityProtocolPaillier(**config.grpc_arbiter.security_protocol_params.init_params) @@ -135,21 +151,25 @@ def get_party_member(config_path: str, member_rank: int, is_infer: bool = False) do_save_model=config.vfl_model.do_save_model, model_path=config.vfl_model.vfl_model_path, use_inner_join=False, - seed=config.common.seed + seed=config.common.seed, + device='cuda' if config.docker.use_gpu else 'cpu', ) else: master_processor, processors = load_processors_honest(config) - if 'logreg' in config.vfl_model.vfl_model_name: - member_class = HonestPartyMemberLogReg - elif "resnet" in config.vfl_model.vfl_model_name: - member_class = HonestPartyMemberResNet - elif "efficientnet" in config.vfl_model.vfl_model_name: - member_class = HonestPartyMemberEfficientNet - elif "mlp" in config.vfl_model.vfl_model_name: - member_class = HonestPartyMemberMLP + if config.vfl_model.vfl_model_name in ['logreg', 'resnet', 'efficientnet', 'mlp', 'linreg']: + if 'logreg' in config.vfl_model.vfl_model_name: + member_class = HonestPartyMemberLogReg + elif "resnet" in config.vfl_model.vfl_model_name: + member_class = HonestPartyMemberResNet + elif "efficientnet" in config.vfl_model.vfl_model_name: + member_class = HonestPartyMemberEfficientNet + elif "mlp" in config.vfl_model.vfl_model_name: + member_class = HonestPartyMemberMLP + else: + member_class = HonestPartyMemberLinReg else: - member_class = HonestPartyMemberLinReg + member_class = get_plugin_agent(config.vfl_model.vfl_model_name, Role.member) member_ids = [f"member-{member_rank}" for member_rank in range(config.common.world_size)] return member_class( @@ -171,7 +191,8 @@ def get_party_member(config_path: str, member_rank: int, is_infer: bool = False) model_path=config.vfl_model.vfl_model_path, model_params=config.member.member_model_params, use_inner_join=True if member_rank == 0 else False, - seed=config.common.seed + seed=config.common.seed, + device='cuda' if config.docker.use_gpu else 'cpu', ) @@ -179,8 +200,10 @@ def get_party_arbiter(config_path: str, is_infer: bool = False) -> PartyArbiter: config = VFLConfig.load_and_validate(config_path) if not config.grpc_arbiter.use_arbiter: raise RuntimeError('Arbiter should not be called in honest setting.') - - arbiter_class = PartyArbiterLogReg + if config.vfl_model.vfl_model_name in ['logreg']: + arbiter_class = PartyArbiterLogReg + else: + arbiter_class = get_plugin_agent(config.vfl_model.vfl_model_name, Role.arbiter) if config.grpc_arbiter.security_protocol_params is not None: if config.grpc_arbiter.security_protocol_params.he_type == 'paillier': sp_arbiter = SecurityProtocolArbiterPaillier(**config.grpc_arbiter.security_protocol_params.init_params) diff --git a/stalactite/helpers.py b/stalactite/helpers.py index eb865db..a052f11 100644 --- 
a/stalactite/helpers.py +++ b/stalactite/helpers.py @@ -1,3 +1,5 @@ +import importlib +import inspect import logging import os import time @@ -9,7 +11,28 @@ from stalactite.base import PartyMaster, PartyMember from stalactite.configs import VFLConfig -from stalactite.ml.arbitered.base import PartyArbiter, Role +from stalactite.ml.arbitered.base import PartyArbiter +from stalactite.utils import Role + + +def get_plugin_agent(module_path: str, role: Role): + agent_path = f"{module_path}.party_{role}" + target_class = None + try: + module = importlib.import_module(agent_path) + except ModuleNotFoundError as exc: + raise ValueError(f"No {role} is defined in plugin/. Check correctness of the config file") from exc + classes = inspect.getmembers(module, inspect.isclass) + for cls, path in classes: + if agent_path in path.__module__ and f"party{role}" in cls.lower(): + target_class = cls + if target_class is not None: + return getattr(module, target_class) + else: + raise NameError( + f"Defined classes` names violate the naming convention " + f"(Party{role.capitalize()})" + ) def global_logging( @@ -38,6 +61,8 @@ def global_logging( logger.setLevel(logging.INFO) else: logger.setLevel(logging_level) + else: + logger.propagate = False @contextmanager diff --git a/stalactite/main.py b/stalactite/main.py index 5f3e21d..27ac9eb 100644 --- a/stalactite/main.py +++ b/stalactite/main.py @@ -7,7 +7,6 @@ import enum import logging import os -from pathlib import Path import click import mlflow as _mlflow @@ -52,11 +51,13 @@ def cli(): """Main stalactite CLI command group.""" click.echo("Stalactite module API") + @cli.command() @click.option( "--remove-images", is_flag=True, show_default=True, default=False, help="Remove built images." ) def clean_all(remove_images): + """ Clear all created Docker objects (containers, volumes, networks, images). """ click.echo("Removing all the stalactite related docker objects") client = APIClient() filters = { @@ -70,10 +71,10 @@ def clean_all(remove_images): except NotFound: logger.warning('Could not remove container, resource already has been deleted') volumes = client.volumes(filters=filters) - click.echo(f'Removing {len(volumes)} volumes') - for volume in volumes: + click.echo(f'Removing {len(volumes["Volumes"])} volumes') + for volume in volumes['Volumes']: try: - client.remove_volume(volume, force=True) + client.remove_volume(volume['Name'], force=True) except NotFound: logger.warning('Could not remove volume, resource already has been deleted') networks = client.networks(filters=filters) @@ -651,167 +652,6 @@ def logs(ctx, agent_id, follow, tail): logger.info("Logs in the single-process mode are not available") -@cli.group() -@click.pass_context -@click.option( - "--single-process", - is_flag=True, - show_default=True, - default=False, - help="Run single-node single-process (multi-thread) test.", -) -@click.option( - "--multi-process", - is_flag=True, - show_default=True, - default=False, - help="Run single-node multi-process (dockerized) test.", -) -def test(ctx, multi_process, single_process): - """ - Local tests (multi-process / single process) mode command group. 
- - :param ctx: Click context - :param single_process: Run tests on single process experiment - :param multi_process: Run tests on multiple process (dockerized) experiment - """ - if multi_process and not single_process: - click.echo("Manage multi-process single-node tests") - elif single_process and not multi_process: - click.echo("Manage single-process (multi-thread) single-node tests") - else: - raise SyntaxError("Either `--single-process` or `--multi-process` flag can be set.") - ctx.obj = dict() - ctx.obj["multi_process"] = multi_process - ctx.obj["single_process"] = single_process - - -@test.command() -@click.pass_context -@click.option( - "--config-path", type=str, required=True, help="Absolute path to the configuration file in `YAML` format." -) -def start(ctx, config_path): - """ - Start local test experiment. - For a multiprocess mode run integration tests on started VFL master and members containers. - Single-process test will run tests within a python process. - - :param ctx: Click context - :param config_path: Absolute path to the configuration file in `YAML` format - """ - config = VFLConfig.load_and_validate(config_path) - if ctx.obj["multi_process"] and not ctx.obj["single_process"]: - test_group_name = "TestLocalGroupStart" - report_file_name = f"{test_group_name}-log-{config.common.experiment_label}.jsonl" - run_subprocess_command( - command=f"python -m pytest --test_config_path {config_path} " - f"tests/distributed_grpc/integration_test.py -k '{test_group_name}' -x " - f"--report-log={os.path.join(config.common.reports_export_folder, report_file_name)} " - "-W ignore::DeprecationWarning", - logger_err_info="Failed running test", - cwd=Path(__file__).parent.parent, - shell=True, - ) - else: - report_file_name = f"local-tests-log-{config.common.experiment_label}.jsonl" - run_subprocess_command( - command=f"python -m pytest tests/test_local.py -x " - f"--report-log={os.path.join(config.common.reports_export_folder, report_file_name)} " - "-W ignore::DeprecationWarning --log-cli-level 20", - logger_err_info="Failed running test", - cwd=Path(__file__).parent.parent, - shell=True, - ) - - -@test.command() -@click.pass_context -@click.option("--config-path", type=str, help="Absolute path to the configuration file in `YAML` format.") -@click.option( - "--no-tests", - is_flag=True, - show_default=True, - default=False, - help="Remove test containers without launching Pytest.", -) -def stop(ctx, config_path, no_tests): - """ - Stop local test experiment. - For a multiprocess mode run integration tests while stopping VFL master and members containers. - Does nothing for a single-process mode. - - :param ctx: Click context - :param config_path: Absolute path to the configuration file in `YAML` format - :param no_tests: Remove test containers without launching Pytest. 
Useful if `start` command did not succeed, - but containers have already been created - """ - if config_path is None and not no_tests: - raise SyntaxError("Specify `--config-path` or pass flag `--no-tests`") - if no_tests: - _test = 1 - if ctx.obj["multi_process"] and not ctx.obj["single_process"]: - logger.info("Removing test containers") - client = ctx.obj.get("client", APIClient()) - try: - container_label = BASE_CONTAINER_LABEL + ("-test" if _test else "") - containers = client.containers(all=True, filters={"label": f"{KEY_CONTAINER_LABEL}={container_label}"}) - stop_containers(client, containers, leave_containers=False) - except APIError as exc: - logger.error("Error while stopping (and removing) containers", exc_info=exc) - return - config = VFLConfig.load_and_validate(config_path) - if ctx.obj["multi_process"] and not ctx.obj["single_process"]: - test_group_name = "TestLocalGroupStop" - report_file_name = f"{test_group_name}-log-{config.common.experiment_label}.jsonl" - run_subprocess_command( - command=f"python -m pytest --test_config_path {config_path} " - f"tests/distributed_grpc/integration_test.py -k '{test_group_name}' -x " - f"--report-log={os.path.join(config.common.reports_export_folder, report_file_name)} " - "-W ignore::DeprecationWarning", - logger_err_info="Failed running test", - cwd=Path(__file__).parent.parent, - shell=True, - ) - - -@test.command() -@click.option("--agent-id", type=str, default=None, help="ID of the agents` container.") -def status(agent_id): - """ - Print status of the experimental test container(s). - If the `agent-id` is not passed, all the created on test containers` statuses will be returned. - - :param ctx: Click context - :param agent_id: ID of the agents` container - """ - _test = True - container_label = BASE_CONTAINER_LABEL + ("-test" if _test else "") - get_status(agent_id=agent_id, containers_label=f"{KEY_CONTAINER_LABEL}={container_label}") - - -@test.command() -@click.option("--agent-id", type=str, default=None, help="ID of the agents` container.") -@click.option("--tail", type=str, default="all", help="Number of lines to show from the end of the logs.") -@click.option("--config-path", type=str, default=None, help="Absolute path to the configuration file in `YAML` format.") -def logs(agent_id, config_path, tail): - """ - Print logs of the experimental test container or return path to tests` logs. 
- If the `agent-id` is passed, show container logs, otherwise, prints test report - - :param agent_id: ID of the agents` container - :param config_path: Absolute path to the configuration file in `YAML` format - :param tail: Number of lines to show from the end of the logs - """ - if agent_id is None and config_path is None: - raise SyntaxError("Either `--agent-id` or `--config-path` argument must be specified.") - if agent_id is not None: - get_logs(agent_id=agent_id, tail=tail) - if config_path is not None: - config = VFLConfig.load_and_validate(config_path) - logger.info(f"Test-report-logs path: {config.common.reports_export_folder}") - - @cli.command() @click.option("--config-path", type=str, required=True) @click.option( @@ -829,6 +669,7 @@ def logs(agent_id, config_path, tail): help="Run single-node multi-process (dockerized) test.", ) def predict(multi_process, single_process, config_path): + """ Run VFL inference for local experiments (multi-process / single process) """ click.echo("Run VFL predictions") if multi_process and not single_process: client = APIClient() diff --git a/stalactite/ml/arbitered/base.py b/stalactite/ml/arbitered/base.py index 3880626..e10d891 100644 --- a/stalactite/ml/arbitered/base.py +++ b/stalactite/ml/arbitered/base.py @@ -1,4 +1,3 @@ -import enum import logging import time from abc import ABC, abstractmethod @@ -33,12 +32,6 @@ class Keys: private: Any = None -class Role(str, enum.Enum): - arbiter = "arbiter" - master = "master" - member = "member" - - class SecurityProtocol(ABC): """ Base proxy class for Homomorphic Encryption (HE) protocol. """ _keys: Optional[Keys] = None @@ -168,6 +161,7 @@ def register_records_uids(self, uids: List[str]) -> None: def get_public_key(self) -> Keys: """ Return public key if the security protocol is initialized, otherwise, return an empty Keys object. """ + logger.info(f'Arbiter {self.id} returns public key') if self.security_protocol is not None: return self.security_protocol.public_key return Keys() @@ -191,7 +185,7 @@ def calculate_updates(self, gradients: dict) -> dict[str, DataTensor]: def run(self, party: PartyCommunicator) -> None: """ Run main arbiter loops (training and inference). """ - logger.info("Running arbiter %s" % self.id) + logger.info(f"Running arbiter {self.id}") if self.do_train: self.fit(party) @@ -230,12 +224,16 @@ def fit(self, party: PartyCommunicator): finalize_task = party.recv(Task(method_name=Method.finalize, from_id=party.master, to_id=self.id)) self.execute_received_task(finalize_task) - logger.info("Finished master %s" % self.id) + logger.info(f"Finished arbiter {self.id}") def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: """ Main training loop on arbiter. """ - logger.info("Arbiter %s: entering training loop" % self.id) + logger.info(f"Arbiter {self.id}: entering training loop") for titer in batcher: + logger.info( + f"Arbiter {self.id}: train loop - starting batch {titer.seq_num} (sub iter {titer.subiter_seq_num}) " + f"on epoch {titer.epoch}" + ) master_gradient = party.recv( Task(method_name=Method.calculate_updates, from_id=party.master, to_id=self.id), recv_results=False @@ -347,7 +345,7 @@ def aggregate_predictions( def run(self, party: PartyCommunicator) -> None: """ Run main master loops (training and inference). """ - + logger.info(f"Running master {self.id}") if self.do_train: self.fit(party) @@ -357,67 +355,49 @@ def run(self, party: PartyCommunicator) -> None: def fit(self, party: PartyCommunicator): """ Run master for VFL training. 
""" - logger.info("Running master %s" % self.id) - records_uids_tasks = party.broadcast( Method.records_uids, participating_members=party.members, ) records_uids_results = party.gather(records_uids_tasks, recv_results=True) - collected_uids_results = [task.result for task in records_uids_results] - test_records_uids_tasks = party.broadcast( Method.records_uids, participating_members=party.members, method_kwargs=MethodKwargs(other_kwargs={'is_infer': True}), ) - test_records_uids_results = party.gather(test_records_uids_tasks, recv_results=True) - test_collected_uids_results = [task.result for task in test_records_uids_results] - party.broadcast( Method.initialize, participating_members=party.members + [party.arbiter], ) - self.initialize(is_infer=False) - uids = self.synchronize_uids(collected_uids_results, world_size=party.world_size) test_uids = self.synchronize_uids(test_collected_uids_results, world_size=party.world_size, is_infer=True) - party.broadcast( Method.register_records_uids, method_kwargs=MethodKwargs(other_kwargs={"uids": uids}), participating_members=party.members + [party.master] + [party.arbiter], ) - register_records_uids_task = party.recv( Task(method_name=Method.register_records_uids, from_id=party.master, to_id=self.id) ) - self.execute_received_task(register_records_uids_task) - party.broadcast( Method.register_records_uids, method_kwargs=MethodKwargs(other_kwargs={"uids": test_uids, "is_infer": True}), participating_members=party.members + [party.master] + [party.arbiter], ) - register_test_records_uids_task = party.recv( Task(method_name=Method.register_records_uids, from_id=party.master, to_id=self.id) ) - self.execute_received_task(register_test_records_uids_task) - pk_task = party.send(method_name=Method.get_public_key, send_to_id=party.arbiter) pk = party.recv(pk_task, recv_results=True).result - if self.security_protocol is not None: self.security_protocol.keys = pk self.security_protocol.initialize() - self.loop(batcher=self.make_batcher(uids=uids, party_members=party.members), party=party) party.broadcast( @@ -425,7 +405,7 @@ def fit(self, party: PartyCommunicator): participating_members=party.members + [party.arbiter], ) self.finalize() - logger.info("Finished master %s" % self.id) + logger.info(f"Finished master {self.id}") def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: """ Run main training loop on the VFL master. 
@@ -435,9 +415,13 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: :return: None """ - logger.info("Master %s: entering training loop" % self.id) + logger.info(f"Master {self.id}: entering training loop") for titer in batcher: + logger.info( + f"Master {self.id}: train loop - starting batch {titer.seq_num} (sub iter {titer.subiter_seq_num}) " + f"on epoch {titer.epoch}" + ) iter_start_time = time.time() participant_partial_pred_tasks = party.broadcast( Method.predict_partial, @@ -447,29 +431,23 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: ) ) master_partial_preds = self.predict_partial(uids=titer.batch) # d - participant_partial_predictions_tasks = party.gather(participant_partial_pred_tasks, recv_results=True) partial_predictions = [task.result for task in participant_partial_predictions_tasks] - predictions_delta = self.aggregate_partial_predictions( master_prediction=master_partial_preds, members_predictions=partial_predictions, uids=titer.batch, ) # d - if self.security_protocol is None: tensor_kw, other_kw = {'aggregated_predictions_diff': predictions_delta}, {'uids': titer.batch} else: tensor_kw, other_kw = dict(), {'uids': titer.batch, 'aggregated_predictions_diff': predictions_delta} - party.broadcast( Method.compute_gradient, method_kwargs=MethodKwargs(tensor_kwargs=tensor_kw, other_kwargs=other_kw), participating_members=titer.participating_members ) - master_gradient = self.compute_gradient(predictions_delta, titer.batch) # g_enc - if self.security_protocol is None: tensor_kw, other_kw = {'gradient': master_gradient}, dict() else: @@ -481,10 +459,12 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: method_kwargs=MethodKwargs(tensor_kwargs=tensor_kw, other_kwargs=other_kw), ) model_updates = party.recv(calculate_updates_task, recv_results=True) - self.update_weights(upd=model_updates.result, uids=titer.batch) - if self.report_train_metrics_iteration > 0 and titer.seq_num % self.report_train_metrics_iteration == 0: + logger.debug( + f"Master {self.id}: train loop - reporting train metrics on iteration {titer.seq_num} " + f"of epoch {titer.epoch}" + ) predict_tasks = party.broadcast( Method.predict, method_kwargs=MethodKwargs( @@ -492,18 +472,18 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: ), participating_members=titer.participating_members ) - master_predictions, targets = self.predict(uids=None) - participant_partial_predictions_tasks = party.gather(predict_tasks, recv_results=True) aggr_predictions = self.aggregate_predictions( master_predictions=master_predictions, members_predictions=[task.result for task in participant_partial_predictions_tasks], ) - self.report_metrics(targets, aggr_predictions, 'Train', step=titer.seq_num) - if self.report_test_metrics_iteration > 0 and titer.seq_num % self.report_test_metrics_iteration == 0: + logger.debug( + f"Master {self.id}: train loop - reporting test metrics on iteration {titer.seq_num} " + f"of epoch {titer.epoch}" + ) predict_tasks = party.broadcast( Method.predict, method_kwargs=MethodKwargs( @@ -511,14 +491,12 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: ), participating_members=titer.participating_members ) - master_predictions, targets = self.predict(uids=None, is_infer=True) participant_partial_predictions_tasks = party.gather(predict_tasks, recv_results=True) aggr_predictions = self.aggregate_predictions( master_predictions=master_predictions, members_predictions=[task.result for task in 
participant_partial_predictions_tasks] ) - self.report_metrics(targets, aggr_predictions, 'Test', step=titer.seq_num) self.iteration_times.append( IterationTime(client_id=self.id, iteration=titer.seq_num, iteration_time=time.time() - iter_start_time) @@ -569,19 +547,18 @@ def inference(self, party: PartyCommunicator): method_kwargs=MethodKwargs(other_kwargs={'is_infer': True}) ) self.finalize(is_infer=True) - logger.info("Finished master %s" % self.id) + logger.info(f"Finished master {self.id}") def inference_loop(self, batcher: Batcher, party: PartyCommunicator) -> None: """ Run VFL inference loop on master. """ - logger.info("Master %s: entering inference loop" % self.id) + logger.info(f"Master {self.id}: entering inference loop") party_predictions_test = defaultdict(list) test_targets = torch.tensor([]) for titer in batcher: if titer.last_batch: break - logger.debug( - f"Master %s: inference loop - starting batch %s (sub iter %s) on epoch %s" - % (self.id, titer.seq_num, titer.subiter_seq_num, titer.epoch) + logger.info( + f"Master {self.id}: inference loop - starting batch {titer.seq_num} (sub iter {titer.subiter_seq_num})" ) predict_test_tasks = party.broadcast( Method.predict, @@ -653,8 +630,7 @@ def compute_gradient( def run(self, party: PartyCommunicator) -> None: """ Run main member loops (training and inference). """ - logger.info("Running member %s" % self.id) - + logger.info(f"Running member {self.id}") if self.do_train: self.fit(party) @@ -686,15 +662,18 @@ def inference(self, party: PartyCommunicator): finalize_task = party.recv(Task(method_name=Method.finalize, from_id=party.master, to_id=self.id)) self.execute_received_task(finalize_task) - logger.info("Finished member %s" % self.id) + logger.info(f"Finished member {self.id}") def inference_loop(self, batcher: Batcher, party: PartyCommunicator) -> None: """ Run VFL inference loop on member. """ - logger.info("Member %s: entering inference loop" % self.id) + logger.info(f"Member {self.id}: entering inference loop") for titer in batcher: if titer.last_batch: break + logger.info( + f"Member {self.id}: inference loop - starting batch {titer.seq_num} (sub iter {titer.subiter_seq_num})" + ) if titer.participating_members is not None: if self.id not in titer.participating_members: logger.debug(f'Member {self.id} skipping {titer.seq_num}.') @@ -744,44 +723,47 @@ def fit(self, party: PartyCommunicator): finalize_task = party.recv(Task(method_name=Method.finalize, from_id=party.master, to_id=self.id)) self.execute_received_task(finalize_task) - logger.info("Finished member %s" % self.id) + logger.info(f"Finished member {self.id}") def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: """ Run VFL training loop on member. 
""" + logger.info(f"Member {self.id}: entering training loop") for titer in batcher: + logger.info( + f"Member {self.id}: train loop - starting batch {titer.seq_num} (sub iter {titer.subiter_seq_num}) " + f"on epoch {titer.epoch}" + ) participant_partial_pred_task = party.recv( Task(Method.predict_partial, from_id=party.master, to_id=self.id), recv_results=False ) member_partial_preds = self.execute_received_task(participant_partial_pred_task) - party.send( send_to_id=party.master, method_name=Method.predict_partial, result=member_partial_preds ) - compute_gradient_task = party.recv( Task(Method.compute_gradient, from_id=party.master, to_id=self.id), recv_results=False ) member_gradient = self.execute_received_task(compute_gradient_task) - if self.security_protocol is None: tensor_kw, other_kw = {'gradient': member_gradient}, dict() else: tensor_kw, other_kw = dict(), {'gradient': member_gradient} - calculate_updates_task = party.send( send_to_id=party.arbiter, method_name=Method.calculate_updates, method_kwargs=MethodKwargs(tensor_kwargs=tensor_kw, other_kwargs=other_kw), ) model_updates = party.recv(calculate_updates_task, recv_results=True) - self.update_weights(upd=model_updates.result, uids=titer.batch) - if self.report_train_metrics_iteration > 0 and titer.seq_num % self.report_train_metrics_iteration == 0: + logger.debug( + f"Member {self.id}: train loop - reporting train metrics on iteration {titer.seq_num} " + f"of epoch {titer.epoch}" + ) predict_task = party.recv( Task(Method.predict, from_id=party.master, to_id=self.id), recv_results=False @@ -793,6 +775,10 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: result=member_prediction ) if self.report_test_metrics_iteration > 0 and titer.seq_num % self.report_test_metrics_iteration == 0: + logger.debug( + f"Member {self.id}: test loop - reporting train metrics on iteration {titer.seq_num} " + f"of epoch {titer.epoch}" + ) predict_task = party.recv( Task(Method.predict, from_id=party.master, to_id=self.id), recv_results=False diff --git a/stalactite/ml/arbitered/logistic_regression/party_agent.py b/stalactite/ml/arbitered/logistic_regression/party_agent.py index 63381b4..c45dc87 100644 --- a/stalactite/ml/arbitered/logistic_regression/party_agent.py +++ b/stalactite/ml/arbitered/logistic_regression/party_agent.py @@ -9,7 +9,8 @@ from pydantic import BaseModel from stalactite.base import RecordsBatch, DataTensor, PartyAgent -from stalactite.ml.arbitered.base import SecurityProtocol, T, Role +from stalactite.ml.arbitered.base import SecurityProtocol, T +from stalactite.utils import Role from stalactite.models import LogisticRegressionBatch logger = logging.getLogger(__name__) @@ -40,27 +41,41 @@ class ArbiteredPartyAgentLogReg(PartyAgent, ABC): _eval_batch_size: int num_classes: int use_inner_join: bool + device: torch.device + + device_dataset_train_split: torch.Tensor + device_dataset_test_split: torch.Tensor ovr = True + def prepare_device_data(self, is_infer: bool = False): + if not is_infer: + self.device_dataset_train_split = self._dataset[self._data_params.train_split][ + self._data_params.features_key].to(self.device) + else: + self.device_dataset_test_split = self._dataset[self._data_params.test_split][ + self._data_params.features_key].to(self.device) + def initialize_model_from_params(self, **model_params) -> Any: - return LogisticRegressionBatch(**model_params) + return LogisticRegressionBatch(**model_params).to(device=self.device) def update_weights(self, uids: RecordsBatch, upd: DataTensor): + 
logger.info(f"{self.role.capitalize()} {self.id}: updating weights. Incoming tensor: {upd.size()}") tensor_idx = [self.uid2tensor_idx[uid] for uid in uids] if uids else None - X = self._dataset[self._data_params.train_split][self._data_params.features_key][tensor_idx, :] + X = self.device_dataset_train_split[tensor_idx, :] if upd.shape[0] != len(self._model): raise RuntimeError( f'Incorrect number of the updates were received (number of models to update: {upd.shape[0]}), ' f'number of models: {len(self._model)}' ) for upd_i, model in zip(upd, self._model): - model.update_weights(X, upd_i, collected_from_arbiter=True) + model.update_weights(X, upd_i.to(device=self.device), collected_from_arbiter=True) + logger.info(f"{self.role.capitalize()} {self.id}: successfully updated weights") def compute_gradient(self, aggregated_predictions_diff: T, uids: List[str]) -> T: - logger.info(f'{self.id} started computing gradient.') + logger.info(f'{self.role.capitalize()} {self.id} computes gradient') tensor_idx = [self.uid2tensor_idx[uid] for uid in uids] if uids else None - X = self._dataset[self._data_params.train_split][self._data_params.features_key][tensor_idx, :] + X = self.device_dataset_train_split[tensor_idx, :] if self.security_protocol is not None: x = self.security_protocol.encode(X.T / X.shape[0]) g = np.stack([ @@ -75,13 +90,15 @@ def compute_gradient(self, aggregated_predictions_diff: T, uids: List[str]) -> T ]) else: x = X.T / X.shape[0] - g = torch.stack([torch.matmul(x, pred) for pred in aggregated_predictions_diff]) + g = torch.stack([torch.matmul(x, pred.to(device=self.device)) for pred in aggregated_predictions_diff]) if self.l2_alpha is not None: - weights_sum = [model.get_weights().T for model in self._model] + weights_sum = [model.get_weights().T.to(device=self.device) for model in self._model] g = torch.stack([self.l2_alpha * w_sum + g_class for w_sum, g_class in zip(weights_sum, g)]) + logger.debug(f'{self.role.capitalize()} {self.id} successfully computed gradient') return g def records_uids(self, is_infer: bool = False) -> Union[List[str], Tuple[List[str], bool]]: + logger.info(f"{self.role.capitalize()} {self.id}: reporting existing record uids") if not is_infer: if self.role == Role.master: return self.target_uids @@ -94,7 +111,7 @@ def records_uids(self, is_infer: bool = False) -> Union[List[str], Tuple[List[st return self._infer_uids, self.use_inner_join def register_records_uids(self, uids: List[str], is_infer: bool = False): - logger.info("%s: registering %s uids to be used." 
% (self.id, len(uids))) + logger.info(f"{self.role.capitalize()} {self.id}: registering {len(uids)} uids to be used.") if is_infer: self._uids_to_use_test = uids else: @@ -103,10 +120,12 @@ def register_records_uids(self, uids: List[str], is_infer: bool = False): self.fillna(is_infer=is_infer) def initialize_model(self, do_load_model: bool = False): + logger.info(f"{self.role.capitalize()} {self.id} initializes model on device: {self.device}") + logger.debug(f"Model is loaded from path: {do_load_model}") if do_load_model: self._model = self.load_model(is_ovr_models=self.ovr) else: - input_dim = self._dataset[self._data_params.train_split][self._data_params.features_key].shape[1] + input_dim = self.device_dataset_train_split.shape[1] model_type = 'OVR models' if self.num_classes > 1 else 'binary model' logger.info(f'{self.id} Initializing {model_type} for {self.num_classes} classes') self._model = [ @@ -116,8 +135,10 @@ def initialize_model(self, do_load_model: bool = False): init_weights=0.005 ) for _ in range(self.num_classes) ] + self._model = [model.to(device=self.device) for model in self._model] def finalize(self, is_infer: bool = False): + logger.info(f"{self.role.capitalize()} {self.id}: finalizing") self.check_if_ready() if self.do_save_model and not is_infer: self.save_model(is_ovr_models=self.ovr) @@ -135,16 +156,16 @@ def fillna(self, is_infer: bool = False) -> None: if len(uids_to_fill) == 0: return - logger.info(f"Member {self.id} has {len(uids_to_fill)} missing values : using fillna...") + logger.info(f"{self.role.capitalize()} {self.id} has {len(uids_to_fill)} missing values: using fillna...") start_idx = max(_uid2tensor_idx.values()) + 1 idx = start_idx for uid in uids_to_fill: _uid2tensor_idx[uid] = idx idx += 1 - fill_shape = self._dataset[split][self._data_params.features_key].shape[1] + fill_shape = self.device_dataset_test_split.shape[1] if is_infer else self.device_dataset_train_split.shape[1] member_id = int(self.id.split("-")[-1]) + 1 - features = copy(self._dataset[split][self._data_params.features_key]) + features = copy(self.device_dataset_test_split.cpu()) if is_infer else copy(self.device_dataset_train_split.cpu()) features = torch.cat([features, torch.zeros((len(uids_to_fill), fill_shape))]) has_features_column = torch.tensor([1.0 for _ in range(start_idx)] + [0.0 for _ in range(len(uids_to_fill))]) features = torch.cat([features, torch.unsqueeze(has_features_column, 1)], dim=1) @@ -158,5 +179,7 @@ def fillna(self, is_infer: bool = False) -> None: ds = ds.with_format("torch") self._dataset[split] = ds + self.prepare_device_data(is_infer=is_infer) + if not is_infer: self.initialize_model() diff --git a/stalactite/ml/arbitered/logistic_regression/party_arbiter.py b/stalactite/ml/arbitered/logistic_regression/party_arbiter.py index 5c926d2..86327f7 100644 --- a/stalactite/ml/arbitered/logistic_regression/party_arbiter.py +++ b/stalactite/ml/arbitered/logistic_regression/party_arbiter.py @@ -5,7 +5,8 @@ from stalactite.base import DataTensor, Batcher from stalactite.batching import ListBatcher -from stalactite.ml.arbitered.base import PartyArbiter, SecurityProtocolArbiter, Role +from stalactite.ml.arbitered.base import PartyArbiter, SecurityProtocolArbiter +from stalactite.utils import Role logger = logging.getLogger(__name__) @@ -25,6 +26,7 @@ def __init__( momentum: float = 0.0, do_train: bool = True, do_predict: bool = False, + **kwargs, ) -> None: self.id = uid self.epochs = epochs @@ -37,6 +39,9 @@ def __init__( self.do_train = do_train self.do_predict = 
do_predict + if kwargs: + logger.info(f'Passed extra kwargs to arbiter ({kwargs}), ignoring.') + self.is_initialized = False self.is_finalized = False @@ -60,6 +65,7 @@ def make_batcher( ) -> Batcher: if uids is None: uids = self._uids_to_use_test if is_infer else self._uids_to_use + logger.info(f"Arbiter {self.id} makes a batcher for {len(uids)} uids") epochs = 1 if is_infer else self.epochs batch_size = self._eval_batch_size if is_infer else self._batch_size @@ -102,6 +108,7 @@ def _get_delta_gradients(self) -> torch.Tensor: raise ValueError(f"No previous steps were performed.") def calculate_updates(self, gradients: dict) -> dict[str, DataTensor]: + logger.info(f'Arbiter {self.id} calculates updates for {len(gradients)} agents') members = [key for key in gradients.keys() if key != self.master] try: @@ -132,17 +139,21 @@ def calculate_updates(self, gradients: dict) -> dict[str, DataTensor]: splitted_grads = torch.tensor_split(delta_gradients, torch.cumsum(torch.tensor(size_list), 0)[:-1], dim=1) deltas = {agent: splitted_grads[i].clone().detach() for i, agent in enumerate([self.master] + members)} - + logger.debug(f'Arbiter {self.id} has calculated updates') return deltas def initialize(self, is_infer: bool = False): + logger.info(f"Arbiter {self.id}: initializing") if self.security_protocol is not None: self.security_protocol.generate_keys() self.is_initialized = True self.is_finalized = False + logger.info(f"Arbiter {self.id}: has been initialized") def finalize(self, is_infer: bool = False): + logger.info(f"Arbiter {self.id}: finalizing") self.is_finalized = True + logger.info(f"Arbiter {self.id} has finalized") def register_records_uids(self, uids: List[str], is_infer: bool = False): """ Register unique identifiers to be used. @@ -150,7 +161,7 @@ def register_records_uids(self, uids: List[str], is_infer: bool = False): :param uids: List of unique identifiers. :return: None """ - logger.info("Agent %s: registering %s uids to be used." % (self.id, len(uids))) + logger.info(f"Arbiter {self.id}: registering {len(uids)} uids to be used.") if is_infer: self._uids_to_use_test = uids else: diff --git a/stalactite/ml/arbitered/logistic_regression/party_master.py b/stalactite/ml/arbitered/logistic_regression/party_master.py index a01804b..7e3e6b4 100644 --- a/stalactite/ml/arbitered/logistic_regression/party_master.py +++ b/stalactite/ml/arbitered/logistic_regression/party_master.py @@ -1,5 +1,5 @@ import logging -from typing import Any, List, Optional +from typing import List, Optional import mlflow import numpy as np @@ -10,7 +10,8 @@ from stalactite.base import Batcher, RecordsBatch, DataTensor, PartyDataTensor from stalactite.batching import ListBatcher from stalactite.metrics import ComputeAccuracy_numpy -from stalactite.ml.arbitered.base import ArbiteredPartyMaster, SecurityProtocol, T, Role +from stalactite.ml.arbitered.base import ArbiteredPartyMaster, SecurityProtocol, T +from stalactite.utils import Role from stalactite.ml.arbitered.logistic_regression.party_agent import ArbiteredPartyAgentLogReg logger = logging.getLogger(__name__) @@ -39,7 +40,8 @@ def __init__( do_save_model: bool = False, processor=None, run_mlflow: bool = False, - seed: int = None + seed: int = None, + device: str = 'cpu', ) -> None: """ Initialize ArbiteredPartyMasterLinReg. 
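The arbiter's `calculate_updates` hunk above hands each agent only the slice of the aggregated delta gradient that corresponds to its own feature block, using `torch.tensor_split` at the cumulative feature widths. A minimal, self-contained sketch of that splitting step (the agent names and widths below are illustrative, not taken from any config):

.. code-block:: python

    import torch

    # Illustrative per-agent feature widths: master first, then the members.
    agents = ["master", "member-0", "member-1"]
    size_list = [4, 3, 5]

    # One row per OVR class; columns span the concatenated feature blocks.
    delta_gradients = torch.randn(2, sum(size_list))

    # Column boundaries are the cumulative widths (here: 4 and 7).
    split_points = torch.cumsum(torch.tensor(size_list), 0)[:-1]
    splitted_grads = torch.tensor_split(delta_gradients, split_points, dim=1)

    deltas = {agent: grad.clone().detach() for agent, grad in zip(agents, splitted_grads)}
    for agent, grad in deltas.items():
        print(agent, tuple(grad.shape))  # (2, 4), (2, 3), (2, 5)

Each agent then applies only its own block when updating its local weights, which is why the arbiter never needs to know the joined feature matrix.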
@@ -80,15 +82,17 @@ def __init__( self.do_save_model = do_save_model self.model_path = model_path self.seed = seed + self.device = torch.device(device) self.uid2tensor_idx = None self.uid2tensor_idx_test = None def predict_partial(self, uids: RecordsBatch) -> DataTensor: - logger.info("Master %s: predicting. Batch size: %s" % (self.id, len(uids))) + logger.info(f"Master {self.id}: makes partial predictions for gradient computation") self.check_if_ready() Xw, y = self.predict(uids, is_infer=False) d = 0.25 * Xw - 0.5 * y.T.unsqueeze(2) + logger.debug(f"Master {self.id}: made partial predictions") return d def aggregate_partial_predictions( @@ -97,6 +101,7 @@ def aggregate_partial_predictions( members_predictions: List[T], uids: RecordsBatch, ) -> T: + logger.info(f"Master {self.id}: aggregates predictions from {len(members_predictions)} members") class_predictions = [] for class_idx in range(self.num_classes): prediction = master_prediction[class_idx] @@ -104,19 +109,22 @@ def aggregate_partial_predictions( if self.security_protocol is not None: prediction = self.security_protocol.add_matrices(prediction, member_preds[class_idx]) else: - prediction += member_preds[class_idx] + prediction += member_preds[class_idx].to(self.device) class_predictions.append(prediction) stacking_func = np.stack if self.security_protocol is not None else torch.stack master_prediction = stacking_func(class_predictions) - + logger.debug(f"Master {self.id}: aggregated predictions") return master_prediction def initialize(self, is_infer: bool = False): - logger.info("Master %s: initializing" % self.id) + logger.info(f"Master {self.id}: initializing") dataset = self.processor.fit_transform() self._dataset = dataset - self._data_params = self.processor.data_params + + self.prepare_device_data(is_infer=True) + self.prepare_device_data(is_infer=False) + self._common_params = self.processor.common_params self.target = dataset[self._data_params.train_split][self._data_params.label_key] @@ -138,18 +146,20 @@ def initialize(self, is_infer: bool = False): self.initialize_model(do_load_model=is_infer) self.is_initialized = True self.is_finalized = False - logger.info("Master %s: is initialized" % self.id) + logger.info(f"Master {self.id}: has been initialized") def predict(self, uids: Optional[List[str]], is_infer: bool = False): - logger.info(f'{self.id} makes predictions') - split = self._data_params.train_split if not is_infer else self._data_params.test_split + logger.info(f"Master {self.id}: predicting") target = self.target if not is_infer else self.test_target if uids is None: uids = self.inference_target_uids if is_infer else self.target_uids _uid2tensor_idx = self.uid2tensor_idx_test if is_infer else self.uid2tensor_idx tensor_idx = [_uid2tensor_idx[uid] for uid in uids] - X = self._dataset[split][self._data_params.features_key][tensor_idx] + X = self.device_dataset_train_split[tensor_idx] if not is_infer else self.device_dataset_test_split[tensor_idx] y = target[tensor_idx] + if not is_infer: + y = y.to(self.device) + logger.debug(f"Master {self.id}: made predictions") return torch.stack([model.predict(X) for model in self._model]), y def aggregate_predictions(self, master_predictions: DataTensor, members_predictions: PartyDataTensor) -> DataTensor: @@ -168,25 +178,27 @@ def aggregate_predictions(self, master_predictions: DataTensor, members_predicti torch.sigmoid( torch.sum( torch.hstack( - [master_predictions[class_idx]] + - [member_pred[class_idx] for member_pred in members_predictions] + 
[master_predictions[class_idx].to(self.device)] + + [member_pred[class_idx].to(self.device) for member_pred in members_predictions] ), dim=1 ) ) ) predictions = torch.stack(predictions).T - + logger.debug(f"Master {self.id}: aggregated predictions") return predictions def report_metrics(self, y: DataTensor, predictions: DataTensor, name: str, step: int): + logger.info(f"Master {self.id} reporting metrics") + logger.debug(f"Predictions size: {predictions.size()}, Target size: {y.size()}") postfix = "-infer" if step == -1 else "" step = step if step != -1 else None y = torch.where(y == -1., -0., 1.) # After a sigmoid function we calculate metrics on the {0, 1} labels - y = y.numpy() + y = y.cpu().numpy() - predictions = predictions.detach().numpy() + predictions = predictions.detach().cpu().numpy() mae = metrics.mean_absolute_error(y, predictions) acc = ComputeAccuracy_numpy(is_linreg=False).compute(y, predictions) @@ -213,7 +225,7 @@ def make_batcher( uids = self._uids_to_use_test else: raise RuntimeError('Master must initialize batcher with collected uids.') - logger.info("Master %s: making a make_batcher for uids %s" % (self.id, len(uids))) + logger.info(f"Master {self.id} makes a batcher for {len(uids)} uids") self.check_if_ready() batch_size = self._eval_batch_size if is_infer else self._batch_size epochs = 1 if is_infer else self.epochs diff --git a/stalactite/ml/arbitered/logistic_regression/party_member.py b/stalactite/ml/arbitered/logistic_regression/party_member.py index 21fc839..afb9f15 100644 --- a/stalactite/ml/arbitered/logistic_regression/party_member.py +++ b/stalactite/ml/arbitered/logistic_regression/party_member.py @@ -5,7 +5,8 @@ from stalactite.base import RecordsBatch, DataTensor, Batcher from stalactite.batching import ListBatcher -from stalactite.ml.arbitered.base import ArbiteredPartyMember, SecurityProtocol, Role +from stalactite.ml.arbitered.base import ArbiteredPartyMember, SecurityProtocol +from stalactite.utils import Role from stalactite.ml.arbitered.logistic_regression.party_agent import ArbiteredPartyAgentLogReg logger = logging.getLogger(__name__) @@ -33,7 +34,8 @@ def __init__( do_save_model: bool = False, processor=None, use_inner_join: bool = True, - seed: int = None + seed: int = None, + device: str = 'cpu', ) -> None: self.id = uid self.epochs = epochs @@ -54,23 +56,25 @@ def __init__( self.model_path = model_path self.use_inner_join = use_inner_join self.seed = seed + self.device = torch.device(device) def predict_partial(self, uids: RecordsBatch) -> DataTensor: - logger.info(f'{self.id} makes partial predictions') + logger.info(f'Member {self.id} makes partial predictions') predictions = self.predict(uids, is_infer=False) Xw = 0.25 * predictions if self.security_protocol is not None: Xw = self.security_protocol.encrypt(Xw) + logger.debug(f"Member {self.id}: made partial predictions") return Xw def predict(self, uids: Optional[List[str]], is_infer: bool = False) -> Union[DataTensor, List[DataTensor]]: - logger.info(f'{self.id} makes predictions') - split = self._data_params.train_split if not is_infer else self._data_params.test_split + logger.info(f"Member {self.id}: predicting") _uid2tensor_idx = self.uid2tensor_idx_test if is_infer else self.uid2tensor_idx if uids is None: uids = self._uids_to_use_test if is_infer else self._uids_to_use tensor_idx = [_uid2tensor_idx[uid] for uid in uids] - X = self._dataset[split][self._data_params.features_key][tensor_idx] + X = self.device_dataset_train_split[tensor_idx] if not is_infer else 
self.device_dataset_test_split[tensor_idx] + logger.debug(f"Member {self.id}: made predictions") return torch.stack([model.predict(X) for model in self._model]) def make_batcher( @@ -81,14 +85,20 @@ def make_batcher( ) -> Batcher: if uids is None: uids = self._uids_to_use_test if is_infer else self._uids_to_use + logger.info(f"Member {self.id} makes a batcher for {len(uids)} uids") epochs = 1 if is_infer else self.epochs batch_size = self._eval_batch_size if is_infer else self._batch_size return ListBatcher(epochs=epochs, members=None, uids=uids, batch_size=batch_size) def initialize(self, is_infer: bool = False): - logger.info("Member %s: initializing" % self.id) + logger.info(f"Member {self.id}: initializing") self._dataset = self.processor.fit_transform() + self._data_params = self.processor.data_params + + self.prepare_device_data(is_infer=False) + self.prepare_device_data(is_infer=True) + self._common_params = self.processor.common_params self.uid2tensor_idx = {uid: i for i, uid in enumerate(self._uids)} self.uid2tensor_idx_test = {uid: i for i, uid in enumerate(self._infer_uids)} @@ -97,4 +107,4 @@ def initialize(self, is_infer: bool = False): self.is_initialized = True self.is_finalized = False - logger.info("Member %s: has been initialized" % self.id) + logger.info(f"Member {self.id}: has been initialized") diff --git a/stalactite/ml/honest/base.py b/stalactite/ml/honest/base.py index f97e584..82e41aa 100644 --- a/stalactite/ml/honest/base.py +++ b/stalactite/ml/honest/base.py @@ -94,7 +94,7 @@ def run(self, party: PartyCommunicator) -> None: :param party: Communicator instance used for communication between VFL agents. :return: None """ - logger.info("Running master %s" % self.id) + logger.info(f"Running master {self.id}") if self.do_train: self.fit(party) @@ -147,7 +147,7 @@ def fit(self, party: PartyCommunicator) -> None: participating_members=party.members, ) self.finalize() - logger.info("Finished master %s" % self.id) + logger.info(f"Finished master {self.id}") def inference(self, party: PartyCommunicator) -> None: records_uids_tasks = party.broadcast( @@ -184,7 +184,7 @@ def inference(self, party: PartyCommunicator) -> None: method_kwargs=MethodKwargs(other_kwargs={'is_infer': True}) ) self.finalize(is_infer=True) - logger.info("Finished master %s" % self.id) + logger.info(f"Finished master {self.id}") def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: """ Run main training loop on the VFL master. 
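Both the arbitered master and member now call `prepare_device_data` during `initialize`, caching the train and test feature tensors on the configured device so that per-batch indexing (`X[tensor_idx]`) does not trigger a host-to-device copy on every iteration. A simplified sketch of the pattern, with a plain dict and hypothetical split/feature names standing in for the processor output:

.. code-block:: python

    import torch

    class DeviceDataMixin:
        """Minimal sketch: cache dataset splits on the target device once."""

        def __init__(self, dataset: dict, features_key: str, device: str = "cpu"):
            self._dataset = dataset          # e.g. {"train": {...}, "test": {...}}
            self._features_key = features_key
            self.device = torch.device(device)

        def prepare_device_data(self, is_infer: bool = False) -> None:
            split = "test" if is_infer else "train"
            features = self._dataset[split][self._features_key].to(self.device)
            if is_infer:
                self.device_dataset_test_split = features
            else:
                self.device_dataset_train_split = features

    # Usage: batches are indexed directly on the device-resident tensor.
    data = {"train": {"features": torch.randn(8, 3)}, "test": {"features": torch.randn(4, 3)}}
    member = DeviceDataMixin(data, features_key="features", device="cpu")
    member.prepare_device_data(is_infer=False)
    batch = member.device_dataset_train_split[[0, 2, 5]]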
@@ -194,12 +194,12 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: :return: None """ - logger.info("Master %s: entering training loop" % self.id) + logger.info(f"Master {self.id}: entering training loop") updates = self.make_init_updates(party.world_size) for titer in batcher: - logger.debug( - f"Master %s: train loop - starting batch %s (sub iter %s) on epoch %s" - % (self.id, titer.seq_num, titer.subiter_seq_num, titer.epoch) + logger.info( + f"Master {self.id}: train loop - starting batch {titer.seq_num} (sub iter {titer.subiter_seq_num}) " + f"on epoch {titer.epoch}" ) iter_start_time = time.time() if titer.seq_num == 0: @@ -230,8 +230,8 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: if self.report_train_metrics_iteration > 0 and titer.seq_num % self.report_train_metrics_iteration == 0: logger.debug( - f"Master %s: train loop - reporting train metrics on iteration %s of epoch %s" - % (self.id, titer.seq_num, titer.epoch) + f"Master {self.id}: train loop - reporting train metrics on iteration {titer.seq_num} " + f"of epoch {titer.epoch}" ) predict_tasks = party.broadcast( Method.predict, @@ -246,8 +246,8 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: if self.report_test_metrics_iteration > 0 and titer.seq_num % self.report_test_metrics_iteration == 0: logger.debug( - f"Master %s: train loop - reporting test metrics on iteration %s of epoch %s" - % (self.id, titer.seq_num, titer.epoch) + f"Master {self.id}: train loop - reporting test metrics on iteration {titer.seq_num} " + f"of epoch {titer.epoch}" ) predict_test_tasks = party.broadcast( Method.predict, @@ -276,14 +276,13 @@ def inference_loop(self, batcher: Batcher, party: PartyCommunicator) -> None: :return: None """ - logger.info("Master %s: entering inference loop" % self.id) + logger.info(f"Master {self.id}: entering inference loop") party_predictions_test = defaultdict(list) for titer in batcher: if titer.last_batch: break - logger.debug( - f"Master %s: inference loop - starting batch %s (sub iter %s) on epoch %s" - % (self.id, titer.seq_num, titer.subiter_seq_num, titer.epoch) + logger.info( + f"Master {self.id}: inference loop - starting batch {titer.seq_num} (sub iter {titer.subiter_seq_num})" ) predict_test_tasks = party.broadcast( Method.predict, @@ -332,7 +331,8 @@ def __init__( do_save_model: bool = False, use_inner_join: bool = False, model_params: dict = None, - seed: int = None + seed: int = None, + device: str = 'cpu', ) -> None: """ Initialize PartyMemberImpl. @@ -373,6 +373,7 @@ def __init__( self._optimizer = None self.use_inner_join = use_inner_join self.seed = seed + self.device = torch.device(device) if self.is_consequently: if self.members is None: @@ -423,7 +424,7 @@ def run(self, party: PartyCommunicator): :param party: Communicator instance used for communication between VFL agents. 
:return: None """ - logger.info("Running member %s" % self.id) + logger.info(f"Running member {self.id}") if self.do_train: self._run(party, is_infer=False) @@ -464,7 +465,7 @@ def _run(self, party: PartyCommunicator, is_infer: bool = False): finalize_task = party.recv(Task(method_name=Method.finalize, from_id=party.master, to_id=self.id)) self.execute_received_task(finalize_task) - logger.info("Finished member %s" % self.id) + logger.info(f"Finished member {self.id}") def _predict_metrics_loop(self, party: PartyCommunicator): predict_task = party.recv(Task(method_name=Method.predict, from_id=party.master, to_id=self.id)) @@ -479,16 +480,15 @@ def loop(self, batcher: Batcher, party: PartyCommunicator): :return: None """ - logger.info("Member %s: entering training loop" % self.id) - + logger.info(f"Member {self.id}: entering training loop") for titer in batcher: if titer.participating_members is not None: if self.id not in titer.participating_members: logger.debug(f'Member {self.id} skipping {titer.seq_num}.') continue - logger.debug( - f"Member %s: train loop - starting batch %s (sub iter %s) on epoch %s" - % (self.id, titer.seq_num, titer.subiter_seq_num, titer.epoch) + logger.info( + f"Member {self.id}: train loop - starting batch {titer.seq_num} (sub iter {titer.subiter_seq_num}) " + f"on epoch {titer.epoch}" ) update_predict_task = party.recv( Task(method_name=Method.update_predict, from_id=party.master, to_id=self.id) @@ -498,15 +498,15 @@ def loop(self, batcher: Batcher, party: PartyCommunicator): if self.report_train_metrics_iteration > 0 and titer.seq_num % self.report_train_metrics_iteration == 0: logger.debug( - f"Member %s: train loop - calculating train metrics on iteration %s of epoch %s" - % (self.id, titer.seq_num, titer.epoch) + f"Member {self.id}: train loop - reporting train metrics on iteration {titer.seq_num} " + f"of epoch {titer.epoch}" ) self._predict_metrics_loop(party) if self.report_test_metrics_iteration > 0 and titer.seq_num % self.report_test_metrics_iteration == 0: logger.debug( - f"Member %s: train loop - calculating train metrics on iteration %s of epoch %s" - % (self.id, titer.seq_num, titer.epoch) + f"Member {self.id}: test loop - reporting train metrics on iteration {titer.seq_num} " + f"of epoch {titer.epoch}" ) self._predict_metrics_loop(party) @@ -518,11 +518,14 @@ def inference_loop(self, batcher: Batcher, party: PartyCommunicator): :return: None """ - logger.info("Member %s: entering training loop" % self.id) + logger.info(f"Member {self.id}: entering inference loop") for titer in batcher: if titer.last_batch: break + logger.info( + f"Member {self.id}: inference loop - starting batch {titer.seq_num} (sub iter {titer.subiter_seq_num})" + ) if titer.participating_members is not None: if self.id not in titer.participating_members: logger.debug(f'Member {self.id} skipping {titer.seq_num}.') @@ -543,6 +546,7 @@ def make_batcher( uids_to_use = self._uids_to_use_test if is_infer else self._uids_to_use if uids_to_use is None: raise RuntimeError("Cannot create make_batcher, you must `register_records_uids` first.") + logger.info(f"Member {self.id} makes a batcher for {len(uids_to_use)} uids") return self._create_batcher(epochs=epochs, uids=uids_to_use, batch_size=batch_size) def _create_batcher(self, epochs: int, uids: List[str], batch_size: int) -> Batcher: @@ -552,7 +556,7 @@ def _create_batcher(self, epochs: int, uids: List[str], batch_size: int) -> Batc :param uids: List of unique identifiers for dataset rows. 
:param batch_size: Size of the training batch. """ - logger.info("Member %s: making a make_batcher for uids" % self.id) + logger.info(f"Member {self.id}: making a make_batcher for {len(uids)} uids") if not self.is_consequently: return ListBatcher(epochs=epochs, members=None, uids=uids, batch_size=batch_size) else: @@ -565,7 +569,7 @@ def records_uids(self, is_infer: bool = False) -> Tuple[List[str], bool]: :return: List of unique identifiers. """ - logger.info("Member %s: reporting existing record uids" % self.id) + logger.info(f"Member {self.id}: reporting existing record uids") if is_infer: return self._infer_uids, self.use_inner_join return self._uids, self.use_inner_join @@ -576,7 +580,7 @@ def register_records_uids(self, uids: List[str], is_infer: bool = False) -> None :param uids: List of unique identifiers. :return: None """ - logger.info("Member %s: registering %s uids to be used." % (self.id, len(uids))) + logger.info(f"Member {self.id}: registering {len(uids)} uids to be used.") if is_infer: self._uids_to_use_test = uids @@ -595,7 +599,7 @@ def fillna(self, is_infer: bool = False) -> None: if len(uids_to_fill) == 0: return - logger.info(f"Member {self.id} has {len(uids_to_fill)} missing values : using fillna...") + logger.info(f"Member {self.id} has {len(uids_to_fill)} missing values: using fillna...") start_idx = max(_uid2tensor_idx.values()) + 1 idx = start_idx for uid in uids_to_fill: @@ -618,14 +622,23 @@ def fillna(self, is_infer: bool = False) -> None: ds = ds.with_format("torch") self._dataset[split] = ds + self.prepare_device_data(is_infer=is_infer) if not is_infer: self.initialize_model() self.initialize_optimizer() + def prepare_device_data(self, is_infer: bool = False): + if not is_infer: + self.device_dataset_train_split = self._dataset[self._data_params.train_split][ + self._data_params.features_key].to(self.device) + else: + self.device_dataset_test_split = self._dataset[self._data_params.test_split][ + self._data_params.features_key].to(self.device) + def initialize(self, is_infer: bool = False): """ Initialize the party member. """ - logger.info("Member %s: initializing" % self.id) + logger.info(f"Member {self.id}: initializing") self._dataset = self.processor.fit_transform() self.uid2tensor_idx = {uid: i for i, uid in enumerate(self._uids)} self.uid2tensor_idx_test = {uid: i for i, uid in enumerate(self._infer_uids)} @@ -633,18 +646,20 @@ def initialize(self, is_infer: bool = False): self._common_params = self.processor.common_params self.initialize_model(do_load_model=is_infer) self.initialize_optimizer() + self.prepare_device_data(is_infer=True) + self.prepare_device_data(is_infer=False) self.is_initialized = True self.is_finalized = False - logger.info("Member %s: has been initialized" % self.id) + logger.info(f"Member {self.id}: has been initialized") def finalize(self, is_infer: bool = False) -> None: """ Finalize the party member. """ - logger.info("Member %s: finalizing" % self.id) + logger.info(f"Member {self.id}: finalizing") self.check_if_ready() if self.do_save_model and not is_infer: self.save_model(is_ovr_models=self.ovr) self.is_finalized = True - logger.info("Member %s: has been finalized" % self.id) + logger.info(f"Member {self.id}: has finalized") def _prepare_data(self, uids: RecordsBatch) -> Tuple: """ Prepare data for training. @@ -652,6 +667,6 @@ def _prepare_data(self, uids: RecordsBatch) -> Tuple: :param uids: Batch of record unique identifiers. :return: Tuple of three SVD matrices. 
""" - X_train = self._dataset[self._data_params.train_split][self._data_params.features_key][[int(x) for x in uids]] + X_train = self.device_dataset_train_split[[int(x) for x in uids]] U, S, Vh = sp.linalg.svd(X_train.numpy(), full_matrices=False, overwrite_a=False, check_finite=False) return U, S, Vh diff --git a/stalactite/ml/honest/linear_regression/party_master.py b/stalactite/ml/honest/linear_regression/party_master.py index 798c11a..2f8f801 100644 --- a/stalactite/ml/honest/linear_regression/party_master.py +++ b/stalactite/ml/honest/linear_regression/party_master.py @@ -12,14 +12,13 @@ from stalactite.ml.honest.base import HonestPartyMaster, Batcher from stalactite.metrics import ComputeAccuracy + logger = logging.getLogger(__name__) class HonestPartyMasterLinReg(HonestPartyMaster): """ Implementation class of the PartyMaster used for local and distributed VFL training. """ - do_save_model = False - do_load_model = False - model_path = None + def __init__( self, uid: str, @@ -37,7 +36,10 @@ def __init__( do_predict: bool = False, model_name: str = None, model_params: dict = None, - seed: int = None + seed: int = None, + device: str = 'cpu', + model_path: Optional[str] = None, + do_save_model: bool = False, ) -> None: """ Initialize PartyMaster. @@ -75,6 +77,9 @@ def __init__( self.aggregated_output = None self._model_params = model_params self.seed = seed + self.device = torch.device(device) + self.do_save_model = do_save_model + self.model_path = model_path self.uid2tensor_idx = None self.uid2tensor_idx_test = None @@ -87,7 +92,7 @@ def initialize_optimizer(self) -> None: def initialize(self, is_infer: bool = False) -> None: """ Initialize the party master. """ - logger.info("Master %s: initializing" % self.id) + logger.info(f"Master {self.id}: initializing") ds = self.processor.fit_transform() self.target = ds[self.processor.data_params.train_split][self.processor.data_params.label_key] self.test_target = ds[self.processor.data_params.test_split][self.processor.data_params.label_key] @@ -108,10 +113,14 @@ def initialize(self, is_infer: bool = False) -> None: self.binary = False if self._model_name is not None: - self.initialize_model() + self.initialize_model(do_load_model=is_infer) self.initialize_optimizer() + + self.target = self.target.to(self.device) + self.test_target = self.test_target.to(self.device) self.is_initialized = True self.is_finalized = False + logger.info(f"Master {self.id}: has been initialized") def make_batcher( self, @@ -126,12 +135,12 @@ def make_batcher( :return: Batcher instance. """ - logger.info("Master %s: making a make_batcher for uids %s" % (self.id, len(uids))) self.check_if_ready() batch_size = self._eval_batch_size if is_infer else self._batch_size epochs = 1 if is_infer else self.epochs if uids is None: raise RuntimeError('Master must initialize batcher with collected uids.') + logger.info(f"Master {self.id} makes a batcher for {len(uids)} uids") assert party_members is not None, "Master is trying to initialize make_batcher without members list" return ListBatcher(epochs=epochs, members=party_members, uids=uids, batch_size=batch_size) @@ -142,9 +151,9 @@ def make_init_updates(self, world_size: int) -> PartyDataTensor: :return: Initial updates as a list of tensors. 
""" - logger.info("Master %s: making init updates for %s members" % (self.id, world_size)) + logger.info(f"Master {self.id}: makes initial updates for {world_size} members") self.check_if_ready() - return [torch.zeros(self._batch_size) for _ in range(world_size)] + return [torch.zeros(self._batch_size, device=self.device) for _ in range(world_size)] def report_metrics(self, y: DataTensor, predictions: DataTensor, name: str, step: int) -> None: """ Report metrics based on target values and predictions. @@ -155,16 +164,16 @@ def report_metrics(self, y: DataTensor, predictions: DataTensor, name: str, step :return: None """ - logger.info( - f"Master %s: reporting metrics. Y dim: {y.size()}. " f"Predictions size: {predictions.size()}" % self.id - ) + logger.info(f"Master {self.id} reporting metrics") + logger.debug(f"Predictions size: {predictions.size()}, Target size: {y.size()}") postfix = '-infer' if step == -1 else "" step = step if step != -1 else None - - mae = metrics.mean_absolute_error(y, predictions.detach()) - acc = ComputeAccuracy().compute(y, predictions.detach()) - logger.info(f"Master %s: %s metrics (MAE): {mae}" % (self.id, name)) - logger.info(f"Master %s: %s metrics (Accuracy): {acc}" % (self.id, name)) + y = y.cpu() + predictions = predictions.cpu() + mae = metrics.mean_absolute_error(y, predictions) + acc = ComputeAccuracy().compute(y, predictions) + logger.info(f"{name} metrics (MAE): {mae}") + logger.info(f"{name} metrics (Accuracy): {acc}") if self.run_mlflow: mlflow.log_metric(f"{name.lower()}_mae{postfix}", mae, step=step) @@ -181,11 +190,11 @@ def aggregate( :return: Aggregated predictions. """ - logger.info("Master %s: aggregating party predictions (num predictions %s)" % (self.id, len(party_predictions))) + logger.info(f"Master {self.id}: aggregates party predictions (number of predictions {len(party_predictions)})") self.check_if_ready() if not is_infer: for member_id, member_prediction in zip(participating_members, party_predictions): - self.party_predictions[member_id] = member_prediction + self.party_predictions[member_id] = member_prediction.to(self.device) party_predictions = list(self.party_predictions.values()) return torch.sum(torch.stack(party_predictions, dim=1), dim=1) @@ -208,7 +217,7 @@ def compute_updates( :return: List of updates as tensors. """ - logger.info("Master %s: computing updates (world size %s)" % (self.id, world_size)) + logger.info(f"Master {self.id}: computes updates (world size {world_size})") self.check_if_ready() self.iteration_counter += 1 tensor_idx = [self.uid2tensor_idx[uid] for uid in uids] @@ -216,17 +225,19 @@ def compute_updates( for member_id in participating_members: party_predictions_for_upd = [v for k, v in self.party_predictions.items() if k != member_id] if len(party_predictions_for_upd) == 0: - party_predictions_for_upd = [torch.rand(predictions.size())] + party_predictions_for_upd = [torch.rand(predictions.size(), device=self.device)] pred_for_member_upd = torch.mean(torch.stack(party_predictions_for_upd), dim=0) member_update = y - torch.reshape(pred_for_member_upd, (-1,)) self.updates[member_id] = member_update + logger.debug(f"Master {self.id}: computed updates") return [self.updates[member_id] for member_id in participating_members] def finalize(self, is_infer: bool = False) -> None: """ Finalize the party master. 
""" - logger.info("Master %s: finalizing" % self.id) + logger.info(f"Master {self.id}: finalizing") self.check_if_ready() self.is_finalized = True + logger.info(f"Master {self.id}: has finalized") def check_if_ready(self): """ Check if the party master is ready for operations. @@ -256,12 +267,12 @@ def make_batcher( :return: ConsecutiveListBatcher instance. """ - logger.info("Master %s: making a make_batcher for uids %s" % (self.id, len(uids))) self.check_if_ready() epochs = 1 if is_infer else self.epochs batch_size = self._eval_batch_size if is_infer else self._batch_size if uids is None: raise RuntimeError('Master must initialize batcher with collected uids.') + logger.info(f"Master {self.id} makes a batcher for {len(uids)} uids") if not is_infer: return ConsecutiveListBatcher(epochs=epochs, members=party_members, uids=uids, batch_size=batch_size) else: diff --git a/stalactite/ml/honest/linear_regression/party_member.py b/stalactite/ml/honest/linear_regression/party_member.py index fca16e0..1b1d84d 100644 --- a/stalactite/ml/honest/linear_regression/party_member.py +++ b/stalactite/ml/honest/linear_regression/party_member.py @@ -1,12 +1,5 @@ import logging -import os -from abc import ABC from typing import Optional, Any -import math - -import torch -from torch import nn -import numpy as np from stalactite.base import RecordsBatch, DataTensor from stalactite.ml.honest.base import HonestPartyMember @@ -15,24 +8,26 @@ logger = logging.getLogger(__name__) + class HonestPartyMemberLinReg(HonestPartyMember): def initialize_model_from_params(self, **model_params) -> Any: - return LinearRegressionBatch(**model_params) + return LinearRegressionBatch(**model_params).to(self.device) def initialize_model(self, do_load_model: bool = False) -> None: """ Initialize the model based on the specified model name. """ + logger.info(f"Member {self.id} initializes model on device: {self.device}") + logger.debug(f"Model is loaded from path: {do_load_model}") if do_load_model: - self._model = self.load_model() - + self._model = self.load_model().to(self.device) else: self._model = LinearRegressionBatch( input_dim=self._dataset[self._data_params.train_split][self._data_params.features_key].shape[1], **self._model_params - ) + ).to(self.device) init_linear_np(self._model.linear, seed=self.seed) - + self._model.linear.to(self.device) def initialize_optimizer(self) -> None: pass @@ -43,12 +38,16 @@ def update_weights(self, uids: RecordsBatch, upd: DataTensor) -> None: :param uids: Batch of record unique identifiers. :param upd: Updated model weights. """ - logger.info("Member %s: updating weights. Incoming tensor: %s" % (self.id, tuple(upd.size()))) + logger.info(f"Member {self.id}: updating weights. Incoming tensor: {upd.size()}") self.check_if_ready() tensor_idx = [self.uid2tensor_idx[uid] for uid in uids] - X_train = self._dataset[self._data_params.train_split][self._data_params.features_key][tensor_idx, :] - self._model.update_weights(X_train, upd, optimizer=self._optimizer) - logger.info("Member %s: successfully updated weights" % self.id) + X_train = self.device_dataset_train_split[tensor_idx, :] + self._model.update_weights(X_train, upd.to(self.device), optimizer=self._optimizer) + self.move_model_to_device() + logger.debug(f"Member {self.id}: successfully updated weights") + + def move_model_to_device(self): + self._model.linear.weight.to(self.device) def predict(self, uids: Optional[RecordsBatch], is_infer: bool = False) -> DataTensor: """ Make predictions using the current model. 
@@ -58,20 +57,23 @@ def predict(self, uids: Optional[RecordsBatch], is_infer: bool = False) -> DataT :return: Model predictions. """ - logger.info("Member %s: predicting." % (self.id)) + logger.info(f"Member {self.id}: predicting") self.check_if_ready() _uid2tensor_idx = self.uid2tensor_idx_test if is_infer else self.uid2tensor_idx tensor_idx = [_uid2tensor_idx[uid] for uid in uids] if uids else None if is_infer: logger.info("Member %s: using test data" % self.id) if uids is None: - X = self._dataset[self._data_params.test_split][self._data_params.features_key] + X = self.device_dataset_test_split else: - X = self._dataset[self._data_params.test_split][self._data_params.features_key][tensor_idx, :] + X = self.device_dataset_test_split[tensor_idx, :] else: - X = self._dataset[self._data_params.train_split][self._data_params.features_key][tensor_idx, :] + X = self.device_dataset_train_split[tensor_idx, :] + if is_infer: + self._model.eval() predictions = self._model.predict(X) - logger.info("Member %s: made predictions." % self.id) + self._model.train() + logger.debug(f"Member {self.id}: made predictions") return predictions def update_predict(self, upd: DataTensor, previous_batch: RecordsBatch, batch: RecordsBatch) -> DataTensor: @@ -83,11 +85,12 @@ def update_predict(self, upd: DataTensor, previous_batch: RecordsBatch, batch: R :return: Model predictions. """ - logger.info("Member %s: updating and predicting." % self.id) + logger.info(f"Member {self.id}: updating and predicting") self.check_if_ready() if previous_batch is not None: self.update_weights(uids=previous_batch, upd=upd) + predictions = self.predict(batch) self.iterations_counter += 1 - logger.info("Member %s: updated and predicted." % self.id) + logger.debug(f"Member {self.id}: updated and predicted") return predictions diff --git a/stalactite/ml/honest/logistic_regression/party_master.py b/stalactite/ml/honest/logistic_regression/party_master.py index f3ec974..cd4dda4 100644 --- a/stalactite/ml/honest/logistic_regression/party_master.py +++ b/stalactite/ml/honest/logistic_regression/party_master.py @@ -20,9 +20,9 @@ def make_init_updates(self, world_size: int) -> PartyDataTensor: :param world_size: Number of party members. :return: Initial updates as a list of zero tensors. """ - logger.info("Master %s: making init updates for %s members" % (self.id, world_size)) + logger.info(f"Master {self.id}: makes initial updates for {world_size} members") self.check_if_ready() - return [torch.zeros(self._batch_size) for _ in range(world_size)] + return [torch.zeros(self._batch_size).to(self.device) for _ in range(world_size)] def aggregate( self, participating_members: List[str], party_predictions: PartyDataTensor, is_infer: bool = False @@ -35,15 +35,15 @@ def aggregate( :return: Aggregated predictions after applying sigmoid function. 
""" - logger.info("Master %s: aggregating party predictions (num predictions %s)" % (self.id, len(party_predictions))) + logger.info(f"Master {self.id}: aggregates party predictions (number of predictions {len(party_predictions)})") self.check_if_ready() if not is_infer: for member_id, member_prediction in zip(participating_members, party_predictions): self.party_predictions[member_id] = member_prediction party_predictions = list(self.party_predictions.values()) - predictions = torch.sum(torch.stack(party_predictions, dim=1), dim=1) + predictions = torch.sum(torch.stack(party_predictions, dim=1).to(self.device), dim=1) else: - predictions = self.activation(torch.sum(torch.stack(party_predictions, dim=1), dim=1)) + predictions = self.activation(torch.sum(torch.stack(party_predictions, dim=1).to(self.device), dim=1)) return predictions def compute_updates( @@ -64,20 +64,22 @@ def compute_updates( :return: List of gradients as tensors. """ - logger.info("Master %s: computing updates (world size %s)" % (self.id, world_size)) + logger.info(f"Master {self.id}: computes updates (world size {world_size})") self.check_if_ready() self.iteration_counter += 1 tensor_idx = [self.uid2tensor_idx[uid] for uid in uids] y = self.target[tensor_idx] - criterion = torch.nn.BCEWithLogitsLoss(pos_weight=self.class_weights) if self.binary else torch.nn.CrossEntropyLoss(weight=self.class_weights) + criterion = torch.nn.BCEWithLogitsLoss(pos_weight=self.class_weights) \ + if self.binary else torch.nn.CrossEntropyLoss(weight=self.class_weights) targets_type = torch.LongTensor if isinstance(criterion, torch.nn.CrossEntropyLoss) else torch.FloatTensor - loss = criterion(torch.squeeze(predictions), y.type(targets_type)) + predictions = predictions.to(self.device) + loss = criterion(torch.squeeze(predictions), y.type(targets_type).to(self.device)) grads = torch.autograd.grad(outputs=loss, inputs=predictions) for i, member_id in enumerate(participating_members): self.updates[member_id] = grads[0] - + logger.debug(f"Master {self.id}: computed updates") return [self.updates[member_id] for member_id in participating_members] def report_metrics(self, y: DataTensor, predictions: DataTensor, name: str, step: int) -> None: @@ -92,12 +94,11 @@ def report_metrics(self, y: DataTensor, predictions: DataTensor, name: str, step :return: None. """ - logger.info( - f"Master %s: reporting metrics. Y dim: {y.size()}. 
" f"Predictions size: {predictions.size()}" % self.id - ) + logger.info(f"Master {self.id} reporting metrics") + logger.debug(f"Predictions size: {predictions.size()}, Target size: {y.size()}") - y = y.numpy() - predictions = predictions.detach().numpy() + y = y.cpu().numpy() + predictions = predictions.cpu().detach().numpy() postfix = '-infer' if step == -1 else "" step = step if step != -1 else None diff --git a/stalactite/ml/honest/logistic_regression/party_member.py b/stalactite/ml/honest/logistic_regression/party_member.py index 39a2029..e03ae92 100644 --- a/stalactite/ml/honest/logistic_regression/party_member.py +++ b/stalactite/ml/honest/logistic_regression/party_member.py @@ -1,3 +1,4 @@ +import logging from typing import Any from torch.optim import SGD @@ -6,20 +7,25 @@ from stalactite.models import LogisticRegressionBatch from stalactite.utils import init_linear_np +logger = logging.getLogger(__name__) + class HonestPartyMemberLogReg(HonestPartyMemberLinReg): def initialize_model_from_params(self, **model_params) -> Any: - return LogisticRegressionBatch(**model_params) + return LogisticRegressionBatch(**model_params).to(self.device) def initialize_model(self, do_load_model: bool = False) -> None: """ Initialize the model based on the specified model name. """ + logger.info(f"Member {self.id} initializes model on device: {self.device}") + logger.debug(f"Model is loaded from path: {do_load_model}") if do_load_model: - self._model = self.load_model() + self._model = self.load_model().to(self.device) else: self._model = LogisticRegressionBatch( input_dim=self._dataset[self._data_params.train_split][self._data_params.features_key].shape[1], **self._model_params - ) + ).to(self.device) init_linear_np(self._model.linear, seed=self.seed) + self._model.linear.to(self.device) def initialize_optimizer(self) -> None: self._optimizer = SGD([ @@ -29,3 +35,8 @@ def initialize_optimizer(self) -> None: momentum=self._common_params.momentum, weight_decay=self._common_params.weight_decay, ) + + def move_model_to_device(self): + # As the class is inherited from the Linear regression model, we need to skip this step with returning the model + # to device after weights updates + pass diff --git a/stalactite/ml/honest/split_learning/base.py b/stalactite/ml/honest/split_learning/base.py index 484336a..6dd1f3e 100644 --- a/stalactite/ml/honest/split_learning/base.py +++ b/stalactite/ml/honest/split_learning/base.py @@ -16,35 +16,46 @@ class HonestPartyMasterSplitNN(HonestPartyMasterLinReg): + def finalize(self, is_infer: bool = False) -> None: + """ Finalize the party master. """ + logger.info(f"Master {self.id}: finalizing") + self.check_if_ready() + if self.do_save_model and not is_infer: + self.save_model() + self.is_finalized = True + logger.info(f"Master {self.id}: has finalized") def predict(self, x: DataTensor, is_infer: bool = False, use_activation: bool = False) -> DataTensor: """ Make predictions using the master model. :return: Model predictions. 
""" - logger.info("Master: predicting.") + logger.info(f"Master {self.id}: predicting") self.check_if_ready() - predictions = self._model.predict(x) + if is_infer: + self._model.eval() + predictions = self._model.predict(x.to(self.device)) + self._model.train() if use_activation: predictions = self.activation(predictions) - logger.info("Master: made predictions.") + logger.debug(f"Master {self.id}: made predictions") return predictions def update_weights(self, agg_members_output: DataTensor, upd: DataTensor) -> None: - logger.info(f"Master: updating weights. Incoming tensor: {upd.size()}") + logger.info(f"Master {self.id}: updating weights. Incoming tensor: {upd.size()}") self.check_if_ready() self._model.update_weights(x=agg_members_output, gradients=upd, is_single=False, optimizer=self._optimizer) - logger.info("Master: successfully updated weights") + logger.debug(f"Master {self.id}: successfully updated weights") def update_predict(self, upd: DataTensor, agg_members_output: DataTensor) -> DataTensor: - logger.info("Master: updating and predicting.") + logger.info(f"Master {self.id}: updating and predicting") self.check_if_ready() # get aggregated output from previous batch if exist (we do not make update_weights if it's the first iter) if self.aggregated_output is not None: self.update_weights( agg_members_output=self.aggregated_output, upd=upd) predictions = self.predict(agg_members_output, use_activation=False) - logger.info("Master: updated and predicted.") + logger.debug(f"Master {self.id}: updated and predicted") # save current agg_members_output for making update_predict for next batch self.aggregated_output = copy(agg_members_output) return predictions @@ -67,13 +78,13 @@ def compute_updates( :return: List of gradients as tensors. """ - logger.info("Master %s: computing updates (world size %s)" % (self.id, world_size)) + logger.info(f"Master {self.id}: computes updates (world size {world_size})") self.check_if_ready() self.iteration_counter += 1 tensor_idx = [self.uid2tensor_idx[uid] for uid in uids] y = self.target[tensor_idx] targets_type = torch.LongTensor if isinstance(self._criterion, torch.nn.CrossEntropyLoss) else torch.FloatTensor - loss = self._criterion(torch.squeeze(master_predictions), y.type(targets_type)) + loss = self._criterion(torch.squeeze(master_predictions), y.type(targets_type).to(self.device)) if self.run_mlflow: mlflow.log_metric("loss", loss.item(), step=self.iteration_counter) @@ -84,8 +95,8 @@ def compute_updates( self.updates["master"] = torch.autograd.grad( outputs=loss, inputs=master_predictions, retain_graph=True )[0] - - return [self.updates[member_id] for member_id in participating_members] + logger.debug(f"Master {self.id}: computed updates") + return [self.updates[member_id].contiguous() for member_id in participating_members] def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: """ Run main training loop on the VFL master. @@ -94,13 +105,13 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: :param party: Communicator instance used for communication between VFL agents. 
:return: None """ - logger.info("Master %s: entering training loop" % self.id) + logger.info(f"Master {self.id}: entering training loop") updates = self.make_init_updates(party.world_size) for titer in batcher: - logger.debug( - f"Master %s: train loop - starting batch %s (sub iter %s) on epoch %s" - % (self.id, titer.seq_num, titer.subiter_seq_num, titer.epoch) + logger.info( + f"Master {self.id}: train loop - starting batch {titer.seq_num} (sub iter {titer.subiter_seq_num}) " + f"on epoch {titer.epoch}" ) iter_start_time = time.time() # tasks for members @@ -135,8 +146,8 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: if self.report_train_metrics_iteration > 0 and titer.seq_num % self.report_train_metrics_iteration == 0: logger.debug( - f"Master %s: train loop - reporting train metrics on iteration %s of epoch %s" - % (self.id, titer.seq_num, titer.epoch) + f"Master {self.id}: train loop - reporting train metrics on iteration {titer.seq_num} " + f"of epoch {titer.epoch}" ) predict_tasks = party.broadcast( Method.predict, @@ -148,19 +159,23 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: task.result for task in party.gather(predict_tasks, recv_results=True) ] - agg_members_predictions = self.aggregate(party.members, party_members_predictions, is_infer=True) + agg_members_predictions = self.aggregate( + titer.participating_members, + party_members_predictions, + is_infer=True + ) master_predictions = self.predict(x=agg_members_predictions, use_activation=True) target = self.target[[self.uid2tensor_idx[uid] for uid in batcher.uids]] self.report_metrics( - target.numpy(), master_predictions.detach().numpy(), name="Train", step=titer.seq_num + target, master_predictions, name="Train", step=titer.seq_num ) if self.report_test_metrics_iteration > 0 and titer.seq_num % self.report_test_metrics_iteration == 0: logger.debug( - f"Master %s: train loop - reporting test metrics on iteration %s of epoch %s" - % (self.id, titer.seq_num, titer.epoch) + f"Master {self.id}: train loop - reporting test metrics on iteration {titer.seq_num} " + f"of epoch {titer.epoch}" ) predict_test_tasks = party.broadcast( Method.predict, @@ -171,17 +186,52 @@ def loop(self, batcher: Batcher, party: PartyCommunicator) -> None: party_members_predictions = [ task.result for task in party.gather(predict_test_tasks, recv_results=True) ] - agg_members_predictions = self.aggregate(party.members, party_members_predictions, is_infer=True) + agg_members_predictions = self.aggregate( + titer.participating_members, + party_members_predictions, + is_infer=True + ) master_predictions = self.predict(x=agg_members_predictions, use_activation=True) self.report_metrics( - self.test_target.numpy(), master_predictions.detach().numpy(), name="Test", step=titer.seq_num + self.test_target, master_predictions, name="Test", step=titer.seq_num ) self.iteration_times.append( IterationTime(client_id=self.id, iteration=titer.seq_num, iteration_time=time.time() - iter_start_time) ) + def inference_loop(self, batcher: Batcher, party: PartyCommunicator) -> None: + logger.info(f"Master {self.id}: entering inference loop") + predictions = torch.tensor([], device=self.device) + test_targets = torch.tensor([], device=self.device) + for titer in batcher: + if titer.last_batch: + break + logger.info( + f"Master {self.id}: inference loop - starting batch {titer.seq_num} (sub iter {titer.subiter_seq_num})" + ) + predict_test_tasks = party.broadcast( + Method.predict, + 
method_kwargs=MethodKwargs(other_kwargs={"uids": titer.batch, "is_infer": True}), + participating_members=titer.participating_members, + ) + party_members_predictions = [ + task.result for task in party.gather(predict_test_tasks, recv_results=True) + ] + agg_members_predictions = self.aggregate( + titer.participating_members, party_members_predictions, is_infer=True + ) + master_predictions = self.predict(x=agg_members_predictions, use_activation=True) + target = self.test_target[[self.uid2tensor_idx_test[uid] for uid in titer.batch]] + test_targets = torch.cat([test_targets, target]) + predictions = torch.cat([predictions, master_predictions]) + self.report_metrics(test_targets, predictions, name="Test", step=-1) + def report_metrics(self, y: DataTensor, predictions: DataTensor, name: str, step: int) -> None: + logger.info(f"Master {self.id} reporting metrics") + logger.debug(f"Predictions size: {predictions.size()}, Target size: {y.size()}") + y = y.cpu().numpy() + predictions = predictions.cpu().detach().numpy() if self.binary: for avg in ["macro", "micro"]: try: @@ -197,7 +247,6 @@ def report_metrics(self, y: DataTensor, predictions: DataTensor, name: str, step mlflow.log_metric(f"{name.lower()}_roc_auc_{avg}", roc_auc, step=step) mlflow.log_metric(f"{name.lower()}_rmse", rmse, step=step) else: - avg = "macro" roc_auc = roc_auc_score(y, predictions, average=avg, multi_class="ovr") logger.info(f'{name} ROC AUC {avg} on step {step}: {roc_auc}') diff --git a/stalactite/ml/honest/split_learning/efficientnet/party_master.py b/stalactite/ml/honest/split_learning/efficientnet/party_master.py index d1ed512..4f9b4f1 100644 --- a/stalactite/ml/honest/split_learning/efficientnet/party_master.py +++ b/stalactite/ml/honest/split_learning/efficientnet/party_master.py @@ -1,13 +1,12 @@ import logging -from typing import List +from typing import List, Any import mlflow import torch from torch import nn from sklearn.metrics import roc_auc_score - -from stalactite.models.split_learning import EfficientNetTop, MLPTop, ResNetTop +from stalactite.models.split_learning import EfficientNetTop from stalactite.ml.honest.split_learning.base import HonestPartyMasterSplitNN from stalactite.base import DataTensor, PartyDataTensor @@ -16,11 +15,20 @@ class HonestPartyMasterEfficientNetSplitNN(HonestPartyMasterSplitNN): + def initialize_model_from_params(self, **model_params) -> Any: + return EfficientNetTop(**model_params).to(self.device) + def initialize_model(self, do_load_model: bool = False) -> None: """ Initialize the model based on the specified model name. 
""" - self._model = EfficientNetTop(**self._model_params, seed=self.seed) - class_weights = None if self.class_weights is None else self.class_weights.type(torch.FloatTensor) - self._criterion = nn.CrossEntropyLoss(weight=class_weights) + logger.info(f"Master {self.id} initializes model on device: {self.device}") + logger.debug(f"Model is loaded from path: {do_load_model}") + if do_load_model: + self._model = self.load_model().to(self.device) + else: + self._model = EfficientNetTop(**self._model_params, seed=self.seed).to(self.device) + class_weights = None if self.class_weights is None else self.class_weights.type(torch.FloatTensor) \ + .to(self.device) + self._criterion = nn.CrossEntropyLoss(weight=class_weights) self._activation = nn.Softmax(dim=1) def initialize_optimizer(self) -> None: @@ -36,16 +44,21 @@ def initialize_optimizer(self) -> None: def aggregate( self, participating_members: List[str], party_predictions: PartyDataTensor, is_infer: bool = False ) -> DataTensor: - logger.info("Master %s: aggregating party predictions (num predictions %s)" % (self.id, len(party_predictions))) - self.check_if_ready() + logger.info(f"Master {self.id}: aggregates party predictions (number of predictions {len(party_predictions)})") + self.check_if_ready() for member_id, member_prediction in zip(participating_members, party_predictions): - self.party_predictions[member_id] = member_prediction + self.party_predictions[member_id] = member_prediction.to(self.device) party_predictions = list(self.party_predictions.values()) predictions = torch.mean(torch.stack(party_predictions, dim=1), dim=1) return predictions def report_metrics(self, y: DataTensor, predictions: DataTensor, name: str, step: int) -> None: + logger.info(f"Master {self.id} reporting metrics") + logger.debug(f"Predictions size: {predictions.size()}, Target size: {y.size()}") + y = y.cpu().numpy() + predictions = predictions.cpu().detach().numpy() + postfix = "-infer" if step == -1 else "" step = step if step != -1 else None @@ -53,4 +66,4 @@ def report_metrics(self, y: DataTensor, predictions: DataTensor, name: str, step roc_auc = roc_auc_score(y, predictions, average=avg, multi_class="ovr") logger.info(f'{name} ROC AUC {avg} on step {step}: {roc_auc}') if self.run_mlflow: - mlflow.log_metric(f"{name.lower()}_roc_auc_{avg}{postfix}", roc_auc, step=step) \ No newline at end of file + mlflow.log_metric(f"{name.lower()}_roc_auc_{avg}{postfix}", roc_auc, step=step) diff --git a/stalactite/ml/honest/split_learning/efficientnet/party_member.py b/stalactite/ml/honest/split_learning/efficientnet/party_member.py index 45757c6..5468002 100644 --- a/stalactite/ml/honest/split_learning/efficientnet/party_member.py +++ b/stalactite/ml/honest/split_learning/efficientnet/party_member.py @@ -1,3 +1,4 @@ +import logging from typing import Any from torch.optim import SGD @@ -5,15 +6,22 @@ from stalactite.ml.honest.linear_regression.party_member import HonestPartyMemberLinReg from stalactite.models.split_learning import EfficientNetBottom +logger = logging.getLogger(__name__) + class HonestPartyMemberEfficientNet(HonestPartyMemberLinReg): + def initialize_model_from_params(self, **model_params) -> Any: + return EfficientNetBottom(**model_params).to(self.device) + def initialize_model(self, do_load_model: bool = False) -> None: """ Initialize the model based on the specified model name. 
""" + logger.info(f"Member {self.id} initializes model on device: {self.device}") + logger.debug(f"Model is loaded from path: {do_load_model}") if do_load_model: - self._model = self.load_model() + self._model = self.load_model().to(self.device) else: - self._model = EfficientNetBottom(**self._model_params, seed=self.seed) + self._model = EfficientNetBottom(**self._model_params, seed=self.seed).to(self.device) def initialize_optimizer(self) -> None: self._optimizer = SGD([ @@ -22,5 +30,9 @@ def initialize_optimizer(self) -> None: lr=self._common_params.learning_rate, momentum=self._common_params.momentum, weight_decay=self._common_params.weight_decay, - ) + + def move_model_to_device(self): + # As the class is inherited from the Linear regression model, we need to skip this step with returning the model + # to device after weights updates + pass diff --git a/stalactite/ml/honest/split_learning/mlp/party_master.py b/stalactite/ml/honest/split_learning/mlp/party_master.py index d1d2cca..ca12279 100644 --- a/stalactite/ml/honest/split_learning/mlp/party_master.py +++ b/stalactite/ml/honest/split_learning/mlp/party_master.py @@ -1,5 +1,5 @@ import logging -from typing import List +from typing import List, Any import torch @@ -12,10 +12,21 @@ class HonestPartyMasterMLPSplitNN(HonestPartyMasterSplitNN): + def initialize_model_from_params(self, **model_params) -> Any: + return MLPTop(**model_params).to(self.device) + def initialize_model(self, do_load_model: bool = False) -> None: """ Initialize the model based on the specified model name. """ - self._model = MLPTop(**self._model_params, seed=self.seed) - self._criterion = torch.nn.BCEWithLogitsLoss(pos_weight=self.class_weights) if self.binary else torch.nn.CrossEntropyLoss(weight=self.class_weights) + logger.info(f"Master {self.id} initializes model on device: {self.device}") + logger.debug(f"Model is loaded from path: {do_load_model}") + if do_load_model: + self._model = self.load_model().to(self.device) + else: + self._model = MLPTop(**self._model_params, seed=self.seed).to(self.device) + class_weights = None if self.class_weights is None else self.class_weights.type(torch.FloatTensor) \ + .to(self.device) + self._criterion = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights) \ + if self.binary else torch.nn.CrossEntropyLoss(weight=class_weights) def initialize_optimizer(self) -> None: self._optimizer = torch.optim.SGD([ @@ -24,17 +35,17 @@ def initialize_optimizer(self) -> None: lr=self._common_params.learning_rate, momentum=self._common_params.momentum, weight_decay=self._common_params.weight_decay, - ) def aggregate( self, participating_members: List[str], party_predictions: PartyDataTensor, is_infer: bool = False ) -> DataTensor: - logger.info("Master %s: aggregating party predictions (num predictions %s)" % (self.id, len(party_predictions))) + logger.info(f"Master {self.id}: aggregates party predictions (number of predictions {len(party_predictions)})") + self.check_if_ready() for member_id, member_prediction in zip(participating_members, party_predictions): - self.party_predictions[member_id] = member_prediction + self.party_predictions[member_id] = member_prediction.to(self.device) party_predictions = list(self.party_predictions.values()) predictions = torch.sum(torch.stack(party_predictions, dim=1), dim=1) return predictions diff --git a/stalactite/ml/honest/split_learning/mlp/party_member.py b/stalactite/ml/honest/split_learning/mlp/party_member.py index 72a7e85..8df3771 100644 --- 
a/stalactite/ml/honest/split_learning/mlp/party_member.py +++ b/stalactite/ml/honest/split_learning/mlp/party_member.py @@ -1,18 +1,28 @@ +import logging +from typing import Any + from torch.optim import SGD from stalactite.ml.honest.linear_regression.party_member import HonestPartyMemberLinReg from stalactite.models.split_learning import MLPBottom +logger = logging.getLogger(__name__) + class HonestPartyMemberMLP(HonestPartyMemberLinReg): + def initialize_model_from_params(self, **model_params) -> Any: + return MLPBottom(**model_params).to(self.device) + def initialize_model(self, do_load_model: bool = False) -> None: """ Initialize the model based on the specified model name. """ + logger.info(f"Member {self.id} initializes model on device: {self.device}") + logger.debug(f"Model is loaded from path: {do_load_model}") input_dim = self._dataset[self._data_params.train_split][self._data_params.features_key].shape[1] if do_load_model: - self._model = self.load_model() + self._model = self.load_model().to(self.device) else: - self._model = MLPBottom(input_dim=input_dim, **self._model_params, seed=self.seed) + self._model = MLPBottom(input_dim=input_dim, **self._model_params, seed=self.seed).to(self.device) def initialize_optimizer(self) -> None: self._optimizer = SGD([ @@ -23,3 +33,8 @@ def initialize_optimizer(self) -> None: weight_decay=self._common_params.weight_decay, ) + + def move_model_to_device(self): + # As the class is inherited from the Linear regression model, we need to skip this step with returning the model + # to device after weights updates + pass diff --git a/stalactite/ml/honest/split_learning/resnet/party_master.py b/stalactite/ml/honest/split_learning/resnet/party_master.py index 418a2f2..c000e28 100644 --- a/stalactite/ml/honest/split_learning/resnet/party_master.py +++ b/stalactite/ml/honest/split_learning/resnet/party_master.py @@ -1,8 +1,7 @@ import logging -from typing import List +from typing import List, Any import torch -from torch import nn from stalactite.models.split_learning import ResNetTop from stalactite.ml.honest.split_learning.base import HonestPartyMasterSplitNN from stalactite.base import DataTensor, PartyDataTensor @@ -12,11 +11,21 @@ class HonestPartyMasterResNetSplitNN(HonestPartyMasterSplitNN): + def initialize_model_from_params(self, **model_params) -> Any: + return ResNetTop(**model_params).to(self.device) + def initialize_model(self, do_load_model: bool = False) -> None: """ Initialize the model based on the specified model name. 
""" - self._model = ResNetTop(**self._model_params, seed=self.seed) - self._criterion = torch.nn.BCEWithLogitsLoss( - pos_weight=self.class_weights) if self.binary else torch.nn.CrossEntropyLoss(weight=self.class_weights) + logger.info(f"Master {self.id} initializes model on device: {self.device}") + logger.debug(f"Model is loaded from path: {do_load_model}") + if do_load_model: + self._model = self.load_model().to(self.device) + else: + self._model = ResNetTop(**self._model_params, seed=self.seed).to(self.device) + class_weights = None if self.class_weights is None else self.class_weights.type(torch.FloatTensor) \ + .to(self.device) + self._criterion = torch.nn.BCEWithLogitsLoss( + pos_weight=class_weights) if self.binary else torch.nn.CrossEntropyLoss(weight=class_weights) def initialize_optimizer(self) -> None: self._optimizer = torch.optim.SGD([ @@ -31,11 +40,10 @@ def initialize_optimizer(self) -> None: def aggregate( self, participating_members: List[str], party_predictions: PartyDataTensor, is_infer: bool = False ) -> DataTensor: - logger.info("Master %s: aggregating party predictions (num predictions %s)" % (self.id, len(party_predictions))) + logger.info(f"Master {self.id}: aggregates party predictions (number of predictions {len(party_predictions)})") self.check_if_ready() - for member_id, member_prediction in zip(participating_members, party_predictions): - self.party_predictions[member_id] = member_prediction + self.party_predictions[member_id] = member_prediction.to(self.device) party_predictions = list(self.party_predictions.values()) predictions = torch.cat(party_predictions, dim=1) return predictions diff --git a/stalactite/ml/honest/split_learning/resnet/party_member.py b/stalactite/ml/honest/split_learning/resnet/party_member.py index 729bac9..5eb037e 100644 --- a/stalactite/ml/honest/split_learning/resnet/party_member.py +++ b/stalactite/ml/honest/split_learning/resnet/party_member.py @@ -1,3 +1,4 @@ +import logging from typing import Any from torch.optim import SGD @@ -5,16 +6,23 @@ from stalactite.ml.honest.linear_regression.party_member import HonestPartyMemberLinReg from stalactite.models.split_learning import ResNetBottom +logger = logging.getLogger(__name__) + class HonestPartyMemberResNet(HonestPartyMemberLinReg): + def initialize_model_from_params(self, **model_params) -> Any: + return ResNetBottom(**model_params).to(self.device) + def initialize_model(self, do_load_model: bool = False) -> None: """ Initialize the model based on the specified model name. 
""" + logger.info(f"Member {self.id} initializes model on device: {self.device}") + logger.debug(f"Model is loaded from path: {do_load_model}") input_dim = self._dataset[self._data_params.train_split][self._data_params.features_key].shape[1] if do_load_model: - self._model = self.load_model() + self._model = self.load_model().to(self.device) else: - self._model = ResNetBottom(input_dim=input_dim, **self._model_params, seed=self.seed) + self._model = ResNetBottom(input_dim=input_dim, **self._model_params, seed=self.seed).to(self.device) def initialize_optimizer(self) -> None: self._optimizer = SGD([ @@ -25,3 +33,8 @@ def initialize_optimizer(self) -> None: weight_decay=self._common_params.weight_decay, ) + + def move_model_to_device(self): + # As the class is inherited from the Linear regression model, we need to skip this step with returning the model + # to device after weights updates + pass diff --git a/stalactite/models/efficient_net.py b/stalactite/models/efficient_net.py index 0634552..30d2ef3 100644 --- a/stalactite/models/efficient_net.py +++ b/stalactite/models/efficient_net.py @@ -57,6 +57,12 @@ def __init__( super().__init__() _log_api_usage_once(self) + self.width_mult = width_mult + self.depth_mult = depth_mult + self.dropout = dropout + self.stochastic_depth_prob = stochastic_depth_prob + self.num_classes = num_classes + self.criterion = torch.nn.CrossEntropyLoss() inverted_residual_setting, last_channel = _efficientnet_conf(width_mult=width_mult, depth_mult=depth_mult) @@ -154,7 +160,7 @@ def update_weights(self, x: torch.Tensor, y: torch.Tensor, is_single: bool = Fal optimizer.zero_grad() logit = self.forward(x) loss = self.criterion(torch.squeeze(logit), y.type(torch.LongTensor)) - logger.info(f"loss: {loss.item()}") + logger.info(f"Loss: {loss.item()}") loss.backward() optimizer.step() @@ -164,3 +170,14 @@ def predict(self, x: torch.Tensor) -> torch.Tensor: def get_weights(self) -> torch.Tensor: return self.linear.weight.clone() + @property + def init_params(self): + return { + 'width_mult': self.width_mult, + 'depth_mult': self.depth_mult, + 'dropout': self.dropout, + 'stochastic_depth_prob': self.stochastic_depth_prob, + 'num_classes': self.num_classes, + } + + diff --git a/stalactite/models/linreg_batch.py b/stalactite/models/linreg_batch.py index 225bd25..fcf1307 100644 --- a/stalactite/models/linreg_batch.py +++ b/stalactite/models/linreg_batch.py @@ -37,11 +37,11 @@ def forward(self, x): def update_weights(self, X_train, rhs, optimizer=None) -> None: # todo: add docs - U, S, Vh = sp.linalg.svd(X_train.numpy(), full_matrices=False, overwrite_a=False, check_finite=False) - logger.debug("updating weights inside model") - coeffs, num_rank = solve_ols_svd(U, S, Vh, rhs, self.reg_lambda) + U, S, Vh = sp.linalg.svd(X_train.cpu().numpy(), full_matrices=False, overwrite_a=False, check_finite=False) + logger.debug("Updating weights inside model") + coeffs, num_rank = solve_ols_svd(U, S, Vh, rhs.cpu(), self.reg_lambda) self.linear.weight.copy_(torch.as_tensor(coeffs).t()) # TODO: copying is not efficient - logger.debug("SUCCESS update weights") + logger.debug("Success: update weights") def predict(self, X_pred): Y_pred = self.forward(X_pred) diff --git a/stalactite/models/logreg_batch.py b/stalactite/models/logreg_batch.py index 9fb26c7..32e2557 100644 --- a/stalactite/models/logreg_batch.py +++ b/stalactite/models/logreg_batch.py @@ -49,6 +49,7 @@ def update_weights( targets_type = torch.LongTensor if isinstance(criterion, torch.nn.CrossEntropyLoss) else torch.FloatTensor 
loss = criterion(torch.squeeze(logit), gradients.type(targets_type)) + logger.info(f"Loss: {loss.item()}") loss.backward() else: logit.backward(gradient=gradients) diff --git a/stalactite/models/mlp.py b/stalactite/models/mlp.py index c1ec223..2824a02 100644 --- a/stalactite/models/mlp.py +++ b/stalactite/models/mlp.py @@ -12,20 +12,24 @@ class MLP(nn.Module): def __init__( - self, - input_dim: int, - output_dim: int, - hidden_channels: List[int], - norm_layer: Optional[Callable[..., torch.nn.Module]] = None, - activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, - bias: bool = True, - dropout: float = 0.0, - init_weights: float = None, + self, + input_dim: int, + output_dim: int, + hidden_channels: List[int], + norm_layer: Optional[Callable[..., torch.nn.Module]] = None, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + bias: bool = True, + dropout: float = 0.0, + init_weights: float = None, ) -> None: super().__init__() _log_api_usage_once(self) + self.input_dim = input_dim + self.output_dim = output_dim + self.dropout = dropout + layers = [] in_dim = input_dim for hidden_dim in hidden_channels: @@ -64,9 +68,17 @@ def update_weights(self, x: torch.Tensor, y: torch.Tensor, is_single: bool = Fal targets_type = torch.LongTensor if isinstance(criterion, torch.nn.CrossEntropyLoss) else torch.FloatTensor loss = criterion(torch.squeeze(logit), y.type(targets_type)) - logger.info(f"loss: {loss.item()}") + logger.info(f"Loss: {loss.item()}") loss.backward() optimizer.step() def predict(self, x: torch.Tensor) -> torch.Tensor: return self.forward(x) + + @property + def init_params(self): + return { + 'input_dim': self.input_dim, + 'output_dim': self.output_dim, + 'dropout': self.dropout, + } diff --git a/stalactite/models/resnet.py b/stalactite/models/resnet.py index 76e90c9..25a2eba 100644 --- a/stalactite/models/resnet.py +++ b/stalactite/models/resnet.py @@ -49,17 +49,17 @@ class ResNetBlock(nn.Module): """ def __init__( - self, - n_in: int, - hid_factor: float, - n_out: int, - drop_rate: List[float] = [0.1, 0.1], - noise_std: float = 0.05, - act_fun: nn.Module = nn.ReLU, - use_bn: bool = True, - use_noise: bool = False, - device: torch.device = torch.device("cpu"), - **kwargs, + self, + n_in: int, + hid_factor: float, + n_out: int, + drop_rate: List[float] = [0.1, 0.1], + noise_std: float = 0.05, + act_fun: nn.Module = nn.ReLU, + use_bn: bool = True, + use_noise: bool = False, + device: torch.device = torch.device("cpu"), + **kwargs, ): super(ResNetBlock, self).__init__() self.features = nn.Sequential(OrderedDict([])) @@ -104,20 +104,26 @@ class ResNet(nn.Module): """ def __init__( - self, - input_dim: int, - output_dim: int = 1, - hid_factor: List[float] = [2, 2], - drop_rate: Union[float, List[float], List[List[float]]] = 0.1, - noise_std: float = 0.05, - act_fun: nn.Module = nn.ReLU, - num_init_features: Optional[int] = None, - use_bn: bool = True, - use_noise: bool = False, - device: torch.device = torch.device("cpu"), - init_weights: float = None, - **kwargs, + self, + input_dim: int, + output_dim: int = 1, + hid_factor: List[float] = [2, 2], + drop_rate: Union[float, List[float], List[List[float]]] = 0.1, + noise_std: float = 0.05, + act_fun: nn.Module = nn.ReLU, + num_init_features: Optional[int] = None, + use_bn: bool = True, + use_noise: bool = False, + device: torch.device = torch.device("cpu"), + init_weights: float = None, + **kwargs, ): + + self.input_dim = input_dim + self.output_dim = output_dim + self.drop_rate = drop_rate + 
self.noise_std = noise_std + super(ResNet, self).__init__() if isinstance(drop_rate, float): drop_rate = [[drop_rate, drop_rate]] * len(hid_factor) @@ -125,7 +131,7 @@ def __init__( drop_rate = [drop_rate] * len(hid_factor) else: assert ( - len(drop_rate) == len(hid_factor) and len(drop_rate[0]) == 2 + len(drop_rate) == len(hid_factor) and len(drop_rate[0]) == 2 ), "Wrong number hidden_sizes/drop_rates. Must be equal." num_features = input_dim if num_init_features is None else num_init_features @@ -180,13 +186,22 @@ def update_weights(self, x: torch.Tensor, y: torch.Tensor, is_single: bool = Fal targets_type = torch.LongTensor if isinstance(criterion, torch.nn.CrossEntropyLoss) else torch.FloatTensor loss = criterion(torch.squeeze(logit), y.type(targets_type)) - logger.info(f"loss: {loss.item()}") + logger.info(f"Loss: {loss.item()}") loss.backward() optimizer.step() def predict(self, x: torch.Tensor) -> torch.Tensor: return self.forward(x) + @property + def init_params(self): + return { + 'input_dim': self.input_dim, + 'output_dim': self.output_dim, + 'drop_rate': self.drop_rate, + 'noise_std': self.noise_std, + } + if __name__ == "__main__": model = ResNet(input_dim=200, output_dim=10, use_bn=True, hid_factor=[0.1, 0.1]) diff --git a/stalactite/models/split_learning/efficientnet_bottom.py b/stalactite/models/split_learning/efficientnet_bottom.py index d879a8c..10c9ab1 100644 --- a/stalactite/models/split_learning/efficientnet_bottom.py +++ b/stalactite/models/split_learning/efficientnet_bottom.py @@ -53,6 +53,11 @@ def __init__( super().__init__() _log_api_usage_once(self) self.seed = seed + self.width_mult = width_mult + self.depth_mult = depth_mult + self.stochastic_depth_prob = stochastic_depth_prob + self.init_weights = init_weights + inverted_residual_setting, last_channel = _efficientnet_conf(width_mult=width_mult, depth_mult=depth_mult) if norm_layer is None: @@ -138,3 +143,13 @@ def predict(self, x: torch.Tensor) -> torch.Tensor: def get_weights(self) -> torch.Tensor: return self.linear.weight.clone() + + @property + def init_params(self): + return { + 'seed': self.seed, + 'width_mult': self.width_mult, + 'depth_mult': self.depth_mult, + 'stochastic_depth_prob': self.stochastic_depth_prob, + 'init_weights': self.init_weights, + } diff --git a/stalactite/models/split_learning/efficientnet_top.py b/stalactite/models/split_learning/efficientnet_top.py index 3d55235..7266ded 100644 --- a/stalactite/models/split_learning/efficientnet_top.py +++ b/stalactite/models/split_learning/efficientnet_top.py @@ -1,9 +1,8 @@ -import math +import logging from functools import partial from typing import Optional, Sequence, Union, Tuple import torch -import numpy as np from torch import nn, Tensor from torchvision.models.efficientnet import MBConvConfig, FusedMBConvConfig @@ -11,10 +10,12 @@ from stalactite.utils import init_linear_np +logger = logging.getLogger(__name__) + def _efficientnet_conf( - width_mult: float, - depth_mult: float + width_mult: float, + depth_mult: float ) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]: inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]] bneck_conf = partial(MBConvConfig, width_mult=width_mult, depth_mult=depth_mult) @@ -34,12 +35,12 @@ def _efficientnet_conf( class EfficientNetTop(nn.Module): def __init__( - self, - dropout: float = 0.1, - input_dim=None, - num_classes: int = 1000, - init_weights: float = None, - seed: int = None, + self, + dropout: float = 0.1, + input_dim=None, + num_classes: int = 1000, 
+ init_weights: float = None, + seed: int = None, ) -> None: """ @@ -54,6 +55,11 @@ def __init__( super().__init__() _log_api_usage_once(self) + self.dropout = dropout + self.input_dim = input_dim + self.num_classes = num_classes + self.init_weights = init_weights + self.criterion = torch.nn.CrossEntropyLoss() self.seed = seed self.avgpool = nn.AdaptiveAvgPool2d(1) @@ -87,6 +93,7 @@ def update_weights(self, x: torch.Tensor, gradients: torch.Tensor, is_single: bo if is_single: logit = self.forward(x) loss = self.criterion(torch.squeeze(logit), gradients.type(torch.LongTensor)) + logger.info(f"Loss: {loss.item()}") grads = torch.autograd.grad(outputs=loss, inputs=x, retain_graph=True) loss.backward() optimizer.step() @@ -95,13 +102,19 @@ def update_weights(self, x: torch.Tensor, gradients: torch.Tensor, is_single: bo model_output = self.forward(x) model_output.backward(gradient=gradients) optimizer.step() - + def predict(self, x: torch.Tensor) -> torch.Tensor: return self.forward(x) def get_weights(self) -> torch.Tensor: return self.linear.weight.clone() - - - + @property + def init_params(self): + return { + 'dropout': self.dropout, + 'input_dim': self.input_dim, + 'num_classes': self.num_classes, + 'init_weights': self.init_weights, + 'seed': self.seed, + } diff --git a/stalactite/models/split_learning/mlp_bottom.py b/stalactite/models/split_learning/mlp_bottom.py index 5b73b61..345c3ea 100644 --- a/stalactite/models/split_learning/mlp_bottom.py +++ b/stalactite/models/split_learning/mlp_bottom.py @@ -28,7 +28,14 @@ def __init__( super().__init__() _log_api_usage_once(self) + + self.input_dim = input_dim + self.bias = bias + self.dropout = dropout + self.multilabel = multilabel + self.init_weights = init_weights self.seed = seed + if multilabel: self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights) @@ -69,3 +76,14 @@ def update_weights(self, x: torch.Tensor, gradients: torch.Tensor, is_single: bo def predict(self, x: torch.Tensor) -> torch.Tensor: return self.forward(x) + + @property + def init_params(self): + return { + 'input_dim': self.input_dim, + 'bias': self.bias, + 'dropout': self.dropout, + 'multilabel': self.multilabel, + 'init_weights': self.init_weights, + 'seed': self.seed, + } diff --git a/stalactite/models/split_learning/mlp_top.py b/stalactite/models/split_learning/mlp_top.py index 1181e00..3505c70 100644 --- a/stalactite/models/split_learning/mlp_top.py +++ b/stalactite/models/split_learning/mlp_top.py @@ -24,7 +24,13 @@ def __init__( super().__init__() _log_api_usage_once(self) + self.input_dim = input_dim + self.output_dim = output_dim + self.bias = bias + self.multilabel = multilabel + self.init_weights = init_weights self.seed = seed + if multilabel: self.criterion = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights) else: @@ -55,6 +61,7 @@ def update_weights(self, x: torch.Tensor, gradients: torch.Tensor, is_single: bo logit = self.forward(x) loss = self.criterion(torch.squeeze(logit), gradients.type(torch.FloatTensor)) grads = torch.autograd.grad(outputs=loss, inputs=x, retain_graph=True) + logger.info(f"Loss: {loss.item()}") loss.backward() optimizer.step() return grads[0] @@ -66,5 +73,14 @@ def update_weights(self, x: torch.Tensor, gradients: torch.Tensor, is_single: bo def predict(self, x: torch.Tensor) -> torch.Tensor: return self.forward(x) - + @property + def init_params(self): + return { + 'input_dim': self.input_dim, + 'output_dim': self.output_dim, + 'bias': self.bias, + 'multilabel': self.multilabel, + 'init_weights': self.init_weights, + 
'seed': self.seed, + } diff --git a/stalactite/models/split_learning/resnet_bottom.py b/stalactite/models/split_learning/resnet_bottom.py index 3b27e26..e35b450 100644 --- a/stalactite/models/split_learning/resnet_bottom.py +++ b/stalactite/models/split_learning/resnet_bottom.py @@ -46,6 +46,15 @@ def __init__( **kwargs, ): super(ResNetBottom, self).__init__() + + self.input_dim = input_dim + self.noise_std = noise_std + self.use_bn = use_bn + self.num_init_features = num_init_features + self.use_noise = use_noise + self.init_weights = init_weights + self.seed = seed + if isinstance(drop_rate, float): drop_rate = [[drop_rate, drop_rate]] * len(hid_factor) elif isinstance(drop_rate, list) and len(drop_rate) == 2: @@ -54,7 +63,6 @@ def __init__( assert ( len(drop_rate) == len(hid_factor) and len(drop_rate[0]) == 2 ), "Wrong number hidden_sizes/drop_rates. Must be equal." - self.seed = seed num_features = input_dim if num_init_features is None else num_init_features self.dense0 = nn.Linear(input_dim, num_features) if num_init_features is not None else nn.Identity() self.features1 = nn.Sequential(OrderedDict([])) @@ -103,3 +111,16 @@ def update_weights(self, x: torch.Tensor, gradients: torch.Tensor, criterion=Non def predict(self, x: torch.Tensor) -> torch.Tensor: return self.forward(x) + + @property + def init_params(self): + return { + 'input_dim': self.input_dim, + 'noise_std': self.noise_std, + 'use_bn': self.use_bn, + 'num_init_features': self.num_init_features, + 'use_noise': self.use_noise, + 'seed': self.seed, + 'init_weights': self.init_weights, + } + diff --git a/stalactite/models/split_learning/resnet_top.py b/stalactite/models/split_learning/resnet_top.py index ec8b67b..b5755e4 100644 --- a/stalactite/models/split_learning/resnet_top.py +++ b/stalactite/models/split_learning/resnet_top.py @@ -41,7 +41,14 @@ def __init__( super(ResNetTop, self).__init__() num_features = input_dim if num_init_features is None else num_init_features + + self.input_dim = input_dim + self.output_dim = output_dim + self.num_init_features = num_init_features + self.use_bn = use_bn + self.init_weights = init_weights self.seed = seed + self.features = nn.Sequential(OrderedDict([])) if use_bn: self.features.add_module("norm", nn.BatchNorm1d(num_features)) @@ -71,6 +78,7 @@ def update_weights(self, x: torch.Tensor, gradients: torch.Tensor, criterion=Non logit = self.forward(x) loss = criterion(torch.squeeze(logit), gradients.type(torch.FloatTensor)) grads = torch.autograd.grad(outputs=loss, inputs=x, retain_graph=True) + logger.info(f"Loss: {loss.item()}") loss.backward() optimizer.step() return grads[0] @@ -81,3 +89,14 @@ def update_weights(self, x: torch.Tensor, gradients: torch.Tensor, criterion=Non def predict(self, x: torch.Tensor) -> torch.Tensor: return self.forward(x) + + @property + def init_params(self): + return { + 'input_dim': self.input_dim, + 'output_dim': self.output_dim, + 'num_init_features': self.num_init_features, + 'use_bn': self.use_bn, + 'seed': self.seed, + 'init_weights': self.init_weights, + } diff --git a/stalactite/run_grpc_agent.py b/stalactite/run_grpc_agent.py index 85e8e7f..7b793fd 100644 --- a/stalactite/run_grpc_agent.py +++ b/stalactite/run_grpc_agent.py @@ -9,9 +9,8 @@ ) from stalactite.helpers import reporting, global_logging from stalactite.configs import VFLConfig -from stalactite.ml.arbitered.base import Role from stalactite.data_utils import get_party_master, get_party_arbiter, get_party_member -from stalactite.utils import seed_all +from stalactite.utils import 
seed_all, Role import logging diff --git a/stalactite/utils.py b/stalactite/utils.py index a94385c..0f7c359 100644 --- a/stalactite/utils.py +++ b/stalactite/utils.py @@ -1,3 +1,4 @@ +import enum import math import os import random @@ -22,3 +23,9 @@ def seed_all(seed: int): np.random.seed(seed) random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) + + +class Role(str, enum.Enum): + arbiter = "arbiter" + master = "master" + member = "member" diff --git a/stalactite/utils_main.py b/stalactite/utils_main.py index b4ec425..34ec6a6 100644 --- a/stalactite/utils_main.py +++ b/stalactite/utils_main.py @@ -15,7 +15,7 @@ from stalactite.configs import VFLConfig, raise_path_not_exist from stalactite.data_utils import get_party_arbiter, get_party_master, get_party_member from stalactite.helpers import run_local_agents, reporting, global_logging -from stalactite.ml.arbitered.base import Role +from stalactite.utils import Role BASE_CONTAINER_LABEL = "grpc-experiment" DOCKER_OBJECTS_LABEL = {"framework": "stalactite"}
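The patch pairs two additions that only meet across files: each model gains an ``init_params`` property mirroring its constructor arguments, and each agent gains ``initialize_model_from_params(**model_params)``, which rebuilds the model and places it on ``self.device``. The sketch below is not part of the patch; it only illustrates the round trip those two pieces appear designed to support, and the concrete parameter values (``input_dim=1280``, ``num_classes=10``, ``seed=42``) are illustrative assumptions.

.. code-block:: python

    import torch
    from stalactite.models.split_learning import EfficientNetTop

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build a top model, capture its constructor arguments via init_params,
    # and rebuild a freshly initialised copy on the target device: the same
    # round trip an agent performs with initialize_model_from_params(**params).
    model = EfficientNetTop(dropout=0.1, input_dim=1280, num_classes=10, seed=42)
    clone = EfficientNetTop(**model.init_params).to(device)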
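All three SplitNN masters now move member activations to the master's device before combining them, but each combines differently: the EfficientNet master averages, the MLP master sums, and the ResNet master concatenates along the feature axis. A stand-alone illustration with dummy activations of shape ``[batch, features]``:

.. code-block:: python

    import torch

    # Two members' intermediate activations.
    a = torch.randn(4, 8)
    b = torch.randn(4, 8)

    mean_agg = torch.mean(torch.stack([a, b], dim=1), dim=1)  # EfficientNet master -> [4, 8]
    sum_agg = torch.sum(torch.stack([a, b], dim=1), dim=1)    # MLP master -> [4, 8]
    cat_agg = torch.cat([a, b], dim=1)                        # ResNet master -> [4, 16]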
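``report_metrics`` now detaches predictions and copies both tensors to host memory before calling scikit-learn, which operates on NumPy arrays and cannot consume CUDA tensors. A minimal, self-contained version of that path with dummy values for the binary case:

.. code-block:: python

    import torch
    from sklearn.metrics import roc_auc_score

    y = torch.tensor([0, 1, 0, 1, 1, 0])                            # targets, possibly on GPU
    preds = torch.tensor([0.2, 0.8, 0.1, 0.7, 0.6, 0.4], requires_grad=True)

    # Detach from the autograd graph and move to CPU before handing over to sklearn.
    roc_auc = roc_auc_score(y.cpu().numpy(), preds.cpu().detach().numpy(), average="macro")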
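``Role`` is now defined in ``stalactite.utils`` as a ``str``-backed enum and imported from there in ``run_grpc_agent.py`` and ``utils_main.py``. Because of the ``str`` mixin, its members compare equal to plain strings such as ``"master"``; a quick, purely illustrative check of those semantics:

.. code-block:: python

    from stalactite.utils import Role

    assert Role.master == "master"          # str mixin: members compare equal to their values
    assert Role("arbiter") is Role.arbiter  # construction from a value returns the member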
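As an aside on the class-weight handling: the masters cast weights with ``.type(torch.FloatTensor)``, which always yields a CPU tensor, and then move them with ``.to(self.device)``. A single ``to`` call with both ``dtype`` and ``device`` should be equivalent; the check below is a sketch under that assumption, not a change to the patch.

.. code-block:: python

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    w = torch.tensor([1.0, 3.0])                        # e.g. per-class weights

    w_two_step = w.type(torch.FloatTensor).to(device)   # cast to CPU float, then move
    w_one_step = w.to(device=device, dtype=torch.float32)
    assert torch.equal(w_two_step, w_one_step)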