Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download
1576 views
1
import itertools
2
import numpy as np
3
import re
4
import numbers
5
import sys
6
from scipy import linalg
7
import pandas as pd
8
import grammar
9
from collections import OrderedDict
10
import matplotlib.pyplot as plt
11
12
# NOTE(review): presumably the version number of this module -- confirm.
ver = 0.3
13
14
15
class Fillers(object):
    """Filler symbols derived from a context-free grammar.

    The grammar string is handed to ``grammar.Grammar`` and the resulting
    ``hnfRules`` mapping (left-hand symbol -> list of right-hand sides, each
    a sequence of symbols) is used to build three attributes:

    * ``names``    -- sorted list of every symbol occurring in the rules,
      followed by the null filler ``'_'`` when ``add_null`` is true;
    * ``treelets`` -- lists of the form ``[mother, bracketed, daughter, ...]``;
    * ``pairs``    -- ``[[upper, lower], 'i_of_n']`` entries derived from
      the treelets.

    ``null`` holds the null filler symbol, or ``None`` when it is not used.
    """

    # A bracketed filler is any name ending in "[<digits>]", e.g. "S[1]".
    # Raw string: "\[" is not a valid escape in a normal string literal.
    _BRACKETED_PATTERN = re.compile(r'.*\[[0-9]+\]$')

    def __init__(self, cfg, add_null=True):
        """Build filler names, treelets and pairs from the grammar ``cfg``.

        :param cfg: grammar specification passed on to ``grammar.Grammar``.
        :param add_null: when true, append the null filler ``'_'`` to
            ``names`` and store it in ``null``.
        """
        null_filler = '_'

        self.g = grammar.Grammar(cfg)
        self._get_filler_names()
        if add_null:
            self.null = null_filler
            self.names.append(null_filler)
        else:
            self.null = None
        self._construct_treelets()
        self._construct_pairs()

    def _get_filler_names(self):
        """Collect every symbol of ``hnfRules`` into the sorted ``names``."""
        hnf_rules = self.g.hnfRules
        symbols = set()
        for lhs in hnf_rules:
            symbols.add(lhs)
            for rhs in hnf_rules[lhs]:
                symbols.update(rhs)
        self.names = sorted(symbols)

    def _construct_treelets(self):
        """Build ``treelets`` from rules whose right-hand side is one
        bracketed symbol.

        For a rule ``lhs -> X[k]`` the treelet is ``[lhs, X[k]]`` followed by
        all symbols on the right-hand sides of ``X[k]``.
        """
        hnf_rules = self.g.hnfRules
        bracketed = set(self.subset_bracketed())  # computed once, not per rule
        treelets = []

        for lhs in hnf_rules.keys():
            for rhs in hnf_rules[lhs]:
                if len(rhs) == 1 and rhs[0] in bracketed:
                    treelet = [lhs, rhs[0]]
                    for daughters in hnf_rules[rhs[0]]:
                        treelet.extend(daughters)
                    # NOTE(review): the append is placed after the daughter
                    # loop so each treelet is recorded once; indentation was
                    # lost in the scraped source -- confirm.
                    treelets.append(treelet)

        self.treelets = treelets

    def _construct_pairs(self):
        """Split every treelet into labelled (upper, lower) pairs.

        The first two symbols form a ``'1_of_1'`` pair; the second symbol is
        then paired with each remaining symbol and labelled ``'i_of_n'``.
        """
        self.pairs = []

        for treelet in self.treelets:
            self.pairs.append([[treelet[0], treelet[1]], '1_of_1'])
            n_daughters = len(treelet) - 2

            for ii in range(n_daughters):
                self.pairs.append([[treelet[1], treelet[ii + 2]],
                                   '%d_of_%d' % (ii + 1, n_daughters)])

    def find(self, filler):
        """Placeholder; always returns 0."""
        return 0

    def find_mothers(self, filler):
        """Placeholder; always returns 0."""
        return 0

    def find_daughters(self, filler):
        """Placeholder; always returns 0."""
        return 0

    def subset_bracketed(self):
        """Return a list of bracketed filler symbols"""
        return [filler for filler in self.names
                if self._BRACKETED_PATTERN.match(filler) is not None]

    def subset_unbracketed(self):
        """Return a list of unbracketed filler symbols"""
        bracketed = set(self.subset_bracketed())
        return [filler for filler in self.names if filler not in bracketed]

    def subset_root(self):
        """Return what ``getRootNode`` of the grammar yields for hnfRules."""
        return self.g.getRootNode(self.g.hnfRules)

    def subset_terminals(self):
        """Return what ``getTerminalNodes`` of the grammar yields for
        hnfRules."""
        return self.g.getTerminalNodes(self.g.hnfRules)

    def is_terminal(self, filler):
        """Return True if ``filler`` is in ``subset_terminals()``."""
        return filler in self.subset_terminals()

    def is_bracketed(self, filler):
        """Return True if ``filler`` is in ``subset_bracketed()``."""
        return filler in self.subset_bracketed()

    def is_unbracketed(self, filler):
        """Return True if ``filler`` is in ``subset_unbracketed()``."""
        return filler in self.subset_unbracketed()

    def is_root(self, filler):
        """Return True if ``filler`` is in ``subset_root()``."""
        return filler in self.subset_root()

    def read_treelets(self):
        """Print every treelet as ``mother ( bracketed ( daughters ) )``."""
        for treelet in self.treelets:
            print(treelet[0] + ' ( ' +
                  treelet[1] + ' ( ' + ' '.join(treelet[2:]) + ' ) )')
121
122
123
class SpanRoles(object):
    """Span roles for sentences of up to ``max_sent_len`` positions.

    A role is a strictly increasing tuple of positions.  A 2-tuple
    ``(i, j)`` is an unbracketed span; a longer tuple such as ``(i, k, j)``
    whose width is greater than 1 is a bracketed role.  Role names are the
    tuples rendered without spaces, e.g. ``'(0,1,2)'``.

    Attributes built by the constructor: ``roles`` (sorted tuples),
    ``num_roles``, ``names``, ``G`` (mother/daughter graph),
    ``quant_list``, ``treelets`` and ``pairs``.
    """

    def __init__(self, max_sent_len, max_num_branch=2):
        '''Construct a SpanRole object.

        :param max_sent_len: largest position index used in a role.
        :param max_num_branch: roles with 2 .. max_num_branch + 1 positions
            are generated.
        '''
        # use_terminal: do you have special terminal symbols?
        self.max_sent_len = max_sent_len
        self.max_num_branch = max_num_branch
        self.use_terminal = False
        self._generate()
        self._sort()
        self._name()
        self._graph()
        self._check_overlap()
        self._construct_treelets()
        self._construct_pairs()

    def _generate(self):
        """Generate the role tuples and store them in ``roles``."""
        positions = range(self.max_sent_len + 1)
        roles = []
        for num_branch in range(self.max_num_branch):
            roles += list(itertools.combinations(positions, num_branch + 2))

        if self.use_terminal:
            # Special terminal roles (i, j, j) for every width-1 span.
            roles = roles + [(role[0], role[1], role[1])
                             for role in roles if self.get_width(role) == 1]

        self.roles = roles
        self.num_roles = len(roles)

    def sort_roles(self, roles_tuple):
        """Return the roles sorted by
        (1) width - increasing, (2) num_branch - decreasing,
        (3) initial position - increasing."""
        return sorted(
            roles_tuple,
            key=lambda role: (self.get_width(role),
                              -self.get_num_branch(role),
                              role[0]))

    def _sort(self):
        """Sort ``roles`` in place of the unsorted list (see sort_roles)."""
        self.roles = self.sort_roles(self.roles)

    def _name(self):
        """Store the string name of every role in ``names``."""
        self.names = [self.tuple2str(role) for role in self.roles]

    def _graph(self):
        """Build ``G``: role name -> {'m': mothers, 'd': daughters}.

        Both values are lists of lists of role names.
        """
        G = OrderedDict()
        for role_str in self.names:
            G[role_str] = {'m': [], 'd': []}

        for role_tuple in self.roles:
            role_str = self.tuple2str(role_tuple)

            if self.is_bracketed(role_tuple):
                mother_str = self.tuple2str(self._find_mother(role_tuple))
                # Zero-width daughters are dropped.
                daughter_strs = [
                    self.tuple2str(d)
                    for d in self._find_daughter(role_tuple)
                    if self.get_width(d) > 0]
                G[role_str]['m'].append([mother_str])
                G[mother_str]['d'].append([role_str])
                G[role_str]['d'].append(daughter_strs)
                for d in daughter_strs:
                    G[d]['m'].append([role_str])

            if self.use_terminal and self.is_terminal(role_tuple):
                mother_str = self.tuple2str(self._find_mother(role_tuple))
                G[role_str]['m'].append([mother_str])
                G[mother_str]['d'].append([role_str])

        self.G = G

    def _check_overlap(self):
        """Build ``quant_list``, a list of lists of role names.

        It is assumed that span roles are sorted correctly (see sort_roles).
        Bracketed roles that share the same mother (e.g. (0,1,3) and (0,2,3))
        form one group; every unbracketed role forms a group of its own.
        """
        quant_list = []
        for role_name in self.G:
            if self.is_bracketed(self.str2tuple(role_name)):
                continue
            # An unbracketed role cannot have multiple daughters, so each
            # entry of its daughter list holds a single binding.
            daughters = self.G[role_name]['d']
            if len(daughters) > 0:
                group = [d[0] for d in daughters
                         if not self.is_terminal(self.str2tuple(d[0]))]
                if len(group) > 0:
                    quant_list.append(group)
            # NOTE(review): appended once per unbracketed role; indentation
            # was lost in the scraped source -- confirm.
            quant_list.append([role_name])

        def sort_key(group):
            first = self.str2tuple(group[0])
            return (self.get_width(first),
                    -self.get_num_branch(first),
                    first[0])

        self.quant_list = sorted(quant_list, key=sort_key)

    def _construct_treelets(self):
        """Build ``treelets``: [mother, bracketed role, daughters...].

        Example: ['(0,2)', '(0,1,2)', '(0,1)', '(1,2)']  # S(S[1](A B))
        """
        treelets = []
        for role_str in self.G:
            if not self.is_bracketed(self.str2tuple(role_str)):
                continue
            mothers = self.G[role_str]['m']
            if len(mothers) > 1:  # error
                sys.exit('CHECK!!!')
            treelet = list(mothers[0])
            treelet.append(role_str)
            treelet.extend(self.G[role_str]['d'][0])
            treelets.append(treelet)

        self.treelets = treelets

    def _construct_pairs(self):
        """Split every treelet into labelled (upper, lower) pairs."""
        self.pairs = []
        for treelet in self.treelets:
            self.pairs.append([[treelet[0], treelet[1]], '1_of_1'])
            n_daughters = len(treelet) - 2
            for ii in range(n_daughters):
                self.pairs.append([[treelet[1], treelet[ii + 2]],
                                   '%d_of_%d' % (ii + 1, n_daughters)])

    def _find_mother(self, role_tuple_bracketed):
        """Return the (first, last) span of a bracketed (or, when
        ``use_terminal`` is set, terminal) role; otherwise return None."""
        if self.is_bracketed(role_tuple_bracketed):
            return (role_tuple_bracketed[0], role_tuple_bracketed[-1])
        if self.use_terminal and self.is_terminal(role_tuple_bracketed):
            return (role_tuple_bracketed[0], role_tuple_bracketed[-1])

    def _find_daughter(self, role_tuple_bracketed):
        """Return the spans between consecutive positions of a bracketed
        role; return None for any other role."""
        if self.is_bracketed(role_tuple_bracketed):
            return [(role_tuple_bracketed[ii], role_tuple_bracketed[ii + 1])
                    for ii in range(len(role_tuple_bracketed) - 1)]

    def str2tuple(self, role_str):
        """Convert a role name such as '(0,1,2)' to a tuple of ints."""
        return tuple([int(pos) for pos in role_str[1:-1].split(',')])

    def tuple2str(self, role_tuple):
        """Convert a role tuple to its name (``str`` without spaces)."""
        return str(role_tuple).replace(' ', '')

    def get_width(self, role_tuple):
        """Return last position minus first position."""
        return role_tuple[-1] - role_tuple[0]

    def get_num_branch(self, role_tuple):
        """Return the number of positions minus one."""
        return len(role_tuple) - 1

    def is_bracketed(self, role_tuple):
        """Return True for roles with more than one branch and width > 1."""
        # If we assume HNF, only bracketed roles can have multiple children.
        return (self.get_num_branch(role_tuple) > 1 and
                self.get_width(role_tuple) > 1)

    def is_terminal(self, role_tuple):
        """Return True for width-1 roles with 2 branches when
        ``use_terminal`` is set, with 1 branch otherwise."""
        expected_branches = 2 if self.use_terminal else 1
        return (self.get_width(role_tuple) == 1 and
                self.get_num_branch(role_tuple) == expected_branches)

    def subset_terminals(self):
        '''Return a list of terminal roles (tuple)'''
        return [r for r in self.roles if self.is_terminal(r)]

    def subset_bracketed(self):
        '''Return a list of bracketed roles (tuple)'''
        return [r for r in self.roles if self.is_bracketed(r)]

    def read_treelets(self):
        """Print every treelet as ``mother ( bracketed ( daughters ) )``."""
        for treelet in self.treelets:
            print(treelet[0] + ' ( ' +
                  treelet[1] + ' ( ' + ' '.join(treelet[2:]) + ' ) )')
320
321
322
class HarmonicGrammar(object):
    """Harmonic Grammar rules built from a context-free grammar.

    ``rules`` is a list of ``[[name, ...], harmony]`` entries: two names for
    a binary rule, one name for a unary rule.  Binding names have the form
    ``'<filler>/<role>'``.
    """

    # A bracketed filler is any name ending in "[<digits>]", e.g. "S[1]".
    # Raw string: "\[" is not a valid escape in a normal string literal.
    _BRACKETED_PATTERN = re.compile(r'.*\[[0-9]+\]$')

    def __init__(self, cfg, size, role_type="span_role",
                 add_null=True, match='pair', penalize='impossible',
                 null_bias=0., add_constraint=False, unary_base='filler'):
        """Build fillers, roles, bindings and rules.

        :param cfg: (string) context-free grammar
        :param size: (int) max_sent_len when span roles are used,
            max_depth when recursive roles are used.
        :param role_type: (string) 'span_role' or 'recursive_role'; only
            'span_role' creates roles and binding names here.
        :param add_null: (bool) add a null treelet or not
        :param match: 'pair' or 'treelet'; selects how binary rules are
            generated (see ``_add_rules_binary``).
        :param penalize: 'unused' or 'impossible'; only used when
            ``add_constraint`` is true (see ``_add_constraints``).
        :param null_bias: (numeric) bias values assigned to null bindings
        :param add_constraint: (bool) filter treelets and add constraint
            rules.
        :param unary_base: forwarded to ``_add_rules_unary``.
        """
        self.role_type = role_type
        self.size = size
        self.rules = []
        self.filler_names = None
        self.role_names = None
        self.binding_names = None

        # CFG(CNF) -> CFG(HNF)
        # For now, only binary branching is supported.
        self.num_branch = 2
        self.bias_constraint = -10.
        self.unused_binding_harmony = -10
        self.null_bias = null_bias

        # Use Nick's program to get a grammar in harmonic normal form.
        # Later integrate the program with this and
        # make it more consistent with this program.
        self.grammar = grammar.Grammar(cfg)
        self.fillers = Fillers(cfg, add_null=add_null)
        self.filler_names = self.fillers.names

        if role_type == "span_role":
            self.roles = SpanRoles(
                max_sent_len=size, max_num_branch=self.num_branch)
            self.role_names = self.roles.names
            self._set_binding_names()

        self._construct_treelets(add_constraint)
        self._construct_pairs()
        self._convert_tuple_to_str()
        self._add_rules_binary(match=match)
        self._add_rules_unary(which=unary_base)

        if add_constraint:
            self._add_constraints(penalize=penalize)

    def _add_constraints(self, penalize='unused'):
        """Append unary penalty rules for individual bindings.

        :param penalize: ``'unused'`` penalizes (with
            ``unused_binding_harmony``) every non-null binding that occurs in
            no treelet; ``'impossible'`` penalizes (with ``bias_constraint``)
            bindings whose filler type and role type do not fit.  Any other
            value adds nothing.
        """
        # Strings are compared with ==; identity ("is") only works by
        # accident for interned strings.
        if penalize == 'unused':
            bindings_used = {binding
                             for treelet in self.treelets
                             for binding in treelet}
            for binding in self.binding_names:
                if binding in bindings_used:
                    continue
                if binding.split('/')[0] != '_':
                    self.rules.append(
                        [[binding], self.unused_binding_harmony])

        elif penalize == 'impossible':
            # The subsets do not change inside the loop; compute them once.
            terminal_fillers = self.fillers.subset_terminals()
            bracketed_fillers = self.get_bracketed_fillers()
            terminal_roles = set(self.roles.subset_terminals())
            bracketed_roles = set(self.roles.subset_bracketed())

            for b in self.binding_names:
                f, r = b.split('/')
                r_tuple = self.roles.str2tuple(r)

                if f == self.fillers.null:
                    continue

                if f in terminal_fillers:
                    # terminal fillers - terminal roles
                    impossible = r_tuple not in terminal_roles
                elif self.fillers.is_root(f):
                    # starting fillers - non-terminal, unbracketed roles
                    impossible = (r_tuple in terminal_roles or
                                  r_tuple in bracketed_roles)
                elif f not in bracketed_fillers:
                    # unbracketed fillers - non-terminal, unbracketed roles
                    impossible = (r_tuple in terminal_roles or
                                  r_tuple in bracketed_roles)
                else:
                    # bracketed fillers - bracketed roles
                    impossible = r_tuple not in bracketed_roles

                if impossible:
                    self.rules.append([[b], self.bias_constraint])

    def _set_binding_names(self, sep='/'):
        """Set ``binding_names`` to every filler/role combination, with the
        filler varying fastest."""
        self.binding_names = [f + sep + r for r in self.roles.names
                              for f in self.fillers.names]

    def _construct_treelets(self, add_constraint):
        """Pair filler treelets with role treelets of the same length.

        Each result is a list of ``(filler, role)`` tuples.  When
        ``add_constraint`` is true, a pairing is kept only if in every
        binding the filler and the role are both terminal or both
        non-terminal, or the filler is the null filler.
        """
        terminal_fillers = self.get_terminal_fillers()
        terminal_roles = self.get_terminal_roles()
        null_filler = self.fillers.null

        treelet_bindings = []
        for treelet_f in self.fillers.treelets:
            for treelet_r in self.roles.treelets:
                if len(treelet_f) != len(treelet_r):
                    continue
                treelet_b = list(zip(treelet_f, treelet_r))
                if add_constraint:
                    # test terminals and non-terminals
                    consistent = all(
                        ((f in terminal_fillers) == (r in terminal_roles)) or
                        f == null_filler
                        for f, r in treelet_b)
                    if not consistent:
                        continue
                treelet_bindings.append(treelet_b)

        self.treelets = treelet_bindings

    def _construct_pairs(self):
        """Pair filler pairs with role pairs that carry the same
        ``'i_of_n'`` label; each result is a list of (filler, role) tuples."""
        self.pairs = [list(zip(pair_f[0], pair_r[0]))
                      for pair_f in self.fillers.pairs
                      for pair_r in self.roles.pairs
                      if pair_f[1] == pair_r[1]]

    def _convert_tuple_to_str(self):
        """Replace every (filler, role) tuple in ``treelets`` and ``pairs``
        by the binding name ``'filler/role'`` (in place)."""
        for group in (self.treelets, self.pairs):
            for ii, bindings in enumerate(group):
                for jj, binding in enumerate(bindings):
                    group[ii][jj] = binding[0] + '/' + binding[1]

    def _add_rules_binary(self, match='pair'):
        """Add binary HG rules (harmony 2.0) from treelets or pairs."""
        if match == 'treelet':
            for treelet in self.treelets:
                self.rules.append([[treelet[0], treelet[1]], 2.0])
                for daughter in treelet[2:]:
                    self.rules.append([[treelet[1], daughter], 2.0])

        elif match == 'pair':
            for pair in self.pairs:
                self.rules.append([[pair[0], pair[1]], 2.0])

    def _add_rules_unary(self, which="filler"):
        """Add unary HG rules, one per filler name.

        ``which`` is accepted but not used by this implementation.
        """
        # Case 1
        # role_type: span_role
        # unary_base: filler
        bracketed = self.fillers.subset_bracketed()
        terminals = self.fillers.subset_terminals()
        root = self.fillers.subset_root()

        for filler in self.fillers.names:
            if filler in bracketed:
                harmony = -1.0 - self.num_branch
            elif filler in terminals:
                harmony = -1.0
            elif filler in root:
                harmony = -1.0
            elif filler == self.fillers.null:
                harmony = self.null_bias
            else:
                harmony = -2.0
            self.rules.append([[filler], harmony])

    def read_rules(self):
        """Print the binary rules, then the unary rules.

        Returns 0 as soon as a unary rule names something that is neither a
        filler, a role nor a binding; otherwise returns None.
        """
        print("Binary rules:\n")
        for rule in self.rules:
            if len(rule[0]) == 2:
                print('H(' + rule[0][0] + ', ' + rule[0][1] +
                      ') = %.1f' % rule[1])

        print("\nUnary rules:\n")
        for rule in self.rules:
            if len(rule[0]) != 1:
                continue
            name = rule[0][0]
            if name in self.fillers.names:
                print('H(' + name + '/*) = %.1f' % rule[1])
            elif name in self.roles.names:
                print('H(*/' + name + ') = %.1f' % rule[1])
            elif name in self.binding_names:
                print('H(' + name + ') = %.1f' % rule[1])
            else:
                return 0

    def _add_rules_root(self):
        """Add a unary rule with harmony -1 for every root filler;
        returns 0."""
        # empty bias
        root_fillers = self.grammar.getRootNode(self.grammar.hnfRules)
        if not isinstance(root_fillers, list):
            root_fillers = [root_fillers]
        for filler in root_fillers:
            self.rules.append([[filler], -1.])
        return 0

    def reorder_fillers(self, fillers_ordered):
        """Replace the filler order and rebuild ``binding_names``.

        The null filler is appended when it is in use and missing from
        ``fillers_ordered``.  The program exits when the given names are not
        the same set as the current filler names.
        """
        # Work on a copy so the caller's list is not modified.
        fillers_ordered = list(fillers_ordered)
        if self.fillers.null is not None:
            if self.fillers.null not in fillers_ordered:
                fillers_ordered.append(self.fillers.null)

        if set(self.fillers.names) == set(fillers_ordered):
            self.fillers.names = fillers_ordered
        else:
            sys.exit("The set of filler names constructued ffrom grammar is not same as the set of filler names you provided.")

        self._set_binding_names()

    def get_binary(self):
        """Placeholder; always returns 0."""
        return 0

    def get_unary(self):
        """Placeholder; always returns 0."""
        return 0

    def sort(self):
        """Placeholder; always returns 0."""
        return 0

    def get_terminal_fillers(self):
        """Return the grammar's terminal nodes as a sorted list."""
        return sorted(self.grammar.getTerminalNodes(self.grammar.hnfRules))

    def get_bracketed_fillers(self):
        """Return a list of bracketed filler symbols"""
        return [key for key in self.grammar.hnfRules.keys()
                if self._BRACKETED_PATTERN.match(key) is not None]

    def get_terminal_roles(self):
        """Return the names of the terminal roles."""
        return [str(role).replace(' ', '')
                for role in self.roles.subset_terminals()]

    def get_bracketed_roles(self):
        """Return the names of the bracketed roles."""
        return [str(role).replace(' ', '')
                for role in self.roles.subset_bracketed()]
577
578
579
class GscNet(object):
580
581
def __init__(self, hg=None, encodings=None, opts=None, seed=None):
    """Set up options, encodings, weights, biases and the initial state.

    :param hg: a HarmonicGrammar instance, or None.  When None, filler and
        role names must be supplied through ``encodings``.
    :param encodings: dict overriding the defaults of ``_set_encodings``.
    :param opts: dict overriding the defaults of ``_set_opts``.
    :param seed: passed to ``set_seed`` when not None.
    """
    self._set_opts()
    self._update_opts(opts=opts)

    self._set_encodings()
    self._update_encodings(encodings=encodings)

    if seed is not None:
        self.set_seed(seed)
    # NOTE(review): stored unconditionally so the attribute always exists;
    # indentation was lost in the scraped source -- confirm.
    self.seed = seed

    self.hg = hg
    self._add_names()
    self._generate_encodings()
    self._compute_TPmat()

    # Weights and biases in conceptual (binding) coordinates.
    self.WC = np.zeros((self.num_bindings, self.num_bindings))
    self.bC = np.zeros(self.num_bindings)
    if hg is not None:
        self._build_model(hg)
    # NOTE(review): the next three calls are run for every network so that
    # W, b and quant_list exist even without a grammar -- confirm.
    self._set_weights()
    self._set_biases()
    self._set_quant_list()

    # Activation state (C = conceptual coordinates, plain = neural).
    self.actC = np.zeros(self.num_bindings)
    self.actC_prev = np.zeros(self.num_bindings)
    self.act = self.C2N()
    self.act_prev = self.C2N(actC=self.actC_prev)

    # External input.
    self.extC = np.zeros(self.num_bindings)
    self.extC_prev = np.zeros(self.num_bindings)
    self.ext = self.C2N(actC=self.extC)
    self.ext_prev = self.C2N(actC=self.extC_prev)

    # Bowl center: a scalar option is expanded to one value per binding.
    if isinstance(self.opts['bowl_center'], numbers.Number):
        self.bowl_center = (self.opts['bowl_center'] *
                            np.ones(self.num_bindings))
    else:
        self.bowl_center = self.opts['bowl_center']

    # Bowl strength: use the recommended value plus offset when not given,
    # otherwise validate the user-supplied value.
    if self.opts['bowl_strength'] is None:
        self.opts['bowl_strength'] = (
            self._compute_recommended_bowl_strength() +
            self.opts['beta_min_offset'])
    else:
        self.check_bowl_strength()
    self.bowl_strength = self.opts['bowl_strength']
    self.zeta = self.C2N(actC=self.bowl_center)

    self.t = 0
    self.speed = None
    self.ema_speed = None
    self.clamped = False
    self.q = self.opts['q_init']
    self.T = self.opts['T_init']
    self.dt = self.opts['dt']

    self.reset()
639
640
########################################################################
641
#
642
# Build a model
643
#
644
########################################################################
645
646
def _set_quant_list(self):
    """Set ``quant_list`` to one ``find_roles`` result per role name."""
    self.quant_list = [self.find_roles(role_name)
                       for role_name in self.role_names]
651
652
def _set_opts(self):
    """Set ``opts`` to the default option values."""
    ema_factor = 0.001
    self.opts = {
        'trace_varnames': [
            'act', 'H', 'H0', 'Q', 'q', 'T', 't', 'ema_speed', 'speed'],
        'norm_ord': np.inf,
        'coord': 'N',
        'ema_factor': ema_factor,
        # ema_tau is derived from ema_factor (see _update_opts).
        'ema_tau': -1 / np.log(ema_factor),
        'T_init': 1e-3,
        'T_min': 0.,
        'T_decay_rate': 1e-3,
        'q_init': 0.,
        'q_max': 200.,
        'q_rate': 10.,
        'c': 0.5,
        'bowl_center': 0.5,
        'bowl_strength': None,
        'beta_min_offset': 0.1,
        'dt': 0.001,
        'H0_on': True,
        'H1_on': True,
        'Hq_on': True,
        'max_dt': 0.01,
        'min_dt': 0.0005,
        'q_policy': None,
    }
678
679
def _update_opts(self, opts):
    """Update option variable values

    Every key of ``opts`` must already exist in ``self.opts``; an unknown
    key terminates the program.  ``ema_factor`` and ``ema_tau`` are kept
    consistent with each other.
    """
    if opts is None:
        return

    for key in opts:
        if key not in self.opts:
            sys.exit('Check [opts]')
        self.opts[key] = opts[key]
        if key == 'ema_factor':
            self.opts['ema_tau'] = -1 / np.log(self.opts[key])
        if key == 'ema_tau':
            self.opts['ema_factor'] = np.exp(-1 / self.opts[key])
692
693
def _set_encodings(self):
    """Set encoding variables to default values."""
    self.encodings = {
        'dp_f': 0.,
        'dp_r': 0.,
        'coord_f': 'dist',
        'coord_r': 'dist',
        'dim_f': None,
        'dim_r': None,
        'filler_names': None,
        'role_names': None,
        'F': None,
        'R': None,
        'similarity': None,
    }
708
709
def _update_encodings(self, encodings):
    """Update encoding variables

    Only keys that already exist in ``self.encodings`` are copied; any
    other key is ignored.
    """
    if encodings is None:
        return

    for key in encodings:
        if key in self.encodings:
            self.encodings[key] = encodings[key]
716
717
def _add_names(self):
    """Add filler, role, and binding names to the GscNet object

    Without a grammar the names come from ``encodings`` (both
    ``filler_names`` and ``role_names`` are required); with a grammar they
    are taken from the HarmonicGrammar object.  Also sets ``num_fillers``,
    ``num_roles`` and ``num_bindings``.
    """
    if self.hg is None:
        if self.encodings['filler_names'] is None:
            sys.exit("Please provide a list of filler names.")
        if self.encodings['role_names'] is None:
            sys.exit("Please provide a list of role names.")
        self.filler_names = self.encodings['filler_names']
        self.role_names = self.encodings['role_names']
        # Filler varies fastest within each role.
        self.binding_names = [f + '/' + r
                              for r in self.role_names
                              for f in self.filler_names]
    elif isinstance(self.hg, HarmonicGrammar):
        self.filler_names = self.hg.fillers.names
        self.role_names = self.hg.roles.names
        self.binding_names = self.hg.binding_names
    else:
        sys.exit('[hg] is not an instance of HarmonicGrammar class.')

    self.num_fillers = len(self.filler_names)
    self.num_roles = len(self.role_names)
    self.num_bindings = len(self.binding_names)
739
740
def _generate_encodings(self):
    """Generate vector encodings of fillers, roles, and their bindings

    Sets ``F``, ``R``, ``dim_f``, ``dim_r``, ``num_units`` and
    ``unit_names``.  When ``encodings['similarity']`` is given, each entry
    ``[[sym1, sym2], value]`` is written symmetrically into the filler or
    role dot-product matrix, which then replaces ``dp_f`` / ``dp_r``.
    """
    similarity = self.encodings['similarity']
    if similarity is not None:
        # Update dp_f and dp_r
        dp_f = np.diag(np.ones(self.num_fillers))
        dp_r = np.diag(np.ones(self.num_roles))

        for dp in similarity:
            symbols, value = dp[0], dp[1]
            if all(sym in self.filler_names for sym in symbols):
                names, matrix = self.filler_names, dp_f
            elif all(sym in self.role_names for sym in symbols):
                names, matrix = self.role_names, dp_r
            else:
                sys.exit('Cannot find some f/r bindings in your similarity list.')
            first = names.index(symbols[0])
            second = names.index(symbols[1])
            matrix[first, second] = value
            matrix[second, first] = value

        self.encodings['dp_f'] = dp_f
        self.encodings['dp_r'] = dp_r

    self.F = encode_symbols(
        self.num_fillers,
        coord=self.encodings['coord_f'],
        dp=self.encodings['dp_f'],
        dim=self.encodings['dim_f'])

    self.R = encode_symbols(
        self.num_roles,
        coord=self.encodings['coord_r'],
        dp=self.encodings['dp_r'],
        dim=self.encodings['dim_r'])

    # Overwrite if users provide F and R
    if self.encodings['F'] is not None:
        self.F = self.encodings['F']
    if self.encodings['R'] is not None:
        self.R = self.encodings['R']

    self.dim_f = self.F.shape[0]
    self.dim_r = self.R.shape[0]
    self.num_units = self.dim_f * self.dim_r

    # Unit names are zero-padded to a common width: U01, U02, ...
    ndigits = len(str(self.num_units))
    self.unit_names = ['U' + str(ii + 1).zfill(ndigits)
                       for ii in range(self.num_units)]
785
786
def _build_model(self, hg):
    """Set the weight and bias values using a HarmonicGrammar object [hg].

    Binary rules become connection weights.  Unary rules become biases of
    a single binding, of all bindings of a filler, or of all bindings of a
    role, depending on what the rule names.  The program exits when ``hg``
    is not a HarmonicGrammar, when a unary rule names nothing known, or
    when the resulting weight matrix is not symmetric.
    """
    if not isinstance(hg, HarmonicGrammar):
        sys.exit('The given grammar as hg is not an instance of HarmonicGrammar class.')

    for rule in hg.rules:
        names, value = rule[0], rule[1]
        if len(names) == 2:  # binary rules
            self.set_weight(names[0], names[1], value)
        elif len(names) == 1:  # unary rules
            name = names[0]
            if name in self.binding_names:
                self.set_bias(name, value)
            elif name in self.filler_names:
                self.set_filler_bias(name, value)
            elif name in self.role_names:
                self.set_role_bias(name, value)
            else:
                # rule is a list; convert it before concatenating.
                sys.exit('Check the rule in your Harmonic Grammar:' + str(rule))

    # Make sure W exists even if the grammar had no binary rule.
    self._set_weights()
    if not np.allclose(self.W, self.W.T):
        sys.exit("The weight matrix (2D array) is not symmetric. Please check it.")
807
808
def _compute_TPmat(self):
    """Compute the matrices of change of basis from conceptual to neural
    and from neural to conceptual coordinates.

    Sets ``TP`` (Kronecker product of ``R`` and ``F``), ``TPinv`` and
    ``Gc = TPinv.T . TPinv``.
    """
    # TP matrix that converts local to distributed representations
    # (conceptual coordinate to neural coordinate).
    # See http://en.wikipedia.org/wiki/Vectorization_(mathematics) for
    # justification of kronecker product.
    TP = np.kron(self.R, self.F)  # Pay attention to the argument order.

    n_rows, n_cols = TP.shape[0], TP.shape[1]
    if n_rows == n_cols:
        TPinv = linalg.inv(TP)
    else:
        # TP may be a non-square matrix. So use pseudo-inverse.
        TPinv = linalg.pinv(TP)

    self.TP = TP
    self.TPinv = TPinv
    self.Gc = self.TPinv.T.dot(self.TPinv)
822
823
def _set_weights(self):
    """Compute the weight values in the neural space (distributed
    representation): ``W = TPinv.T . WC . TPinv``."""
    tpinv = self.TPinv
    self.W = tpinv.T.dot(self.WC).dot(tpinv)
827
828
def _set_biases(self):
    """Compute the bias values in the neural space (distributed
    representation): ``b = TPinv.T . bC``."""
    tpinv_t = self.TPinv.T
    self.b = tpinv_t.dot(self.bC)
832
833
def _compute_recommended_bowl_strength(self):
    """Compute the recommended value of bowl strength. Note that the value
    depends on external input.

    Returns the largest eigenvalue of ``WC`` when the bowl center is all
    zeros; otherwise the maximum of that eigenvalue and two bounds derived
    from ``bC + extC`` and the bowl center.
    """
    # WC should be a symmetric matrix. So eigh() was used instead of eig()
    eigvals, _ = np.linalg.eigh(self.WC)
    eig_max = max(eigvals)  # Condition 1: beta > eig_max to be stable

    if not np.sum(abs(self.bowl_center)) > 0:
        return eig_max

    total_bias = self.bC + self.extC
    if self.num_bindings == 1:
        beta1 = -total_bias / self.bowl_center
        beta2 = (total_bias + eig_max) / (1 - self.bowl_center)
    else:
        # Condition 2: beta > beta1
        beta1 = -min(total_bias / self.bowl_center)
        # Condition 3: beta > beta2 [CHECK]
        beta2 = max((total_bias + eig_max) / (1 - self.bowl_center))
    return max(eig_max, beta1, beta2)
850
851
def check_bowl_strength(self, disp=True):
    '''Compute and print the recommended beta value
    given the weights and biases in the C-space.

    The program exits when ``opts['bowl_strength']`` is not strictly
    greater than the recommended minimum.

    :param disp: when true, print the current and the minimum value.
    '''
    beta_min = self._compute_recommended_bowl_strength()
    current = self.opts['bowl_strength']
    if current <= beta_min:
        sys.exit("Bowl strength should be greater than %.4f." % beta_min)

    if disp:
        print('(Current bowl strength: %.3f) must be greater than (minimum: %.3f)' % (current, beta_min))
861
862
def update_bowl_center(self, bowl_center):
    r"""Update the bowl center

    Usage:

        >>> net.update_bowl_center(0.3)  # Set the bowl center to 0.3 * \vec{1}
        >>>
        >>> import numpy as np
        >>> bowl_center = np.random.random(net.num_bindings)
        >>> net.update_bowl_center(bowl_center)

    : bowl_center: float or 1d NumPy array (size=number of bindings). the bowl center.

    Sets ``bowl_center`` and its neural-coordinate counterpart ``zeta``.
    The program exits on any other argument type or on a wrong array size.
    """
    if not isinstance(bowl_center, (np.ndarray, numbers.Number)):
        sys.exit('You must provide a scalar or a NumPy array as bowl_center.')

    # A scalar is expanded to one value per binding.
    if isinstance(bowl_center, numbers.Number):
        bowl_center = bowl_center * np.ones(self.num_bindings)

    if bowl_center.shape[0] != self.num_bindings:
        sys.exit('When you provide a NumPy array as bowl_center, it must have the same number of elements as the number of f/r bindings.')

    self.bowl_center = bowl_center
    self.zeta = self.C2N(actC=self.bowl_center)
887
888
def update_bowl_strength(self, bowl_strength=None):
    """Replace the current bowl strength with
    the recommended bowl strength (+ offset)

    Usage:

        >>> net = gsc.GscNet(...)
        >>> net.set_weight('a/(0,1)', 'b/(1,2)', 2.0)
        >>> net.update_bowl_strength()

    : bowl_strength : float or None (=default)

    With None, the recommended value plus ``opts['beta_min_offset']`` is
    used; otherwise the given value is stored as is.
    """
    if bowl_strength is None:
        bowl_strength = (self._compute_recommended_bowl_strength() +
                         self.opts['beta_min_offset'])
    self.opts['bowl_strength'] = bowl_strength
    self.bowl_strength = self.opts['bowl_strength']
908
909
def set_weight(self, binding_name1, binding_name2, weight, symmetric=True):
    '''Set the weight of a connection between binding1 and binding2.
    When symmetric is set to True (default), the connection weight from
    binding2 to binding1 is set to the same value.

    With ``symmetric=False`` only ``WC[idx2, idx1]`` is written.  The
    neural-space weights are recomputed afterwards.
    '''
    idx1 = self.find_bindings(binding_name1)
    idx2 = self.find_bindings(binding_name2)
    if symmetric:
        self.WC[idx1, idx2] = weight
    self.WC[idx2, idx1] = weight
    self._set_weights()
921
922
def set_bias(self, binding_name, bias):
    '''Set bias values of [binding_name] to [bias]

    The neural-space biases are recomputed afterwards.
    '''
    positions = self.find_bindings(binding_name)
    self.bC[positions] = bias
    self._set_biases()
928
929
def set_filler_bias(self, filler_name, bias):
    '''Set the bias of bindings of all roles
    with particular fillers to [bias].

    ``filler_name`` may be one name or a list of names.  The neural-space
    biases are recomputed afterwards.
    '''
    if not isinstance(filler_name, list):
        filler_name = [filler_name]

    # Filler part of every binding name 'filler/role'.
    filler_list = [bb.split('/')[0] for bb in self.binding_names]
    for filler in filler_name:
        idx = [ii for ii, ff in enumerate(filler_list) if filler == ff]
        self.bC[idx] = bias

    self._set_biases()
941
942
def set_role_bias(self, role_name, bias):
    '''Set the bias of bindings of all fillers
    with particular roles to [bias].

    ``role_name`` may be one name or a list of names.  The neural-space
    biases are recomputed afterwards.
    '''
    if not isinstance(role_name, list):
        role_name = [role_name]

    # Role part of every binding name 'filler/role'.
    role_list = [bb.split('/')[1] for bb in self.binding_names]
    for role in role_name:
        idx = [ii for ii, rr in enumerate(role_list) if role == rr]
        self.bC[idx] = bias

    self._set_biases()
954
955
######################################################################
956
#
957
# Util functions
958
#
959
######################################################################
960
961
def find_bindings(self, binding_names):
    '''Find the indices of the bindings from the list of binding names.

    ``binding_names`` may be one name or a list of names; a list of
    indices into ``self.binding_names`` is returned.  An unknown name
    raises ValueError (from ``list.index``).
    '''
    if not isinstance(binding_names, list):
        binding_names = [binding_names]
    known = self.binding_names
    return [known.index(bb) for bb in binding_names]
967
968
def find_fillers(self, filler_name):
    """Return the indices of all bindings whose filler is ``filler_name``.

    ``filler_name`` may be one name or a list of names; indices are
    concatenated in the order of the given names.
    """
    if not isinstance(filler_name, list):
        filler_name = [filler_name]

    # Filler part of every binding name 'filler/role'.
    filler_list = [bb.split('/')[0] for bb in self.binding_names]
    filler_idx = []
    for filler in filler_name:
        filler_idx += [ii for ii, ff in enumerate(filler_list)
                       if filler == ff]

    return filler_idx
980
981
def find_roles(self, role_name):
    """Return the indices of all bindings whose role is ``role_name``.

    ``role_name`` may be one name or a list of names; indices are
    concatenated in the order of the given names.
    """
    if not isinstance(role_name, list):
        role_name = [role_name]

    # Role part of every binding name 'filler/role'.
    role_list = [bb.split('/')[1] for bb in self.binding_names]
    role_idx = []
    for role in role_name:
        role_idx += [ii for ii, rr in enumerate(role_list) if role == rr]

    return role_idx
993
994
def vec2mat(self, actC=None):
    """Reshape an activation vector into a filler-by-role matrix.

    The vector is reshaped in column-major ('F') order to
    (num_fillers, num_roles), so each row is a filler and each column is
    a role. If [actC] is omitted, the vector returned by self.S2C() is
    used.
    """

    vec = self.S2C() if actC is None else actC
    return vec.reshape((self.num_fillers, self.num_roles), order='F')
1002
1003
def C2N(self, actC=None):
    """Change basis: conceptual/pattern coordinates -> neural coordinates.

    Returns self.TP.dot(actC); self.actC is used when [actC] is omitted.
    """

    source = self.actC if actC is None else actC
    return self.TP.dot(source)
1009
1010
def N2C(self, act=None):
    """Change basis: neural coordinates -> conceptual/pattern coordinates.

    Returns self.TPinv.dot(act); self.act is used when [act] is omitted.
    """

    source = self.act if act is None else act
    return self.TPinv.dot(source)
1016
1017
def read_weight(self, which='WC'):
    """Print the weight matrix named by [which] as a labelled table.

    When the attribute name ends in 'C' (pattern coordinates) rows and
    columns are labelled with binding names; otherwise with unit names.
    """

    labels = self.binding_names if which[-1] == 'C' else self.unit_names
    table = pd.DataFrame(getattr(self, which), index=labels, columns=labels)
    print(table)
1029
1030
def read_bias(self, which='bC', print_vertical=True):
    """Print the bias vector named by [which] as a labelled table.

    : which          : name of the attribute to print. If it ends in 'C'
    :                  (pattern coordinates) entries are labelled with
    :                  binding names, otherwise with unit names.
    : print_vertical : if True print one entry per row, otherwise print
    :                  a single row.

    The vector is reshaped from its own length. The previous version used
    num_bindings even when labelling with unit names, which fails whenever
    the number of units differs from the number of bindings.
    """

    labels = self.binding_names if which[-1] == 'C' else self.unit_names
    values = np.asarray(getattr(self, which))

    if print_vertical:
        table = pd.DataFrame(
            values.reshape(-1, 1), index=labels, columns=["bias"])
    else:
        table = pd.DataFrame(
            values.reshape(1, -1), index=["bias"], columns=labels)
    print(table)
1051
1052
def read_state(self, act=None):
    """Print a state, in conceptual coordinates, as a filler-by-role table.

    [act] is a neural-coordinate activation vector (self.act when
    omitted). It is converted with self.N2C and self.vec2mat, then printed
    as a pandas DataFrame with fillers as rows and roles as columns.
    """

    state = self.act if act is None else act
    grid = self.vec2mat(self.N2C(state))
    table = pd.DataFrame(
        grid, index=self.filler_names, columns=self.role_names)
    print(table)
1061
1062
def read_grid_point(self, act=None, disp=True, skip=True):
    """Return (and optionally print) a grid point close to a state.

    The grid point is chosen by the snap-it method: in each role the
    filler with the highest activation value wins.

    : act  : neural-coordinate activation vector (self.act when omitted)
    : disp : if True, print the list of winning bindings
    : skip : if True, drop winners that are the null filler and winners
    :        whose activation value is not greater than 0.5

    Returns a list of 'filler/role' strings in role order.
    """

    act_min = 0.5

    if act is None:
        act = self.act

    actC = self.vec2mat(self.N2C(act))
    winner_idx = np.argmax(actC, axis=0)
    winners = ["%s/%s" % (self.filler_names[filler_num], role)
               for filler_num, role in zip(winner_idx, self.role_names)]

    if skip:
        # Compare the null filler by value. The previous identity test
        # (`is not`) only worked when both strings were the same object.
        winners = [
            winners[role_num]
            for role_num, filler_num in enumerate(winner_idx)
            if (actC[filler_num, role_num] > act_min and
                self.filler_names[filler_num] != self.hg.fillers.null)]

    if disp:
        print(winners)
    return winners
1088
1089
def get_grid_points(self, n=1e7):
    """Collect the top [n] grid points ranked by H0.

    Every grid point (one binding per role) is visited regardless of [n].
    Their number is [num_fillers]^[num_roles], which grows explosively
    with [num_roles], so the program exits when it exceeds 1e7.

    : n : number of grid points to keep. It may be a float (the default
    :     is 1e7); it is capped at the total number of grid points and
    :     converted to an int.

    Results are stored in self.gpset (lists of binding names) and
    self.gpset_h (their H0 values), both in decreasing order of H0.

    NOTE: self.set_state() is called for every grid point, so the
    network's current state is overwritten.
    """

    max_points = 1e7
    total = self.num_fillers ** self.num_roles
    if total > max_points:
        sys.exit('There are too many grid points (= %d) to check all grid points.' % total)

    # np.zeros() needs an integer size; [n] defaults to a float.
    n = max(int(min(n, total)), 0)

    quant_list = [
        [self.binding_names[ii] for ii in self.find_roles(role)]
        for role in self.role_names]

    gpset = []
    gpset_h = np.zeros(n)
    if total > 10000:
        print('Of %d grid points: ' % total)

    for ii, gp in enumerate(itertools.product(*quant_list)):

        # Progress report: one counter per 10,000 points, a line break
        # after every 100,000.
        if (ii + 1) % 10000 == 0:
            print('[%06d]' % (ii + 1), end='')
        if (ii + 1) % 100000 == 0:
            print('')

        gp = list(gp)
        self.set_state(gp)
        hh = self.H0()
        if ii < n:
            gpset_h[ii] = hh
            gpset.append(gp)
        elif n > 0:
            # Replace the currently worst kept point if this one is better.
            worst = np.argmin(gpset_h)
            if hh > gpset_h[worst]:
                gpset[worst] = gp
                gpset_h[worst] = hh

    # Sort the grid points in decreasing order of harmony.
    idx = np.argsort(gpset_h)[::-1]
    self.gpset = [gpset[ii] for ii in idx]
    self.gpset_h = gpset_h[idx]
1134
1135
def set_seed(self, num):
    """Seed NumPy's global random number generator with [num]."""

    np.random.seed(seed=num)
1139
1140
######################################################################
1141
#
1142
# Harmony
1143
#
1144
######################################################################
1145
1146
def H(self, act=None):
    """Evaluate total harmony: Hg(act) + Hq_on * q * Qa(act).

    opts['Hq_on'] is converted to float, so a False value switches the
    quantization term off.
    """

    grammar_part = self.Hg(act)
    quant_part = self.Qa(act)
    return grammar_part + float(self.opts['Hq_on']) * self.q * quant_part
1150
1151
def Hg(self, act=None):
    """Evaluate H_G = H0_on * H0(act) + H1_on * H1(act).

    The two switches in self.opts are converted to float, so a False
    value removes the corresponding term.
    """

    h0_part = float(self.opts['H0_on']) * self.H0(act)
    h1_part = float(self.opts['H1_on']) * self.H1(act)
    return h0_part + h1_part  # + constant
1156
1157
def H0(self, act=None):
    """Evaluate H0 = 0.5 * act' W act + (b + ext)' act.

    [act] is a neural-coordinate state (self.act when omitted).
    """

    state = self.act if act is None else act
    quadratic = 0.5 * state.dot(self.W).dot(state)
    linear = (self.b + self.ext).dot(state)
    return quadratic + linear
1163
1164
def H1(self, act=None):
    """Evaluate H1 (bowl harmony).

    H1 = bowl_strength * (-0.5 * (act - zeta)' Gc (act - zeta)), where
    [act] defaults to self.act.
    """

    state = self.act if act is None else act
    bowl = -0.5 * (state - self.zeta).T.dot(self.Gc).dot(state - self.zeta)
    return self.bowl_strength * bowl
1171
1172
def Q(self, act=None):
    """Evaluate quantization harmony Q = c * Q0 + (1 - c) * Q1.

    The mixing weight c is read from self.opts['c'].
    """

    c = self.opts['c']
    q0_part = self.Q0(act)
    q1_part = self.Q1(act)
    return c * q0_part + (1 - c) * q1_part
1177
1178
def Qa(self, act=None):  # Experimental
    """Evaluate quantization harmony Q = c * Q0 + (1 - c) * Q1a.

    Q1a is evaluated over the binding groups in self.quant_list; the
    mixing weight c is read from self.opts['c'].
    """

    c = self.opts['c']
    q0_part = self.Q0(act)
    q1_part = self.Q1a(act=act, quant_list=self.quant_list)
    return c * q0_part + (1 - c) * q1_part
1182
1183
def Qb(self, act=None):  # Experimental
    """Evaluate quantization harmony Q = c * Q0 + (1 - c) * Q1b.

    Q1b is evaluated over the binding groups in self.quant_list; the
    mixing weight c is read from self.opts['c'].
    """

    c = self.opts['c']
    q0_part = self.Q0(act)
    q1_part = self.Q1b(act=act, quant_list=self.quant_list)
    return c * q0_part + (1 - c) * q1_part
1187
1188
def Q0(self, act=None):
    """Evaluate Q0 = -sum(a^2 * (1 - a)^2) over all bindings.

    a is the state in conceptual coordinates, N2C(act); [act] defaults
    to self.act.
    """

    state = self.act if act is None else act
    a = self.N2C(state)
    penalty = a**2 * (1 - a)**2
    return -np.sum(penalty)
1195
1196
def Q1(self, act=None):
    """Evaluate Q1 = -sum over roles of (sum of squared activations - 1)^2.

    The state ([act], or self.act when omitted) is converted to a
    filler-by-role matrix; squared activations are summed over fillers
    (axis 0) within each role.
    """

    state = self.act if act is None else act
    amat = self.vec2mat(self.N2C(state))
    role_ssq = np.sum(amat**2, axis=0)
    return -np.sum((role_ssq - 1)**2)
1202
1203
def Q1a(self, act=None, quant_list=None):
    """Evaluate Q1 (sum of squared = 1).

    For each group of binding indices in [quant_list], ssq is the sum of
    squared activations (conceptual coordinates) in that group; the
    result is -sum((ssq - 1)^2). [act] defaults to self.act and
    [quant_list] to self.quant_list.
    """

    state = self.act if act is None else act
    groups = self.quant_list if quant_list is None else quant_list

    actC = self.N2C(state)
    q1 = 0
    for members in groups:
        group_act = actC[members]
        q1 += (group_act.dot(group_act) - 1)**2

    return -q1
1219
1220
def Q1b(self, act=None, quant_list=None):
    """Evaluate Q1 (sum of squared = 0 or 1).

    For each group of binding indices in [quant_list], ssq is the sum of
    squared activations (conceptual coordinates) in that group; the
    result is -sum((ssq - 1)^2 * ssq^2). [act] defaults to self.act and
    [quant_list] to self.quant_list.
    """

    state = self.act if act is None else act
    groups = self.quant_list if quant_list is None else quant_list

    actC = self.N2C(state)
    q1 = 0
    for members in groups:
        group_act = actC[members]
        ssq = group_act.dot(group_act)
        q1 += (ssq - 1)**2 * ssq**2

    return -q1
1236
1237
######################################################################
1238
#
1239
# Harmony Gradient
1240
#
1241
######################################################################
1242
1243
def HGrad(self, act=None):
    """Compute the gradient of total harmony.

    Returns HgGrad(act) + Hq_on * q * QaGrad(act), with opts['Hq_on']
    converted to float.
    """

    grammar_grad = self.HgGrad(act)
    quant_grad = self.QaGrad(act)
    return grammar_grad + float(self.opts['Hq_on']) * self.q * quant_grad
1247
1248
def HgGrad(self, act=None):
    """Compute the gradient of grammar harmony H_G.

    Returns H0_on * H0Grad(act) + H1_on * H1Grad(act); both switches in
    self.opts are converted to float.
    """

    h0_part = float(self.opts['H0_on']) * self.H0Grad(act)
    h1_part = float(self.opts['H1_on']) * self.H1Grad(act)
    return h0_part + h1_part
1253
1254
def H0Grad(self, act=None):
    """Compute W.dot(act) + b + ext, with [act] defaulting to self.act."""

    state = self.act if act is None else act
    recurrent = self.W.dot(state)
    return recurrent + self.b + self.ext
1259
1260
def H1Grad(self, act=None):
    """Compute bowl_strength * (-Gc.dot(act) + Gc.dot(zeta)).

    [act] defaults to self.act.
    """

    state = self.act if act is None else act
    away = self.Gc.dot(state)
    toward = self.Gc.dot(self.zeta)
    return self.bowl_strength * (-away + toward)
1266
1267
def QGrad(self, act=None):
    """Compute c * Q0Grad(act) + (1 - c) * Q1Grad(act), c = opts['c']."""

    c = self.opts['c']
    q0_part = self.Q0Grad(act)
    q1_part = self.Q1Grad(act)
    return c * q0_part + (1 - c) * q1_part
1271
1272
def QaGrad(self, act=None):
    """Compute c * Q0Grad(act) + (1 - c) * Q1aGrad(act), c = opts['c'].

    Q1aGrad is evaluated over the binding groups in self.quant_list.
    """

    c = self.opts['c']
    q0_part = self.Q0Grad(act)
    q1_part = self.Q1aGrad(act=act, quant_list=self.quant_list)
    return c * q0_part + (1 - c) * q1_part
1275
1276
def QbGrad(self, act=None):
    """Compute c * Q0Grad(act) + (1 - c) * Q1bGrad(act), c = opts['c'].

    Q1bGrad is evaluated over the binding groups in self.quant_list.
    """

    c = self.opts['c']
    q0_part = self.Q0Grad(act)
    q1_part = self.Q1bGrad(act=act, quant_list=self.quant_list)
    return c * q0_part + (1 - c) * q1_part
1279
1280
def Q0Grad(self, act=None):
    """Compute the gradient of Q0 with respect to the neural state.

    With a = N2C(act), g = 2a(1 - a)(1 - 2a) is computed per binding and
    contracted with the rows of TPinv. [act] defaults to self.act.
    """

    state = self.act if act is None else act
    a = self.N2C(state)
    g = 2 * a * (1 - a) * (1 - 2 * a)  # g_{fr} vectorized
    projected = np.einsum('ij,i', self.TPinv, g)
    return -projected
1287
1288
def Q1Grad(self, act=None):
    """Compute the gradient of quantization harmony (Q1).

    [act] defaults to self.act. The state is converted to a
    filler-by-role matrix; TPinv is reshaped (column-major) to
    (num_fillers, num_roles, num_units) so it can be indexed the same way.
    """

    state = self.act if act is None else act

    tpinv_cube = self.TPinv.reshape(
        (self.num_fillers, self.num_roles, self.num_units), order='F')
    amat = self.vec2mat(self.N2C(state))

    # Index legend for the einsum calls below:
    #   i: filler index (f)
    #   j: role index (r)
    #   k: unit index (phi-rho pair)
    role_excess = np.einsum('ij->j', amat**2) - 1
    role_dirs = np.einsum('ij,ijk->jk', amat, tpinv_cube)
    return -4 * np.einsum('j,jk', role_excess, role_dirs)
1304
1305
def Q1aGrad(self, act=None, quant_list=None):
    """Compute the gradient of quantization harmony (Q1a).

    For each group of binding indices in [quant_list] the contribution is
    (ssq - 1) * sum_i a_i * TPinv[i, :], where ssq is the group's sum of
    squared activations; the summed contributions are scaled by -4.
    [act] defaults to self.act and [quant_list] to self.quant_list.
    """

    state = self.act if act is None else act
    groups = self.quant_list if quant_list is None else quant_list

    actC = self.N2C(state)

    q1grad = 0
    for members in groups:
        group_act = actC[members]
        group_rows = self.TPinv[members, :]

        excess = (group_act**2).sum() - 1
        direction = np.einsum('i,ij->j', group_act, group_rows)
        q1grad += excess * direction

    return -4 * q1grad
1325
1326
def Q1bGrad(self, act=None, quant_list=None):
    """Compute the gradient of quantization harmony (Q1b).

    For each group of binding indices in [quant_list] the contribution is
    ssq * (ssq - 1) * (2 * ssq - 1) * sum_i a_i * TPinv[i, :], where ssq
    is the group's sum of squared activations; the summed contributions
    are scaled by -4. [act] defaults to self.act and [quant_list] to
    self.quant_list.
    """

    state = self.act if act is None else act
    groups = self.quant_list if quant_list is None else quant_list

    actC = self.N2C(state)

    q1grad = 0
    for members in groups:
        group_act = actC[members]
        group_rows = self.TPinv[members, :]

        ssq = group_act.dot(group_act)
        scale = ssq * (ssq - 1) * (2*ssq - 1)
        direction = np.einsum('i,ij->j', group_act, group_rows)  # CHECK
        q1grad += scale * direction

    return -4 * q1grad
1347
1348
######################################################################
1349
#
1350
# Log traces
1351
#
1352
######################################################################
1353
1354
def initialize_traces(self, trace_list):
    """Create storage for traces and log the current state once.

    : trace_list : 'all' (every name in opts['trace_varnames']) or a list
    :              of variable names taken from opts['trace_varnames'].

    The program exits (sys.exit) if [trace_list] is neither 'all' nor a
    list, or if it names a variable that is not available.

    If self.traces already exists, the requested traces are converted
    back to lists so logging can continue; a requested name that has no
    trace yet starts from an empty list. Otherwise a new dict of empty
    lists is created. Finally self.update_traces() is called.
    """

    available = self.opts['trace_varnames']

    if trace_list == 'all':
        trace_list = available
    else:
        if not isinstance(trace_list, list):
            sys.exit(('Check [trace_list] that should be a list object. \n'
                      'If you want to log a single variable (e.g., H), \n'
                      'you must provide ["H"], not "H", as the value of [trace_list].'))

        unknown = [var for var in trace_list if var not in available]
        if unknown:
            # The names must be joined into a string: concatenating the
            # list itself raised TypeError instead of showing the message.
            sys.exit(('Check [trace_list]. You provided variable name(s) that are not availalbe in the software.\n'
                      'Currently, the following variables are available:\n' +
                      ', '.join(str(var) for var in available)))

    if hasattr(self, 'traces'):
        for key in trace_list:
            # .get() avoids a KeyError for a name that was not logged before.
            self.traces[key] = list(self.traces.get(key, []))
    else:
        self.traces = {key: [] for key in trace_list}

    self.update_traces()
1379
1380
def update_traces(self):
    """Append the current value of every traced variable to its trace.

    Only keys present in self.traces are logged. Each value is read
    lazily, so harmony functions are called only for traced keys.
    """

    readers = (
        ('act', lambda: list(self.act)),
        ('H', lambda: self.H()),
        ('H0', lambda: self.H0()),
        ('Q', lambda: self.Q()),
        ('q', lambda: self.q),
        ('t', lambda: self.t),
        ('T', lambda: self.T),
        ('ema_speed', lambda: self.ema_speed),
        ('speed', lambda: self.speed),
    )
    for key, read in readers:
        if key in self.traces:
            self.traces[key].append(read())
1401
1402
def finalize_traces(self):
    """Replace every trace in self.traces by a NumPy array of its values."""

    for key, history in list(self.traces.items()):
        self.traces[key] = np.array(history)
1407
1408
######################################################################
1409
#
1410
# Input (clamp vs. external input)
1411
#
1412
######################################################################
1413
1414
def _compute_projmat(self, A):
    """Return the projection matrix A (A'A)^-1 A' for the matrix A.

    A is an n x m matrix whose columns are basis vectors of the subspace.
    A'A is inverted directly, so this works only when the rank of A is
    equal to the number of columns of A.
    """

    gram_inverse = linalg.inv(A.T.dot(A))
    return A.dot(gram_inverse).dot(A.T)
1421
1422
def clamp(self, binding_names,
          clamp_vals=1.0, clamp_comp=False):  # [CHECK]
    """Clamp f/r bindings to [clamp_vals].

    : binding_names : one binding name or a list of binding names
    : clamp_vals    : one value (used for every binding) or a list with
    :                 one value per binding; the program exits if a list
    :                 of several values does not match [binding_names]
    : clamp_comp    : if True, every binding that shares a role with a
    :                 clamped binding is also fixed (at 0)

    Sets self.clamped, self.binding_names_clamped, self.clampvecC,
    self.clampvec and self.projmat, then projects the current state with
    self.act_clamped and refreshes self.actC.
    """

    if not isinstance(clamp_vals, list):
        clamp_vals = [clamp_vals]
    if not isinstance(binding_names, list):
        binding_names = [binding_names]
    if len(clamp_vals) > 1 and len(clamp_vals) != len(binding_names):
        sys.exit('The number of bindings clamped is not equal to the number of values provided.')

    self.clamped = True
    self.binding_names_clamped = binding_names

    target = np.zeros(self.num_bindings)
    if clamp_comp:
        competitor_roles = [name.split('/')[1] for name in binding_names]
        competitors = self.find_roles(competitor_roles)
        target[competitors] = 0.0

    fixed = self.find_bindings(binding_names)
    target[fixed] = clamp_vals
    self.clampvecC = target

    if clamp_comp:
        fixed += competitors
        fixed.sort()

    # Bindings that are not fixed span the subspace the state may move in.
    fixed_set = set(fixed)
    free = [bb for bb in np.arange(self.num_bindings)
            if bb not in fixed_set]
    if len(free) > 0:
        self.projmat = self._compute_projmat(self.TP[:, free])
    else:
        self.projmat = np.zeros((self.num_units, self.num_units))

    self.clampvec = self.C2N(target)
    self.act = self.act_clamped(self.act)
    self.actC = self.N2C()
1460
1461
def unclamp(self):
    """Remove clamping information, if the network is clamped.

    When self.clamped is True the attributes clampvec, clampvecC, projmat
    and binding_names_clamped are deleted and self.clamped is set to
    False. Otherwise nothing happens.
    """

    if self.clamped is not True:
        return

    for attr in ('clampvec', 'clampvecC', 'projmat',
                 'binding_names_clamped'):
        delattr(self, attr)
    self.clamped = False
1469
1470
def act_clamped(self, act=None):
    """Project an activation vector to the clamped subspace.

    Returns self.projmat.dot(act) + self.clampvec; [act] defaults to
    self.act.
    """

    state = self.act if act is None else act
    return self.projmat.dot(state) + self.clampvec
1477
1478
def set_input(self, binding_names, ext_vals, inhib_comp=False):
    """Set external input to the given bindings.

    : binding_names : one binding name or a list of binding names
    : ext_vals      : one value (used for every binding) or a list with
    :                 one value per binding; the program exits if a list
    :                 of several values does not match [binding_names]
    : inhib_comp    : not used by this method

    Any previous input is cleared first. self.extC holds the input in
    conceptual coordinates and self.ext its C2N conversion.
    """

    if not isinstance(ext_vals, list):
        ext_vals = [ext_vals]
    if not isinstance(binding_names, list):
        binding_names = [binding_names]
    if len(ext_vals) > 1 and len(binding_names) != len(ext_vals):
        sys.exit("binding_names and ext_vals have different lengths.")

    self.clear_input()

    positions = self.find_bindings(binding_names)
    self.extC[positions] = ext_vals
    self.ext = self.C2N(self.extC)
1493
1494
def clear_input(self):
    """Remove external input.

    self.extC becomes a zero vector of length num_bindings and self.ext
    its C2N conversion.
    """

    no_input = np.zeros(self.num_bindings)
    self.extC = no_input
    self.ext = self.C2N(no_input)
1498
1499
#######################################################################
1500
#
1501
# Set state
1502
#
1503
#######################################################################
1504
1505
def reset(self):
    """Reset the model.

    q and T are set to their initial values (opts['q_init'],
    opts['T_init']) and t to 0. The state is randomized, the previous
    external input is zeroed, the network is unclamped, external input
    is cleared, and any existing traces are deleted.
    """

    self.q, self.T = self.opts['q_init'], self.opts['T_init']
    self.t = 0
    self.randomize_state()
    self.actC = self.N2C()

    previous_input = np.zeros(self.num_bindings)
    self.extC_prev = previous_input
    self.ext_prev = self.TP.dot(previous_input)
    self.unclamp()
    self.clear_input()
    if hasattr(self, 'traces'):
        delattr(self, 'traces')
1520
1521
def set_state(self, binding_names, vals=1.0):
    """Set the state from a set of active bindings.

    The activation values of the given bindings are set to [vals]
    (default=1.0) and those of all other bindings to 0. self.actC is the
    resulting vector and self.act is obtained with self.C2N().
    """

    positions = self.find_bindings(binding_names)
    pattern = np.zeros(self.num_bindings)
    pattern[positions] = vals
    self.actC = pattern
    self.act = self.C2N()
1530
1531
def set_init_state(self, mu=0.5, sd=0.2):
    """Draw a random state from a normal distribution.

    Each of the num_bindings entries of self.actC is sampled with mean
    [mu] and standard deviation [sd]; self.act is set with self.C2N().
    """

    self.actC = np.random.normal(mu, sd, self.num_bindings)
    self.act = self.C2N()
1535
1536
def randomize_state(self, minact=0, maxact=1):
    """Set the activation state to a random vector
    inside a hypercube of [minact, maxact]^num_bindings.

    self.actC is sampled uniformly and self.act is its C2N conversion.
    """

    sample = np.random.uniform(
        low=minact, high=maxact, size=self.num_bindings)
    self.actC = sample
    self.act = self.C2N(sample)
1542
1543
#######################################################################
1544
#
1545
# Update
1546
#
1547
#######################################################################
1548
1549
def run(self, duration, update_T=True, update_q=True, log_trace=True,
        trace_list='all', plot=False, tol=None, testvar='ema_speed',
        grayscale=False, colorbar=True):
    """Run the simulation for [duration] units of model time.

    : duration  : amount of time to add to the current time self.t
    : update_T  : passed on to self.update
    : update_q  : passed on to self.update
    : log_trace : if True, log the variables in [trace_list]
    : plot      : if True (and [log_trace] is True), show a heatmap of
    :             the binding activations over time
    : tol       : if given, stop early once self.check_convergence sets
    :             self.converged
    : testvar   : variable used by the convergence check
    : grayscale, colorbar : passed on to heatmap

    After the loop, self.rt holds the final time and the previous
    external input (extC_prev, ext_prev) is overwritten in place.
    """

    self.converged = False
    self.step = 0
    deadline = self.t + duration

    if log_trace:
        self.initialize_traces(trace_list)

    while self.t < deadline:
        self.update(update_T=update_T, update_q=update_q)
        if log_trace:
            self.update_traces()

        if tol is None:
            continue
        self.check_convergence(tol=tol, testvar=testvar)
        if self.converged:
            break

    self.rt = self.t
    self.extC_prev[:] = self.extC
    self.ext_prev[:] = self.TP.dot(self.extC_prev)

    if not log_trace:
        return
    self.finalize_traces()

    if plot:
        actC_trace = self.N2C(self.traces['act'].T).T
        times = self.traces['t']
        # Time steps vary in size, so resample every binding's trace on
        # an evenly spaced time grid before drawing.
        times_new = np.linspace(times[0], times[-1], times.shape[0])
        resampled = [np.interp(times_new, times, actC_trace[:, b_ind])
                     for b_ind in range(actC_trace.shape[1])]
        actC_trace_new = np.array(resampled).T
        heatmap(
            actC_trace_new.T,
            xlabel="Time", xtick=False,
            ylabel="Bindings", yticklabels=self.binding_names,
            grayscale=grayscale, colorbar=colorbar, val_range=[0, 1])
1592
1593
def update(self, update_T=True, update_q=True):
    """Update state, speed, ema_speed (and optionally T and q).

    The current state is copied in place into act_prev / actC_prev before
    the state is advanced. T is updated only if [update_T] is True and
    opts['T_decay_rate'] is positive; q only if [update_q] is True.
    """

    self.act_prev[:] = self.act
    self.actC_prev[:] = self.actC
    self.update_state()
    self.update_speed()

    cooling = update_T and (self.opts['T_decay_rate'] > 0)
    if cooling:
        self.update_T()
    if update_q:
        self.update_q()
1605
1606
def update_state(self):
    """Update state (with noise).

    The step size dt is max_dt divided by the gradient magnitude, capped
    at max_dt and floored at min_dt; it is left unchanged when the
    gradient is zero. The state moves along the harmony gradient, noise
    is added, the clamp projection is applied if the network is clamped,
    and self.actC is refreshed.
    """

    grad = self.HGrad()
    grad_mag = np.sqrt(grad.dot(grad))
    if grad_mag > 0:
        max_dt = self.opts['max_dt']
        capped = min(max_dt, max_dt / grad_mag)
        self.dt = max(self.opts['min_dt'], capped)

    self.t += self.dt
    self.act += self.dt * grad
    self.add_noise()
    if self.clamped:
        self.act = self.act_clamped()
    self.actC = self.N2C()
1621
1622
def add_noise(self):
    """Add noise to state in neural coordinates.

    Standard-normal samples scaled by sqrt(2 * T * dt) are added in place
    to self.act, one sample per unit.
    """

    scale = np.sqrt(2 * self.T * self.dt)
    self.act += scale * np.random.randn(self.num_units)
1627
1628
def update_T(self):
    """Update temperature.

    The distance between T and opts['T_min'] shrinks by the factor
    exp(-T_decay_rate * dt).
    """

    decay = np.exp(-self.opts['T_decay_rate'] * self.dt)
    floor = self.opts['T_min']
    self.T = decay * (self.T - floor) + floor
1633
1634
def update_q(self):
    """Update quantization strength.

    With a policy in opts['q_policy'] (column 0: time, column 1: q), q is
    linearly interpolated at the current time t. Otherwise q grows by
    q_rate * dt and is kept within [0, q_max].
    """

    policy = self.opts['q_policy']
    if policy is None:
        raised = self.q + self.opts['q_rate'] * self.dt
        self.q = max(min(raised, self.opts['q_max']), 0)
    else:
        self.q = np.interp(self.t, policy[:, 0], policy[:, 1])
1643
1644
def update_speed(self):
    """Update speed and ema_speed.

    Speed is the norm (order opts['norm_ord']) of the last state change
    divided by |dt|, measured in neural ('N') or conceptual ('C')
    coordinates according to opts['coord']. ema_speed starts at the first
    measured speed and is afterwards an exponential moving average with
    weight ema_factor ** |dt|.
    """

    coord = self.opts['coord']
    if coord == 'N':
        diff = self.act - self.act_prev
    elif coord == 'C':
        diff = self.actC - self.actC_prev

    elapsed = abs(self.dt)
    self.speed = linalg.norm(diff, ord=self.opts['norm_ord']) / elapsed

    if self.ema_speed is None:
        self.ema_speed = self.speed
        return

    # See EMA_{eq} in the following document: http://www.eckner.com/papers/ts_alg.pdf
    # : tau = -1 / log(ema_factor).
    # : ema_factor = exp(-1/ema_tau)
    ema_weight = self.opts['ema_factor'] ** elapsed
    self.ema_speed = (ema_weight * self.ema_speed +
                      (1 - ema_weight) * self.speed)
1664
1665
def check_convergence(self, tol, testvar='ema_speed'):
    """Set self.converged to True if the convergence criterion is met.

    : tol     : threshold
    : testvar : 'ema_speed' (converged when ema_speed < tol) or
    :           'Q' (converged when Q() > tol)

    self.converged is never reset to False here.
    """

    if testvar == 'ema_speed' and self.ema_speed < tol:
        self.converged = True

    if testvar == 'Q' and self.Q() > tol:
        self.converged = True
1675
1676
def plot_state(self, act=None, actC=None, coord='C',
               colorbar=True, disp=True, grayscale=True):
    """Plot an activation state in a heatmap.

    : act, actC : the state in neural or conceptual coordinates. Pass at
    :             most one of them (the program exits if both are given);
    :             the current state is used if neither is given.
    : coord     : 'C' plots fillers (rows) by roles (columns) with values
    :             in [0, 1]; 'N' plots the neural state reshaped
    :             (column-major) to (dim_f, dim_r).
    : colorbar, disp, grayscale : passed on to heatmap
    """

    if (act is not None) and (actC is not None):
        sys.exit('Error. You must pass either act or actC but not both to the function.')

    if (act is None) and (actC is None):
        act, actC = self.act, self.actC
    elif act is None:
        act = self.C2N(actC)
    else:
        actC = self.N2C(act)

    if coord == 'C':
        heatmap(
            self.vec2mat(actC), xticklabels=self.role_names,
            yticklabels=self.filler_names, grayscale=grayscale,
            colorbar=colorbar, disp=disp, val_range=[0, 1])
    elif coord == 'N':
        unit_grid = act.reshape((self.dim_f, self.dim_r), order='F')
        row_labels = ['f' + str(ii) for ii in range(self.dim_f)]
        col_labels = ['r' + str(ii) for ii in range(self.dim_r)]
        heatmap(
            unit_grid, xticklabels=col_labels, yticklabels=row_labels,
            grayscale=grayscale, colorbar=colorbar, disp=disp)
1702
1703
def plot_trace(self, varname):
    """Plot the trace of a given variable against time.

    : varname : a key of self.traces, or 'actC', in which case the logged
    :           'act' trace is converted to conceptual coordinates with
    :           self.N2C before plotting.
    """

    x = self.traces['t']
    # Compare by value: `is` tested object identity, which is not
    # guaranteed to hold for equal strings.
    if varname == 'actC':
        y = self.N2C(self.traces[varname[:-1]].T).T
    else:
        y = self.traces[varname]

    plt.plot(x, y)
    plt.xlabel('Time', fontsize=16)
    plt.ylabel(varname, fontsize=16)
    plt.grid(True)
    plt.show()
1717
1718
1719
def encode_symbols(num_symbols, coord='dist', dp=0., dim=None):
    """Generate the vector encodings of [num_symbols] symbols assuming a given similarity structure.
    Each column vector will represent a unique symbol.

    Usage:

        >>> gsc.encode_symbols(2)
        >>> gsc.encode_symbols(3, coord='dist', dp=0.3, dim=5)

    : num_symbols : int, number of symbols to encode
    : coord       : string, 'dist' (distributed representation, default) or 'local' (local representation)
    : dp          : float (0 [default] <= dp <= 1) or 2D-numpy array of pairwise similarity (dot product)
    : dim         : int, number of dimensions to encode a symbol. must not be smaller than [num_symbols]
    :
    : [dp] and [dim] values are ignored if coord is set to 'local'.

    The program exits (sys.exit) if [dim] is smaller than [num_symbols].
    """

    if coord == 'local':
        return np.eye(num_symbols)

    if dim is None:
        dim = num_symbols
    elif dim < num_symbols:
        sys.exit("The [dim] value must be same as or greater than the [num_symbols] value.")

    if isinstance(dp, numbers.Number):
        # A scalar similarity becomes a matrix with [dp] off the diagonal
        # and 1 on the diagonal.
        off_diagonal = dp * np.ones((num_symbols, num_symbols))
        diagonal_fix = (1 - dp) * np.eye(num_symbols, num_symbols)
        dp = off_diagonal + diagonal_fix

    return dot_products(num_symbols, dim, dp)
1752
1753
1754
def dot_products(num_symbols, dim, dp_mat, max_iter=100000):
    """Generate a 2D numpy array of random numbers such that the pairwise dot
    products of column vectors are close to the numbers specified in [dp_mat].

    Don Matthias wrote the original script in MATLAB for the LDNet program.
    He explains how this program works as follows:

    Given square matrix dpMatrix of dimension N-by-N, find N
    dim-dimensional unit vectors whose pairwise dot products match
    dpMatrix. Results are returned in the columns of M. itns is the
    number of iterations of search required, and may be ignored.

    Algorithm: Find a matrix M such that M'*M = dpMatrix. This is done
    via gradient descent on a cost function that is the square of the
    frobenius norm of (M'*M-dpMatrix).

    NOTE: It has trouble finding more than about 16 vectors, possibly for
    dumb numerical reasons (like stepsize and tolerance), which might be
    fixable if necessary.

    The program exits (sys.exit) if [dp_mat] is not symmetric or does not
    have ones on its main diagonal. A message is printed if the search
    does not converge within [max_iter] iterations; the last iterate is
    returned in either case.
    """

    if not (dp_mat.T == dp_mat).all():
        sys.exit('dot_products: dp_mat must be symmetric')

    if (np.diag(dp_mat) != 1).any():
        sys.exit('dot_products: dp_mat must have all ones on the main diagonal')

    # Random start: uniform samples arranged column by column.
    sym_mat = np.random.uniform(
        size=dim * num_symbols).reshape(dim, num_symbols, order='F')
    min_step = .1
    tol = 1e-6

    for _ in range(max_iter):
        inc = sym_mat.dot(sym_mat.T.dot(sym_mat) - dp_mat)
        # Limit the largest change of any entry to 0.01 per iteration.
        step = min(min_step, .01 / abs(inc).max())
        sym_mat = sym_mat - step * inc
        max_diff = abs(sym_mat.T.dot(sym_mat) - dp_mat).max()
        if max_diff <= tol:
            break
    else:
        print("Didn't converge after %d iterations" % max_iter)

    return sym_mat
1799
1800
1801
def heatmap(data, xlabel=None, ylabel=None, xticklabels=None, yticklabels=None,
            grayscale=False, colorbar=True, rotate_xticklabels=False,
            xtick=True, ytick=True, disp=True, val_range=None):
    """Draw a 2D array as a heatmap with matplotlib.

    : data               : 2D array to draw
    : xlabel, ylabel     : axis labels (omitted if None)
    : xticklabels        : labels for the columns (omitted if None)
    : yticklabels        : labels for the rows (omitted if None)
    : grayscale          : use the "gray_r" colormap instead of "Reds"
    : colorbar           : if True, add a colorbar
    : rotate_xticklabels : if True, draw the x tick labels vertically
    : xtick, ytick       : pass False to hide the ticks and tick labels
    :                      of the x / y axis
    : disp               : if True, call plt.show()
    : val_range          : [vmin, vmax] for the color scale (optional)
    """

    # plt.get_cmap is used because plt.cm.get_cmap is not available in
    # recent Matplotlib releases.
    cmap = plt.get_cmap("gray_r" if grayscale else "Reds")

    if val_range is not None:
        plt.imshow(data, cmap=cmap, vmin=val_range[0], vmax=val_range[1],
                   interpolation="nearest", aspect='auto')
    else:
        plt.imshow(data, cmap=cmap, interpolation="nearest", aspect='auto')

    if xlabel is not None:
        plt.xlabel(xlabel, fontsize=16)
    if ylabel is not None:
        plt.ylabel(ylabel, fontsize=16)

    if xticklabels is not None:
        if rotate_xticklabels:
            plt.xticks(
                np.arange(len(xticklabels)), xticklabels,
                rotation='vertical')
        else:
            plt.xticks(np.arange(len(xticklabels)), xticklabels)

    if yticklabels is not None:
        plt.yticks(np.arange(len(yticklabels)), yticklabels)

    # tick_params expects booleans; the string 'off' is truthy and would
    # leave the ticks switched on.
    if xtick is False:
        plt.tick_params(
            axis='x',            # changes apply to the x-axis
            which='both',        # both major and minor ticks are affected
            bottom=False,        # ticks along the bottom edge are off
            top=False,           # ticks along the top edge are off
            labelbottom=False)   # labels along the bottom edge are off
    if ytick is False:
        plt.tick_params(
            axis='y',            # changes apply to the y-axis
            which='both',        # both major and minor ticks are affected
            left=False,          # ticks along the left edge are off
            right=False,         # ticks along the right edge are off
            labelleft=False)     # labels along the left edge are off

    if colorbar:
        plt.colorbar()

    if disp:
        plt.show()
1851
1852
1853
def plot_TP(vec1, vec2, figsize=None):
    """Compute the outer product of two vectors and present it in a diagram.

    [vec2] is drawn along the top row, [vec1] down the left column and
    their outer product in the remaining grid. Every entry is a circle
    shaded by its absolute value; a negative entry is drawn as a smaller
    shaded circle inside an empty outline.
    """

    nrow, ncol = vec1.shape[0], vec2.shape[0]
    radius = 0.4

    arr = np.zeros((nrow + 1, ncol + 1))
    arr[0, 1:] = vec2
    arr[1:, 0] = vec1
    arr[1:, 1:] = np.outer(vec1, vec2)

    if figsize is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = plt.subplots(figsize=figsize)

    def outline(center, rad):
        # Empty black circle.
        return plt.Circle(center, rad, color='k', fill=False)

    def shaded(center, rad, value):
        # Filled circle; a larger magnitude gives a darker gray.
        return plt.Circle(
            center, rad, color=plt.cm.gray(1 - abs(value)), alpha=1)

    for ii in range(nrow + 1):
        for jj in range(ncol + 1):
            if (ii == 0) and (jj == 0):
                continue  # the top-left corner stays empty

            center = (jj, -ii)
            value = arr[ii, jj]
            if value >= 0:
                artists = [shaded(center, radius, value),
                           outline(center, radius)]
            else:
                artists = [outline(center, radius),
                           shaded(center, radius - 0.1, value),
                           outline(center, radius - 0.1)]
            for artist in artists:
                ax.add_artist(artist)

    ax.axis([
        0 - radius - 0.6, ncol + radius + 0.6,
        - nrow - radius - 0.6, 0 + radius + 0.6])
    ax.set_aspect('equal', adjustable='box')
    ax.axis('off')
1909
1910
1911
# ============================================================================
# The functions below are not part of the GSC simulator. They were used to
# compute the gradient of quantization harmony in an earlier version. They
# are included (1) to check that the newly implemented functions return the
# same values, and (2) to compare the computation speed of both versions.
#
# Q0GradE and Q1GradE (elementwise computation) must return the same values
# as Q0GradV and Q1GradV (partial vectorization).
# ============================================================================
1921
1922
def b_ind(f, r, net):
    """Get the index of the binding of filler [f] with role [r].

    Bindings are numbered role by role: index = f + r * net.num_fillers.
    """
    offset = r * net.num_fillers
    return f + offset
1925
1926
1927
def u_ind(phi, rho, net):
    """Get the index of the unit for the pair ([phi], [rho]).

    index = phi + rho * net.dim_f.
    """
    offset = rho * net.dim_f
    return phi + offset
1930
1931
1932
def w(f, r, phi, rho, net):
    """Return the entry of net.TPinv for binding (f, r) and unit (phi, rho)."""
    row = b_ind(f, r, net)
    col = u_ind(phi, rho, net)
    return net.TPinv[row, col]
1935
1936
1937
def get_a(n, net, f, r):
    """Compute the activation value of the f/r binding.

    [n] is a neural-coordinate state. The result is the sum over all
    units (phi, rho) of w(f, r, phi, rho) * n[unit index].
    """
    total = 0
    for phi in range(net.dim_f):
        for rho in range(net.dim_r):
            unit = u_ind(phi, rho, net)
            total += w(f, r, phi, rho, net) * n[unit]
    return total
1944
1945
1946
def n2a(n, net, f=None, r=None):
    """Convert a neural state [n] to binding activations, elementwise.

    : f and r given : the activation of the single f/r binding
    : only r given  : a vector over fillers for role r
    : only f given  : a vector over roles for filler f
    : neither given : the full vector over all bindings, indexed by b_ind
    """
    if (f is not None) and (r is not None):
        return get_a(n, net, f, r)

    if (f is None) and (r is None):
        avec = np.zeros(net.num_bindings)
        for filler in range(net.num_fillers):
            for role in range(net.num_roles):
                avec[b_ind(filler, role, net)] = get_a(n, net, filler, role)
        return avec

    if f is None:
        # Role fixed: one entry per filler.
        avec = np.zeros(net.num_fillers)
        for filler in range(net.num_fillers):
            avec[filler] = get_a(n, net, filler, r)
        return avec

    # Filler fixed: one entry per role.
    avec = np.zeros(net.num_roles)
    for role in range(net.num_roles):
        avec[role] = get_a(n, net, f, role)
    return avec
1966
1967
1968
def Q0E(net, n):
    """Evaluate Q0 elementwise: -sum over bindings of a^2 * (1 - a)^2.

    a is the activation of each f/r binding, obtained with n2a.
    """
    total = 0.0
    for f in range(net.num_fillers):
        for r in range(net.num_roles):
            a_fr = n2a(n, net, f=f, r=r)
            total += a_fr**2 * (1 - a_fr)**2
    return -total
1974
1975
1976
def Q0GradE(net, n):
    """Compute the gradient of Q0 elementwise. Very slow.

    Based on the first derivation: for every unit (phi, rho) the entry is
    the sum over bindings of w(f, r, phi, rho) * g_fr, with
    g_fr = 2a(1 - a)(1 - 2a); the result is negated.
    """
    q0grad = np.zeros(net.num_units)
    for phi in range(net.dim_f):
        for rho in range(net.dim_r):
            unit = u_ind(phi, rho, net)
            q0grad[unit] = 0.0
            for f in range(net.num_fillers):
                for r in range(net.num_roles):
                    a_fr = n2a(n, net, f, r)
                    g_fr = 2 * a_fr * (1 - a_fr) * (1 - 2 * a_fr)
                    q0grad[unit] += w(f, r, phi, rho, net) * g_fr
    return -q0grad
1990
1991
1992
def Q1E(net, n):
    """Evaluate Q1 elementwise.

    For every role r the squared activations of all fillers (from n2a)
    are summed; Q1 = -sum over roles of (that sum - 1)^2.

    The previous version computed this sum in the loop and then discarded
    it, returning a second, vectorized computation of the same quantity.
    The loop result is now returned directly.
    """
    q1 = 0.0
    for r in range(net.num_roles):
        q1 += (np.sum(n2a(n, net, r=r)**2) - 1)**2
    return -q1
1997
1998
1999
def Q1GradE(net, n):  # Elementwise computation
    """Compute the gradient of Q1 elementwise.

    For every unit (phi, rho) the entry is the sum over roles of
    4 * (sum of squared activations in the role - 1)
      * sum over fillers of a_fr * w(f, r, phi, rho);
    the result is negated.
    """
    q1grad = np.zeros(net.num_units)
    for phi in range(net.dim_f):
        for rho in range(net.dim_r):
            unit_grad = 0.0
            for r in range(net.num_roles):
                excess = np.sum(n2a(n, net, r=r)**2) - 1
                weighted = 0.0
                for f in range(net.num_fillers):
                    weighted += n2a(n, net, f, r) * w(f, r, phi, rho, net)
                unit_grad += 4 * excess * weighted
            q1grad[u_ind(phi, rho, net)] = unit_grad
    return -q1grad
2012
2013
2014
def Q0GradV(net, n):
    """Compute the gradient of Q0 (partially vectorized).

    Based on the first derivation: g = 2a(1 - a)(1 - 2a) with
    a = net.N2C(n); each row of net.TPinv is scaled by its g value and
    the rows are summed. The result is negated.
    """
    a = net.N2C(n)
    g = 2 * a * (1 - a) * (1 - 2 * a)  # a vectorized version of g_{fr}
    # Repeat g across the unit axis so it lines up with TPinv's columns.
    g_columns = np.tile(g, (net.num_units, 1)).T
    weighted = net.TPinv * g_columns
    return -np.sum(weighted, axis=0)
2021
2022
2023
def Q1GradV(net, n):
    """Compute the gradient of Q1 (partially vectorized).

    For every role, the rows of net.TPinv that belong to the role's
    bindings are weighted by the binding activations and summed; the sum
    is scaled by (sum of squared activations in the role - 1). The total
    over roles is multiplied by 4 and negated.
    """
    a = net.N2C(n)
    q1grad = 0.0
    for role in net.role_names:
        members = net.find_roles(role)
        role_act = a[members]
        # Repeat the activations across the unit axis to match TPinv rows.
        act_columns = np.tile(role_act, (net.num_units, 1)).T
        direction = np.sum(net.TPinv[members, :] * act_columns, axis=0)
        excess = np.sum(role_act ** 2) - 1
        q1grad += excess * direction
    return -(4 * q1grad)
2034
2035