diff --git a/min_spanning_tree/prim/CMakeLists.txt b/min_spanning_tree/prim/CMakeLists.txt new file mode 100644 index 0000000..a5a575b --- /dev/null +++ b/min_spanning_tree/prim/CMakeLists.txt @@ -0,0 +1,6 @@ +cmake_minimum_required(VERSION 3.10) +project(prim) +add_executable(prim + ../../weighted_graph.cpp + ../../weighted_graph.h + prim.cpp) diff --git a/min_spanning_tree/prim/prim.cpp b/min_spanning_tree/prim/prim.cpp new file mode 100644 index 0000000..2ecedbc --- /dev/null +++ b/min_spanning_tree/prim/prim.cpp @@ -0,0 +1,97 @@ +// Algorithm generating a minimum spanning tree using Prim's algorithm +// Authors: GeorĒµi Kocharyan + +#include +#include +#include +#include +#include +#include +#include +#include +#include "../../weighted_graph.h" + + + struct key_compare + { + bool operator()(const std::pair& l, const std::pair& r) + { + return l.second > r.second; + } + }; + +void next_edge(WeightedGraph const & G, std::vector> & min_neighbours, std::priority_queue, std::vector>, key_compare> & notMST, std::vector & inMST, int & last_added, double & total_weight) { + // update the priority queue + for (const auto& j: G.adjList(last_added)) + { + int to = j.first; + double weight = j.second; + if (weight < min_neighbours[to].second) + { + min_neighbours[to] = std::make_pair(last_added, weight); + notMST.push(std::make_pair(to, weight)); + } + } + // add the next edge + std::pair node = notMST.top(); + notMST.pop(); + if (!inMST[node.first]) + { + std::cout << node.first << "-" << min_neighbours[node.first].first << "\t" << node.second << std::endl; + last_added = node.first; + inMST[node.first] = true; + total_weight = total_weight + node.second; + } + +} + +void prim(WeightedGraph const & G) { + // preprocessing: remove all double edges except the minimal ones + + WeightedGraph H = G.remove_parallel(); + + // preprocessing: create a vector with the min MST neighbour of all vertices + + std::vector> min_neighbours(H.num_nodes(), std::make_pair(0,std::numeric_limits::infinity())); + + // preprocessing: track with a vector which elements are in the MST + + std::vector inMST(H.num_nodes(), false); + inMST[0] = true; + + // preprocessing: create a priority queue of vertices not yet in the MST + + std::priority_queue, std::vector>, key_compare> notMST; + for (int i = 1; i < H.num_nodes(); i++) + { + notMST.push(std::make_pair(i, std::numeric_limits::infinity())); + } + int last_added = 0; + double total_weight = 0; + while (!notMST.empty()) + { + next_edge(H, min_neighbours, notMST, inMST, last_added, total_weight); + } + std::cout << "The total weight of the MST is " << total_weight << std::endl; +} + +int main() { + + int size = 8; + WeightedGraph G(size); + G.add_edge(3,4,2); + G.add_edge(4,3,3); + G.add_edge(5,6,6); + G.add_edge(6,7,1); + G.add_edge(1,2,3); + G.add_edge(2,3,8); + G.add_edge(7,5,0.2); + G.add_edge(7,3,9); + G.add_edge(0,3,1); + G.add_edge(3,0,5); + + prim(G); + + return 0; + +} \ No newline at end of file diff --git a/weighted_graph.cpp b/weighted_graph.cpp new file mode 100644 index 0000000..68849a1 --- /dev/null +++ b/weighted_graph.cpp @@ -0,0 +1,89 @@ +#include "weighted_graph.h" +#include +#include + +Node::Node(std::list> neighbours_) : neighbours(std::move(neighbours_)) +{ + +} + +void Node::add_edge(int node_id, double weight) +{ + neighbours.push_back(std::make_pair(node_id,weight)); +} + +int Node::deg() const +{ + return neighbours.size(); +} + +WeightedGraph::WeightedGraph(size_t num_nodes) : nodes(num_nodes) +{ + +} + +void WeightedGraph::add_edge(int from, int to, double weight) +{ + nodes[from].add_edge(to, weight); + nodes[to].add_edge(from, weight); +} + +std::list> WeightedGraph::adjList(int node_id) const +{ + return (nodes[node_id]).neighbours; +} + +int WeightedGraph::deg(int node_id) const +{ + return nodes[node_id].deg(); +} + +size_t WeightedGraph::num_nodes() const +{ + return nodes.size(); +} + +std::pair WeightedGraph::min_neighbour(int node_id) const +{ + if (adjList(node_id).empty()) + { + return std::make_pair(0,0); + } + int result = (adjList(node_id).front()).first; + int current_min = (adjList(node_id).front()).second; + for (const auto& i : adjList(node_id)) + { + if (current_min > i.second) + { + current_min = i.second; + result = i.first; + } + } + return std::make_pair(result,current_min); +} + +// removes duplicate edges, leaving only the lightest edge (guarantees m = O(n^2)) + +WeightedGraph WeightedGraph::remove_parallel() const +{ + WeightedGraph G(num_nodes()); + for (int i = 0; i < num_nodes(); i++) + { + std::unordered_map lightestEdges; + for (const auto& j: adjList(i)) + { + int to = j.first; + double weight = j.second; + if (lightestEdges.find(to) == lightestEdges.end() || weight < lightestEdges[to]) + { + lightestEdges[to] = weight; + } + } + for (const auto& [to, weight] : lightestEdges) + { + G.nodes[i].add_edge(to, weight); + } + } + return G; +} + diff --git a/weighted_graph.h b/weighted_graph.h new file mode 100644 index 0000000..1b4d4fb --- /dev/null +++ b/weighted_graph.h @@ -0,0 +1,33 @@ +#ifndef C___WEIGHTEDGRAPH_H +#define C___WEIGHTEDGRAPH_H + +#include +#include +#include + + +struct Node { + std::list> neighbours; + Node() = default; + Node(std::list> neighbours); + void add_edge(int node_id, double weight); + int deg() const; +}; + +class WeightedGraph { +public: + WeightedGraph(size_t num_nodes); + void add_edge(int from, int to, double weight); + std::list> adjList(int node_id) const; + size_t num_nodes() const; + int deg(int v) const; + std::pair min_neighbour(int node_id) const; + WeightedGraph remove_parallel() const; + +private: + std::vector nodes; +}; + +#endif //C___WEIGHTEDGRAPH_H + +