@@ -148,12 +148,14 @@ def bfs(g, s):
148
148
149
149
def dijk (g , s ):
150
150
"""
151
- dijk({1: [(2, 7), (3, 9), (6, 14)],
152
- 2: [(1, 7), (3, 10), (4, 15)],
153
- 3: [(1, 9), (2, 10), (4, 11), (6, 2)],
154
- 4: [(2, 15), (3, 11), (5, 6)],
155
- 5: [(4, 6), (6, 9)],
156
- 6: [(1, 14), (3, 2), (5, 9)]}, 1)
151
+ >>> dijk({
152
+ ... 1: [(2, 7), (3, 9), (6, 14)],
153
+ ... 2: [(1, 7), (3, 10), (4, 15)],
154
+ ... 3: [(1, 9), (2, 10), (4, 11), (6, 2)],
155
+ ... 4: [(2, 15), (3, 11), (5, 6)],
156
+ ... 5: [(4, 6), (6, 9)],
157
+ ... 6: [(1, 14), (3, 2), (5, 9)]
158
+ ... }, 1)
157
159
7
158
160
9
159
161
11
@@ -165,7 +167,7 @@ def dijk(g, s):
165
167
if len (known ) == len (g ) - 1 :
166
168
break
167
169
mini = 100000
168
- for key , value in dist :
170
+ for key , value in dist . items () :
169
171
if key not in known and value < mini :
170
172
mini = value
171
173
u = key
@@ -187,6 +189,15 @@ def dijk(g, s):
187
189
188
190
189
191
def topo (g , ind = None , q = None ):
192
+ """
193
+ Perform a topological sort on a directed acyclic graph.
194
+
195
+ >>> topo({1: [2, 3], 2: [4], 3: [4], 4: []})
196
+ 1
197
+ 2
198
+ 3
199
+ 4
200
+ """
190
201
if q is None :
191
202
q = [1 ]
192
203
if ind is None :
@@ -256,16 +267,35 @@ def adjm():
256
267
"""
257
268
258
269
259
- def floy (a_and_n ):
270
+ def floyd_warshall (a_and_n ):
271
+ """
272
+ Floyd-Warshall algorithm to compute all-pairs shortest paths.
273
+
274
+ Parameters:
275
+ a_and_n (tuple): A tuple (a, n) where
276
+ a is an N x N adjacency matrix (list of lists),
277
+ n is the number of nodes.
278
+
279
+ Example:
280
+ >>> floyd_warshall(([
281
+ ... [0, 5, float('inf')],
282
+ ... [50, 0, 10],
283
+ ... [float('inf'), float('inf'), 0]
284
+ ... ], 3))
285
+ [[0, 5, 15], [50, 0, 10], [inf, inf, 0]]
286
+ """
287
+
260
288
(a , n ) = a_and_n
261
- dist = list ( a )
289
+ dist = [ row [:] for row in a ] # create a deep copy of matrix a
262
290
path = [[0 ] * n for i in range (n )]
291
+
263
292
for k in range (n ):
264
293
for i in range (n ):
265
294
for j in range (n ):
266
295
if dist [i ][j ] > dist [i ][k ] + dist [k ][j ]:
267
296
dist [i ][j ] = dist [i ][k ] + dist [k ][j ]
268
- path [i ][k ] = k
297
+ path [i ][k ] = k # possible error
298
+
269
299
print (dist )
270
300
271
301
0 commit comments