# memory_utils.py
import torch
from typing import Dict, Optional, Tuple

from titans_pytorch import MemoryAsContextTransformer


def load_model_from_checkpoint(checkpoint_path: str) -> MemoryAsContextTransformer:
    """
    Load a model from a checkpoint, restoring all necessary configuration.
    """
    # Load on CPU first so GPU-saved checkpoints also open on CPU-only
    # machines; the model is moved to GPU below if one is available
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    # Get the hyperparameters saved alongside the weights
    hyperparams = checkpoint['hyperparams']

    # Initialize the model with the saved hyperparameters
    model = MemoryAsContextTransformer(
        num_tokens = 256,
        dim = 384,
        depth = 8,
        segment_len = hyperparams['WINDOW_SIZE'],
        num_persist_mem_tokens = hyperparams['NUM_PERSIST_MEM'],
        num_longterm_mem_tokens = hyperparams['NUM_LONGTERM_MEM'],
        neural_memory_layers = hyperparams['NEURAL_MEM_LAYERS'],
        neural_memory_segment_len = hyperparams['NEURAL_MEM_SEGMENT_LEN'],
        neural_mem_gate_attn_output = True,
        aux_kv_recon_loss_weight = hyperparams['KV_RECON_LOSS_WEIGHT'],
        use_flex_attn = True,
        sliding_window_attn = hyperparams['SLIDING_WINDOWS'],
        neural_memory_kwargs = dict(
            dim_head = 64,
            heads = 4,
            use_accelerated_scan = True,
            learned_mem_model_weights = hyperparams['LEARNED_MEM_MODEL_WEIGHTS'],
            default_model_kwargs = dict(
                depth = hyperparams['NEURAL_MEMORY_DEPTH'],
            )
        )
    )

    # Move to GPU if available
    if torch.cuda.is_available():
        model = model.cuda()

    # Restore the trained weights
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded checkpoint from batch {checkpoint['batch_idx']}")
    return model
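
# Example usage (hypothetical path; assumes the checkpoint was written with
# 'hyperparams', 'model_state_dict' and 'batch_idx' keys as expected above):
#
#   model = load_model_from_checkpoint('./checkpoints/model_batch_1000.pt')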


def save_memory_state(model, save_path: str) -> Dict:
    """
    Save the neural memory states from all memory layers in the model.
    """
    memory_states = {}

    # Extract the memory state from each layer that has a neural memory module
    for idx, (attn, _) in enumerate(model.layers):
        if hasattr(attn, 'neural_mem'):
            mem = attn.neural_mem
            if mem is not None and hasattr(mem, 'previous_state'):
                memory_states[f'layer_{idx}_memory'] = {
                    'previous_state': mem.previous_state
                }

    # Save to file
    torch.save(memory_states, save_path)
    print(f"Memory states saved to {save_path}")
    return memory_states


def load_memory_state(model, load_path: str) -> Dict:
    """
    Load neural memory states from disk and restore them to the model.
    """
    # Map to CPU when no GPU is available so GPU-saved states still load
    if not torch.cuda.is_available():
        memory_states = torch.load(load_path, map_location='cpu')
    else:
        memory_states = torch.load(load_path)

    # Restore the memory state to each matching layer
    for idx, (attn, _) in enumerate(model.layers):
        if hasattr(attn, 'neural_mem'):
            mem = attn.neural_mem
            if mem is not None and f'layer_{idx}_memory' in memory_states:
                mem.previous_state = memory_states[f'layer_{idx}_memory']['previous_state']

    print(f"Memory states loaded from {load_path}")
    return memory_states
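
# Round-trip example (hypothetical path):
#
#   states = save_memory_state(model, './memory_state.pt')
#   load_memory_state(model, './memory_state.pt')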


def process_text_and_update_memory(
    model,
    text: str,
    chunk_size: int = 512,
    save_memory: bool = True,
    memory_path: Optional[str] = None
) -> Tuple[torch.Tensor, Optional[Dict]]:
    """
    Process text through the model while updating its neural memory.

    Returns the model output for the final chunk, along with the saved
    memory states (or None if saving was not requested).
    """
    model.eval()

    # Convert text to byte-level tokens (the model was built with
    # num_tokens = 256, so characters must fall in the 0-255 range)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tokens = torch.tensor([[ord(c) for c in text]], device=device)

    # Process in chunks so the neural memory is updated incrementally
    chunks = tokens.split(chunk_size, dim=1)
    last_output = None
    with torch.no_grad():
        for chunk in chunks:
            last_output = model(chunk)

    # Save the updated memory state if requested
    memory_states = None
    if save_memory and memory_path:
        memory_states = save_memory_state(model, memory_path)

    return last_output, memory_states


def decode_token(token):
    """Decode a single token id to a character (control chars map to space)."""
    return chr(max(32, token))


def decode_tokens(tokens):
    """Decode a sequence of token ids to text."""
    return ''.join(map(decode_token, tokens))
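

# Minimal end-to-end sketch of how these utilities fit together. The paths
# below are hypothetical placeholders, and this assumes a checkpoint saved in
# the format expected by load_model_from_checkpoint above.
if __name__ == '__main__':
    model = load_model_from_checkpoint('./checkpoints/model_batch_1000.pt')

    # Feed text through the model in chunks, persisting the updated memory
    output, memory_states = process_text_and_update_memory(
        model,
        text = 'The neural memory accumulates context from everything it reads.',
        chunk_size = 512,
        save_memory = True,
        memory_path = './memory_state.pt',
    )

    # Later (e.g. in a fresh process), the saved memory can be restored
    load_memory_state(model, './memory_state.pt')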