aboutsummaryrefslogtreecommitdiff
path: root/ass2/ukkonen.py
diff options
context:
space:
mode:
authorakiyamn2021-04-27 14:12:47 +1000
committerakiyamn2021-04-27 14:12:47 +1000
commit1695a8e7c9e0d345918946ef6fbc8f56c7751e32 (patch)
tree41829151bb336da54aab6b13ff4c47b3a6454698 /ass2/ukkonen.py
parente347fd3246a4282d7fea85c1ae727e48c810480b (diff)
downloadfit3155-1695a8e7c9e0d345918946ef6fbc8f56c7751e32.tar.gz
fit3155-1695a8e7c9e0d345918946ef6fbc8f56c7751e32.zip
Ass 2: Guts working for suffix array
Diffstat (limited to 'ass2/ukkonen.py')
-rw-r--r--ass2/ukkonen.py279
1 files changed, 279 insertions, 0 deletions
diff --git a/ass2/ukkonen.py b/ass2/ukkonen.py
new file mode 100644
index 0000000..8d536c5
--- /dev/null
+++ b/ass2/ukkonen.py
@@ -0,0 +1,279 @@
+import sys
+
+
+class OrderedDict(dict):
+ def __init__(self):
+ super().__init__()
+ self.first_letters = [None for _ in range(27)]
+
+ def __setitem__(self, key, value):
+ super().__setitem__(key, value)
+ try:
+ self.first_letters[self.rank(key)] = value
+ except IndexError:
+ raise IndexError(f"Could not add item of key '{key}' since it is out of range of the rank function.")
+
+ def __delitem__(self, key):
+ super().__delitem__(key)
+ self.first_letters[self.rank(key)] = None
+
+ def ordered_items(self):
+ return filter(lambda x: x is not None, self.first_letters)
+
+ @staticmethod
+ def rank(char):
+ # result = 26 if char == "$" else ord(char) - 97
+ result = 0 if char == "$" else ord(char) - 96
+ assert result in range(0, 27)
+ return result
+
+
+class Node:
+ global_end = 0
+ num_splits = 0
+ all_nodes = []
+ string = ""
+
+ def __init__(self, start, end):
+ self.root = False
+ self.start = start
+ self.end = end
+ self.children = OrderedDict()
+ self.id = len(self.all_nodes)
+ self.suffix_index = self.id - self.num_splits - 1
+ self.all_nodes.append(self)
+ self.parent = None
+ self.link = None
+
+ def __str__(self):
+ link_str = "" if self.link is None else f" -> {self.link.id}"
+ if not self.root:
+ j, i = self.tuple()
+ return f"[{self.id}, {self.tuple()}, {self.string[j:i + 1]}{link_str}]"
+ return f"[{self.id} root{link_str}]"
+
+ def __repr__(self):
+ return f"[{self.id}]"
+
+ def print_tree(self, spaces=1):
+ print(f"{self}")
+ for edge in self.children:
+ print(f" " * spaces, end="")
+ self.get_child(edge).print_tree(spaces=spaces + 1)
+
+ def first_char(self):
+ return self.string[self.start]
+
+ def get_child(self, char):
+ if char in self.children:
+ return self.children[char]
+ return None
+
+ def add_child(self, child):
+ child.parent = self
+ self.children[child.first_char()] = child
+ return child
+
+ def remove_child(self, child):
+ self.children.pop(child.first_char())
+
+ @property
+ def end_index(self):
+ return self.tuple()[1]
+
+ def tuple(self):
+ if self.root:
+ raise Exception("Can't get substring of root.")
+ if self.end == "#":
+ return self.start, self.global_end
+ return self.start, self.end
+
+ @property
+ def edge_length(self):
+ if self.root:
+ return 0
+ else:
+ start, end = self.tuple()
+ return end - start + 1
+
+ def detach(self):
+ self.parent.remove_child(self)
+ self.parent = None
+
+
+class Point:
+ def __init__(self, node, edge="", length=0):
+ assert isinstance(node, Node)
+ self.node = node
+ self.edge = edge
+ self.length = length
+
+ def __repr__(self):
+ return f"(Node {self.node.id}'s edge:'{self.edge}', {self.length} along.)"
+
+ def is_explicit(self): # a.k.a. is not on an edge
+ return self.edge == ""
+
+ def set_node(self, node):
+ self.node = node
+ self.edge = ""
+ self.length = 0
+ if not self.is_explicit():
+ print("WARNING: Node.set_node", file=sys.stderr)
+
+ @property
+ def edge_node(self) -> Node:
+ return self.node.get_child(self.edge)
+
+ def index_here(self):
+ if self.is_explicit():
+ return 0 if self.node.root else self.node.start
+ return self.edge_node.start + self.length - 1
+
+ def char_here(self):
+ return Node.string[self.index_here()]
+
+
+def create_root():
+ assert len(Node.all_nodes) == 0
+ root = Node(None, None)
+ root.root = True
+ root.link = root
+ return root
+
+
+def split_edge(split_point: Point):
+ assert not split_point.is_explicit()
+ edge = split_point.edge_node
+ original = edge.tuple()
+ edge.detach()
+ Node.num_splits += 1
+ mediator = Node(original[0], original[0] + split_point.length - 1)
+ mediator.suffix_index = None
+ edge.start = original[0] + split_point.length
+ assert edge.start <= edge.end_index
+ mediator.add_child(edge)
+ split_point.node.add_child(mediator)
+ return mediator
+
+
+def pos(n: int):
+ return max(n, 0)
+
+
+def do_phase(root: Node, active: Point, i, last_j, remainder):
+ root_point = Point(root)
+ Node.global_end += 1
+ did_rule_three = False
+ j = last_j + 1
+ node_just_created = None
+ while not did_rule_three and j <= i + 1:
+
+ curr_char = Node.string[i]
+ match = char_is_after(active, curr_char)
+ if match:
+ # print(3)
+ remainder += 1
+ if node_just_created is not None:
+ node_just_created.link = active.node
+ active = skip_count(1, active, i)
+ did_rule_three = True
+ else:
+ # print(2)
+ if not active.is_explicit():
+ mediator = split_edge(active)
+ mediator.add_child(Node(i, "#"))
+ if node_just_created is not None:
+ node_just_created.link = mediator
+ node_just_created = mediator
+ active.length -= 1
+ if active.length == 0:
+ active.set_node(active.node)
+ else:
+ active.node.add_child(Node(i, "#"))
+ if node_just_created is not None and node_just_created.link is None:
+ node_just_created.link = active.node
+ remainder = pos(remainder - 1)
+ active.set_node(active.node.link)
+ if remainder > 0:
+ active = skip_count(remainder, Point(root), i - remainder)
+ last_j = j
+ j += 1
+ print(active)
+ root.print_tree()
+ return active, remainder, last_j
+
+
+def char_is_after(point: Point, char):
+ if point.is_explicit():
+ return char in point.node.children
+ else:
+ if point.length == point.edge_node.edge_length:
+ return Node.string[point.edge_node.start] == char
+ else: # If not at the end of an edge
+ # return Node.string[point.index_here() + point.length] == char
+ return Node.string[point.index_here() + 1] == char
+
+
+def skip_count(num_chars, start_point: Point, index):
+ incoming_length = -1
+ existing_length = 0
+ head = start_point
+ chars_left = num_chars
+ char = ""
+
+ if not head.is_explicit():
+ incoming_length = head.edge_node.edge_length - head.length
+ if num_chars < incoming_length:
+ head.length += num_chars
+ return head
+ head.set_node(head.edge_node)
+ chars_left -= incoming_length
+ index += incoming_length
+
+ # Node.string[i] if head.node.root else Node.string[head.node.end_index + 1]
+ # assert head.node.end_index + 1 + chars_left < len(Node.string)
+ while chars_left > 0:
+ # assert head.node.end_index + 1 + chars_left < len(Node.string)
+ direction = Node.string[index]
+ next_node = head.node.get_child(direction)
+ if next_node is None:
+ raise Exception(f"Attempted to traverse char\n '{direction}' at point {head}. ({index=})")
+ incoming_length = next_node.edge_length
+ if chars_left < incoming_length:
+ break
+ chars_left -= incoming_length
+ index += incoming_length
+ head.set_node(next_node)
+
+ # direction = Node.string[index]
+
+ if chars_left > 0: # Landed on an edge
+ head.edge = Node.string[index]
+ head.length = chars_left
+
+ return head
+
+
+def ukkonen(string):
+ string += "$"
+ Node.string = string
+ Node.global_end = 0
+ Node.num_splits = 0
+ Node.all_nodes.clear()
+ n = len(string)
+ remainder = 0
+ last_j = 1
+ root = create_root()
+ root.add_child(Node(0, "#"))
+ active = Point(root)
+ for i in range(1, n):
+ active, remainder, last_j = do_phase(root, active, i, last_j, remainder)
+ return root
+
+
+if __name__ == "__main__":
+ # ukkonen("DEFDBEFFDDEFFFADEFFB")
+ ukkonen("abacabad")
+ print("done")
+# ukkonen("abcbcbc$")