aboutsummaryrefslogtreecommitdiff
path: root/ass2/ukkonen.py
blob: 5aa94e788d01dc5ecc055f77baec74290060f8cb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
"""
This file is imported into questions 2 and 3.
"""

import sys

ALPHABET_SIZE = 28  # 26 lowercase letters 'a'..'z' plus the two terminators "$" and "&"


class OrderedDict(dict):
    """
    A hybrid Python dictionary/list
        All set/get item operations on this data structure are the same complexity of a normal dictionary O(1)-ish
        For Ukkonen's operation, only the normal dictionary features are used.
    As values are stored in the dictionary, they are also referenced in a list of size O(alphabet).
        This acts as a kind of 'counting sort': reading it back is O(alphabet), but it provides a pre-sorted
        list of all children nodes.
        This is used for generating the suffix array.

    Keys must be a single lowercase letter, "$" or "&"; any other single character raises IndexError on insert.
    Only item assignment, `del` and `pop` keep the sorted view in sync; other dict mutators
    (update, setdefault, clear, popitem) bypass it.
    """
    def __init__(self):
        super().__init__()
        self.first_letters = [None for _ in range(ALPHABET_SIZE)]

    def __setitem__(self, key, value):
        # Validate before touching the dict so a rejected key leaves both views unchanged.
        slot = self._slot(key)
        super().__setitem__(key, value)
        self.first_letters[slot] = value

    def __delitem__(self, key):
        super().__delitem__(key)  # Raises KeyError for a missing key before the list is touched
        self.first_letters[self.rank(key)] = None

    def pop(self, key, *default):
        """
        Remove key and return its value, also clearing its slot in the sorted view.
        dict.pop does not go through __delitem__, so without this override the list would keep a stale value.
        Behaves like dict.pop for missing keys (returns the default if given, else raises KeyError).
        """
        if key in self:
            value = super().pop(key)
            self.first_letters[self.rank(key)] = None
            return value
        return super().pop(key, *default)

    def ordered_items(self):  # Return iterable of pre-sorted items (for suffix array)
        return filter(lambda x: x is not None, self.first_letters)

    def _slot(self, key):
        """
        Return the list index for key, raising IndexError if key is outside the supported alphabet.
        Explicit range checking stops negative ranks from silently wrapping around to the end of the list.
        """
        index = self.rank(key)
        num_letters = ALPHABET_SIZE - 2  # Slots below this are for letters; the last two are "$" and "&"
        if key not in ("$", "&") and not 0 <= index < num_letters:
            raise IndexError(f"Could not add item of key '{key}' since it is out of range of the rank function.")
        return index

    @staticmethod
    def rank(char):
        """
        Define a number value to an alphabet letter, including special characters so they can fit in a list
        'a'..'z' map to 0..25, "$" to 26 and "&" to 27, so every supported character has its own slot.
        """
        if char == "$":
            return 26
        elif char == "&":
            return 27
        else:
            return ord(char) - ord("a")


class Node:
    """
    Represents an arbitrary node in a suffix tree
    Also statically stores some state information about the algorithm (not pretty, I know)
    """
    global_end = 0  # Resolved end index for every node whose end is the '#' pointer (see tuple())
    num_splits = 0  # Subtracted from the node id when computing suffix_index in __init__
    all_nodes = []  # Every node created so far; a node's id is its position in this list
    string = ""  # The text that the start/end indices of all nodes refer to

    def __init__(self, start, end):
        """
        Create a node whose incoming edge covers string[start..end] (inclusive).
        end may be the '#' pointer, which resolves to Node.global_end.
        The node registers itself in Node.all_nodes and takes its id from its position there.
        """
        self.root = False
        self.start = start
        self.end = end
        self.children = OrderedDict()
        self.id = len(self.all_nodes)
        self.suffix_index = self.id - self.num_splits - 1
        self.all_nodes.append(self)
        self.parent = None
        self.link = None

    def __str__(self):
        """
        String representation of node, shows important internal values of a node (for debug)
        """
        link_str = "" if self.link is None else f" -> {self.link.id}"
        if not self.root:
            j, i = self.tuple()
            return f"[{self.id}, {self.tuple()}, {self.string[j:i + 1]}{link_str}]"
        return f"[{self.id} root{link_str}]"

    def __repr__(self):  # Shorter representation of node
        return f"[{self.id}]"

    def print_tree(self, spaces=1):
        """
        Recursively prints tree of nodes (for debug)
        """
        print(f"{self}")
        for edge in self.children:
            print("   " * spaces, end="")
            self.get_child(edge).print_tree(spaces=spaces + 1)

    def first_char(self):
        """
        Return the first character of this node's incoming edge
        """
        return self.string[self.start]

    def get_child(self, char):
        """
        Return the child whose edge starts with char, or None if there is no such child
        """
        return self.children.get(char)

    def add_child(self, child):
        """
        Attach child under this node, keyed by the first character of its edge, and return it
        """
        child.parent = self
        self.children[child.first_char()] = child
        return child

    def remove_child(self, child):
        """
        Remove the child stored under child's first character. Raises KeyError if there is none.
        """
        # Use `del` rather than dict.pop: pop does not run the container's __delitem__,
        # which would leave the children container's sorted view holding a stale entry.
        del self.children[child.first_char()]

    @property
    def end_index(self):  # Translates end index into a number (could be '#' pointer)
        return self.tuple()[1]

    def tuple(self):
        """
        Returns the resolved start and end coordinates of the substring this node represents
        """
        if self.root:
            raise Exception("Can't get substring of root.")
        if self.end == "#":  # Translate '#' into global_end
            return self.start, self.global_end
        return self.start, self.end

    @property
    def edge_length(self):
        """
        Number of characters on this node's incoming edge (0 for the root)
        """
        if self.root:
            return 0
        else:
            start, end = self.tuple()
            return end - start + 1

    def detach(self):
        """
        Remove this node from its parent's children and clear its parent reference
        """
        self.parent.remove_child(self)
        self.parent = None


class Point:
    """
    A single location in the tree, bundling the active node, active edge and active length together.
    The location is either explicit (exactly on a node, edge == "") or implicit (length characters
    along the child edge of node that begins with the character stored in edge).
    Keeping the three values in one object lets helper functions take and return whole positions.
    """
    def __init__(self, node, edge="", length=0):
        assert isinstance(node, Node)
        self.node, self.edge, self.length = node, edge, length

    def __repr__(self):
        return f"(Node {self.node.id}'s edge:'{self.edge}', {self.length} along.)"

    def is_explicit(self):
        """
        True when the point sits on a node rather than part-way along an edge
        """
        return self.edge == ""

    def set_node(self, node):
        """
        Move the point onto node, clearing the edge and length
        """
        self.node, self.edge, self.length = node, "", 0

    @property
    def edge_node(self) -> Node:
        """
        The child node at the bottom of the edge this point lies on (None if there is no such child)
        """
        return self.node.get_child(self.edge)

    def index_here(self):
        """
        Return the index in the original string that this point refers to
        """
        if not self.is_explicit():
            # Implicit point: offset into the edge, counted from the edge's first character
            return self.edge_node.start + self.length - 1
        if self.node.root:
            return 0
        return self.node.start

    def char_here(self):
        """
        Return the char in the original string that this point refers to
        """
        position = self.index_here()
        return Node.string[position]


def create_root():
    """
    Build the root node used to initialise the algorithm.
    Must be the first node created (Node.all_nodes has to be empty); the root's suffix link points to itself.
    """
    assert not Node.all_nodes
    node = Node(None, None)
    node.root, node.link = True, node
    return node


def split_edge(split_point: Point):
    """
    Break the edge that split_point lies on into two edges joined by a new internal node (the 'mediator').
    The mediator takes the first split_point.length characters of the old edge; the old child keeps the rest.
    Used for Rule 2 extensions on implicit suffixes.
    Returns the mediator node.
    """
    assert not split_point.is_explicit()
    lower = split_point.edge_node
    edge_start, _ = lower.tuple()
    # Detach before changing lower.start: the parent looks the child up by its current first character
    lower.detach()
    Node.num_splits += 1
    offset = split_point.length
    mediator = Node(edge_start, edge_start + offset - 1)
    mediator.suffix_index = None  # Internal nodes do not correspond to a suffix
    lower.start = edge_start + offset
    assert lower.start <= lower.end_index
    mediator.add_child(lower)
    split_point.node.add_child(mediator)
    return mediator


def pos(n: int):
    """
    Clamp n to be non-negative: negative values become 0, everything else is returned unchanged.
    """
    return n if n > 0 else 0


def do_phase(root: Node, active: Point, i, last_j, remainder):
    """
    Performs a single phase of Ukkonen's algorithm, returning values used for the next phase.

    root: root of the tree; traversals that restart from the top begin at Point(root).
    active: the active point carried over from the previous phase (mutated and/or replaced here).
    i: index into Node.string of the character being added in this phase.
    last_j: extension number of the most recent Rule 2 extension; this phase starts at last_j + 1.
    remainder: counter incremented on every Rule 3 and decremented (never below 0) on every Rule 2;
        it is the number of characters skip-counted from the root after a Rule 2.

    Returns the tuple (active, remainder, last_j) - note the order differs from the parameter order.
    """

    # Initialisation
    root_point = Point(root)  # NOTE(review): never read below - looks like a leftover
    # Every leaf stores '#' as its end, which Node.tuple() resolves to Node.global_end,
    # so this single increment extends all existing leaves by one character.
    Node.global_end += 1  # Perform rapid leaf extension trick (Rule 1)
    did_rule_three = False
    j = last_j + 1
    node_just_created = None  # Most recent internal node made by a split in this phase (awaiting a suffix link)

    while not did_rule_three and j <= i + 1:  # Run only the required extensions for this phase

        curr_char = Node.string[i]  # Same value on every iteration (i does not change inside the loop)
        match = char_is_after(active, curr_char)
        if match:  # Decide if Rule 2 or 3.
            # RULE 3 LOGIC
            remainder += 1
            if node_just_created is not None:
                node_just_created.link = active.node  # Create suffix link (Rule 3)
            active = skip_count(1, active, i)  # Move active node
            did_rule_three = True  # Break loop
        else:
            # RULE 2 LOGIC
            if not active.is_explicit():  # Active point on an edge, need to split
                mediator = split_edge(active)
                mediator.add_child(Node(i, "#"))  # Dangle new character off of mediator node from split
                if node_just_created is not None:
                    node_just_created.link = mediator  # Create suffix link (First sub-case)
                node_just_created = mediator
                active.length -= 1
                if active.length == 0:
                    # NOTE(review): active is reset again by set_node a few lines below,
                    # so this reset appears to have no lasting effect.
                    active.set_node(active.node)
            else:  # Active point on node, just dangle off a new node
                active.node.add_child(Node(i, "#"))
                if node_just_created is not None and node_just_created.link is None:
                    node_just_created.link = active.node  # Create suffix link (Second sub-case)
            remainder = pos(remainder - 1)
            active.set_node(active.node.link)  # Go to suffix link
            if remainder > 0:
                # Replace the active point with a fresh walk of `remainder` characters from the root,
                # reading those characters from Node.string starting at i - remainder.
                active = skip_count(remainder, Point(root), i - remainder)  # Traverse from root
            last_j = j
            j += 1
        # Debug aids, left disabled:
        # print(active)
        # root.print_tree()
    return active, remainder, last_j


def char_is_after(point: Point, char):
    """
    Tell whether char can be followed immediately after the given point.
    This is the test that selects between Rule 2 (no match) and Rule 3 (match).
    """
    if point.is_explicit():
        # On a node: the character must start one of the outgoing edges
        return char in point.node.children

    edge = point.edge_node
    if point.length == edge.edge_length:
        # Point has consumed the whole edge: compare against the edge's first character
        probe = edge.start
    else:
        # Part-way along the edge: compare against the next character on it
        probe = point.index_here() + 1
    return Node.string[probe] == char


def skip_count(num_chars, start_point: Point, index):
    """
    Walk num_chars characters down the tree from start_point using the skip-counting trick:
    whole edges are jumped over by length, and only the first character of each edge is looked up.
    index is the position in Node.string from which the characters to follow are read.
    start_point is modified in place and is also the object returned.
    Raises IndexError if the walk needs an edge that does not exist.
    """
    head = start_point
    chars_left = num_chars

    # If starting part-way along an edge, either stay on it or finish it off first
    if not head.is_explicit():
        edge_child = head.edge_node
        to_edge_end = edge_child.edge_length - head.length
        if num_chars < to_edge_end:
            head.length += num_chars
            return head
        head.set_node(edge_child)
        chars_left -= to_edge_end
        index += to_edge_end

    # Hop from node to node while a whole edge still fits in the remaining budget
    while chars_left > 0:
        direction = Node.string[index]
        next_node = head.node.get_child(direction)
        if next_node is None:  # Went off the tree -> error
            raise IndexError(f"Attempted to traverse char\n '{direction}' at point {head}. ({index=})")
        hop = next_node.edge_length
        if chars_left < hop:
            # The walk ends inside this edge: record the edge and how far along it we are
            head.edge = direction
            head.length = chars_left
            return head
        chars_left -= hop
        index += hop
        head.set_node(next_node)

    return head


def ukkonen(string):
    """
    Build a suffix tree for string (with a "$" terminator appended) and return its root,
    using everyone's favourite algorithm: Ukkonen's algorithm.
    All algorithm state kept on the Node class is reset first, so any previously built tree is invalidated.
    """
    # Reset the shared state held on Node
    string += "$"
    Node.string = string
    Node.global_end = Node.num_splits = 0
    Node.all_nodes.clear()

    # Base case (phase i = 0): a root with a single leaf for the first character
    root = create_root()
    root.add_child(Node(0, "#"))
    active, remainder, last_j = Point(root), 0, 1

    # Remaining phases, one per further character of the terminated string
    for phase in range(1, len(string)):
        active, remainder, last_j = do_phase(root, active, phase, last_j, remainder)
    return root


if __name__ == "__main__":
    # Smoke run: build a tree for a small sample string when executed as a script
    sample = "abacabad"
    ukkonen(sample)
    print("done")