Skip to content

Commit 1ce00e0

Browse files
committed
Add doctests to topo, fix dijk bug and doctest, rename and test floyd_warshall
1 parent 7a0fee4 commit 1ce00e0

File tree

1 file changed

+40
-10
lines changed

1 file changed

+40
-10
lines changed

graphs/basic_graphs.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,14 @@ def bfs(g, s):
148148

149149
def dijk(g, s):
150150
"""
151-
dijk({1: [(2, 7), (3, 9), (6, 14)],
152-
2: [(1, 7), (3, 10), (4, 15)],
153-
3: [(1, 9), (2, 10), (4, 11), (6, 2)],
154-
4: [(2, 15), (3, 11), (5, 6)],
155-
5: [(4, 6), (6, 9)],
156-
6: [(1, 14), (3, 2), (5, 9)]}, 1)
151+
>>> dijk({
152+
... 1: [(2, 7), (3, 9), (6, 14)],
153+
... 2: [(1, 7), (3, 10), (4, 15)],
154+
... 3: [(1, 9), (2, 10), (4, 11), (6, 2)],
155+
... 4: [(2, 15), (3, 11), (5, 6)],
156+
... 5: [(4, 6), (6, 9)],
157+
... 6: [(1, 14), (3, 2), (5, 9)]
158+
... }, 1)
157159
7
158160
9
159161
11
@@ -165,7 +167,7 @@ def dijk(g, s):
165167
if len(known) == len(g) - 1:
166168
break
167169
mini = 100000
168-
for key, value in dist:
170+
for key, value in dist.items():
169171
if key not in known and value < mini:
170172
mini = value
171173
u = key
@@ -187,6 +189,15 @@ def dijk(g, s):
187189

188190

189191
def topo(g, ind=None, q=None):
192+
"""
193+
Perform a topological sort on a directed acyclic graph.
194+
195+
>>> topo({1: [2, 3], 2: [4], 3: [4], 4: []})
196+
1
197+
2
198+
3
199+
4
200+
"""
190201
if q is None:
191202
q = [1]
192203
if ind is None:
@@ -256,16 +267,35 @@ def adjm():
256267
"""
257268

258269

259-
def floy(a_and_n):
270+
def floyd_warshall(a_and_n):
271+
"""
272+
Floyd-Warshall algorithm to compute all-pairs shortest paths.
273+
274+
Parameters:
275+
a_and_n (tuple): A tuple (a, n) where
276+
a is an N x N adjacency matrix (list of lists),
277+
n is the number of nodes.
278+
279+
Example:
280+
>>> floyd_warshall(([
281+
... [0, 5, float('inf')],
282+
... [50, 0, 10],
283+
... [float('inf'), float('inf'), 0]
284+
... ], 3))
285+
[[0, 5, 15], [50, 0, 10], [inf, inf, 0]]
286+
"""
287+
260288
(a, n) = a_and_n
261-
dist = list(a)
289+
dist = [row[:] for row in a] # create a deep copy of matrix a
262290
path = [[0] * n for i in range(n)]
291+
263292
for k in range(n):
264293
for i in range(n):
265294
for j in range(n):
266295
if dist[i][j] > dist[i][k] + dist[k][j]:
267296
dist[i][j] = dist[i][k] + dist[k][j]
268-
path[i][k] = k
297+
path[i][k] = k # possible error
298+
269299
print(dist)
270300

271301

0 commit comments

Comments
 (0)