1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
|
from functools import reduce
class Edge:
    """An edge joining two vertices, carrying a weight.

    Attributes:
        start: Label of the first endpoint.
        end: Label of the second endpoint.
        weight: Cost associated with the edge.
    """

    def __init__(self, start, end, weight):
        self.start, self.end, self.weight = start, end, weight

    def __int__(self):
        # Makes ``int(edge)`` yield the weight, so a collection of edges can
        # be totalled with ``sum(map(int, edges))``.
        return self.weight

    def __str__(self):
        return f"{self.start} -({self.weight})-> {self.end}"

    def __repr__(self):
        # Same text as ``str(edge)`` so that printing a list of edges is
        # readable.
        return str(self)
def create_edge_from_line(line):
    """Parse one text line of the form ``"<start> <end> <weight>"`` into an Edge.

    Fields are separated by any run of whitespace (spaces or tabs); leading
    and trailing whitespace is ignored.

    Args:
        line: Text holding three integer fields.

    Returns:
        An ``Edge`` built from the three parsed integers.

    Raises:
        ValueError: If a field is not a valid integer.
        TypeError: If the line does not hold exactly three fields (raised by
            the ``Edge`` constructor).
    """
    # ``split()`` with no separator collapses repeated whitespace, whereas
    # ``split(" ")`` would emit empty tokens that ``int`` cannot parse.
    fields = [int(token) for token in line.split()]
    return Edge(*fields)
def read_edges(filename):
    """Read a list of edges from a text file, one edge per line.

    Blank lines (including the empty remainder after a trailing newline at
    the end of the file) are skipped instead of being handed to the parser,
    which would fail on them.

    Args:
        filename: Path of the file to read.

    Returns:
        A list with one edge per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    edges = []
    with open(filename, "r") as file:
        for raw_line in file:
            # Strip the line terminator and surrounding whitespace so the
            # parser only ever sees the fields themselves.
            line = raw_line.strip()
            if line:
                edges.append(create_edge_from_line(line))
    return edges
def count_vertices(edges):
    """Return the vertex count implied by the highest label used in *edges*.

    Vertices are taken to be labelled ``0 .. max_label``, so the count is the
    highest endpoint label plus one. An empty edge collection gives 0.
    """
    labels = [label for edge in edges for label in (edge.start, edge.end)]
    # The -1 floor makes the result 0 when there are no edges (or no
    # non-negative labels).
    return max([-1] + labels) + 1
def sort_edges(edges):
    """Return a new list holding *edges* in ascending order of weight.

    The input is left untouched. The sort is stable, so edges of equal
    weight keep their original relative order.
    """

    def by_weight(edge):
        return edge.weight

    ordered = list(edges)
    ordered.sort(key=by_weight)  # Timsort, O(e log e)
    return ordered
def find(groups, vertex):
    """Return the root of the set containing *vertex*.

    ``groups`` is a parent array: a non-negative entry is the index of the
    vertex's parent, a negative entry marks that vertex as a root.

    Args:
        groups: Parent array of the disjoint-set forest.
        vertex: Index of the vertex to look up.

    Returns:
        The index of the root reached by following parent links.
    """
    node = vertex
    # Walk up the parent chain until a root (negative entry) is reached.
    while groups[node] >= 0:
        node = groups[node]
    return node
def union(groups, u, v):
    """Merge the sets containing *u* and *v* in place.

    Each root stores the negated height of its tree. The shorter tree is
    attached beneath the taller one; when both are equally tall, the root of
    *u* is attached beneath the root of *v*, whose height grows by one.
    Nothing happens if *u* and *v* already share a root.

    Args:
        groups: Parent array of the disjoint-set forest (mutated).
        u: Index of a vertex in the first set.
        v: Index of a vertex in the second set.
    """
    first = find(groups, u)
    second = find(groups, v)
    if first == second:
        return

    first_height, second_height = -groups[first], -groups[second]

    if first_height > second_height:
        groups[second] = first
        return

    # Otherwise the tree of ``u`` hangs beneath the root of ``v``.
    groups[first] = second
    if first_height == second_height:
        # Joining two equally tall trees makes the result one level taller.
        groups[second] = -(second_height + 1)
def is_spanning(min_span_tree, v):
    """Return True if the given edges connect all vertices ``0 .. v-1``.

    Checking only that each vertex touches some edge is not enough: two
    separate components (e.g. edges 0-1 and 2-3) would pass. Instead, the
    edges are treated as undirected and a traversal from vertex 0 must reach
    every vertex.

    Args:
        min_span_tree: Iterable of edges exposing ``start`` and ``end``
            vertex indices.
        v: Number of vertices in the graph.

    Returns:
        True when every vertex is reachable from vertex 0 through the given
        edges (trivially True when there are no vertices), otherwise False.

    Raises:
        IndexError: If an edge refers to a vertex index of ``v`` or above.
    """
    if v <= 0:
        return True

    neighbours = [[] for _ in range(v)]
    for edge in min_span_tree:
        neighbours[edge.start].append(edge.end)
        neighbours[edge.end].append(edge.start)

    # Iterative depth-first traversal starting from vertex 0.
    reached = [False] * v
    reached[0] = True
    pending = [0]
    while pending:
        for neighbour in neighbours[pending.pop()]:
            if not reached[neighbour]:
                reached[neighbour] = True
                pending.append(neighbour)
    return all(reached)
def main(filename="edges.txt"):
    """Compute and print a minimum spanning tree using Kruskal's algorithm.

    Reads the edges from *filename*, prints them as read and again sorted by
    weight, then scans the sorted edges, keeping every edge that joins two
    vertices not yet connected. Finally prints the total weight of the kept
    edges followed by the edges themselves, one per line.

    Args:
        filename: Path of the edge list to read; defaults to ``"edges.txt"``.

    Raises:
        AssertionError: If the selected edges do not pass ``is_spanning``.
    """
    edges = read_edges(filename)
    print(edges)

    # Sort once and reuse the result for both the printout and the scan.
    ordered = sort_edges(edges)
    print(ordered)

    v = count_vertices(edges)
    # Every vertex starts as the root of its own tree of height 1.
    groups = [-1] * v
    min_span_tree = []
    for edge in ordered:
        # An edge whose endpoints already share a root would close a cycle.
        if find(groups, edge.start) != find(groups, edge.end):
            union(groups, edge.start, edge.end)
            min_span_tree.append(edge)

    weight_total = sum(edge.weight for edge in min_span_tree)
    print(weight_total)
    print("\n".join(map(str, min_span_tree)))
    assert is_spanning(min_span_tree, v)
# Script entry point: run the program when this file is executed.
# NOTE(review): the call is unguarded, so it also runs when the module is
# imported; consider an ``if __name__ == "__main__":`` guard if this file is
# ever imported from elsewhere.
main()
|