From 032196cbae1e8cb8da91d6eda42c70584a20ddff Mon Sep 17 00:00:00 2001 From: seruva19 Date: Sun, 26 Jan 2025 22:11:36 +0300 Subject: [PATCH] add option to define env variables from yaml file --- .gitignore | 3 +- src/env.py | 5 +++ src/kubin.py | 1 + src/utils/env_data.py | 78 ++++++++++++++++++++++++++++++++++++++----- 4 files changed, 78 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index e77fd5c..cb3ff81 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,5 @@ main.* /models /output /train -/checkpoints \ No newline at end of file +/checkpoints +kubin.env.yaml \ No newline at end of file diff --git a/src/env.py b/src/env.py index 0d402a6..dba7694 100644 --- a/src/env.py +++ b/src/env.py @@ -3,6 +3,7 @@ import torch from extension.ext_registry import ExtensionRegistry from params import KubinParams +from utils.env_data import load_custom_env from utils.logging import k_error, k_log @@ -32,6 +33,10 @@ def with_args(self, args): self.params.register_change_callback(self.ext_registry.propagate_params_changes) + def with_envvars(self): + yaml_path = "kubin.env.yaml" + load_custom_env(yaml_path) + def with_pipeline(self): use_mock = self.params("general", "mock") model_name = self.params("general", "model_name") diff --git a/src/kubin.py b/src/kubin.py index 75ad40d..d69f26a 100644 --- a/src/kubin.py +++ b/src/kubin.py @@ -15,6 +15,7 @@ def init_kubin(kubin: Kubin): kubin.with_args(args) + kubin.with_envvars() kubin.with_utils() kubin.with_extensions() kubin.with_hooks() diff --git a/src/utils/env_data.py b/src/utils/env_data.py index dc6b00c..a56533c 100644 --- a/src/utils/env_data.py +++ b/src/utils/env_data.py @@ -1,18 +1,80 @@ -import os import torch +import gc +from typing import Union, List, Dict, Optional from collections import defaultdict from safetensors.torch import load_file, save_file +from utils.logging import k_log +import os +import yaml + +models: Dict[str, torch.nn.Module] = {} + + +def reg(model_id, weights): + if model_id in models: + k_log(f"model with name '{model_id}' already exists") + models[model_id] = weights + + +def clear(model_names: Optional[Union[str, List[str]]] = None): + names_to_clear = [] + + if model_names is None: + names_to_clear = list(models.keys()) + elif isinstance(model_names, str): + names_to_clear = [model_names] + else: + names_to_clear = model_names + + for name in names_to_clear: + if name not in models: + k_log(f"model '{name}' not registered, cannot release") + + try: + models[name].to("cpu") + except: + k_log(f"failed to release model '{name}'") + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + del models[name] + gc.collect() + + +def load_env_value(key, default_value): + return os.environ.get(key, default_value) + + +def load_custom_env(file_path): + try: + if os.path.exists(file_path): + k_log(f"loading custom env values from {file_path}") + with open(file_path, "r") as yaml_file: + config = yaml.safe_load(yaml_file) + + if config is None: + return + + for key, value in config.items(): + os.environ[key] = str(value) + k_log(f"custom environment variable set: {key} = {value}") + + except Exception as e: + k_log(f"error loading custom env values from {file_path}: {e}") def map_target_to_task(target): return ( "text2img" if target == "t2i" - else "img2img" - if target == "i2i" - else "inpainting" - if target == "inpaint" - else "outpainting" - if target == "outpaint" - else target + else ( + "img2img" + if target == "i2i" + else ( + "inpainting" + if target == "inpaint" + else "outpainting" if target == "outpaint" else target + ) + ) )