# COMP 554 - Spring 2018
# Assignment 2, question 3
# Matthew Costa, Osbelia Duenas, Victoria Lam and Chase Sariaslani
import heapq
# A graph is represented by its symmetric cost matrix: entry [i][j] is the
# cost of the edge between vertices i and j, and 0.0 means "no edge".
test_graph = [
    [0.0, 10.0, 12.0, 11.0, 3.0, 5.0],
    [10.0, 0.0, 9.0, 12.0, 2.0, 4.0],
    [12.0, 9.0, 0.0, 8.0, 0.0, 0.0],
    [11.0, 12.0, 8.0, 0.0, 7.0, 0.0],
    [3.0, 2.0, 0.0, 7.0, 0.0, 6.0],
    [5.0, 4.0, 0.0, 0.0, 6.0, 0.0],
]
def loop(graph, vertex_set, spanning_tree, included_vertices, heap_queue):
    """Grow ``spanning_tree`` until it reaches every vertex, then return it.

    Repeatedly calls ``step`` to add the cheapest edge crossing the cut
    until ``included_vertices`` equals ``vertex_set``. The finished tree is
    printed and returned.

    Args:
        graph: the cost matrix being spanned.
        vertex_set: set of all vertices of ``graph``.
        spanning_tree: set of ``(i, j)`` edges chosen so far.
        included_vertices: set of vertices ``spanning_tree`` touches.
        heap_queue: ``heapq`` min-heap of candidate ``(cost, i, j)`` edges.

    Returns:
        The final spanning tree as a set of ``(i, j)`` edges.

    Raises:
        ValueError: if ``heap_queue`` runs out before every vertex is
            reached, i.e. the graph is not connected.
    """
    # Iterate rather than recurse so the result is always returned to the
    # caller and large graphs do not exceed Python's recursion limit.
    while vertex_set != included_vertices:
        if not heap_queue:
            raise ValueError("graph is not connected; no edge leaves the tree")
        spanning_tree, included_vertices, heap_queue = step(
            graph, spanning_tree, included_vertices, heap_queue
        )
    print(spanning_tree)
    return spanning_tree
def step(graph, spanning_tree, included_vertices, heap_queue):
    """Add the cheapest edge that crosses the cut to the spanning tree.

    Args:
        graph: the cost matrix; a 0 entry means "no edge".
        spanning_tree: set of ``(i, j)`` edges (without their cost).
        included_vertices: set of vertices ``spanning_tree`` touches.
        heap_queue: ``heapq`` min-heap of ``(cost, i, j)`` candidate edges.
            It must contain every edge that crosses the cut, meaning exactly
            one endpoint is included. Edges whose endpoints are both
            already included may also be present; they are discarded here.

    Returns:
        The tuple ``(spanning_tree, included_vertices, heap_queue)``. All
        three containers are updated in place.

    Raises:
        ValueError: if no edge crosses the cut, i.e. the graph is not
            connected.
    """
    # Lazily discard stale edges: once both endpoints are included, an edge
    # can no longer extend the tree.
    edge = None
    while heap_queue:
        candidate = heapq.heappop(heap_queue)
        if (candidate[1] in included_vertices) != (candidate[2] in included_vertices):
            edge = candidate
            break
    if edge is None:
        raise ValueError("graph is not connected; no edge leaves the tree")

    _, a, b = edge
    new_vertex = b if a in included_vertices else a
    spanning_tree.add((a, b))
    included_vertices.add(new_vertex)

    # Only edges touching the new vertex can have started crossing the cut,
    # so push just those instead of rebuilding the heap from the whole matrix.
    # Edges are keyed (cost, lo, hi) with the cost read from the upper
    # triangle, matching the tuples make_edge_set produces.
    for other in range(len(graph)):
        if other in included_vertices:
            continue
        lo, hi = min(new_vertex, other), max(new_vertex, other)
        cost = graph[lo][hi]
        if cost != 0:
            heapq.heappush(heap_queue, (cost, lo, hi))
    return (spanning_tree, included_vertices, heap_queue)
def make_edge_set(graph):
    """Return the edges of ``graph`` as a set of ``(cost, i, j)`` tuples.

    Only the upper triangle of the matrix (diagonal included, ``j >= i``) is
    scanned, so each undirected edge appears once. Zero entries are treated
    as "no edge" and skipped.
    """
    size = len(graph)
    return {
        (row[col], row_index, col)
        for row_index, row in enumerate(graph)
        for col in range(row_index, size)
        if row[col] != 0
    }
def make_heap_queue(included_vertices, edge_set):
    """Build a min-heap of the edges that cross the cut around ``included_vertices``.

    An edge crosses the cut when exactly one of its endpoints is in
    ``included_vertices``. Those are the edges that can extend the tree.

    Args:
        included_vertices: set of vertices already in the spanning tree.
        edge_set: iterable of ``(cost, i, j)`` tuples, such as the output of
            ``make_edge_set``.

    Returns:
        A list satisfying the ``heapq`` invariant, so element ``[0]`` is the
        cheapest crossing edge (ties broken by comparing the vertex indices).
    """
    heap = [
        edge
        for edge in edge_set
        if (edge[1] in included_vertices) != (edge[2] in included_vertices)
    ]
    # heapify builds the heap in O(n), versus O(n log n) for pushing one by one.
    heapq.heapify(heap)
    return heap
def validate_input(graph):
    """Check that ``graph`` is a valid cost matrix.

    The matrix must be square, every entry must be a ``float``, the diagonal
    must be zero, and the matrix must be symmetric. Explicit exceptions are
    raised instead of ``assert`` so the checks still run when Python is
    started with ``-O``.

    Raises:
        ValueError: if the matrix is not square, has a non-zero diagonal
            entry, or is not symmetric.
        TypeError: if any entry is not a ``float``.
    """
    size = len(graph)
    # Check every row's length first, so a ragged matrix is reported clearly
    # instead of surfacing as an IndexError from the symmetry check.
    for i, row in enumerate(graph):
        if len(row) != size:
            raise ValueError(
                "row {} has {} entries; a {}x{} matrix is required".format(
                    i, len(row), size, size))
    for i, row in enumerate(graph):
        for j, entry in enumerate(row):
            if not isinstance(entry, float):
                raise TypeError(
                    "entry [{}][{}] is {!r}; every entry must be a float".format(
                        i, j, entry))
    for i in range(size):
        if graph[i][i] != 0.0:
            raise ValueError(
                "diagonal entry [{0}][{0}] is {1!r}; it must be 0.0".format(
                    i, graph[i][i]))
        # Comparing the upper triangle against the lower covers every pair.
        for j in range(i + 1, size):
            if graph[i][j] != graph[j][i]:
                raise ValueError(
                    "matrix is not symmetric: [{0}][{1}] = {2!r} but "
                    "[{1}][{0}] = {3!r}".format(i, j, graph[i][j], graph[j][i]))
def prims(graph):
    """Compute a minimum spanning tree of ``graph`` with Prim's algorithm.

    ``graph`` is a square, symmetric cost matrix of floats with a zero
    diagonal, where a 0.0 entry means "no edge". The tree is grown from
    vertex 0 by ``loop``.

    Returns:
        The spanning tree produced by ``loop``: a set of ``(i, j)`` edges
        with ``i < j``. An empty graph yields an empty tree.

    Raises:
        Whatever ``validate_input`` raises for a malformed matrix.
    """
    validate_input(graph)
    vertex_set = set(range(len(graph)))
    # Start from vertex 0. An empty graph has no vertex 0, so it starts (and
    # immediately finishes) with no vertices and an empty tree.
    included = {0} if graph else set()
    queue = make_heap_queue(included, make_edge_set(graph))
    return loop(graph, vertex_set, set(), included, queue)
# Run the example only when executed as a script, so importing this module
# does not compute and print a spanning tree as a side effect.
if __name__ == "__main__":
    prims(test_graph)