aboutsummaryrefslogtreecommitdiff
path: root/ass2/q1/kruskals.py
blob: 6b73019c08ee9322577ae90c93e4344e33ead5a7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Alexander Occhipinti - 29994705

import sys


class Edge:
    """
    A weighted edge between two vertices of a graph.

    str(edge) gives "start end weight" followed by a newline, which is the per-edge line
    format written to the output file.
    """

    def __init__(self, start, end, weight):
        self.start, self.end, self.weight = start, end, weight

    def __repr__(self):
        return f"{self.start} {self.end} {self.weight}"

    def __str__(self):
        # Same text as repr(), newline-terminated so edges can be written straight to a file
        return f"{self!r}\n"

    def __len__(self):
        # NOTE: len(edge) is the edge's weight. Python requires __len__ to return a non-negative
        # int, so len() raises for negative (ValueError) or non-integer (TypeError) weights.
        return self.weight


def create_edge_from_line(line):
    """
    Creates an Edge object by parsing a line of the form "start end weight" (the format specified
    in the assignment brief), where every field is an integer.

    Fields may be separated by any run of whitespace. Leading and trailing whitespace, including a
    trailing newline or carriage return, is ignored.

    :raises ValueError: if the line is blank or a field is not an integer
    :raises TypeError: if the line does not contain exactly three fields (Edge takes three arguments)
    """
    # split() with no separator collapses runs of whitespace; split(" ") would yield empty fields
    fields = line.split()
    if not fields:
        # Blank input is a parse error, reported as ValueError like any other non-integer field
        raise ValueError(f"Cannot create an edge from a blank line: {line!r}")
    return Edge(*map(int, fields))


def read_edges(filename):
    """
    Reads a file with one "start end weight" line per edge (the format specified in the assignment
    brief) and returns the corresponding Edge objects as a list, in file order.

    Blank lines are skipped. This includes the empty line that splitting on "\\n" would produce
    after a trailing newline at the end of the file.
    """
    edges = []
    with open(filename, "r") as file:
        for line in file:
            stripped = line.strip()
            if stripped:  # ignore blank / whitespace-only lines instead of failing to parse them
                edges.append(create_edge_from_line(stripped))
    return edges


def sort_edges(edges):
    """
    Returns a new list of the edges sorted by weight, smallest first. The input is not modified.

    The sort is stable, so edges of equal weight keep their input order. Sorting reads the
    `weight` attribute directly rather than going through len(), which Python restricts to
    non-negative ints. This means negative and non-integer weights also work.
    Complexity is O(e log e) where e = num of edges due to timSort.
    """
    return sorted(edges, key=lambda edge: edge.weight)


def find(groups, vertex):
    """
    Finds the root of the vertex's group in a union-find parent array, with full path compression.

    `groups[i] < 0` marks i as a root (the magnitude is maintained by `union` as the tree height);
    otherwise `groups[i]` is i's parent. After the call, every vertex on the path from `vertex` to
    the root points directly at the root.

    The search is iterative, so long parent chains cannot exhaust the recursion limit.
    """
    # First pass: walk up to the root
    root = vertex
    while groups[root] >= 0:
        root = groups[root]
    # Second pass: point every node on the path straight at the root
    while vertex != root:
        parent = groups[vertex]
        groups[vertex] = root
        vertex = parent
    return root


def union(groups, u, v):
    """
    Merges the groups containing vertices u and v in a union-find parent array, using union by
    height.

    The shorter tree's root is attached under the taller one. On a tie, u's root goes under v's
    root, and v's stored height (kept as a negative number at the root) grows by one. If u and v
    are already in the same group, nothing changes.
    """
    root_u = find(groups, u)
    root_v = find(groups, v)
    if root_u == root_v:
        return  # already connected

    rank_u = -groups[root_u]
    rank_v = -groups[root_v]
    if rank_u > rank_v:
        groups[root_v] = root_u
        return
    groups[root_u] = root_v
    if rank_u == rank_v:
        groups[root_v] = -(rank_v + 1)  # equal heights: the merged tree is one level taller


def is_spanning(min_span_tree, v):
    """
    Returns True if every vertex 0..v-1 is an endpoint of at least one edge in `min_span_tree`.

    Only vertex coverage is checked, not connectivity: two disjoint edges covering four vertices
    count as spanning. Vacuously True when v is 0.
    """
    seen = [False] * v
    for edge in min_span_tree:
        seen[edge.start] = seen[edge.end] = True
    return all(seen)


def kruskals(edges, v):
    """
    Builds a minimum spanning tree with Kruskal's algorithm.

    :param edges: Edge objects whose endpoints are vertex indices in range(v)
    :param v: number of vertices
    :return: (total weight of the chosen edges, list of chosen edges in the order they were accepted)

    Edges are considered cheapest first, and each one that joins two different groups is kept.
    If the graph is disconnected, the result is a minimum spanning forest.
    """
    parents = [-1] * v  # every vertex starts as the root of its own height-1 tree
    chosen = []
    for edge in sort_edges(edges):
        if find(parents, edge.start) == find(parents, edge.end):
            continue  # both endpoints already connected: this edge would close a cycle
        union(parents, edge.start, edge.end)
        chosen.append(edge)
    return sum(edge.weight for edge in chosen), chosen


def write_output(total_weight, min_span_tree, filename):
    """
    Writes the result to `filename`, overwriting it. The first line is the total weight; after it
    comes str(edge) for each edge in order (Edge's str form is already newline-terminated).
    """
    with open(filename, "w") as out:
        out.write(f"{total_weight}\n")
        for edge in min_span_tree:
            out.write(str(edge))


def parse_args(argv):
    """
    Parses the command-line argument vector.

    :param argv: full argument list with the program name first (e.g. sys.argv)
    :return: (number of vertices as an int, input filename)
    :raises IndexError: if argv does not have exactly 3 entries
    :raises ValueError: if the vertex count is not an integer
    """
    # Validate the argv that was passed in, not the global sys.argv
    if len(argv) != 3:
        raise IndexError(f"Incorrect number of arguments provided. Expected 3, got {len(argv)}.")
    return int(argv[1]), argv[2]


def main():
    """
    Program entry point. Reads the vertex count and edge file from the command line, runs
    Kruskal's algorithm, and writes the result to output_kruskals.txt.

    Usage: python kruskals.py <number_of_vertices> <edges_file>
    """
    vertex_count, input_path = parse_args(sys.argv)
    total, tree = kruskals(read_edges(input_path), vertex_count)
    write_output(total, tree, "output_kruskals.txt")
    print("Done.")


if __name__ == "__main__":
    main()