-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCPT.py
220 lines (183 loc) · 7.9 KB
/
CPT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import pandas as pd
class Tree:
    """Node of the prediction tree used by CPT.

    Attributes are documented as comments in ``__init__``.
    """

    def __init__(self, item_value=None):
        self.Item = item_value  # value stored in this node (None for a default-built root)
        self.Count = 0  # counter starting at 0; no Tree method modifies it
        self.Children = []  # child Tree nodes, in insertion order
        self.Parent = None  # parent node, or None for a root

    def add_child(self, child):
        """Append a new child node holding the value *child* and return it.

        No duplicate check is made; use ``has_child``/``get_child`` first to
        avoid creating two children with the same item.
        """
        new_child = Tree(child)
        new_child.Parent = self
        self.Children.append(new_child)
        return new_child

    def get_child(self, tg):
        """Return the first child whose Item equals *tg*, or None if there is none."""
        for child in self.Children:
            if child.Item == tg:
                return child
        return None

    def has_child(self, target):
        """Return True if some child's Item equals *target*."""
        return self.get_child(target) is not None

    def remove_child(self, child):
        """Remove every child whose Item equals *child*; a no-op if none match."""
        # Rebuild the list in place (same list object) instead of calling
        # list.remove() inside a loop over the same list, which skips the
        # element following each removed one.
        self.Children[:] = [node for node in self.Children if node.Item != child]
class CPT:
    """Compact Prediction Tree (CPT) sequence predictor.

    ``train`` builds three structures:

    * ``root_node`` -- a ``Tree`` storing every (truncated) training sequence
      as a path starting below the root;
    * ``II`` -- inverted index: item -> set of ids of the training sequences
      that contain it;
    * ``LT`` -- lookup table: sequence id -> tree node of that sequence's last
      item.

    Sequence ids are the positions of the sequences in the data passed to
    ``train`` (they restart at 0 on every call).
    """

    def __init__(self):
        self.alphabet = set()  # every item seen during training
        self.II = {}  # inverted index: item -> set of sequence ids
        self.LT = {}  # lookup table: sequence id -> node of its last item
        self.root_node = Tree()  # root of the prediction tree (Item is None)

    def train(self, data, max_seq_length=10):
        """Insert every sequence of *data* into the tree, II and LT.

        Only the last *max_seq_length* items of each sequence are kept
        (Python slicing: a value of 0 keeps the whole sequence).

        Returns:
            True.
        """
        for idx, seq in enumerate(data):
            seq = seq[-max_seq_length:]  # keep only the most recent items
            current_node = self.root_node
            for item in seq:
                # AJO: the root counts every item inserted
                self.root_node.Count += 1
                self.alphabet.add(item)
                # Reuse the existing branch for this item, or create it
                child = current_node.get_child(item)
                if child is None:
                    current_node.add_child(item)
                    child = current_node.get_child(item)
                current_node = child
                # AJO: each node counts how many times training passed through it
                current_node.Count += 1
                self.II.setdefault(item, set()).add(idx)
            # An empty sequence maps to the root node
            self.LT[idx] = current_node
        return True

    def prune(self, min_leaf_count=1):
        """Detach tree nodes whose Count is below *min_leaf_count*.

        For every sequence in LT, walk from its last node up to the root.
        Each node with ``Count < min_leaf_count`` is removed from its parent
        and the sequence's LT entry moves up to that parent; when the removed
        node sits directly under the root, the sequence is dropped from LT.
        Finally the sequence id is removed from II for every item whose
        highest Count on the walked path is below the threshold; II entries
        left without ids are deleted.
        """
        branches_to_remove = []
        for idx in self.LT:
            current_node = self.LT[idx]
            item_count = {}  # item -> highest node Count seen for it on this path
            while current_node.Parent is not None:
                item_count[current_node.Item] = max(
                    item_count.get(current_node.Item, 0), current_node.Count
                )
                if current_node.Count < min_leaf_count:
                    item = current_node.Item
                    current_node = current_node.Parent
                    if current_node.Parent is not None:
                        self.LT[idx] = current_node
                    else:
                        # Delete from LT after the loop: a dict must not
                        # change size while it is being iterated.
                        branches_to_remove.append(idx)
                    current_node.remove_child(item)
                else:
                    current_node = current_node.Parent
            # Remove references in II
            for item, count in item_count.items():
                if count >= min_leaf_count:
                    continue
                ids = self.II.get(item)
                if ids is None:
                    continue
                ids.discard(idx)
                if not ids:
                    del self.II[item]
        for branch in branches_to_remove:
            del self.LT[branch]

    def predict(self, target, k=10, n=1, p=1, coef=2):
        """Predict the next items for every sequence in *target*.

        Args:
            target: iterable of item sequences.
            k: only the last k items of each target sequence are used.
            n: maximum number of predictions returned per target sequence.
            p: nodes with ``Count <= p`` are skipped when rebuilding the
                similar sequences from the tree.
            coef: factor applied to the predecessor weight for each
                predecessor shared with the target.

        Returns:
            One list per target sequence, holding up to n ``(item, share)``
            tuples sorted by decreasing score. ``share`` is the item's score
            divided by the sum of the returned scores, rounded to 2 decimals.
            When nothing can be predicted the list is
            ``[('--NO-RESULT--', 0)]``.
        """
        predictions = []
        # Real LT keys: after prune() they are no longer 0..len(LT)-1
        known_ids = set(self.LT)
        for t_seq in target:
            t_seq = t_seq[-k:]  # Take only the last k element of sequence in target
            similar_ids = self._similar_ids(t_seq, known_ids)
            similar_sequences = self._rebuild_sequences(similar_ids, p)
            count_table = self._score_candidates(t_seq, similar_sequences, coef)
            predictions.append(self._top_n(count_table, n))
        return predictions

    def _similar_ids(self, t_seq, known_ids):
        """Return ids in *known_ids* of sequences containing every known item of *t_seq*.

        Items missing from II (unseen in training, or pruned) are ignored.
        If no item of *t_seq* is known, the result is empty.
        """
        intersection = None
        for item in t_seq:
            ids = self.II.get(item)
            if ids is None:  # manage if code not seen during training
                continue
            # Accumulate: each known item narrows the candidate set
            base = known_ids if intersection is None else intersection
            intersection = base & ids
        return set() if intersection is None else intersection

    def _rebuild_sequences(self, seq_ids, p):
        """Rebuild root-to-leaf item lists for *seq_ids* from the tree.

        This allows predicting from the tree instead of the original data.
        Items of nodes with ``Count <= p`` are left out (AJO); sequences that
        end up empty are dropped.
        """
        similar_sequences = []
        for seq_id in seq_ids:
            current_node = self.LT.get(seq_id)
            if current_node is None:  # id no longer in the lookup table
                continue
            items = []
            while current_node.Item is not None:
                if current_node.Count > p:  # AJO
                    items.append(current_node.Item)
                current_node = current_node.Parent
            if items:  # AJO
                items.reverse()
                similar_sequences.append(items)
        return similar_sequences

    @staticmethod
    def _score_candidates(t_seq, similar_sequences, coef):
        """Score the items that follow the last target item in each similar sequence.

        Each occurrence of a candidate scores
        ``weight_predecessor + 1/len(similar_sequences) + 0.001/distance``;
        repeated occurrences of the same candidate multiply their scores.
        """
        count_table = {}
        if not similar_sequences:
            return count_table
        last_item = t_seq[-1]
        weight_level = 1 / len(similar_sequences)  # len(similar_sequences) = support
        for seq in similar_sequences:
            # find index in similar sequence of last item in target
            try:
                index = seq.index(last_item)
            except ValueError:
                continue
            # Multiply by coef for each aligned predecessor matching the target
            weight_predecessor = 1
            if index > 0:
                seq_pred = seq[:index + 1]
                ridx = 2
                while len(t_seq) >= ridx and len(seq_pred) >= ridx:
                    if t_seq[-ridx] == seq_pred[-ridx]:
                        weight_predecessor = weight_predecessor * coef
                    ridx += 1
            for distance, element in enumerate(seq[index + 1:], start=1):
                weight_distance = 1 / distance
                score = weight_predecessor + weight_level + weight_distance * 0.001
                if element in count_table:
                    count_table[element] = score * count_table[element]
                else:
                    count_table[element] = score
        return count_table

    @staticmethod
    def _top_n(count_table, n):
        """Return the n best ``(item, share)`` pairs, or the no-result marker."""
        largest = sorted(count_table.items(), key=lambda t: t[1], reverse=True)[:n]
        if not largest:
            return [('--NO-RESULT--', 0)]
        total = sum([score for _, score in largest])
        return [(item, round(score / total, 2)) for item, score in largest]
def read_file(filename, id_col, line_num_col, code_col, require_sorting=False):
    """Read a ';'-separated CSV file and group its codes into sequences.

    Consecutive rows sharing the same *id_col* value form one sequence made of
    their *code_col* values, in file order.

    Args:
        filename: path or file-like object accepted by ``pandas.read_csv``.
        id_col: column identifying the sequence a row belongs to.
        line_num_col: column giving the row's position inside its sequence;
            only used when *require_sorting* is True.
        code_col: column holding the item stored in the sequence.
        require_sorting: if True, sort rows by *id_col* then *line_num_col*
            before grouping.

    Returns:
        A list of lists of codes, one list per run of identical ids (an empty
        list for a file without data rows).
    """
    # keep_default_na=False: strings such as "NA" or "" stay as-is, not NaN
    df = pd.read_csv(filename, sep=";", engine='python', keep_default_na=False)
    if require_sorting:
        # NOTE(review): presumably input is normally pre-sorted upstream (Pentaho)
        df = df.sort_values(by=[id_col, line_num_col], ignore_index=True)
    dat = []
    current_seq = []
    hist_ref = None
    # Iterate the two columns directly: unlike iterrows(), this keeps each
    # column's own dtype instead of coercing all values of a row together.
    for ref, code in zip(df[id_col], df[code_col]):
        if current_seq and ref != hist_ref:
            dat.append(current_seq)
            current_seq = []
        hist_ref = ref
        current_seq.append(code)
    # The last sequence has no following id change to flush it
    if current_seq:
        dat.append(current_seq)
    return dat
def pprint_tree(node, file=None, _prefix="", _last=True):
    """Print *node* and all its descendants as a tree, one node per line.

    Each line reads ``<prefix><connector><Item> = <Count>``. Output goes to
    *file* (standard output when None). ``_prefix`` and ``_last`` are set by
    the recursion and are not meant to be passed by callers.
    """
    print(_prefix, "└─ " if _last else "├─ ", node.Item, " = ", node.Count, sep="", file=file)
    # Extend the prefix by exactly the connector width (3 characters) so each
    # "│" sits under its parent's "├" and nested levels line up.
    _prefix += "   " if _last else "│  "
    child_count = len(node.Children)
    for i, child in enumerate(node.Children):
        pprint_tree(child, file, _prefix, i == child_count - 1)