Added Prim's algorithm
This commit is contained in:
parent
8174057384
commit
342efb8e59
4 changed files with 225 additions and 0 deletions
6
min_spanning_tree/prim/CMakeLists.txt
Normal file
6
min_spanning_tree/prim/CMakeLists.txt
Normal file
|
@ -0,0 +1,6 @@
|
|||
cmake_minimum_required(VERSION 3.10)
|
||||
project(prim)
|
||||
add_executable(prim
|
||||
../../weighted_graph.cpp
|
||||
../../weighted_graph.h
|
||||
prim.cpp)
|
97
min_spanning_tree/prim/prim.cpp
Normal file
97
min_spanning_tree/prim/prim.cpp
Normal file
|
@ -0,0 +1,97 @@
|
|||
// Algorithm generating a minimum spanning tree using Prim's algorithm
|
||||
// Authors: Georǵi Kocharyan
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
#include <list>
|
||||
#include <queue>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
#include "../../weighted_graph.h"
|
||||
|
||||
|
||||
struct key_compare
|
||||
{
|
||||
bool operator()(const std::pair<int, double>& l, const std::pair<int, double>& r)
|
||||
{
|
||||
return l.second > r.second;
|
||||
}
|
||||
};
|
||||
|
||||
void next_edge(WeightedGraph const & G, std::vector<std::pair<int, double>> & min_neighbours, std::priority_queue<std::pair<int, double>, std::vector<std::pair<int, double>>, key_compare> & notMST, std::vector<bool> & inMST, int & last_added, double & total_weight) {
|
||||
// update the priority queue
|
||||
for (const auto& j: G.adjList(last_added))
|
||||
{
|
||||
int to = j.first;
|
||||
double weight = j.second;
|
||||
if (weight < min_neighbours[to].second)
|
||||
{
|
||||
min_neighbours[to] = std::make_pair(last_added, weight);
|
||||
notMST.push(std::make_pair(to, weight));
|
||||
}
|
||||
}
|
||||
// add the next edge
|
||||
std::pair<int, double> node = notMST.top();
|
||||
notMST.pop();
|
||||
if (!inMST[node.first])
|
||||
{
|
||||
std::cout << node.first << "-" << min_neighbours[node.first].first << "\t" << node.second << std::endl;
|
||||
last_added = node.first;
|
||||
inMST[node.first] = true;
|
||||
total_weight = total_weight + node.second;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void prim(WeightedGraph const & G) {
|
||||
// preprocessing: remove all double edges except the minimal ones
|
||||
|
||||
WeightedGraph H = G.remove_parallel();
|
||||
|
||||
// preprocessing: create a vector with the min MST neighbour of all vertices
|
||||
|
||||
std::vector<std::pair<int, double>> min_neighbours(H.num_nodes(), std::make_pair(0,std::numeric_limits<double>::infinity()));
|
||||
|
||||
// preprocessing: track with a vector which elements are in the MST
|
||||
|
||||
std::vector<bool> inMST(H.num_nodes(), false);
|
||||
inMST[0] = true;
|
||||
|
||||
// preprocessing: create a priority queue of vertices not yet in the MST
|
||||
|
||||
std::priority_queue<std::pair<int, double>, std::vector<std::pair<int, double>>, key_compare> notMST;
|
||||
for (int i = 1; i < H.num_nodes(); i++)
|
||||
{
|
||||
notMST.push(std::make_pair(i, std::numeric_limits<double>::infinity()));
|
||||
}
|
||||
int last_added = 0;
|
||||
double total_weight = 0;
|
||||
while (!notMST.empty())
|
||||
{
|
||||
next_edge(H, min_neighbours, notMST, inMST, last_added, total_weight);
|
||||
}
|
||||
std::cout << "The total weight of the MST is " << total_weight << std::endl;
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
int size = 8;
|
||||
WeightedGraph G(size);
|
||||
G.add_edge(3,4,2);
|
||||
G.add_edge(4,3,3);
|
||||
G.add_edge(5,6,6);
|
||||
G.add_edge(6,7,1);
|
||||
G.add_edge(1,2,3);
|
||||
G.add_edge(2,3,8);
|
||||
G.add_edge(7,5,0.2);
|
||||
G.add_edge(7,3,9);
|
||||
G.add_edge(0,3,1);
|
||||
G.add_edge(3,0,5);
|
||||
|
||||
prim(G);
|
||||
|
||||
return 0;
|
||||
|
||||
}
|
89
weighted_graph.cpp
Normal file
89
weighted_graph.cpp
Normal file
|
@ -0,0 +1,89 @@
|
|||
#include "weighted_graph.h"
|
||||
#include <unordered_map>
|
||||
#include <iostream>
|
||||
|
||||
Node::Node(std::list<std::pair<int, double>> neighbours_) : neighbours(std::move(neighbours_))
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
void Node::add_edge(int node_id, double weight)
|
||||
{
|
||||
neighbours.push_back(std::make_pair(node_id,weight));
|
||||
}
|
||||
|
||||
int Node::deg() const
|
||||
{
|
||||
return neighbours.size();
|
||||
}
|
||||
|
||||
WeightedGraph::WeightedGraph(size_t num_nodes) : nodes(num_nodes)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
void WeightedGraph::add_edge(int from, int to, double weight)
|
||||
{
|
||||
nodes[from].add_edge(to, weight);
|
||||
nodes[to].add_edge(from, weight);
|
||||
}
|
||||
|
||||
std::list<std::pair<int, double>> WeightedGraph::adjList(int node_id) const
|
||||
{
|
||||
return (nodes[node_id]).neighbours;
|
||||
}
|
||||
|
||||
int WeightedGraph::deg(int node_id) const
|
||||
{
|
||||
return nodes[node_id].deg();
|
||||
}
|
||||
|
||||
size_t WeightedGraph::num_nodes() const
|
||||
{
|
||||
return nodes.size();
|
||||
}
|
||||
|
||||
std::pair<int,double> WeightedGraph::min_neighbour(int node_id) const
|
||||
{
|
||||
if (adjList(node_id).empty())
|
||||
{
|
||||
return std::make_pair(0,0);
|
||||
}
|
||||
int result = (adjList(node_id).front()).first;
|
||||
int current_min = (adjList(node_id).front()).second;
|
||||
for (const auto& i : adjList(node_id))
|
||||
{
|
||||
if (current_min > i.second)
|
||||
{
|
||||
current_min = i.second;
|
||||
result = i.first;
|
||||
}
|
||||
}
|
||||
return std::make_pair(result,current_min);
|
||||
}
|
||||
|
||||
// removes duplicate edges, leaving only the lightest edge (guarantees m = O(n^2))
|
||||
|
||||
WeightedGraph WeightedGraph::remove_parallel() const
|
||||
{
|
||||
WeightedGraph G(num_nodes());
|
||||
for (int i = 0; i < num_nodes(); i++)
|
||||
{
|
||||
std::unordered_map<int, double> lightestEdges;
|
||||
for (const auto& j: adjList(i))
|
||||
{
|
||||
int to = j.first;
|
||||
double weight = j.second;
|
||||
if (lightestEdges.find(to) == lightestEdges.end() || weight < lightestEdges[to])
|
||||
{
|
||||
lightestEdges[to] = weight;
|
||||
}
|
||||
}
|
||||
for (const auto& [to, weight] : lightestEdges)
|
||||
{
|
||||
G.nodes[i].add_edge(to, weight);
|
||||
}
|
||||
}
|
||||
return G;
|
||||
}
|
||||
|
33
weighted_graph.h
Normal file
33
weighted_graph.h
Normal file
|
@ -0,0 +1,33 @@
|
|||
#ifndef C___WEIGHTEDGRAPH_H
|
||||
#define C___WEIGHTEDGRAPH_H
|
||||
|
||||
#include <list>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
|
||||
struct Node {
|
||||
std::list<std::pair<int, double>> neighbours;
|
||||
Node() = default;
|
||||
Node(std::list<std::pair<int, double>> neighbours);
|
||||
void add_edge(int node_id, double weight);
|
||||
int deg() const;
|
||||
};
|
||||
|
||||
class WeightedGraph {
|
||||
public:
|
||||
WeightedGraph(size_t num_nodes);
|
||||
void add_edge(int from, int to, double weight);
|
||||
std::list<std::pair<int, double>> adjList(int node_id) const;
|
||||
size_t num_nodes() const;
|
||||
int deg(int v) const;
|
||||
std::pair<int,double> min_neighbour(int node_id) const;
|
||||
WeightedGraph remove_parallel() const;
|
||||
|
||||
private:
|
||||
std::vector<Node> nodes;
|
||||
};
|
||||
|
||||
#endif //C___WEIGHTEDGRAPH_H
|
||||
|
||||
|
Loading…
Reference in a new issue