Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download

📚 The CoCalc Library - books, templates and other resources

132928 views
License: OTHER
1
#!/usr/bin/env python
2
#
3
# Copyright 2019 the original author or authors.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
# http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
#
17
import pygame
18
import numpy as np
19
import matplotlib.pyplot as plt
20
import networkx as nx
21
from cmath import isclose
22
import io
23
24
from vqe_playground.utils.resources import load_mem_image
25
from vqe_playground.utils.labels import comp_graph_node_labels
26
27
28
class NetworkGraph(pygame.sprite.Sprite):
    """Pygame sprite that displays a weighted network graph rendered by matplotlib.

    The graph is built from a square adjacency matrix: every entry
    ``adj_matrix[i, j]`` with ``i < j`` that is not close to zero becomes an
    edge ``(i, j)`` labelled with that weight. Nodes are coloured from
    ``self.solution`` (see ``calc_node_colors``), the figure is saved to an
    in-memory PNG and loaded into ``self.image`` / ``self.rect``.
    """

    def __init__(self, adj_matrix):
        """Create the sprite and render it for ``adj_matrix``.

        Args:
            adj_matrix: square matrix supporting ``.shape[0]`` and
                ``adj_matrix[i, j]`` indexing (presumably a numpy array —
                TODO confirm against callers).
        """
        pygame.sprite.Sprite.__init__(self)
        self.image = None
        self.rect = None
        self.adj_matrix = None
        self.solution = None
        self.graph = nx.Graph()
        self.graph_pos = None
        self.num_nodes = adj_matrix.shape[0]  # Number of nodes in graph
        # Single matplotlib figure reused for every redraw; created lazily so
        # repeated set_adj_matrix() calls no longer leak a new figure each time.
        self._figure = None
        self.set_adj_matrix(adj_matrix)

    def update(self):
        """Redraw the graph with node colours from the current solution.

        Presumably invoked by a pygame ``Group.update()`` each frame.
        """
        self.draw_network_graph(self.calc_node_colors())

    def set_adj_matrix(self, adj_matrix):
        """Rebuild the graph from ``adj_matrix``, reset the solution and redraw.

        The node count is re-read from ``adj_matrix`` so a matrix of a
        different size than the one given to ``__init__`` is handled
        correctly. A fresh spring layout is computed for the new graph.

        Args:
            adj_matrix: square matrix; only the upper triangle (``i < j``) is
                read, so it is treated as symmetric.
        """
        self.adj_matrix = adj_matrix
        self.num_nodes = adj_matrix.shape[0]
        self.solution = np.zeros(self.num_nodes)

        graph = nx.Graph()
        graph.add_nodes_from(np.arange(0, self.num_nodes, 1))
        # Each tuple is (i, j, weight) where (i, j) is the edge; near-zero
        # entries mean "no edge".
        graph.add_weighted_edges_from(
            (i, j, adj_matrix[i, j])
            for i in range(self.num_nodes)
            for j in range(i + 1, self.num_nodes)
            if not isclose(adj_matrix[i, j], 0.0)
        )
        self.graph = graph

        self.graph_pos = nx.spring_layout(self.graph)
        self.draw_network_graph(self.calc_node_colors())

    def set_solution(self, solution):
        """Store ``solution`` and redraw the graph with the resulting node colours.

        Args:
            solution: per-node values indexed as described in
                ``calc_node_colors``.
        """
        self.solution = solution

        self.draw_network_graph(self.calc_node_colors())

    def _get_figure(self):
        """Return this sprite's matplotlib figure, creating it on first use."""
        if getattr(self, "_figure", None) is None:
            self._figure = plt.figure(figsize=(7, 5))
        return self._figure

    def draw_network_graph(self, colors):
        """Render the graph with ``colors`` and load it into ``self.image``/``self.rect``.

        The sprite's own figure is cleared before drawing, so repeated calls
        do not stack new artists over old ones (which previously slowed every
        redraw and let old semi-transparent node colours show through). All
        drawing targets an explicit axes rather than pyplot's "current" one.

        Args:
            colors: one matplotlib colour per node, in node order.
        """
        fig = self._get_figure()
        fig.clf()
        ax = fig.add_subplot(1, 1, 1)

        edge_labels = {(u, v): self.adj_matrix[u, v] for u, v in self.graph.edges()}
        nx.draw_networkx_edge_labels(self.graph, self.graph_pos, edge_labels=edge_labels, ax=ax)

        labels = comp_graph_node_labels(self.num_nodes)
        nx.draw_networkx_labels(self.graph, self.graph_pos, labels, font_size=16, font_color='white', ax=ax)

        nx.draw_networkx(self.graph, self.graph_pos, with_labels=False, node_color=colors,
                         node_size=600, alpha=.8, font_color='white', ax=ax)
        ax.axis('off')

        buf = io.BytesIO()
        fig.savefig(buf, format="png")

        self.image, self.rect = load_mem_image(buf, -1)
        # Surface.convert() returns a new Surface; the original call discarded it.
        self.image = self.image.convert()

    def calc_node_colors(self):
        """Return one colour per node: 'r' where the solution value is 0, else 'b'.

        Node ``i`` reads ``self.solution[num_nodes - i - 1]``, i.e. the
        solution vector is read in reverse node order (presumably to match a
        bit/qubit ordering convention — TODO confirm).
        """
        n = self.num_nodes
        return ['r' if self.solution[n - i - 1] == 0 else 'b' for i in range(n)]
87
88
89