Commit

Feature/distr inference (#19)
* Fix docs

* Benchmark he-frameworks

* Add distributed inference

* Fix pr comments
zakharova-anastasiia authored Feb 28, 2024
1 parent a427a64 commit 7cf7eab
Showing 64 changed files with 6,456 additions and 3,395 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -164,4 +164,7 @@ main_grps.py
data
rsync_repo
reports/
configs/config-local.yml
configs/config-local.yml
.DS_Store
saved_models/
*.orig
14 changes: 14 additions & 0 deletions docs/stalactite.ml.arbitered.rst
@@ -0,0 +1,14 @@
stalactite.ml.arbitered package
===============================

Submodules
----------

stalactite.ml.arbitered.base module
-----------------------------------

.. automodule:: stalactite.ml.arbitered.base
:members:
:undoc-members:
:show-inheritance:

22 changes: 22 additions & 0 deletions docs/stalactite.ml.honest.linear_regression.rst
@@ -0,0 +1,22 @@
stalactite.ml.honest.linear\_regression package
===============================================

Submodules
----------

stalactite.ml.honest.linear\_regression.party\_master module
------------------------------------------------------------

.. automodule:: stalactite.ml.honest.linear_regression.party_master
:members:
:undoc-members:
:show-inheritance:

stalactite.ml.honest.linear\_regression.party\_member module
------------------------------------------------------------

.. automodule:: stalactite.ml.honest.linear_regression.party_member
:members:
:undoc-members:
:show-inheritance:

21 changes: 21 additions & 0 deletions docs/stalactite.ml.honest.logistic_regression.rst
@@ -0,0 +1,21 @@
stalactite.ml.honest.logistic\_regression package
=================================================

Submodules
----------

stalactite.ml.honest.logistic\_regression.party\_master module
--------------------------------------------------------------

.. automodule:: stalactite.ml.honest.logistic_regression.party_master
:members:
:undoc-members:
:show-inheritance:

stalactite.ml.honest.logistic\_regression.party\_member module
--------------------------------------------------------------

.. automodule:: stalactite.ml.honest.logistic_regression.party_member
:members:
:undoc-members:
:show-inheritance:
22 changes: 22 additions & 0 deletions docs/stalactite.ml.honest.rst
@@ -0,0 +1,22 @@
stalactite.ml.honest package
============================

Subpackages
-----------

.. toctree::
:maxdepth: 4

stalactite.ml.honest.linear_regression
stalactite.ml.honest.logistic_regression

Submodules
----------

stalactite.ml.honest.base module
--------------------------------

.. automodule:: stalactite.ml.honest.base
:members: HonestPartyMaster, HonestPartyMember
:undoc-members:
:show-inheritance:
19 changes: 19 additions & 0 deletions docs/stalactite.ml.rst
@@ -0,0 +1,19 @@
stalactite.ml package
=====================

Subpackages
-----------

.. toctree::
:maxdepth: 4

stalactite.ml.arbitered
stalactite.ml.honest

Module contents
---------------

.. automodule:: stalactite.ml
:members:
:undoc-members:
:show-inheritance:
20 changes: 3 additions & 17 deletions docs/stalactite.rst
@@ -5,11 +5,13 @@ Subpackages
-----------

.. toctree::
:maxdepth: 4
:maxdepth: 8

stalactite.communications
stalactite.data_preprocessors
stalactite.models
stalactite.ml


Submodules
----------
@@ -22,22 +24,6 @@ stalactite.base module
:undoc-members:


stalactite.party\_master\_impl module
-------------------------------------

.. automodule:: stalactite.party_master_impl
:members:
:undoc-members:
:show-inheritance:

stalactite.party\_member\_impl module
-------------------------------------

.. automodule:: stalactite.party_member_impl
:members:
:undoc-members:
:show-inheritance:

stalactite.party\_single\_impl module
-------------------------------------

9 changes: 7 additions & 2 deletions docs/tutorials/configuration_file_tutorial.rst
@@ -27,10 +27,15 @@ VFL model are training and model specific parameters also used in any experiment
vfl_model:
vfl_model_name: # Model name to train
vfl_model_path: # Directory to save the model for further evaluation
do_train: # Whether to do training of the model
do_predict: # Whether to do evaluation of the model
do_save_model: # Whether to save the model to the `vfl_model_path` after training
epochs: # Number of training epochs
batch_size: # Training batch size
# For local experiment with `linreg` model you can choose the consequent batcher, to update on member at a time
is_consequently: False # Set True for consequent batcher
eval_batch_size: # Evaluation batch size
# For local experiment with `linreg` model you can choose the consequent make_batcher, to update on member at a time
is_consequently: False # Set True for consequent make_batcher
learning_rate: # Experiment learning rate
use_class_weights: # Used in `logreg`
72 changes: 53 additions & 19 deletions docs/tutorials/distr_communicator_tutorial.rst
@@ -28,48 +28,68 @@ defined in the :ref:`local_comm_tutorial`, we import it from the examples folder
from examples.utils.local_experiment import load_processors
def get_party_master(config_path: str):
def get_party_master(config_path: str, is_infer: bool = False):
config = VFLConfig.load_and_validate(config_path) # Load configuration file
processors = load_processors(config) # Load processors
# Define target uids (simulating only partially available data)
# Define target uids and evaluation target uids (simulating only partially available data)
target_uids = [str(i) for i in range(config.data.dataset_size)]
inference_target_uids = [str(i) for i in range(1000)]
# The rest of party master definition is similar to the local example
if 'logreg' in config.common.vfl_model_name:
master_class = PartyMasterImplLogreg
if 'logreg' in config.vfl_model.vfl_model_name:
master_class = HonestPartyMasterLogReg
else:
if config.common.is_consequently:
master_class = PartyMasterImplConsequently
master_class = HonestPartyMasterLinRegConsequently
else:
master_class = PartyMasterImpl
master_class = HonestPartyMasterLinReg
return master_class(
uid="master",
epochs=config.common.epochs,
epochs=config.vfl_model.epochs,
report_train_metrics_iteration=config.common.report_train_metrics_iteration,
report_test_metrics_iteration=config.common.report_test_metrics_iteration,
processor=processors[0],
target_uids=target_uids,
batch_size=config.common.batch_size,
batch_size=config.vfl_model.batch_size,
eval_batch_size=config.vfl_model.eval_batch_size,
model_update_dim_size=0,
run_mlflow=config.master.run_mlflow,
do_train=not is_infer, # To launch from stalactite CLI we use separate commands for training and inference
do_predict=is_infer, # To launch from stalactite CLI we use separate commands for training and inference
inference_target_uids=inference_target_uids,
)
# Because we create separate containers we pass the rank to load correct processors
def get_party_member(config_path: str, member_rank: int):
def get_party_member(config_path: str, member_rank: int, is_infer: bool = False):
config = VFLConfig.load_and_validate(config_path)
processors = load_processors(config_path)
processors = load_processors(config)
target_uids = [str(i) for i in range(config.data.dataset_size)]
# We do not pass members ids due to sequential distributed algorithm cannot be used
return PartyMemberImpl(
inference_target_uids = [str(i) for i in range(1000)]
if 'logreg' in config.vfl_model.vfl_model_name:
member_class = HonestPartyMemberLogReg
else:
member_class = HonestPartyMemberLinReg
return member_class(
uid=f"member-{member_rank}",
member_record_uids=target_uids,
model_name=config.common.vfl_model_name,
member_inference_record_uids=inference_target_uids,
model_name=config.vfl_model.vfl_model_name,
processor=processors[member_rank],
batch_size=config.common.batch_size,
epochs=config.common.epochs,
batch_size=config.vfl_model.batch_size,
eval_batch_size=config.vfl_model.eval_batch_size,
epochs=config.vfl_model.epochs,
report_train_metrics_iteration=config.common.report_train_metrics_iteration,
report_test_metrics_iteration=config.common.report_test_metrics_iteration,
do_train=not is_infer,
do_predict=is_infer,
do_save_model=True, # To work with the stalactite CLI train and predict we always save the model
model_path=config.vfl_model.vfl_model_path,
)
# Alternatively, we can set `do_train=config.vfl_model.do_train` and `do_predict=config.vfl_model.do_predict`
# for both master and member configuration
# For the member configuration, `do_save_model` can be altered with: `config.vfl_model.do_save_model`
The CLI ``stalactite local --multi-process start`` and ``stalactite <master/member> start`` commands launches containers
using the ``grpc-base:latest`` image built from one of the dockerfiles which can be found in the ``docker/`` folder in
`github <https://github.com/sb-ai-lab/vfl-benchmark/tree/main>`_.
@@ -92,15 +112,22 @@ For the master communicator the following script is used (``run_grpc_master.py``
@click.command()
@click.option("--config-path", type=str, default="../configs/config.yml")
def main(config_path):
@click.option(
"--infer",
is_flag=True,
show_default=True,
default=False,
help="Run in an inference mode.",
)
def main(config_path, infer):
# Same to the local experiment load the configuration into the VFLConfig Pydantic model
config = VFLConfig.load_and_validate(config_path)
# Use context manager to log metrics to mlflow (if enabled)
with reporting(config):
# In the GRpcMasterPartyCommunicator several keyword arguments appear, mostly required for the gRPC server start
comm = GRpcMasterPartyCommunicator(
participant=get_party_master(config_path),
participant=get_party_master(config_path, is_infer=infer),
world_size=config.common.world_size,
port=config.grpc_server.port,
host=config.grpc_server.host,
@@ -136,7 +163,14 @@ For the member communicator we implemented the following (``run_grpc_member.py``
@click.command()
@click.option("--config-path", type=str, default="../configs/config.yml")
def main(config_path):
@click.option(
"--infer",
is_flag=True,
show_default=True,
default=False,
help="Run in an inference mode.",
)
def main(config_path, infer):
# Due to the metrics and parameters are logged from the master, we do not need to start the mlflow
# experiment here
@@ -153,7 +187,7 @@ For the member communicator we implemented the following (``run_grpc_member.py``
# Again, GRpcMemberPartyCommunicator requires additional keyword args to act as the gRPC client to the
# server on master
comm = GRpcMemberPartyCommunicator(
participant=get_party_member(config_path, member_rank),
participant=get_party_member(config_path, member_rank, is_infer=infer),
master_host=grpc_host,
master_port=config.grpc_server.port,
max_message_size=config.grpc_server.max_message_size,
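
Note on the ``--infer`` flag: both launch scripts in this diff forward the flag as ``is_infer``, and the tutorial code maps it to ``do_train=not is_infer`` and ``do_predict=is_infer``. The sketch below illustrates only that mapping. It is not the project's code: it uses the standard-library ``argparse`` in place of ``click`` so that it runs without extra dependencies, and ``mode_flags`` and ``parse_args`` are hypothetical helper names.

```python
import argparse
from typing import Dict, List, Optional


def mode_flags(is_infer: bool) -> Dict[str, bool]:
    """Hypothetical helper: mirror the do_train / do_predict mapping from the diff."""
    return {"do_train": not is_infer, "do_predict": is_infer}


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the two options the launch scripts expose (argparse stand-in for click)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-path", type=str, default="../configs/config.yml")
    # store_true gives a boolean flag that defaults to False, like is_flag=True in click
    parser.add_argument("--infer", action="store_true", help="Run in an inference mode.")
    return parser.parse_args(argv)


# Explicit argv lists keep the example independent of the real command line.
train_flags = mode_flags(parse_args([]).infer)
infer_flags = mode_flags(parse_args(["--infer"]).infer)
```

With this mapping exactly one of the two modes is active per launch, which matches the diff's comment that the stalactite CLI uses separate commands for training and inference.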
50 changes: 35 additions & 15 deletions docs/tutorials/local_communicator_tutorial.rst
@@ -83,6 +83,7 @@ define the ``load_processors`` function
# We initialize the target uids here because we want to simulate only partially available data
target_uids = [str(i) for i in range(config.data.dataset_size)]
inference_target_uids = [str(i) for i in range(500)]
# Local communicator requires party information, we initialize it as an empty dictionary as no data is passed for
# the experiment
shared_party_info = dict()
@@ -93,27 +94,40 @@ After we can get all required data, let's initialize the master class

.. code-block:: python
from stalactite.party_master_impl import PartyMasterImpl, PartyMasterImplConsequently, PartyMasterImplLogreg
from stalactite.ml import (
HonestPartyMasterLinRegConsequently,
HonestPartyMasterLinReg,
HonestPartyMemberLogReg,
HonestPartyMemberLinReg,
HonestPartyMasterLogReg
)
def run(config_path: str):
...
if 'logreg' in config.common.vfl_model_name:
master_class = PartyMasterImplLogreg
if 'logreg' in config.vfl_model.vfl_model_name:
master_class = HonestPartyMasterLogReg
member_class = HonestPartyMemberLogReg
else:
if config.common.is_consequently:
master_class = PartyMasterImplConsequently
member_class = HonestPartyMemberLinReg
if config.vfl_model.is_consequently:
master_class = HonestPartyMasterLinRegConsequently
else:
master_class = PartyMasterImpl
master_class = HonestPartyMasterLinReg
master = master_class(
uid="master",
epochs=config.common.epochs,
epochs=config.vfl_model.epochs,
report_train_metrics_iteration=config.common.report_train_metrics_iteration,
report_test_metrics_iteration=config.common.report_test_metrics_iteration,
processor=processors[0], # For the master we take the first processor
target_uids=target_uids,
batch_size=config.common.batch_size,
inference_target_uids=inference_target_uids,
batch_size=config.vfl_model.batch_size,
eval_batch_size=config.vfl_model.eval_batch_size,
model_update_dim_size=0, # Let us leave this parameter as is, it will be updated later
run_mlflow=config.master.run_mlflow,
do_train=config.vfl_model.do_train,
do_predict=config.vfl_model.do_predict,
)
....
@@ -125,23 +139,29 @@ After the master is ready, we need to prepare the members:
def run(config_path: str):
...
# Members ids are required before the initialization only in local sequential linear regression case
# for the batcher initialization (it needs to have a list of the participants),
# for the make_batcher initialization (it needs to have a list of the participants),
# and are not applicable or used in other cases
member_ids = [f"member-{member_rank}" for member_rank in range(config.common.world_size)]
members = [
PartyMemberImpl(
member_class(
uid=member_uid,
member_record_uids=target_uids,
model_name=config.common.vfl_model_name,
member_inference_record_uids=inference_target_uids,
model_name=config.vfl_model.vfl_model_name,
processor=processors[member_rank],
batch_size=config.common.batch_size,
epochs=config.common.epochs,
batch_size=config.vfl_model.batch_size,
eval_batch_size=config.vfl_model.eval_batch_size,
epochs=config.vfl_model.epochs,
report_train_metrics_iteration=config.common.report_train_metrics_iteration,
report_test_metrics_iteration=config.common.report_test_metrics_iteration,
is_consequently=config.common.is_consequently,
members=member_ids if config.common.is_consequently else None,
is_consequently=config.vfl_model.is_consequently,
members=member_ids if config.vfl_model.is_consequently else None,
do_train=config.vfl_model.do_train,
do_predict=config.vfl_model.do_predict,
do_save_model=config.vfl_model.do_save_model,
model_path=config.vfl_model.vfl_model_path
)
for member_rank, member_uid in enumerate(member_ids)
]
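
Note on class selection: the updated local tutorial picks the master and member classes from ``vfl_model_name`` and ``is_consequently``. The sketch below restates that branching as a standalone function. It returns the class names as strings so that it runs without ``stalactite`` installed, and ``select_class_names`` is a hypothetical helper, not part of the library.

```python
from typing import Tuple


def select_class_names(vfl_model_name: str, is_consequently: bool) -> Tuple[str, str]:
    """Hypothetical helper: return (master class name, member class name).

    Mirrors the branching in the local tutorial diff; the names are plain
    strings here so the example needs no stalactite import.
    """
    if "logreg" in vfl_model_name:
        # Logistic regression ignores the sequential setting
        return "HonestPartyMasterLogReg", "HonestPartyMemberLogReg"
    if is_consequently:
        master = "HonestPartyMasterLinRegConsequently"
    else:
        master = "HonestPartyMasterLinReg"
    # Both linear regression variants share the same member class
    return master, "HonestPartyMemberLinReg"


logreg_pair = select_class_names("logreg", is_consequently=False)
sequential_pair = select_class_names("linreg", is_consequently=True)
```

The sequential variant changes only the master class here. The member instead receives ``is_consequently`` and the ``members`` list as constructor arguments, as the hunk above shows.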