Skip to content

Commit cdac2cb

Browse files
Update graph.py
1 parent 476aff0 commit cdac2cb

File tree

1 file changed

+29
-21
lines changed
  • graphs/minimum-spanning-tree/kruskals-algorithm

1 file changed

+29
-21
lines changed

graphs/minimum-spanning-tree/kruskals-algorithm/graph.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,53 @@
1-
from typing import Set, Dict, List, Tuple
1+
2+
from typing import List, Set, Dict, Tuple
23

34
class Graph:
45

6+
57
def __init__(self) -> None:
68
self.vertices:Set[str] = set()
7-
self.root:Dict[str, str] = dict()
8-
self.rank:Dict[str, int] = dict()
9+
self.roots:Dict[str, str] = dict()
10+
self.sizes:Dict[str, int] = dict()
911
self.edges:List[Tuple[str, str, int]] = list()
1012
self.mst:List[Tuple[str, str, int]] = list()
1113

12-
14+
1315
def add_vertex(self, label:str) -> None:
1416
self.vertices.add(label)
15-
self.root[label] = label
16-
self.rank[label] = 0
17+
self.roots[label] = label
18+
self.sizes[label] = 1
1719

1820

1921
def add_edge(self, label1:str, label2:str, weight:int) -> None:
22+
if label1 not in self.vertices or label2 not in self.vertices:
23+
raise Exception("Vertices must be added before connecting them")
2024
self.edges.append((label1, label2, weight))
2125

2226

23-
def kruskal(self) -> None:
27+
def kruskal(self) -> List[Tuple[str, str, int]]:
28+
self.mst.clear()
2429
self.edges.sort(key = lambda edge: edge[2])
25-
for l1, l2, weight in self.edges:
2630

27-
root1:str = self._find_root(l1)
28-
root2:str = self._find_root(l2)
31+
for v1, v2, weight in self.edges:
32+
33+
root1:str = self._find_root(v1)
34+
root2:str = self._find_root(v2)
2935

3036
if root1 != root2:
31-
if self.rank[root1] > self.rank[root2]:
32-
self.root[root2] = root1
33-
self.rank[root1] = self.rank[root1] + 1
37+
if self.sizes[root1] >= self.sizes[root2]:
38+
self.roots[root2] = root1
39+
self.sizes[root1] += self.sizes[root2]
3440
else:
35-
self.root[root1] = root2
36-
self.rank[root2] = self.rank[root2] + 1
37-
self.mst.append((l1, l2, weight))
41+
self.roots[root1] = root2
42+
self.sizes[root2] += self.sizes[root1]
43+
self.mst.append((v1, v2, weight))
44+
45+
return list(self.mst)
46+
47+
def _find_root(self, label:str) -> str:
48+
if self.roots[label] != label:
49+
self.roots[label] = self._find_root(self.roots[label])
50+
return self.roots[label]
3851

39-
print(self.mst)
4052

4153

42-
def _find_root(self, label:str) -> str:
43-
if self.root[label] != label:
44-
self.root[label] = self._find_root(self.root[label])
45-
return self.root[label]

0 commit comments

Comments
 (0)