diff options
| author | akiyamn | 2021-04-27 14:12:47 +1000 |
|---|---|---|
| committer | akiyamn | 2021-04-27 14:12:47 +1000 |
| commit | 1695a8e7c9e0d345918946ef6fbc8f56c7751e32 (patch) | |
| tree | 41829151bb336da54aab6b13ff4c47b3a6454698 /ass2/ukkonen.py | |
| parent | e347fd3246a4282d7fea85c1ae727e48c810480b (diff) | |
| download | fit3155-1695a8e7c9e0d345918946ef6fbc8f56c7751e32.tar.gz fit3155-1695a8e7c9e0d345918946ef6fbc8f56c7751e32.zip | |
Ass 2: Guts working for suffix array
Diffstat (limited to 'ass2/ukkonen.py')
| -rw-r--r-- | ass2/ukkonen.py | 279 |
1 files changed, 279 insertions, 0 deletions
diff --git a/ass2/ukkonen.py b/ass2/ukkonen.py new file mode 100644 index 0000000..8d536c5 --- /dev/null +++ b/ass2/ukkonen.py @@ -0,0 +1,279 @@ +import sys + + +class OrderedDict(dict): + def __init__(self): + super().__init__() + self.first_letters = [None for _ in range(27)] + + def __setitem__(self, key, value): + super().__setitem__(key, value) + try: + self.first_letters[self.rank(key)] = value + except IndexError: + raise IndexError(f"Could not add item of key '{key}' since it is out of range of the rank function.") + + def __delitem__(self, key): + super().__delitem__(key) + self.first_letters[self.rank(key)] = None + + def ordered_items(self): + return filter(lambda x: x is not None, self.first_letters) + + @staticmethod + def rank(char): + # result = 26 if char == "$" else ord(char) - 97 + result = 0 if char == "$" else ord(char) - 96 + assert result in range(0, 27) + return result + + +class Node: + global_end = 0 + num_splits = 0 + all_nodes = [] + string = "" + + def __init__(self, start, end): + self.root = False + self.start = start + self.end = end + self.children = OrderedDict() + self.id = len(self.all_nodes) + self.suffix_index = self.id - self.num_splits - 1 + self.all_nodes.append(self) + self.parent = None + self.link = None + + def __str__(self): + link_str = "" if self.link is None else f" -> {self.link.id}" + if not self.root: + j, i = self.tuple() + return f"[{self.id}, {self.tuple()}, {self.string[j:i + 1]}{link_str}]" + return f"[{self.id} root{link_str}]" + + def __repr__(self): + return f"[{self.id}]" + + def print_tree(self, spaces=1): + print(f"{self}") + for edge in self.children: + print(f" " * spaces, end="") + self.get_child(edge).print_tree(spaces=spaces + 1) + + def first_char(self): + return self.string[self.start] + + def get_child(self, char): + if char in self.children: + return self.children[char] + return None + + def add_child(self, child): + child.parent = self + self.children[child.first_char()] = child + return child + + def remove_child(self, child): + self.children.pop(child.first_char()) + + @property + def end_index(self): + return self.tuple()[1] + + def tuple(self): + if self.root: + raise Exception("Can't get substring of root.") + if self.end == "#": + return self.start, self.global_end + return self.start, self.end + + @property + def edge_length(self): + if self.root: + return 0 + else: + start, end = self.tuple() + return end - start + 1 + + def detach(self): + self.parent.remove_child(self) + self.parent = None + + +class Point: + def __init__(self, node, edge="", length=0): + assert isinstance(node, Node) + self.node = node + self.edge = edge + self.length = length + + def __repr__(self): + return f"(Node {self.node.id}'s edge:'{self.edge}', {self.length} along.)" + + def is_explicit(self): # a.k.a. is not on an edge + return self.edge == "" + + def set_node(self, node): + self.node = node + self.edge = "" + self.length = 0 + if not self.is_explicit(): + print("WARNING: Node.set_node", file=sys.stderr) + + @property + def edge_node(self) -> Node: + return self.node.get_child(self.edge) + + def index_here(self): + if self.is_explicit(): + return 0 if self.node.root else self.node.start + return self.edge_node.start + self.length - 1 + + def char_here(self): + return Node.string[self.index_here()] + + +def create_root(): + assert len(Node.all_nodes) == 0 + root = Node(None, None) + root.root = True + root.link = root + return root + + +def split_edge(split_point: Point): + assert not split_point.is_explicit() + edge = split_point.edge_node + original = edge.tuple() + edge.detach() + Node.num_splits += 1 + mediator = Node(original[0], original[0] + split_point.length - 1) + mediator.suffix_index = None + edge.start = original[0] + split_point.length + assert edge.start <= edge.end_index + mediator.add_child(edge) + split_point.node.add_child(mediator) + return mediator + + +def pos(n: int): + return max(n, 0) + + +def do_phase(root: Node, active: Point, i, last_j, remainder): + root_point = Point(root) + Node.global_end += 1 + did_rule_three = False + j = last_j + 1 + node_just_created = None + while not did_rule_three and j <= i + 1: + + curr_char = Node.string[i] + match = char_is_after(active, curr_char) + if match: + # print(3) + remainder += 1 + if node_just_created is not None: + node_just_created.link = active.node + active = skip_count(1, active, i) + did_rule_three = True + else: + # print(2) + if not active.is_explicit(): + mediator = split_edge(active) + mediator.add_child(Node(i, "#")) + if node_just_created is not None: + node_just_created.link = mediator + node_just_created = mediator + active.length -= 1 + if active.length == 0: + active.set_node(active.node) + else: + active.node.add_child(Node(i, "#")) + if node_just_created is not None and node_just_created.link is None: + node_just_created.link = active.node + remainder = pos(remainder - 1) + active.set_node(active.node.link) + if remainder > 0: + active = skip_count(remainder, Point(root), i - remainder) + last_j = j + j += 1 + print(active) + root.print_tree() + return active, remainder, last_j + + +def char_is_after(point: Point, char): + if point.is_explicit(): + return char in point.node.children + else: + if point.length == point.edge_node.edge_length: + return Node.string[point.edge_node.start] == char + else: # If not at the end of an edge + # return Node.string[point.index_here() + point.length] == char + return Node.string[point.index_here() + 1] == char + + +def skip_count(num_chars, start_point: Point, index): + incoming_length = -1 + existing_length = 0 + head = start_point + chars_left = num_chars + char = "" + + if not head.is_explicit(): + incoming_length = head.edge_node.edge_length - head.length + if num_chars < incoming_length: + head.length += num_chars + return head + head.set_node(head.edge_node) + chars_left -= incoming_length + index += incoming_length + + # Node.string[i] if head.node.root else Node.string[head.node.end_index + 1] + # assert head.node.end_index + 1 + chars_left < len(Node.string) + while chars_left > 0: + # assert head.node.end_index + 1 + chars_left < len(Node.string) + direction = Node.string[index] + next_node = head.node.get_child(direction) + if next_node is None: + raise Exception(f"Attempted to traverse char\n '{direction}' at point {head}. ({index=})") + incoming_length = next_node.edge_length + if chars_left < incoming_length: + break + chars_left -= incoming_length + index += incoming_length + head.set_node(next_node) + + # direction = Node.string[index] + + if chars_left > 0: # Landed on an edge + head.edge = Node.string[index] + head.length = chars_left + + return head + + +def ukkonen(string): + string += "$" + Node.string = string + Node.global_end = 0 + Node.num_splits = 0 + Node.all_nodes.clear() + n = len(string) + remainder = 0 + last_j = 1 + root = create_root() + root.add_child(Node(0, "#")) + active = Point(root) + for i in range(1, n): + active, remainder, last_j = do_phase(root, active, i, last_j, remainder) + return root + + +if __name__ == "__main__": + # ukkonen("DEFDBEFFDDEFFFADEFFB") + ukkonen("abacabad") + print("done") +# ukkonen("abcbcbc$") |
