|
1 |
| -from typing import Set, Dict, List, Tuple |
| 1 | + |
| 2 | +from typing import List, Set, Dict, Tuple |
2 | 3 |
|
3 | 4 | class Graph:
|
4 | 5 |
|
| 6 | + |
5 | 7 | def __init__(self) -> None:
|
6 | 8 | self.vertices:Set[str] = set()
|
7 |
| - self.root:Dict[str, str] = dict() |
8 |
| - self.rank:Dict[str, int] = dict() |
| 9 | + self.roots:Dict[str, str] = dict() |
| 10 | + self.sizes:Dict[str, int] = dict() |
9 | 11 | self.edges:List[Tuple[str, str, int]] = list()
|
10 | 12 | self.mst:List[Tuple[str, str, int]] = list()
|
11 | 13 |
|
12 |
| - |
| 14 | + |
13 | 15 | def add_vertex(self, label:str) -> None:
|
14 | 16 | self.vertices.add(label)
|
15 |
| - self.root[label] = label |
16 |
| - self.rank[label] = 0 |
| 17 | + self.roots[label] = label |
| 18 | + self.sizes[label] = 1 |
17 | 19 |
|
18 | 20 |
|
19 | 21 | def add_edge(self, label1:str, label2:str, weight:int) -> None:
|
| 22 | + if label1 not in self.vertices or label2 not in self.vertices: |
| 23 | + raise Exception("Vertices must be added before connecting them") |
20 | 24 | self.edges.append((label1, label2, weight))
|
21 | 25 |
|
22 | 26 |
|
23 |
| - def kruskal(self) -> None: |
| 27 | + def kruskal(self) -> List[Tuple[str, str, int]]: |
| 28 | + self.mst.clear() |
24 | 29 | self.edges.sort(key = lambda edge: edge[2])
|
25 |
| - for l1, l2, weight in self.edges: |
26 | 30 |
|
27 |
| - root1:str = self._find_root(l1) |
28 |
| - root2:str = self._find_root(l2) |
| 31 | + for v1, v2, weight in self.edges: |
| 32 | + |
| 33 | + root1:str = self._find_root(v1) |
| 34 | + root2:str = self._find_root(v2) |
29 | 35 |
|
30 | 36 | if root1 != root2:
|
31 |
| - if self.rank[root1] > self.rank[root2]: |
32 |
| - self.root[root2] = root1 |
33 |
| - self.rank[root1] = self.rank[root1] + 1 |
| 37 | + if self.sizes[root1] >= self.sizes[root2]: |
| 38 | + self.roots[root2] = root1 |
| 39 | + self.sizes[root1] += self.sizes[root2] |
34 | 40 | else:
|
35 |
| - self.root[root1] = root2 |
36 |
| - self.rank[root2] = self.rank[root2] + 1 |
37 |
| - self.mst.append((l1, l2, weight)) |
| 41 | + self.roots[root1] = root2 |
| 42 | + self.sizes[root2] += self.sizes[root1] |
| 43 | + self.mst.append((v1, v2, weight)) |
| 44 | + |
| 45 | + return list(self.mst) |
| 46 | + |
| 47 | + def _find_root(self, label:str) -> str: |
| 48 | + if self.roots[label] != label: |
| 49 | + self.roots[label] = self._find_root(self.roots[label]) |
| 50 | + return self.roots[label] |
38 | 51 |
|
39 |
| - print(self.mst) |
40 | 52 |
|
41 | 53 |
|
42 |
| - def _find_root(self, label:str) -> str: |
43 |
| - if self.root[label] != label: |
44 |
| - self.root[label] = self._find_root(self.root[label]) |
45 |
| - return self.root[label] |
|
0 commit comments