diff options
Diffstat (limited to 'ass2')
| -rw-r--r-- | ass2/q2/suffix_array.py | 162 | 
1 files changed, 162 insertions, 0 deletions
| diff --git a/ass2/q2/suffix_array.py b/ass2/q2/suffix_array.py new file mode 100644 index 0000000..1bb88b9 --- /dev/null +++ b/ass2/q2/suffix_array.py @@ -0,0 +1,162 @@ +class Node: +    global_end = 0 +    string = "" +    all = [] + +    def __init__(self, data=None): +        self.data = data +        self.children = [] +        self.id = len(self.all) +        self.all.append(self) +        self.parent = None + +    def __str__(self): +        tup = self.get_tuple() +        if tup is not None: +            j, i = tup +            return f"[{self.id}, {self.data}, {self.string[j:i + 1]}]" +        return f"[{self.id} root]" + +    def __repr__(self): +        return f"[{self.id}]" + +    def print_children(self): +        print(self, end=" contains: ") +        for child in self.children: +            print(child, end=", ") + +    def add_child(self, child): +        if child: +            child.parent = self +            self.children.append(child) + +    def num_chars_deep(self): +        if self.data is None: +            return 0 +        else: +            return self.parent.num_chars_deep() + self.edge_length() + +    def remove_child(self, child): +        self.children.remove(child) + +    def bastardise(self): +        self.parent.remove_child(self) +        parent = None + +    def get_tuple(self): +        if self.data is not None and self.data[1] == "#": +            return self.data[0], self.global_end +        return self.data + +    def first_letter(self): +        return self.string[self.data[0]] + +    def get_child(self, char): +        for child in self.children: +            if child.first_letter() == char: +                return child +        return None + +    def edge_length(self): +        if self.data is None: +            return 0 +        else: +            substring = self.get_tuple() +            return substring[1] - substring[0] + 1 + +    def print_tree(self, spaces=0): +        print(f"{self}") +        for child in self.children: +            print(f"   " * spaces, end="") +            child.print_tree(spaces=spaces + 1) + + +def branch_from_point(active_node, active_edge, active_length, j): +    edge = active_node.get_child(active_edge) +    edge.bastardise() +    original_substr = edge.get_tuple() +    edge.data = (edge.data[0] + active_length, edge.data[1]) +    mediator = Node((original_substr[0], original_substr[0] + active_length - 1)) +    active_node.add_child(mediator) +    mediator.add_child(edge) +    mediator.add_child(Node((j + 1, '#'))) + + +def traverse_down(string, substring, active_node, active_edge, active_length, remainder): +    j, i = substring +    traversed = 0 +    g = 0 +    char = "" +    remainder -= 0 if active_node.data is None else active_node.data[0] + 1 +    while remainder > 0: +        char = string[i + traversed] +        path = active_node.get_child(char) +        assert path is not None +        traversed += path.edge_length() +        g = min(remainder, path.edge_length()) +        # if path is None: +        #     return active_node, active_edge, active_length +        if remainder >= path.edge_length(): +            active_node = path +            active_length = 0 +            active_edge = '' +            remainder -= path.edge_length() +        else: +            active_length = remainder +            active_edge = char if active_length != 0 else None +            break +    active_length = remainder +    return active_node, active_edge, active_length + + +def char_match_after(char, active_node, active_edge, active_length): +    if active_length < 1: +        return active_node.get_child(char) is not None +    edge = active_node.get_child(active_edge) +    position = edge.data[0] + active_length + 1 +    return Node.string[position] == char + + +def suffix_tree(string): +    # string += "$" +    Node.string = string +    n = len(string) +    last_j = 0 +    root = Node() +    root.add_child(Node((0, '#'))) +    active_node = root +    active_edge = "" +    active_length = 0 +    remainder = 0 +    last_rule = 0 +    for i in range(1, n): +        Node.global_end += 1  # Rule 1s implicitly +        for j in range(last_j + 1, i + 1): +            print(string[j:i + 1], end=": ") +            # if 2 or more letter string, just match the last character at the current position (we already went there) i.e. i i think +            char = string[i] if i-j <= 1 else string[j] +            match = char_match_after(char, active_node, active_edge, active_length) +            if match:  # Rule 3 +                print(3, -j) +                remainder += 1 +                active_node, active_edge, active_length = traverse_down(string, (j, i), active_node, active_edge, +                                                                        active_length, remainder) + +                break +            else:  # Rule 2 +                print(2, -j) +                if active_length > 0: +                    branch_from_point(active_node, active_edge, active_length, j) +                else: +                    active_node.add_child(Node((j, '#'))) +                active_edge = "" +                remainder = max(remainder - 1, 0) +                active_length = 0 +                last_j = j + +        print(f"{i=} '{string[i]}' => {(active_node, active_edge, active_length)}, {remainder}") +        root.print_tree() +        print("\n" * 10) + + +suffix_tree("abacabad$") | 
