Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hypergraph Heat Kernel Lift (Hypergraph to Simplicial) #58

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions configs/datasets/contact_primary_school.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: hypergraph
data_type: contact
data_name: ContactPrimarySchool
data_dir: datasets/${data_domain}/${data_type}
#data_split_dir: ${oc.env:PROJECT_ROOT}/datasets/data_splits/${data_name}

# Dataset parameters
num_nodes: 242
num_hyperedges: 12704
num_classes: 11
max_dim: 1
task: classification
8 changes: 8 additions & 0 deletions configs/datasets/manual_hypergraph.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
data_domain: hypergraph
data_type: toy_hypergraph
data_name: manual_hg
data_dir: datasets/${data_domain}/${data_type}

num_nodes: 12
num_hyperedges: 24
max_dim: 2
12 changes: 12 additions & 0 deletions configs/datasets/senate_committee.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: hypergraph
data_type: interaction
data_name: senate_committee
data_dir: datasets/${data_domain}/${data_type}
#data_split_dir: ${oc.env:PROJECT_ROOT}/datasets/data_splits/${data_name}

# Dataset parameters
num_nodes: 282
num_hyperedges: 315
num_classes: 2
max_dim: 2
task: classification
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
transform_type: lifting
transform_name: HypergraphHeatLifting
complex_dim: 2
signed: True
feature_lifting: ProjectionSum
21 changes: 20 additions & 1 deletion modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from modules.data.utils.custom_dataset import CustomDataset
from modules.data.utils.utils import (
load_cell_complex_dataset,
load_contact_primary_school,
load_hypergraph_pickle_dataset,
load_manual_graph,
load_manual_hypergraph,
load_senate_committee,
load_simplicial_dataset,
)

Expand Down Expand Up @@ -203,4 +206,20 @@ def load(
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_hypergraph_pickle_dataset(self.parameters)
## Define the path to the data directory
root_folder = rootutils.find_root()
root_data_dir = os.path.join(root_folder, self.parameters["data_dir"])

self.data_dir = os.path.join(root_data_dir, self.parameters["data_name"])
if self.parameters.data_name in ["ContactPrimarySchool"]:
data = load_contact_primary_school(self.parameters, self.data_dir)
dataset = CustomDataset([data], self.data_dir)
elif self.parameters.data_name in ["senate_committee"]:
data = load_senate_committee(self.parameters, self.data_dir)
dataset = CustomDataset([data], self.data_dir)
elif self.parameters.data_name in ["manual_hg"]:
data = load_manual_hypergraph(self.parameters)
dataset = CustomDataset([data], self.data_dir)
else:
dataset = load_hypergraph_pickle_dataset(self.parameters)
return dataset
4 changes: 2 additions & 2 deletions modules/data/preprocess/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ class PreProcessor(torch_geometric.data.InMemoryDataset):

Parameters
----------
data_dir : str
Path to the directory containing the data.
data_list : list
List of data objects.
transforms_config : DictConfig | dict
Configuration parameters for the transforms.
data_dir : str
Path to the directory containing the data.
**kwargs: optional
Additional arguments.
"""
Expand Down
174 changes: 165 additions & 9 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import hashlib
import itertools as it
import os.path as osp
import pickle
import tempfile
import zipfile

import networkx as nx
import numpy as np
import omegaconf
import toponetx.datasets.graph as graph
import torch
import torch_geometric
import torch_sparse
from topomodelx.utils.sparse import from_sparse
from torch_geometric.data import Data
from torch_sparse import coalesce
from torch_sparse import SparseTensor, coalesce


def get_complex_connectivity(complex, max_rank, signed=False):
Expand Down Expand Up @@ -50,16 +54,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand Down Expand Up @@ -334,6 +338,158 @@ def load_manual_graph():
)


def load_manual_hypergraph(cfg: dict):
    """Create a small random hypergraph for testing purposes.

    The hypergraph is sampled with a fixed RNG seed, so repeated calls with
    the same configuration produce the identical dataset.

    Parameters
    ----------
    cfg : dict
        Dataset configuration. Recognized keys (all optional):
        - "num_nodes": number of nodes (default 12).
        - "num_hyperedges": number of hyperedges to sample (default 24);
          duplicate hyperedges are discarded, so the final count may be lower.
        - "max_dim": maximum dimension stored on the returned object
          (default 3).

    Returns
    -------
    torch_geometric.data.Data
        Data object carrying a bipartite ``edge_index`` of
        (hyperedge index, node) pairs and a sparse node-by-hyperedge
        ``incidence_hyperedges`` matrix.
    """
    rng = np.random.default_rng(1234)
    # Previously hard-coded to 12 x 24; now read from the config (the
    # manual_hypergraph.yaml declares num_nodes/num_hyperedges) with the
    # original values as backward-compatible defaults.
    n = int(cfg.get("num_nodes", 12))
    m = int(cfg.get("num_hyperedges", 24))
    # Each hyperedge is the support of a random 0/1 vector over the nodes;
    # the set removes duplicates (draw order is unchanged from the original).
    hyperedges = {
        tuple(np.flatnonzero(rng.choice([0, 1], size=n))) for _ in range(m)
    }
    hyperedges = [np.array(he) for he in hyperedges]
    # COO triplets (row = node, col = hyperedge, value = 1) of the incidence.
    R = torch.tensor(np.concatenate(hyperedges), dtype=torch.long)
    C = torch.tensor(
        np.repeat(np.arange(len(hyperedges)), [len(he) for he in hyperedges]),
        dtype=torch.long,
    )
    V = torch.tensor(np.ones(len(R)))
    incidence_hyperedges = torch_sparse.SparseTensor(row=R, col=C, value=V)
    incidence_hyperedges = incidence_hyperedges.coalesce().to_torch_sparse_coo_tensor()

    ## Bipartite graph representation: one (hyperedge index, node) pair per
    ## membership.
    edges = np.array(
        list(it.chain(*[[(i, v) for v in he] for i, he in enumerate(hyperedges)]))
    )
    return Data(
        x=torch.empty((n, 0)),
        edge_index=torch.tensor(edges, dtype=torch.long),
        num_nodes=n,
        num_node_features=0,
        num_edges=len(hyperedges),
        incidence_hyperedges=incidence_hyperedges,
        max_dim=cfg.get("max_dim", 3),
    )


def load_contact_primary_school(cfg: dict, data_dir: str):
    """Download and parse the contact-primary-school hypergraph benchmark.

    Fetches the archive from Google Drive into a temporary file, then reads
    node labels, hyperedges, and label names from the contained text files.

    Parameters
    ----------
    cfg : dict
        Dataset configuration; only "max_dim" is read (default 1).
    data_dir : str
        Target data directory.
        NOTE(review): currently unused — the archive is re-downloaded on
        every call; consider caching the extracted files under ``data_dir``.

    Returns
    -------
    torch_geometric.data.Data
        Data object with node labels ``y``, label names ``y_names``, a COO
        ``edge_index`` of (node, hyperedge) pairs, and the sparse
        ``incidence_hyperedges`` matrix.
    """
    import gdown

    url = "https://drive.google.com/uc?id=1H7PGDPvjCyxbogUqw17YgzMc_GHLjbZA"
    # Context managers guarantee the temporary file and the zip handle are
    # released deterministically (the previous version leaked both).
    with tempfile.NamedTemporaryFile() as fn:
        gdown.download(url, fn.name, quiet=False)
        with zipfile.ZipFile(fn.name, "r") as archive:
            labels = archive.open(
                "contact-primary-school/node-labels-contact-primary-school.txt", "r"
            ).readlines()
            hyperedges = archive.open(
                "contact-primary-school/hyperedges-contact-primary-school.txt", "r"
            ).readlines()
            label_names = archive.open(
                "contact-primary-school/label-names-contact-primary-school.txt", "r"
            ).readlines()

    # Each hyperedge line is a comma-separated list of node ids.
    hyperedges = [
        list(map(int, he.decode().replace("\n", "").strip().split(",")))
        for he in hyperedges
    ]
    labels = np.array([int(b.decode().replace("\n", "").strip()) for b in labels])
    label_names = np.array([b.decode().replace("\n", "").strip() for b in label_names])

    # COO layout (row 0 = node ids, row 1 = hyperedge ids). Based on:
    # https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HypergraphConv.html
    HE_coo = torch.tensor(
        np.array(
            [
                np.concatenate(hyperedges),
                np.repeat(np.arange(len(hyperedges)), [len(he) for he in hyperedges]),
            ]
        )
    )

    incidence_hyperedges = (
        SparseTensor(
            row=HE_coo[0, :],
            col=HE_coo[1, :],
            value=torch.tensor(np.ones(HE_coo.shape[1])),
        )
        .coalesce()
        .to_torch_sparse_coo_tensor()
    )

    return Data(
        x=torch.empty((len(labels), 0)),
        edge_index=HE_coo,
        y=torch.LongTensor(labels),
        y_names=label_names,
        num_nodes=len(labels),
        num_node_features=0,
        num_edges=len(hyperedges),
        incidence_hyperedges=incidence_hyperedges,
        max_dim=cfg.get("max_dim", 1),
    )


def load_senate_committee(cfg: dict, data_dir: str) -> torch_geometric.data.Data:
    """Download and parse the senate-committees hypergraph benchmark.

    Fetches the archive from Google Drive into a temporary file, then reads
    node labels, hyperedges, and node names from the contained text files.

    Parameters
    ----------
    cfg : dict
        Dataset configuration; only "max_dim" is read (default 2).
    data_dir : str
        Target data directory.
        NOTE(review): currently unused — the archive is re-downloaded on
        every call; consider caching the extracted files under ``data_dir``.

    Returns
    -------
    torch_geometric.data.Data
        Data object with node labels ``y``, node names ``y_names``, a COO
        ``edge_index`` of (node, hyperedge) pairs, and the sparse
        ``incidence_hyperedges`` matrix.
    """
    # tempfile, zipfile, and SparseTensor are already imported at module
    # level; only the optional gdown dependency is imported lazily.
    import gdown

    url = "https://drive.google.com/uc?id=17ZRVwki_x_C_DlOAea5dPBO7Q4SRTRRw"
    # Context managers guarantee the temporary file and the zip handle are
    # released deterministically (the previous version leaked both).
    with tempfile.NamedTemporaryFile() as fn:
        gdown.download(url, fn.name, quiet=False)
        with zipfile.ZipFile(fn.name, "r") as archive:
            labels = archive.open(
                "senate-committees/node-labels-senate-committees.txt", "r"
            ).readlines()
            hyperedges = archive.open(
                "senate-committees/hyperedges-senate-committees.txt", "r"
            ).readlines()
            label_names = archive.open(
                "senate-committees/node-names-senate-committees.txt", "r"
            ).readlines()

    # Each hyperedge line is a comma-separated list of node ids.
    hyperedges = [
        list(map(int, he.decode().replace("\n", "").strip().split(",")))
        for he in hyperedges
    ]
    labels = np.array([int(b.decode().replace("\n", "").strip()) for b in labels])
    label_names = np.array([b.decode().replace("\n", "").strip() for b in label_names])

    # COO layout (row 0 = node ids, row 1 = hyperedge ids). The raw files use
    # 1-based node ids, hence the "- 1" shift to 0-based indices. Based on:
    # https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HypergraphConv.html
    HE_coo = torch.tensor(
        np.array(
            [
                np.concatenate(hyperedges) - 1,
                np.repeat(np.arange(len(hyperedges)), [len(he) for he in hyperedges]),
            ]
        )
    )

    incidence_hyperedges = (
        SparseTensor(
            row=HE_coo[0, :],
            col=HE_coo[1, :],
            value=torch.tensor(np.ones(HE_coo.shape[1])),
        )
        .coalesce()
        .to_torch_sparse_coo_tensor()
    )

    return Data(
        x=torch.empty((len(labels), 0)),
        edge_index=HE_coo,
        y=torch.LongTensor(labels),
        y_names=label_names,
        num_nodes=len(labels),
        num_node_features=0,
        num_edges=len(hyperedges),
        incidence_hyperedges=incidence_hyperedges,
        max_dim=cfg.get("max_dim", 2),
    )


def get_Planetoid_pyg(cfg):
r"""Loads Planetoid graph datasets from torch_geometric.

Expand Down
5 changes: 5 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.hypergraph2simplicial.heat_lifting import (
HypergraphHeatLifting,
)

TRANSFORMS = {
# Graph -> Hypergraph
Expand All @@ -23,6 +26,8 @@
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Hypergraph -> Simplicial Complex
"HypergraphHeatLifting": HypergraphHeatLifting,
# Feature Liftings
"ProjectionSum": ProjectionSum,
# Data Manipulations
Expand Down
4 changes: 3 additions & 1 deletion modules/transforms/feature_liftings/feature_liftings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ def lift_features(
-------
torch_geometric.data.Data | dict
The lifted data."""
keys = sorted([key.split("_")[1] for key in data.keys() if "incidence" in key]) # noqa : SIM118
keys = sorted(
[key.split("_")[1] for key in data.keys() if "incidence" in key] # noqa
)
for elem in keys:
if f"x_{elem}" not in data:
idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1
Expand Down
19 changes: 19 additions & 0 deletions modules/transforms/liftings/hypergraph2simplicial/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from modules.transforms.liftings.lifting import HypergraphLifting


class Hypergraph2SimplicialLifting(HypergraphLifting):
    r"""Abstract base class for lifting hypergraphs to simplicial complexes.

    Parameters
    ----------
    complex_dim : int, optional
        Dimension of the simplicial complex to be generated. Default is 2.
    **kwargs : optional
        Additional arguments forwarded to the parent class. The keyword
        ``signed`` (default ``False``) is also read and stored on the
        instance as ``self.signed``.
    """

    def __init__(self, complex_dim=2, **kwargs):
        # kwargs (including "signed") are passed through to the parent.
        super().__init__(**kwargs)
        self.type = "hypergraph2simplicial"
        self.signed = kwargs.get("signed", False)
        self.complex_dim = complex_dim
Loading
Loading