Skip to content

Commit

Permalink
add option to define env variables from yaml file
Browse files Browse the repository at this point in the history
  • Loading branch information
seruva19 committed Jan 26, 2025
1 parent 8562b00 commit 032196c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 9 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ main.*
/models
/output
/train
/checkpoints
/checkpoints
kubin.env.yaml
5 changes: 5 additions & 0 deletions src/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from extension.ext_registry import ExtensionRegistry
from params import KubinParams
from utils.env_data import load_custom_env
from utils.logging import k_error, k_log


Expand Down Expand Up @@ -32,6 +33,10 @@ def with_args(self, args):

self.params.register_change_callback(self.ext_registry.propagate_params_changes)

def with_envvars(self):
yaml_path = "kubin.env.yaml"
load_custom_env(yaml_path)

def with_pipeline(self):
use_mock = self.params("general", "mock")
model_name = self.params("general", "model_name")
Expand Down
1 change: 1 addition & 0 deletions src/kubin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

def init_kubin(kubin: Kubin):
kubin.with_args(args)
kubin.with_envvars()
kubin.with_utils()
kubin.with_extensions()
kubin.with_hooks()
Expand Down
78 changes: 70 additions & 8 deletions src/utils/env_data.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,80 @@
import os
import torch
import gc
from typing import Union, List, Dict, Optional
from collections import defaultdict
from safetensors.torch import load_file, save_file
from utils.logging import k_log
import os
import yaml

models: Dict[str, torch.nn.Module] = {}


def reg(model_id, weights):
if model_id in models:
k_log(f"model with name '{model_id}' already exists")
models[model_id] = weights


def clear(model_names: Optional[Union[str, List[str]]] = None):
names_to_clear = []

if model_names is None:
names_to_clear = list(models.keys())
elif isinstance(model_names, str):
names_to_clear = [model_names]
else:
names_to_clear = model_names

for name in names_to_clear:
if name not in models:
k_log(f"model '{name}' not registered, cannot release")

try:
models[name].to("cpu")
except:
k_log(f"failed to release model '{name}'")

if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
del models[name]
gc.collect()


def load_env_value(key, default_value):
return os.environ.get(key, default_value)


def load_custom_env(file_path):
try:
if os.path.exists(file_path):
k_log(f"loading custom env values from {file_path}")
with open(file_path, "r") as yaml_file:
config = yaml.safe_load(yaml_file)

if config is None:
return

for key, value in config.items():
os.environ[key] = str(value)
k_log(f"custom environment variable set: {key} = {value}")

except Exception as e:
k_log(f"error loading custom env values from {file_path}: {e}")


def map_target_to_task(target):
return (
"text2img"
if target == "t2i"
else "img2img"
if target == "i2i"
else "inpainting"
if target == "inpaint"
else "outpainting"
if target == "outpaint"
else target
else (
"img2img"
if target == "i2i"
else (
"inpainting"
if target == "inpaint"
else "outpainting" if target == "outpaint" else target
)
)
)

0 comments on commit 032196c

Please sign in to comment.