Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download
436 views
1
# -*- coding: utf-8 -*-
2
"""
3
GSC Software:
4
Author: Pyeong Whan Cho
5
Department of Cognitive Science, Johns Hopkins University
6
Summer 2015
7
(based on the MATLAB program, LDNet2.0)
8
"""
9
from math import *
10
import numpy as np
11
from matplotlib import pyplot as plt
12
from scipy import optimize as optim
13
import pandas as pd
14
import pylab as P
15
from scipy import linalg as la
16
from numbers import Number
17
import sys
18
import copy
19
import itertools
20
import pprint
21
22
23
# Nested dictionary
24
class Vividict(dict):
    '''Dictionary that builds its own nested levels on first access.

    Looking up a missing key stores a new, empty instance of the same class
    under that key and returns it, so ``d['a']['b'] = 1`` works without
    creating ``d['a']`` by hand.
    Source: http://stackoverflow.com/questions/16333296/how-do-you-create-nested-dict-in-python
    '''

    def __missing__(self, key):
        # Called by dict.__getitem__ only when [key] is absent.
        child = type(self)()
        self[key] = child
        return child
29
30
class RefSet():
    '''Named reference patterns kept in a nested dictionary (self.dict).

    Each key maps to a sub-dictionary with a 'sym' entry (symbolic form)
    and/or a 'num' entry (numeric form).
    '''

    def __init__(self):
        self.dict = Vividict()

    def add(self, key, val, isSym=True, isTree=False, isSpanRole=False):
        '''Store [val] under [key].

        isTree=True  : [val] must expose ``recursiveFRtoString``; every item
                       of it is joined with '/' and the resulting list is
                       stored as 'sym'. Together with isSpanRole=True the
                       program exits (not implemented).
        isTree=False : [val] is stored unchanged, as 'sym' when isSym is
                       True (a list of f/r bindings) and as 'num' otherwise
                       (a state vector, 1d-array, in the pattern coordinate).
        '''
        if isTree and isSpanRole:
            sys.exit('Not yet implemented. Check tree.py.')

        if isTree:
            joined = ['/'.join(pair) for pair in val.recursiveFRtoString]
            self.dict[key]['sym'] = joined
            return

        slot = 'sym' if isSym else 'num'
        self.dict[key][slot] = val

    def disp(self):
        '''Pretty-print the whole reference set.'''
        pprint.pprint(self.dict)
50
51
52
53
class GscNet():
54
def __init__(self, filler_names, role_names, grammar=None,
             WGC=None, bGC=None, extC=None, prev_extC=None, z=0.5, beta=4,
             q_init=0, q_rate=4, q_max=50, q_fun='plinear', c=0.5,  # check q_max
             T_init=0, dt=0.001, reptype_r='local', reptype_f='local',
             dp_r=0, dp_f=0, ndim_r=None, ndim_f=None,
             T_decay_rate=0, T_min=0, quant_list=None, grid_points=None,
             F=None, R=None, getGPset=False):
    '''Construct an instance of GscNet.

    filler_names, role_names : lists of filler / role names
    grammar                  : accepted but not used in this constructor
    WGC, bGC                 : grammar weights / biases in C-space (zeros if None)
    extC, prev_extC          : current / previous external input in C-space
                               (zeros if None)
    z, beta                  : bowl center (vector or scalar) and bowl strength
    q_init, q_rate, q_max, q_fun, c : quantization-dynamics settings
    T_init, T_decay_rate, T_min, dt : temperature settings and time step
    reptype_*, dp_*, ndim_*  : forwarded to encode_symbols() for roles (r)
                               and fillers (f); ndim_* default to the number
                               of roles / fillers
    R, F                     : if given, replace the encoded role / filler
                               matrices
    quant_list, grid_points  : stored as given
    getGPset                 : if True, all_grid_points() is run at the end

    The program exits (sys.exit) when check_beta() rejects [beta] or when
    the S-space weight matrix is not symmetric.
    '''
    # ---- Symbols and local (C-space) bindings --------------------------
    self.binding_names = bind(filler_names, role_names)
    self.filler_names = filler_names
    self.role_names = role_names
    self.nbindings = len(self.binding_names)   # number of bindings
    self.nfillers = len(filler_names)          # number of fillers
    self.nroles = len(role_names)              # number of roles

    # ---- Network parameters in C-space (zeros unless supplied) ---------
    nb = self.nbindings
    self.WGC = np.zeros((nb, nb), order='F') if WGC is None else WGC   # weight matrix
    self.bGC = np.zeros(nb) if bGC is None else bGC                    # bias vector
    self.extC = np.zeros(nb) if extC is None else extC                 # external input
    self.prev_extC = np.zeros(nb) if prev_extC is None else prev_extC

    # ---- Distributed (S-space) representation --------------------------
    # If 'local' is chosen, these parameters would be same as the above.
    self.reptype_r = reptype_r
    self.reptype_f = reptype_f
    self.ndim_r = self.nroles if ndim_r is None else ndim_r
    self.ndim_f = self.nfillers if ndim_f is None else ndim_f

    # R / F: matrices whose column vectors correspond to roles / fillers.
    # TP: columns correspond to bindings (C- to S-space conversion).
    # TPinv: inverse of TP (S- to C-space conversion).
    self.R = encode_symbols(self.nroles, reptype=reptype_r, dp=dp_r, ndim=self.ndim_r)
    self.F = encode_symbols(self.nfillers, reptype=reptype_f, dp=dp_f, ndim=self.ndim_f)
    # Explicitly supplied encodings win over the generated ones.
    if R is not None:
        self.R = R
    if F is not None:
        self.F = F

    self.TP, self.TPinv = compute_TPmat(self.R, self.F)   # Pay attention to the argument order.

    self.nunits = self.ndim_r * self.ndim_f   # CHECK: should be same as the number of rows of TP
    self.act = np.zeros(self.nunits)          # activation state in S-space
    self.actC = self.S2C()

    # ---- Bowl (C-space is assumed) --------------------------------------
    self.z = z          # bowl center: vector or scalar
    self.beta = beta    # bowl strength
    self.WBC = -self.beta * np.eye(nb)
    self.bBC = self.beta * self.z * np.ones(nb)

    self.WC = self.WGC + self.WBC
    self.bC = self.bGC + self.bBC

    # ---- S-space images of the C-space parameters ------------------------
    self.WG = compute_sspace_weights(self.WGC, self)
    self.WB = compute_sspace_weights(self.WBC, self)
    self.bG = compute_sspace_biases(self.bGC, self)
    self.bB = compute_sspace_biases(self.bBC, self)

    self.W = compute_sspace_weights(self.WC, self)
    self.b = compute_sspace_biases(self.bC, self)
    self.zeta = self.TP.dot(self.z * np.ones(nb))
    self.ext = compute_sspace_biases(self.extC, self)
    self.prev_ext = compute_sspace_biases(self.prev_extC, self)

    self.check_beta()

    # ---- Quantization dynamics -------------------------------------------
    self.q_init = q_init                    # quantization strength
    self.q = copy.deepcopy(self.q_init)
    self.q_rate = q_rate
    self.q_fun = q_fun
    self.q_max = q_max                      # used when qfun = 'linear'
    self.q_time = 0
    self.c = c                              # quantization parameter

    # ---- Update parameters -------------------------------------------------
    self.T_init = T_init
    self.T = copy.deepcopy(self.T_init)     # temperature
    self.T_decay_rate = T_decay_rate        # temperature decay rate
    self.T_min = T_min
    self.dt = dt                            # time step constant

    # ---- Clamping state ------------------------------------------------------
    self.clamped = False
    self.act_trace = None
    self.quant_list = quant_list

    # The S-space weight matrix must be symmetric.
    if not np.allclose(self.W, self.W.T):
        sys.exit("The weight matrix (2D array) is not symmetric. Please check it.")

    # ---- Unit labels: U01, U02, ... padded to a common width ------------------
    ndigits = len(str(self.nunits))
    self.unit_names = ['U' + str(k).zfill(ndigits) for k in range(1, self.nunits + 1)]

    self.speed = None
    self.ema_speed = None
    self.grid_points = grid_points

    self.getGPset = getGPset
    if getGPset:
        self.all_grid_points()
185
186
def randomize_state(self, minact=0, maxact=1):
    '''Set the activation state to a random vector inside a hypercube.

    One value per binding is drawn uniformly from [minact, maxact) in
    C-space; its S-space image (via C2S) is stored in self.act.
    '''
    sample = np.random.uniform(low=minact, high=maxact, size=self.nbindings)
    self.act = self.C2S(sample)
190
191
def set_init_state(self, mu=0.2, sd=0.2):
    '''Set the activation state to a Gaussian random vector.

    One value per binding is drawn from N(mu, sd) in C-space; its S-space
    image (via C2S) is stored in self.act.
    '''
    sample = np.random.normal(mu, sd, self.nbindings)
    self.act = self.C2S(sample)
194
195
def randomize_weights(self, mu=0, sigma=1):
    '''Replace WC by a random symmetric matrix and refresh W.

    Entries are drawn as sigma * N(0, 1) + mu, then the matrix is averaged
    with its transpose so that it is symmetric. WGC and WBC are not touched.
    '''
    n = self.nbindings
    draws = np.random.randn(n**2)
    raw = sigma * draws.reshape(n, n, order='F') + mu
    self.WC = (raw + raw.T) / 2
    self.W = compute_sspace_weights(self.WC, self)
201
202
def all_grid_points(self):
    '''Enumerate all grid points and rank them by grammar harmony.

    A grid point takes exactly one binding from every group. The groups
    are self.quant_list when it is set; otherwise there is one group per
    role, holding all bindings of that role.

    Results are stored in
      self.gpset    : list of grid points (lists of binding names),
                      highest grammar harmony first
      self.gpset_gh : the matching grammar-harmony values (1d-array)
    '''
    groups = self.quant_list
    if groups is None:
        groups = [[self.binding_names[k] for k in self.find_roles(role)]
                  for role in self.role_names]

    candidates = [list(combo) for combo in itertools.product(*groups)]

    scores = np.zeros(len(candidates))
    for pos, bindings in enumerate(candidates):
        # Grid point in C-space: 1 on the chosen bindings, 0 elsewhere.
        pattern = np.zeros(self.nbindings)
        pattern[self.find_bindings(bindings)] = 1.0
        scores[pos] = self.Hg(act=self.C2S(pattern))

    order = np.argsort(scores)[::-1]   # descending harmony
    self.gpset = [candidates[k] for k in order]
    self.gpset_gh = scores[order]
224
225
##########################################################################
226
###
227
### set weights and biases
228
###
229
##########################################################################
230
231
def set_weight(self, binding_name1, binding_name2, weight, symmetric=True):
    '''Set the grammar weight of the connection between two bindings.

    The entry WGC[binding2, binding1] is always set to [weight]. When
    [symmetric] is True (default), WGC[binding1, binding2] receives the same
    value. WC, WG and W are then recomputed from the updated WGC.
    '''
    first = self.find_bindings(binding_name1)
    second = self.find_bindings(binding_name2)

    if symmetric:
        self.WGC[first, second] = weight
    self.WGC[second, first] = weight

    self.WC = self.WGC + self.WBC
    self.WG = compute_sspace_weights(self.WGC, self)
    self.W = compute_sspace_weights(self.WC, self)
244
245
def set_bias(self, binding_name, bias):
    '''Set the grammar bias of the given binding(s) to [bias].

    bGC is modified in place; bC, bG and b are recomputed from it.
    '''
    targets = self.find_bindings(binding_name)
    self.bGC[targets] = bias

    self.bC = self.bGC + self.bBC
    self.bG = compute_sspace_biases(self.bGC, self)
    self.b = compute_sspace_biases(self.bC, self)
252
253
def set_role_bias(self, role_name, bias):
    '''Set the grammar bias of every binding in the given role(s) to [bias].

    [role_name] may be a single name or a list of names. bGC is modified
    in place; bC, bG and b are recomputed from it.
    '''
    # Role part of each binding name ('filler/role').
    role_of = [name.split('/')[1] for name in self.binding_names]
    wanted = role_name if isinstance(role_name, list) else [role_name]

    for role in wanted:
        hits = [k for k, rr in enumerate(role_of) if rr == role]
        self.bGC[hits] = bias

    self.bC = self.bGC + self.bBC
    self.bG = compute_sspace_biases(self.bGC, self)
    self.b = compute_sspace_biases(self.bC, self)
265
266
def set_filler_bias(self, filler_name, bias):
    '''Set the grammar bias of every binding with the given filler(s) to [bias].

    [filler_name] may be a single name or a list of names. bGC is modified
    in place; bC, bG and b are recomputed from it.
    '''
    # Filler part of each binding name ('filler/role').
    filler_of = [name.split('/')[0] for name in self.binding_names]
    wanted = filler_name if isinstance(filler_name, list) else [filler_name]

    for filler in wanted:
        hits = [k for k, ff in enumerate(filler_of) if ff == filler]
        self.bGC[hits] = bias

    self.bC = self.bGC + self.bBC
    self.bG = compute_sspace_biases(self.bGC, self)
    self.b = compute_sspace_biases(self.bC, self)
277
278
# set_state()
279
def set_state(self, binding_names, vals=1.0):
    '''Set the state to [vals] on the given bindings and 0 elsewhere.

    The C-space state (self.actC) is rebuilt from zeros, then self.act is
    set to its S-space image.
    '''
    targets = self.find_bindings(binding_names)
    self.actC = np.zeros(self.nbindings)
    self.actC[targets] = vals
    self.act = self.C2S()
284
285
##########################################################################
286
###
287
### Etc
288
###
289
##########################################################################
290
291
def vec2mat(self, act=None):
    '''Reshape a vector into an (nfillers x nroles) matrix, column-major.

    When [act] is None the current state converted by S2C() is used.
    '''
    vec = self.S2C() if act is None else act
    shape = (self.nfillers, self.nroles)
    return vec.reshape(shape, order='F')
295
296
def C2S(self, C=None):
    '''Change basis: from C- to S-space.

    Returns TP applied to [C]; self.actC is used when [C] is None.
    '''
    coords = self.actC if C is None else C
    return self.TP.dot(coords)
301
302
def S2C(self, S=None):
    '''Change basis: from S- to C-space.

    Returns TPinv applied to [S]; self.act is used when [S] is None.
    '''
    state = self.act if S is None else S
    return self.TPinv.dot(state)
307
308
def find_bindings(self, binding_names):
    '''Return the indices of the given binding name(s) in self.binding_names.

    A single name is treated as a one-element list. The result is always a
    list, in the order the names were given. An unknown name raises
    ValueError (from list.index).
    '''
    wanted = binding_names if isinstance(binding_names, list) else [binding_names]
    position_of = self.binding_names.index
    return [position_of(name) for name in wanted]
314
315
def find_roles(self, role_name):
    '''Return the indices of all bindings whose role is in [role_name].

    [role_name] may be a single name or a list. Indices are grouped by
    role in the order the roles were given; unknown roles contribute
    nothing.
    '''
    wanted = role_name if isinstance(role_name, list) else [role_name]
    # Role part of each binding name ('filler/role').
    role_of = [name.split('/')[1] for name in self.binding_names]
    return [k for role in wanted
            for k, rr in enumerate(role_of) if rr == role]
328
329
def find_fillers(self, filler_name):
    '''Return the indices of all bindings whose filler is in [filler_name].

    [filler_name] may be a single name or a list. Indices are grouped by
    filler in the order the fillers were given; unknown fillers contribute
    nothing.
    '''
    wanted = filler_name if isinstance(filler_name, list) else [filler_name]
    # Filler part of each binding name ('filler/role').
    filler_of = [name.split('/')[0] for name in self.binding_names]
    return [k for filler in wanted
            for k, ff in enumerate(filler_of) if ff == filler]
340
341
def read_state(self, act=None):
    '''Print a state (C-SPACE) as a filler-by-role table. Pandas should be installed.

    [act] is an S-space state; the current state is used when it is None.
    '''
    state = self.act if act is None else act
    grid = self.vec2mat(self.S2C(state))
    table = pd.DataFrame(grid, index=self.filler_names, columns=self.role_names)
    print(table)
349
350
def read_grid_point(self, act=None, disp=True):
    '''Return the grid point closest to a state by snapping (winner-takes-all).

    [act] is an S-space state (current state when None). Without a
    quant_list, each role is assumed to bind exactly one filler and the
    most active filler per role wins. With a quant_list (a list of lists
    of binding names), every binding that reaches the maximum activation
    of its group wins, so ties yield several winners per group.

    The list of winning binding names is returned, and printed when
    [disp] is True.
    '''
    state = self.act if act is None else act

    if self.quant_list is None:
        grid = self.vec2mat(self.S2C(state))   # fillers x roles
        best_rows = np.argmax(grid, axis=0)    # winning filler index per role
        winners = ["%s/%s" % (self.filler_names[row], role)
                   for row, role in zip(best_rows, self.role_names)]
    else:
        coords = self.S2C(state)               # pattern coordinate
        winners = []
        for group in self.quant_list:
            members = self.find_bindings(group)
            levels = coords[members]
            if len(levels) == 0:
                continue
            peak = max(levels)
            winners.extend(self.binding_names[m]
                           for m, lvl in zip(members, levels) if lvl == peak)

    if disp:
        print(winners)
    return winners
375
376
def read_weight(self, which='WGC'):
    '''Print the weight matrix named [which] as a labelled table.

    Attribute names ending in 'C' are labelled with the binding names
    (pattern coordinate); all others are labelled with the unit names.
    '''
    labels = self.binding_names if which[-1] == 'C' else self.unit_names
    table = pd.DataFrame(getattr(self, which), index=labels, columns=labels)
    print(table)
382
383
def read_bias(self, which='bGC', print_vertical=True):
    '''Print the bias vector named [which] as a labelled table.

    Attribute names ending in 'C' are labelled with the binding names
    (pattern coordinate); all others are labelled with the unit names.
    [print_vertical] selects a one-column (True) or one-row (False) layout.

    The vector is reshaped to the number of labels of its own space, so
    S-space vectors print correctly even when the number of units differs
    from the number of bindings.
    '''
    labels = self.binding_names if which[-1] == 'C' else self.unit_names
    values = np.asarray(getattr(self, which))

    if print_vertical:
        table = pd.DataFrame(values.reshape(len(labels), 1),
                             index=labels, columns=["bias"])
    else:
        table = pd.DataFrame(values.reshape(1, len(labels)),
                             index=["bias"], columns=labels)
    print(table)
395
396
def hinton(self, which='WGC', label=True):
    '''Draw a hinton diagram of the matrix named [which].

    With [label] True the axes are labelled with the binding names when
    [which] ends in 'C' and with the unit names otherwise; with [label]
    False no labels are passed. Drawing is delegated to the module-level
    hinton() function.
    '''
    # CHECK: which of W or W-beta*I should be visualized?
    if not label:
        labels = None
    elif which[-1] == 'C':
        labels = self.binding_names
    else:
        labels = self.unit_names
    hinton(getattr(self, which), xlabels=labels, ylabels=labels)
407
408
def act_clamped(self, act=None):
    '''Return the clamped version of an S-space state.

    The state (current state when [act] is None) is multiplied by
    self.projmat and self.clampvec is added. Both attributes are created
    by clamp().
    '''
    state = self.act if act is None else act
    free_part = self.projmat.dot(state)
    return free_part + self.clampvec
413
414
def check_beta(self, disp=False):
    '''Check that beta (bowl strength) is large enough for the C-space weights and biases.

    beta must exceed all three of
      1. the largest eigenvalue of WGC,
      2. -min(bGC + extC) / z,
      3. (max(bGC + extC) + largest eigenvalue) / (1 - z).
    The program exits (sys.exit) when self.beta is not greater than this
    minimum. With [disp] True the minimum is printed.

    The final max() compares plain numbers, so a scalar z is required.
    '''
    # WGC should be a symmetric matrix. So eigh() was used instead of eig()
    eigvals, _ = la.eigh(self.WGC)
    eig_max = np.max(eigvals)                           # Condition 1: beta > eig_max to be stable

    # np.min / np.max reduce to scalars for any number of bindings, so the
    # single-binding case needs no special branch.
    drive = np.asarray(self.bGC + self.extC)
    beta1 = -np.min(drive) / self.z                     # Condition 2: beta > beta1
    beta2 = (np.max(drive) + eig_max) / (1 - self.z)    # Condition 3: beta > beta2

    beta_min = max(eig_max, beta1, beta2)
    if self.beta <= beta_min:
        sys.exit("Beta (bowl strength) should be greater than %.4f." % beta_min)
    if disp:
        print('Recommended beta min = %.3f' % beta_min)
429
430
def info(self):
    '''Print the network information: fillers, roles, number of units and current temperature.'''
    report = (('Fillers = ', 'filler_names'),
              ('Roles = ', 'role_names'),
              ('Num_of_units = ', 'nunits'),
              ('Current T = ', 'T'))
    for label, attr in report:
        print(label, getattr(self, attr))
    # Add more information ...
437
438
def compute_dist(self, ref_point, norm_ord=2, space='S'):
    """
    Compute the distance of the current state from a grid point.

    ref_point : a set of bindings (one name or a list of names). The grid
                point has activation 1 on these bindings and 0 elsewhere.
    norm_ord  : order of the norm, passed to np.linalg.norm.
    space     : 'S' to measure the distance in S-space, 'C' to measure it
                in C-space.

    Raises ValueError for any other [space].
    """
    if space not in ('S', 'C'):
        raise ValueError("space must be 'S' or 'C', got %r" % (space,))

    # Grid point in the pattern coordinate.
    target = np.zeros(self.nbindings)
    target[self.find_bindings(ref_point)] = 1.0

    if space == 'S':
        here, there = self.act, self.C2S(target)
    else:
        here, there = self.S2C(self.act), target

    return np.linalg.norm(here - there, ord=norm_ord)
456
457
458
##########################################################################
459
###
460
### Set or clear external input / Clamp or unclamp
461
###
462
##########################################################################
463
464
def set_input(self, binding_names, ext_vals, inhib_comp=False):
    '''Set external input.

    binding_names : one binding name or a list of names
    ext_vals      : one value (applied to every binding) or a list with one
                    value per binding
    inhib_comp    : if True, the competitors of each binding receive the
                    negative of that binding's input value. Competitors are
                    the members of every quant_list group containing the
                    binding, or, without a quant_list, all bindings that
                    share the binding's role.

    Any previous external input is cleared first. The program exits
    (sys.exit) when several values are given and their number differs from
    the number of bindings.
    '''
    # priming effect [CHECK]
    # self.prev_extC = copy.deepcopy(self.extC)     # move to run()

    if not isinstance(ext_vals, list):
        ext_vals = [ext_vals]
    if not isinstance(binding_names, list):
        binding_names = [binding_names]
    if len(ext_vals) > 1 and len(binding_names) != len(ext_vals):
        sys.exit("binding_names and ext_vals have different lengths.")

    # Pair every binding with its own value; a single value is shared.
    if len(ext_vals) == 1:
        per_binding = ext_vals * len(binding_names)
    else:
        per_binding = ext_vals

    self.clear_input()   # Consider removing this line.

    if inhib_comp:
        for name, val in zip(binding_names, per_binding):
            if self.quant_list is None:
                # Assume filler competition in each role.
                rivals = self.find_roles(name.split('/')[1])
            else:
                rivals = [k for group in self.quant_list if name in group
                          for k in self.find_bindings(group)]
            # Inhibit the whole group; the target bindings themselves are
            # overwritten with their positive input below.
            self.extC[rivals] = -np.asarray(val)

    idx = self.find_bindings(binding_names)
    self.extC[idx] = ext_vals
    self.ext = compute_sspace_biases(self.extC, self)
498
499
def clear_input(self):
    '''Remove external input: reset extC to zeros and recompute ext from it.'''
    blank = np.zeros(self.nbindings)
    self.extC = blank
    self.ext = compute_sspace_biases(blank, self)
503
504
def clamp(self, binding_names, clamp_vals=1.0, clamp_comp=False):
    '''Clamp f/r bindings.

    binding_names : one binding name or a list of names
    clamp_vals    : one value (applied to every binding) or a list with one
                    value per binding
    clamp_comp    : if True, the competitors of the clamped bindings are
                    clamped to 0 as well. Competitors are the members of
                    every quant_list group containing a clamped binding,
                    or, without a quant_list, all bindings that share a
                    role with a clamped binding.

    Sets self.clamped, self.binding_names_clamped, self.clampvecC,
    self.clampvec and self.projmat, then applies the clamp to the current
    state. The program exits (sys.exit) when several values are given and
    their number differs from the number of bindings.
    '''
    if not isinstance(clamp_vals, list):
        clamp_vals = [clamp_vals]
    if not isinstance(binding_names, list):
        binding_names = [binding_names]
    if len(clamp_vals) > 1 and len(clamp_vals) != len(binding_names):
        sys.exit('The number of bindings clamped is not equal to the number of values provided.')

    self.clamped = True
    self.binding_names_clamped = binding_names

    # Clamp vector in C-space: requested values on the clamped bindings,
    # zero everywhere else (which covers clamped competitors).
    idx = self.find_bindings(binding_names)
    clampvecC = np.zeros(self.nbindings)
    clampvecC[idx] = clamp_vals
    self.clampvecC = clampvecC

    # Every binding whose activation is held fixed.
    fixed = set(idx)
    if clamp_comp:
        if self.quant_list is None:
            # Assume filler competition in each role.
            roles = [b.split('/')[1] for b in binding_names]
            fixed.update(self.find_roles(roles))
        else:
            # Collect the members of every group that holds a clamped binding.
            for group in self.quant_list:
                if any(b in group for b in binding_names):
                    fixed.update(self.find_bindings(group))

    # Unclamped bindings are free to vary; their TP columns are handed to
    # compute_projmat() as the basis vectors of the subspace.
    free = [k for k in range(self.nbindings) if k not in fixed]
    if free:
        self.projmat = compute_projmat(self.TP[:, free])
    else:
        self.projmat = np.zeros((self.nunits, self.nunits))

    self.clampvec = self.C2S(clampvecC)
    self.act = self.act_clamped(self.act)
    self.actC = self.S2C()
555
556
def unclamp(self):
    '''Undo clamp(): delete the clamping attributes and clear the clamped flag.

    Does nothing unless self.clamped is exactly True.
    '''
    if self.clamped is not True:
        return
    for attr in ('clampvec', 'clampvecC', 'projmat', 'binding_names_clamped'):
        delattr(self, attr)
    self.clamped = False
563
564
##########################################################################
565
###
566
### Harmony and harmony gradient
567
###
568
##########################################################################
569
570
# def harmony(self, act=None, quant_list=None):
571
# '''Compute the total harmony of the current state'''
572
# if act is None:
573
# act = self.act
574
#
575
# if quant_list is None:
576
# quant_list = self.quant_list
577
#
578
# return harmony(net=self, act=act, quant_list=quant_list)
579
580
# def H(self, act=None, quant_list=None):
581
# '''Compute the total harmony of the current state'''
582
# if act is None:
583
# act = self.act
584
# if quant_list is None:
585
# quant_list = self.quant_list
586
# return H(net=self, act=act, quant_list=quant_list)
587
#
588
# def HGrad(self, act=None, quant_list=None):
589
# '''Compute the harmony gradient evaluated at the current state'''
590
# if act is None:
591
# act = self.act
592
# if quant_list is None:
593
# quant_list = self.quant_list
594
# return HGrad(net=self, act=act, quant_list=quant_list)
595
596
def H(self, act=None, quant_list=None):
    '''Compute the total harmony of a state (current state when [act] is None).

    The value is Hg(net, act) + q * Q(net, act), using the module-level
    functions Hg and Q. [quant_list] is accepted but not used here.
    '''
    state = self.act if act is None else act
    quant_list = self.quant_list if quant_list is None else quant_list
    grammar_part = Hg(net=self, act=state)
    return grammar_part + self.q * Q(net=self, n=state)
603
604
def Hg(self, act=None):
    '''Compute the grammar harmony of a state (current state when [act] is None).

    Delegates to the module-level function Hg.
    '''
    state = self.act if act is None else act
    return Hg(net=self, act=state)
609
610
def Hq(self, act=None, quant_list=None):
    '''Return Q(net, act) for a state (current state when [act] is None).

    Delegates to the module-level function Q. [quant_list] is accepted
    but not used here.
    '''
    state = self.act if act is None else act
    quant_list = self.quant_list if quant_list is None else quant_list
    return Q(net=self, n=state)
616
617
def H0(self, act=None):
    '''Return the module-level H0(net, act) for a state (current state when [act] is None).'''
    state = self.act if act is None else act
    return H0(net=self, act=state)
622
623
def H1(self, act=None):
    '''Return the module-level H1(net, act) for a state (current state when [act] is None).'''
    state = self.act if act is None else act
    return H1(net=self, act=state)
628
629
def HGrad(self, act=None, quant_list=None):
    '''Compute the harmony gradient evaluated at a state (current state when [act] is None).

    The value is HgGrad(net, act) + q * QGrad(net, act), using the
    module-level functions HgGrad and QGrad. [quant_list] is accepted but
    not used here.
    '''
    state = self.act if act is None else act
    quant_list = self.quant_list if quant_list is None else quant_list
    grammar_part = HgGrad(net=self, act=state)
    return grammar_part + self.q * QGrad(net=self, n=state)
636
637
def HgGrad(self, act=None):
    '''Return the module-level HgGrad(net, act) for a state (current state when [act] is None).'''
    state = self.act if act is None else act
    return HgGrad(net=self, act=state)
642
643
def H0Grad(self, act=None, quant_list=None):
    '''Return the module-level H0Grad(net, act, quant_list).

    [act] defaults to the current state and [quant_list] to
    self.quant_list.
    '''
    state = self.act if act is None else act
    groups = self.quant_list if quant_list is None else quant_list
    return H0Grad(net=self, act=state, quant_list=groups)
650
651
def H1Grad(self, act=None):
    '''Return the module-level H1Grad(net, act) for a state (current state when [act] is None).'''
    state = self.act if act is None else act
    return H1Grad(net=self, act=state)
656
657
def HqGrad(self, act=None, quant_list=None):
    '''Return QGrad(net, act) for a state (current state when [act] is None).

    Delegates to the module-level function QGrad. [quant_list] is accepted
    but not used here.
    '''
    state = self.act if act is None else act
    quant_list = self.quant_list if quant_list is None else quant_list
    return QGrad(net=self, n=state)
663
664
# def oharmony(self, act=None):
665
# '''Compute the grammar harmony of the current state'''
666
# if act is None:
667
# act = self.act
668
# return oharmony(net=self, act=act)
669
670
# def qharmony(self, act=None, quant_list=None):
671
# '''Compute the grammar harmony of the current state'''
672
# if act is None:
673
# act = self.act
674
# if quant_list is None:
675
# quant_list = self.quant_list
676
# return q2harmony(net=self, act=act, quant_list=quant_list)
677
678
# def Hq(self, act=None, quant_list=None):
679
# '''Compute the grammar harmony of the current state'''
680
# if act is None:
681
# act = self.act
682
# if quant_list is None:
683
# quant_list = self.quant_list
684
# return Hq(net=self, act=act, quant_list=quant_list)
685
686
# def gharmony(self, act=None):
687
# '''Compute the grammar harmony of the current state'''
688
# if act is None:
689
# act = self.act
690
# return gharmony(net=self, act=act)
691
692
# def bharmony(self, act=None):
693
# '''Compute the grammar harmony of the current state'''
694
# if act is None:
695
# act = self.act
696
# return bharmony(net=self, act=act)
697
698
# def hgrad(self, act=None, quant_list=None):
699
# '''Compute the harmony gradient evaluated at the current state'''
700
# if act is None:
701
# act = self.act
702
# if quant_list is None:
703
# quant_list = self.quant_list
704
# return hgrad(net=self, act=act, quant_list=quant_list)
705
706
##########################################################################
707
###
708
### Update state and parameters
709
###
710
##########################################################################
711
712
#def update(self, update_T=True, update_q=True, q_fun='log', quant_list=None):
713
def update(self, update_T=True, update_q=True, q_fun=None, space='S', norm_ord=2):
    '''Advance the network by one step.

    The current state is saved in self.prev_act, then the state, the speed
    and (when self.grid_points is set) the distances to the grid points
    are updated. The temperature and q are updated when [update_T] /
    [update_q] are True; [q_fun] defaults to self.q_fun. [space] and
    [norm_ord] are forwarded to the speed and distance updates.
    '''
    schedule = self.q_fun if q_fun is None else q_fun
    measure = dict(space=space, norm_ord=norm_ord)

    self.prev_act = copy.deepcopy(self.act)
    self.update_state()
    self.update_speed(**measure)
    if self.grid_points is not None:
        self.update_dist(self.grid_points, **measure)

    if update_T:
        self.update_T()
    if update_q:
        self.update_q(fun=schedule)
728
729
def update_state(self):
    '''Update the current state (with noise).

    The state moves by dt times the harmony gradient (in place), noise is
    added, the clamp is re-applied when the network is clamped, and
    self.actC is refreshed from the new state.
    '''
    drift = self.HGrad(quant_list=self.quant_list)
    self.act += self.dt * drift     # deterministic change
    self.add_noise()                # stochastic part
    if self.clamped:
        self.act = self.act_clamped()
    self.actC = self.S2C()
738
739
def add_noise(self):
    '''Add Gaussian noise to the current activation state (in place).

    Each unit receives an independent N(0, 1) draw scaled by
    sqrt(2 * T * dt).
    '''
    scale = sqrt(2 * self.T * self.dt)
    self.act += scale * np.random.randn(self.nunits)
743
744
def update_T(self, method='exponential'):
    '''Update the current temperature.

    With method 'exponential' the excess of T over T_min shrinks by the
    factor (1 - T_decay_rate) per call. Any other method leaves T as is.
    '''
    # In the LDNet, the decay rate per timestep is adjusted based on the time step constant.
    # For simplicity, the decay rate is assumed to be per-step.
    if method != 'exponential':
        return
    excess = self.T - self.T_min
    self.T = (1-self.T_decay_rate) * excess + self.T_min
750
751
def update_q(self, fun='plinear'):
    '''Update the quantization strength q according to schedule [fun].

    'log'     : q_time is incremented and q = q_rate * log(q_time)
    'linear'  : q grows by q_rate
    'plinear' : piecewise linear; q grows by q_rate while q <= q_max
    Any other value leaves q unchanged.
    '''
    if fun == 'log':
        self.q_time += 1
        self.q = self.q_rate * log(self.q_time)
        # self.q = self.q_rate * log(exp(self.q/self.q_rate)+1)   # without q_time
    elif fun == 'linear':
        self.q += self.q_rate
    elif fun == 'plinear' and self.q <= self.q_max:
        self.q += self.q_rate
763
764
def reset(self):
    '''Return the network to its initial condition.

    q, q_time and T go back to their initial values, the state is
    randomized, the previous external input is zeroed, and any clamp and
    external input are removed.
    '''
    self.q, self.q_time = self.q_init, 0
    self.T = self.T_init
    self.randomize_state()

    # [CHECK]
    self.prev_extC = np.zeros(self.nbindings)
    self.prev_ext = compute_sspace_biases(self.prev_extC, self)
    self.unclamp()
    self.clear_input()
776
777
def update_speed(self, ema_factor=0.1, norm_ord=2, space='S'):
    '''Update self.speed and its running average self.ema_speed.

    The speed is the norm (order [norm_ord]) of the difference between
    the current and the previous state, measured in S-space
    (space='S') or C-space (space='C'). The first call initializes
    ema_speed with the speed; afterwards
        ema_speed = ema_factor * ema_speed + (1 - ema_factor) * speed.

    Raises ValueError for any other [space].
    '''
    if space == 'S':
        diff = self.act - self.prev_act
    elif space == 'C':
        diff = self.S2C(self.act) - self.S2C(self.prev_act)
    else:
        raise ValueError("space must be 'S' or 'C', got %r" % (space,))

    self.speed = np.linalg.norm(diff, ord=norm_ord)   #/ net.dt
    if self.ema_speed is None:
        self.ema_speed = self.speed
    else:
        self.ema_speed = ema_factor * self.ema_speed + (1-ema_factor) * self.speed
788
#return self.ema_speed
789
790
def update_dist(self, grid_points, norm_ord=2, space='S'):
    '''Store in self.dist the distance from the current state to each grid point.

    [grid_points] is a list of grid points (each a list of binding names);
    a single grid point given as a flat list is wrapped automatically.
    [norm_ord] and [space] are forwarded to compute_dist().
    '''
    is_nested = any(isinstance(item, list) for item in grid_points)
    targets = grid_points if is_nested else [grid_points]

    distances = [self.compute_dist(ref_point=gp, norm_ord=norm_ord, space=space)
                 for gp in targets]
    self.dist = np.array(distances, dtype=float)
800
801
def update_traces(self):
    '''Record the current step (self.step) in every trace.

    Stored: the S-space state, the harmony values H, Hg, Hq, H0 and H1,
    speed, ema_speed, q, T, the snapped grid point and, when
    self.grid_points is set, the distances in self.dist.
    '''
    t = self.step
    self.act_trace[t, :] = self.act

    for name in ('H', 'Hg', 'Hq', 'H0', 'H1'):
        getattr(self, name + '_trace')[t] = getattr(self, name)()

    for name in ('speed', 'ema_speed', 'q', 'T'):
        getattr(self, name + '_trace')[t] = getattr(self, name)

    # Consider storing winner indices instead of the F/R binding strings.
    # It will reduce the data size a lot
    self.gp_trace[t] = self.read_grid_point(disp=False)

    if self.grid_points is not None:
        self.dist_trace[t, :] = self.dist
824
825
# def update_input(self):
826
827
##########################################################################
828
###
829
### Run
830
###
831
##########################################################################
832
833
def run(self, maxstep, plot=False, grayscale=False, colorbar=True,
        tol=None, ema_factor=0.1, update_T=True, update_q=True, q_fun=None,
        grid_points=None, testvar='ema_speed', space='S', norm_ord=2,
        ext_overlap=False, ext_overlap_steps=0):
    '''
    Run simulations for [maxstep] steps.

    maxstep     : maximum number of update steps
    plot        : if True, draw a heatmap of the C-space activation trace
                  ([grayscale] and [colorbar] are passed to heatmap())
    tol         : if not None, check_convergence(tol, testvar) is called
                  after every step and the run stops once it reports
                  convergence
    ema_factor  : accepted but not used by this method
    update_T, update_q, q_fun, space, norm_ord : forwarded to update();
                  q_fun defaults to self.q_fun
    grid_points : grid points used to size the distance trace; defaults to
                  self.grid_points
    ext_overlap, ext_overlap_steps : if ext_overlap is True, the external
                  input moves linearly from prev_extC to extC during the
                  first [ext_overlap_steps] steps

    All traces are (re)allocated with maxstep+1 entries; entry 0 holds the
    state before the first update. self.rt is the index of the last step
    that was run.
    '''
    if q_fun is None:
        q_fun = self.q_fun
    if grid_points is None:
        grid_points = self.grid_points

    # initialize
    nsteps = maxstep + 1
    self.converged = False
    self.act_trace = np.zeros((nsteps, self.nunits))
    for name in ('H', 'Hg', 'Hq', 'H0', 'H1', 'speed', 'ema_speed', 'q', 'T'):
        setattr(self, name + '_trace', np.zeros(nsteps))
    self.gp_trace = [[] for _ in range(nsteps)]

    if grid_points is not None:
        if not any(isinstance(pp, list) for pp in grid_points):
            grid_points = [grid_points]
        n_grid_points = len(grid_points)
        self.dist_trace = np.zeros((nsteps, n_grid_points))

    # Log initial state
    self.step = 0
    if grid_points is not None:
        # No distance has been computed yet. np.nan is used because the
        # np.NAN alias no longer exists in NumPy 2.
        self.dist = np.full(n_grid_points, np.nan)
    self.update_traces()

    prev_extC = copy.deepcopy(self.prev_extC)
    curr_extC = copy.deepcopy(self.extC)
    for tt in np.arange(maxstep)+1:
        self.step = tt
        if ext_overlap and (tt <= ext_overlap_steps):
            # Linear overlap: blend the previous into the current input.
            mix = tt/ext_overlap_steps
            self.extC = (1 - mix) * prev_extC + mix * curr_extC
            self.ext = compute_sspace_biases(self.extC, self)

        self.update(update_T=update_T, update_q=update_q, q_fun=q_fun,
                    space=space, norm_ord=norm_ord)
        self.update_traces()

        if tol is not None:
            self.check_convergence(tol=tol, testvar=testvar)
            if self.converged:
                break

    self.rt = self.step
    #self.prev_extC = copy.deepcopy(self.extC)  # this is done in the sim method.

    if plot:
        actC_trace = self.S2C(self.act_trace[:(self.rt+1),:].T).T
        heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                yticklabels=self.binding_names, grayscale=grayscale, colorbar=colorbar)
904
905
def check_convergence(self, tol, testvar='ema_speed'):
    '''Check if the convergence criterion (distance vs. ema_speed) has been satisfied.

    testvar 'dist'      : converged when any entry of self.dist is below [tol]
    testvar 'ema_speed' : converged when self.ema_speed is below [tol]
    self.converged is set to True on success and left untouched otherwise.
    '''
    if testvar == 'dist':
        reached = (self.dist < tol).any()
    elif testvar == 'ema_speed':
        reached = self.ema_speed < tol
    else:
        reached = False

    if reached:
        self.converged = True
915
916
def simulate(self):
    '''Placeholder: performs no work and returns 0.'''
    status = 0
    return status
919
920
def plot_act_trace(self, timesteps=None):
    '''Draw the stored activation trace as a heatmap of bindings over time.

    timesteps: row indices of self.act_trace to show; by default the rows
    0..self.rt are used. Terminates via sys.exit() when the object has no
    `act_trace` attribute.
    '''
    # Guard first so the plotting path below reads straight through.
    if not hasattr(self, 'act_trace'):
        sys.exit("There is no [act_trace] attribute in the current object.")

    if timesteps is None:
        timesteps = list(range(self.rt+1))

    selected_rows = self.act_trace[timesteps,:]
    # S2C is applied to a (units x time) matrix; the result is transposed back.
    actC_trace = self.S2C(selected_rows.T).T
    heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
            xticklabels=None, yticklabels=self.binding_names, grayscale=False)
931
932
def plot_trace(self, varname, timesteps=None):
    '''Line-plot the attribute named `<varname>_trace` over the given timesteps.

    timesteps defaults to 0..self.rt. The y-axis is labelled with `varname`.
    '''
    steps = list(range(self.rt+1)) if timesteps is None else timesteps
    trace = getattr(self, varname+'_trace')
    plt.plot(trace[steps])
    plt.xlabel('Timestep', fontsize=16)
    plt.ylabel(varname, fontsize=16)
    plt.show()
940
941
# def plot_dist_trace(self, timesteps=None):
942
# if timesteps is None:
943
# timesteps = list(range(self.rt+1))
944
# plt.plot(self.dist_trace[timesteps, :])
945
# plt.xlabel('Timestep', fontsize=16)
946
# plt.ylabel('Distance', fontsize=16)
947
# plt.show()
948
949
##########################################################################
950
###
951
### Update state and parameters
952
###
953
##########################################################################
954
955
def optim(self, initvals=None, method='nelder-mead', options=None):
    '''Find a local harmony optimum starting from an initial guess.

    initvals : starting point; None or an empty sequence means "use the
               current state self.act".
    method   : solver name forwarded to scipy.optimize.minimize. (It used to
               be ignored in favour of a hard-coded 'nelder-mead'.)
    options  : solver options; defaults to {'xtol': 1e-8, 'disp': True}.
               NOTE(review): these defaults were written for Nelder-Mead;
               other solvers may not accept 'xtol' -- confirm before
               switching methods.

    Returns the scipy result object with `fun` sign-flipped so that it holds
    the (locally) maximal harmony rather than the minimal negative harmony.
    '''
    # CHECK: does this work correctly when some bindings are clamped?
    # np.size() copes with lists, arrays and scalars alike, unlike `== []`.
    if initvals is None or np.size(initvals) == 0:
        initvals = self.act
    if options is None:
        # Built per call so no dict is shared between invocations.
        options = {'xtol': 1e-8, 'disp': True}
    # NOTE(review): module-level `harmony` is not defined in the visible part
    # of the file -- confirm it exists.
    res = optim.minimize(lambda x: -harmony(self, x), initvals, method=method, options=options)
    res.fun = -res.fun  # Convert the (locally) minimal negative harmony (energy) to the (locally) maximal harmony
    return res
963
964
def harmony_landscape(self, binding_name, minact, maxact):
    '''1D harmony landscape along a single binding.

    binding_name : a binding name, or a list holding exactly one name.
    minact, maxact : range of activation values scanned (1000 grid points).

    Returns (agrid, hgrid): the scanned activation values and the harmony at
    each of them, with all other bindings kept at their current values.
    Terminates via sys.exit() when a list without exactly one name is given.
    '''
    # To-do: Currently it works only with local representation.
    if isinstance(binding_name, list):
        if len(binding_name) != 1:
            sys.exit("Choose only one binding.!")
        # Unwrap so the lookup below receives a name, not a list.
        binding_name = binding_name[0]

    idx = self.binding_names.index(binding_name)
    agrid = np.linspace(minact, maxact, 1000)
    hgrid = np.zeros(len(agrid))
    # Work on a private copy so the array returned by S2C() is never modified.
    act0 = np.array(self.S2C())
    for ii, aa in enumerate(agrid):
        act0[idx] = aa
        act = self.C2S(act0)
        # NOTE(review): module-level `harmony` is not defined in the visible
        # part of the file -- confirm it exists.
        hgrid[ii] = harmony(net=self, act=act)
    return agrid, hgrid
981
982
983
984
class WeightTP():
    '''Square table of pairwise weights between named bindings.'''

    def __init__(self, binding_names):
        # One row and one column per binding, all weights start at zero.
        count = len(binding_names)
        self.binding_names = binding_names
        self.nbindings = count
        self.WGC = np.zeros((count, count))

    def set_weight(self, binding1, binding2, weight):
        '''Store `weight` for the pair in both directions (symmetric).'''
        row = self.binding_names.index(binding1)
        col = self.binding_names.index(binding2)
        self.WGC[row, col] = weight
        self.WGC[col, row] = weight

    def show(self):
        '''Print the weight table with binding names as row/column labels.'''
        table = pd.DataFrame(self.WGC, index=self.binding_names, columns=self.binding_names)
        print(table)
998
999
class BiasTP():
    '''Vector of biases, one per named binding ("filler/role").'''

    def __init__(self, binding_names):
        count = len(binding_names)
        self.binding_names = binding_names
        self.nbindings = count
        self.bGC = np.zeros(count)

    def set_bias(self, binding, bias):
        '''Set the bias of one binding.'''
        position = self.binding_names.index(binding)
        self.bGC[position] = bias

    def set_role_bias(self, role, bias):
        '''Set the bias of every binding whose role part equals `role`.'''
        # The role is the second '/'-separated field of the binding name.
        roles = [name.split('/')[1] for name in self.binding_names]
        hits = [pos for pos, current in enumerate(roles) if current == role]
        self.bGC[hits] = bias

    def set_filler_bias(self, filler, bias):
        '''Set the bias of every binding whose filler part equals `filler`.'''
        # The filler is the first '/'-separated field of the binding name.
        fillers = [name.split('/')[0] for name in self.binding_names]
        hits = [pos for pos, current in enumerate(fillers) if current == filler]
        self.bGC[hits] = bias

    def show(self):
        '''Print the biases as a one-column table indexed by binding name.'''
        table = pd.DataFrame(self.bGC, index=self.binding_names, columns=["bias"])
        print(table)
1021
1022
1023
# wh-question
1024
# To-do: quant_list should be considered in read_grid_point()
1025
class GscNet0():
1026
# local representation first.
1027
def __init__(self, net_list, group_names):
    '''Combine several networks ("groups") into one multi-group network.

    net_list    : the member networks; each must expose binding_names, WGC,
                  WBC, bGC, bBC, extC and nunits.
    group_names : one name per network; bindings are renamed to
                  "<binding>:<group name>".

    The members' weight matrices become diagonal blocks of self.WGC / self.WBC,
    their bias and external-input vectors are copied into the matching slices,
    and each group receives a contiguous range of unit indices.
    '''
    self.ngroups = len(net_list)
    self.group_names = group_names
    self.groups = net_list

    # Group-qualified binding names, group by group.
    self.binding_names = [b + ':' + group_names[gi]
                          for gi, net in enumerate(self.groups)
                          for b in net.binding_names]
    self.nbindings = len(self.binding_names)

    size = self.nbindings
    self.WGC = np.zeros((size, size))
    self.WBC = np.zeros((size, size))
    self.bGC = np.zeros(size)
    self.bBC = np.zeros(size)
    self.extC = np.zeros(size)
    self.group_unit_idx = list()
    self.group_binding_idx = list()

    unit_offset = 0
    for gi, net in enumerate(self.groups):
        # Positions of this group's bindings in the combined name list.
        positions = [self.binding_names.index(b + ':' + group_names[gi])
                     for b in net.binding_names]
        diag_block = np.ix_(positions, positions)
        self.WGC[diag_block] = net.WGC
        self.WBC[diag_block] = net.WBC
        self.bGC[np.ix_(positions)] = net.bGC
        self.bBC[np.ix_(positions)] = net.bBC
        self.extC[np.ix_(positions)] = net.extC
        # NOTE(review): the original author flagged this unit-index
        # bookkeeping as "WRONG"; the reason is not recorded.
        self.group_unit_idx.append(np.arange(net.nunits) + unit_offset)
        self.group_binding_idx.append(positions)
        unit_offset = unit_offset + net.nunits

    self.WC = self.WGC + self.WBC
    self.bC = self.bGC + self.bBC
    self.nunits = unit_offset

    self.W = self.compute_sspace_weights()
    self.b, self.ext = self.compute_sspace_b_and_ext()
    self.act = np.zeros(self.nunits)
1070
1071
def compute_sspace_b_and_ext(self):
    '''Assemble the unit-level bias and external-input vectors.

    Each group's `b` and `ext` vectors are copied into that group's unit
    slots (self.group_unit_idx). Returns the pair (b, ext), each of length
    self.nunits.
    '''
    total_units = self.nunits
    b_all = np.zeros(total_units)
    ext_all = np.zeros(total_units)
    for gi, member in enumerate(self.groups):
        slots = self.group_unit_idx[gi]
        b_all[slots] = member.b
        ext_all[slots] = member.ext
    return b_all, ext_all
1078
1079
# The program does not use this:
1080
def compute_sspace_weights(self):
    '''Build the unit-level weight matrix of the whole multi-group network.

    Diagonal blocks are the member networks' own W matrices. An off-diagonal
    block (row group <- column group) is the sum, over all binding pairs, of
    the binding-level weight in self.WC times the outer product of the two
    bindings' TP columns, divided by the product of their squared norms.
    '''
    W_full = np.zeros((self.nunits, self.nunits))
    for row_g, row_net in enumerate(self.groups):
        row_units = self.group_unit_idx[row_g]
        row_bind = self.group_binding_idx[row_g]
        for col_g, col_net in enumerate(self.groups):
            if row_g == col_g:
                # Within-group weights are taken over unchanged.
                W_full[np.ix_(row_units, row_units)] = row_net.W
                continue

            col_units = self.group_unit_idx[col_g]
            col_bind = self.group_binding_idx[col_g]
            W_local = self.WC[np.ix_(row_bind, col_bind)]
            block = np.zeros((row_net.nunits, col_net.nunits))  # from col_net to row_net
            for kk in range(row_net.nbindings):
                vec_row = row_net.TP[:, kk]
                sq_row = np.dot(vec_row, vec_row)
                for ll in range(col_net.nbindings):
                    vec_col = col_net.TP[:, ll]
                    block = block + W_local[kk,ll] * np.outer(vec_row, vec_col) / (sq_row * np.dot(vec_col, vec_col))
            W_full[np.ix_(row_units, col_units)] = block
    return W_full
1101
1102
def randomize_state(self, minact=0, maxact=1):
    '''Randomize the activation state of every group.

    Each member network's own randomize_state() is called without arguments;
    `minact` and `maxact` are accepted but not forwarded.
    '''
    for member in self.groups:
        member.randomize_state()
1106
1107
# def randomize_weights(self, mu=0, sigma=1):
1108
# '''Randomize WC'''
1109
# WC = sigma*np.random.randn(self.nbindings**2).reshape(self.nbindings, self.nbindings, order='F') + mu
1110
# WC = (WC + WC.T)/2
1111
# self.WC = WC
1112
# self.W = compute_sspace_weights(self.WC, self)
1113
1114
##########################################################################
1115
###
1116
### set weights and biases (Use this only for setting the weights between networks)
1117
###
1118
##########################################################################
1119
1120
def set_weight(self, binding_name1, binding_name2, weight, symmetric=True):
    '''Set a binding-level grammar weight and refresh the derived matrices.

    With symmetric=True both WGC[i1, i2] and WGC[i2, i1] receive `weight`;
    otherwise only WGC[i2, i1] does. Afterwards WC and W are recomputed.
    '''
    first = self.binding_names.index(binding_name1)
    second = self.binding_names.index(binding_name2)
    # The (second, first) entry is written in both modes.
    self.WGC[second, first] = weight
    if symmetric:
        self.WGC[first, second] = weight
    self.WC = self.WGC + self.WBC
    self.W = self.compute_sspace_weights()
1129
1130
def set_bias(self, binding_name, bias):
    '''Set the grammar bias of one binding, then recompute bC, b and ext.'''
    position = self.binding_names.index(binding_name)
    self.bGC[position] = bias
    self.bC = self.bGC + self.bBC
    unit_bias, unit_ext = self.compute_sspace_b_and_ext()
    self.b = unit_bias
    self.ext = unit_ext
1135
1136
def set_role_bias(self, role_name, bias):
    '''Set the grammar bias of every binding whose role field contains `role_name`.

    The role field is the second '/'-separated part of the binding name and
    the match is a substring test (`in`), not equality. bC, b and ext are
    recomputed afterwards.
    '''
    role_fields = [name.split('/')[1] for name in self.binding_names]
    matches = [pos for pos, field in enumerate(role_fields) if role_name in field]
    self.bGC[matches] = bias
    self.bC = self.bGC + self.bBC
    unit_bias, unit_ext = self.compute_sspace_b_and_ext()
    self.b = unit_bias
    self.ext = unit_ext
1142
1143
def set_filler_bias(self, filler_name, bias):
    '''Set the grammar bias of every binding whose filler field contains `filler_name`.

    The filler field is the first '/'-separated part of the binding name and
    the match is a substring test (`in`), not equality. bC, b and ext are
    recomputed afterwards.
    '''
    filler_fields = [name.split('/')[0] for name in self.binding_names]
    matches = [pos for pos, field in enumerate(filler_fields) if filler_name in field]
    self.bGC[matches] = bias
    self.bC = self.bGC + self.bBC
    unit_bias, unit_ext = self.compute_sspace_b_and_ext()
    self.b = unit_bias
    self.ext = unit_ext
1149
1150
##########################################################################
1151
###
1152
### Etc --- CHECK if it works
1153
###
1154
##########################################################################
1155
1156
def set_seed(self, num):
    '''Seed NumPy's global random number generator with `num`.'''
    seed_value = num
    np.random.seed(seed_value)
1158
1159
def find_bindings(self, binding_names):
    '''Return the positions of the given binding name(s) in self.binding_names.

    A single non-list argument is treated as a one-element list. The result
    is always a list of indices.
    '''
    wanted = binding_names if isinstance(binding_names, list) else [binding_names]
    positions = []
    for name in wanted:
        positions.append(self.binding_names.index(name))
    return positions
1165
1166
def read_state(self):
    '''Print each group's name followed by that group's own read_state() output.'''
    for gi, member in enumerate(self.groups):
        header = "Group: %s" % self.group_names[gi]
        print(header)
        member.read_state()
        print("\n")
1172
1173
def read_grid_point(self):
    '''Print, per group, the value returned by that group's read_grid_point(disp=False).'''
    # CHECK: current method assumes filler competition in each role.
    # quant_list should be considered
    for gi, member in enumerate(self.groups):
        header = "Group: %s" % self.group_names[gi]
        print(header)
        grid_point = member.read_grid_point(disp=False)
        print(grid_point)
        print("\n")
1180
1181
def read_weight(self):
    '''Print the binding-level weight matrix WC and the bias vector bC as labelled tables.'''
    names = self.binding_names
    weight_table = pd.DataFrame(self.WC, index=names, columns=names)
    print(weight_table)
    # The biases are shown as a single row labelled "bias".
    bias_row = self.bC.reshape(1, self.nbindings)
    print(pd.DataFrame(bias_row, index=["bias"], columns=names))
1184
1185
def hinton(self, label=True):
    '''Draw a Hinton diagram of the binding-level weight matrix WC.

    With label=True the binding names are used as tick labels on both axes.
    '''
    # CHECK: which of W or W-beta*I should be visualized?
    labels = self.binding_names if label else None
    hinton(self.WC, xlabels=labels, ylabels=labels)
1193
1194
# def act_clamped(self, act=None):
1195
# # Not used. Consider removing this part.
1196
# if act is None:
1197
# act = self.act
1198
# return self.projmat.dot(act) + self.clampvec
1199
#
1200
# def check_beta(self):
1201
# # Unused.
1202
# # To-do: eig() vs. eigh()
1203
# eigvals, eigvecs = la.eigh(self.W) # W should be a symmetric matrix.
1204
# eig_max = max(eigvals) # Condition 1: beta > eig_max to be stable
1205
# if self.nbindings == 1:
1206
# beta1 = -(self.b+self.ext)/self.z
1207
# beta2 = (self.b+self.ext+eig_max)/(1-self.z)
1208
# else:
1209
# beta1 = -min(self.b+self.ext)/self.z # Condition 2: beta > beta1
1210
# beta2 = (max(self.b+self.ext)+eig_max)/(1-self.z) # Condition 3: beta > beta2
1211
# beta_min = max(eig_max, beta1, beta2)
1212
# if self.beta <= beta_min:
1213
# sys.exit("Beta (bowl strength) should be greater than %.4f." % beta_min)
1214
#
1215
# def info(self):
1216
# '''Print the network information.'''
1217
# print('Fillers = ', self.fillers)
1218
# print('Roles = ', self.roles)
1219
# print('Num_of_units = ', self.nunits)
1220
# print('Current T = ', self.T)
1221
# # Add more information ...
1222
1223
##########################################################################
1224
###
1225
### Set or clear external input / Clamp or unclamp
1226
###
1227
##########################################################################
1228
1229
def set_input(self, binding_names, ext_vals):
    '''Replace the external input: zero everywhere except the given bindings.

    Terminates via sys.exit() when the two arguments differ in length.
    self.ext is then recomputed from the new self.extC.
    '''
    n_names, n_vals = len(binding_names), len(ext_vals)
    if n_names != n_vals:
        sys.exit("binding_names and ext_vals have different lengths.")
    # Any previous external input is discarded before the lookup.
    self.extC = np.zeros(self.nbindings)
    positions = self.find_bindings(binding_names)
    self.extC[positions] = ext_vals
    self.ext = compute_sspace_biases(self.extC, self)
1236
1237
def clear_input(self):
    '''Reset the external input to zero and recompute self.ext from it.'''
    cleared = np.zeros(self.nbindings)
    self.extC = cleared
    self.ext = compute_sspace_biases(self.extC, self)
1240
1241
def clamp(self, binding_names, clamp_vals):
    '''Clamp f/r bindings by delegating to the groups that own them.

    binding_names : names of the form "<binding>:<group name>".
    clamp_vals    : one value per name.

    Terminates via sys.exit() when the two arguments differ in length.
    Sets self.clamped, then calls clamp() once on every group mentioned,
    passing that group's bindings (group suffix removed) and values.
    '''
    if len(clamp_vals) != len(binding_names):
        sys.exit('The number of bindings clamped is not equal to the number of values provided.')

    self.clamped = True

    # An earlier implementation built a projection matrix for the whole
    # network here; it was disabled in favour of per-group clamping.
    # CHECK: clamping each group
    parts = [name.split(':') for name in binding_names]
    owners = [self.group_names.index(piece[1]) for piece in parts]
    local_names = [piece[0] for piece in parts]

    for owner in set(owners):
        members = [pos for pos, current in enumerate(owners) if current == owner]
        self.groups[owner].clamp(
            binding_names=[local_names[pos] for pos in members],
            clamp_vals=[clamp_vals[pos] for pos in members])
1273
1274
def unclamp(self):
    '''Clear the clamped flag and unclamp every group that is currently clamped.'''
    self.clamped = False
    for member in self.groups:
        if not member.clamped:
            continue
        member.unclamp()
1279
1280
# def clamp_group(self, group_name, clamp_vals):
1281
# self.group_clamped = group_name
1282
# group_idx = self.group_names.index(group_name)
1283
# binding_names = [b + ':' + group_name for b in self.groups[group_idx].binding_names]
1284
# self.clamp(binding_names, clamp_vals)
1285
1286
##########################################################################
1287
###
1288
### Update state and parameters
1289
###
1290
##########################################################################
1291
1292
def update(self, update_T=True, update_q=True):
    '''Advance every group by one step, coupling the groups through WGC.

    Phase 1: for each receiving group, the other groups' current binding-level
    activations (S2C of their act) are multiplied by the between-group block
    of self.WGC, summed, converted with C2S and stored as that group's `ext`.
    Phase 2: each group runs its own update() and its new activation is
    copied into the matching slice of self.act.
    '''
    for to_g, receiver in enumerate(self.groups):  # to
        drive = np.zeros(receiver.nbindings)
        to_idx = [self.binding_names.index(b + ':' + self.group_names[to_g])
                  for b in receiver.binding_names]
        for from_g, sender in enumerate(self.groups):  # from
            if to_g == from_g:
                # Only between-group weights contribute to the input.
                continue
            from_idx = [self.binding_names.index(b + ':' + self.group_names[from_g])
                        for b in sender.binding_names]
            between = self.WGC[np.ix_(to_idx, from_idx)]
            drive = drive + between.dot(sender.S2C(sender.act))
        receiver.ext = receiver.C2S(drive)

    for gi, member in enumerate(self.groups):
        member.update(update_T = update_T, update_q = update_q)
        self.act[self.group_unit_idx[gi]] = member.act
1311
1312
# def update_state(self):
1313
# '''Update the current state (with noise)'''
1314
# self.act += self.dt * self.hgrad() # update state (deterministic change)
1315
# self.add_noise() # add noise (stochastic part)
1316
# if self.clamped:
1317
# self.act = self.act_clamped()
1318
#
1319
# def add_noise(self):
1320
# '''Add a noise to the current activation state.'''
1321
# noise = sqrt(2 * self.T * self.dt) * np.random.randn(self.nunits)
1322
# self.act += noise
1323
1324
# def update_T(self, method='exponential'):
1325
# '''Update the current temperature'''
1326
# # In the LDNet, the decay rate per timestep is adjusted based on the time step constant.
1327
# # For simplicity, the decay rate is assumed to be per-step.
1328
# if method == 'exponential':
1329
# self.T = (1-self.T_decay_rate) * (self.T - self.T_min) + self.T_min
1330
#
1331
# def update_q(self, method='log'):
1332
# if method == 'log':
1333
# self.q_time += 1
1334
# self.q = self.q_factor * log(self.q_time)
1335
1336
def reset(self):
    '''Call reset() on every group.'''
    for member in self.groups:
        member.reset()
1339
1340
##########################################################################
1341
###
1342
### Harmony and harmony gradient
1343
###
1344
##########################################################################
1345
1346
def harmony(self, act=None, quant_list=None):
    '''Total harmony at `act` (defaults to the current state self.act).

    Delegates to the module-level harmony() function.
    '''
    state = self.act if act is None else act
    return harmony(net=self, act=state, quant_list=quant_list)
1351
1352
def gharmony(self, act=None):
    '''Grammar harmony at `act` (defaults to the current state self.act).

    Delegates to the module-level gharmony() function.
    '''
    state = self.act if act is None else act
    return gharmony(net=self, act=state)
1357
1358
def hgrad(self, act=None, quant_list=None):
    '''Harmony gradient at `act` (defaults to the current state self.act).

    Delegates to the module-level hgrad() function.
    '''
    state = self.act if act is None else act
    return hgrad(net=self, act=state, quant_list=quant_list)
1363
1364
1365
##########################################################################
1366
###
1367
### Run
1368
###
1369
##########################################################################
1370
1371
def run(self, maxstep, tol=None):
    '''Run the network via run_GscNet0() and store its four results.

    Stored attributes: rt, final_state, act_trace, ema_speed_trace.
    '''
    rt, final_state, act_trace, ema_speed_trace = run_GscNet0(self, maxstep=maxstep, tol=tol)
    self.rt = rt
    self.final_state = final_state
    self.act_trace = act_trace
    self.ema_speed_trace = ema_speed_trace
1373
1374
def plot_act_trace(self):
    '''Show one activation heatmap per group.

    For every group the columns of self.act_trace belonging to its units are
    mapped through the group's TPinv matrix and plotted with the group's
    binding names on the y-axis.
    '''
    for gi, member in enumerate(self.groups):
        unit_cols = self.group_unit_idx[gi]
        group_trace = self.act_trace[:, unit_cols]
        actC_trace = member.TPinv.dot(group_trace.T).T
        print("Group: %s" % self.group_names[gi])
        heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings", yticklabels=member.binding_names, grayscale=False, colorbar=True)
1380
1381
# def run(self, maxstep, tol=None):
1382
# if tol is not None:
1383
# self.rt, self.final_state, self.act_trace, self.t_trace, self.q_trace, self.ema_speed_trace = run(self, maxstep=maxstep, tol=tol)
1384
# else:
1385
# self.rt, self.final_state, self.act_trace, self.t_trace, self.q_trace, self.ema_speed_trace = run(self, maxstep=maxstep)
1386
#
1387
# def plot_act_trace(self):
1388
# if hasattr(self, 'act_trace'):
1389
# heatmap(self.act_trace.T, xlabel="Timestep", ylabel="Bindings",
1390
# xticklabels=None, yticklabels=self.binding_names, grayscale=False)
1391
# else:
1392
# sys.exit("There is no [act_trace] attribute in the current object.")
1393
1394
##########################################################################
1395
###
1396
### Optim (CHECK)
1397
###
1398
##########################################################################
1399
1400
def gharmony_given_clamping(self, x, group_name_clamped, clampval, quant_list, q):
    '''Harmony of a state whose clamped entries are fixed and whose free entries are `x`.

    The full binding-level vector takes self.clampvec at the positions in
    self.clampind and `x` at all remaining positions. The result is
    0.5 * v.W.v + v.b plus q * self.qharmony(...) for every subgroup of
    binding names listed in quant_list.

    NOTE(review): clamp_group() and qharmony() are not defined in the visible
    part of this class (clamp_group only appears commented out) -- confirm
    they exist before relying on this method.
    '''
    state = np.zeros(self.nbindings)
    self.clamp_group(group_name_clamped, clampval)
    state[self.clampind] = self.clampvec[self.clampind]
    free = list(set(np.arange(self.nbindings)) - set(self.clampind))
    state[free] = x

    total = 0.5 * state.dot(self.W.dot(state)) + state.dot(self.b)
    for subgroup in quant_list:
        members = [self.binding_names.index(bb) for bb in subgroup]
        total = total + q* self.qharmony(state[members])
    return total
1412
1413
def optimal_state_given_clamping(self, group_name_clamped, clampval, quant_list, q, method="Nelder-Mead"):
    '''Search for the harmony-maximizing free state while one group is clamped.

    The negative of gharmony_given_clamping() is minimized from a uniformly
    random start, with scipy's basinhopping when method == "basinhopping" and
    with scipy.optimize.minimize(method=method) otherwise. The optimizer
    result and the reassembled full state are printed; the latter is also
    stored in self.optim_state.
    '''
    group_ind = self.group_names.index(group_name_clamped)
    # One free variable per binding outside the clamped group.
    n_free = self.nbindings - self.groups[group_ind].nbindings

    def negative_harmony(x):
        return -self.gharmony_given_clamping(x, group_name_clamped, clampval, quant_list, q)

    start = np.random.uniform(size = n_free)
    if method=="basinhopping":
        res = optim.basinhopping(negative_harmony, start)
    else:
        res = optim.minimize(negative_harmony, start, method = method)
    print(res)

    full_state = np.zeros(self.nbindings)
    full_state[self.clampind] = self.clampvec[self.clampind]
    free = list(set(np.arange(self.nbindings)) - set(self.clampind))
    full_state[free] = res.x
    self.optim_state = full_state
    print(full_state)
1429
1430
1431
# CHECK: for now, consider all units' act change:
1432
def run_GscNet0(net, maxstep, plot=False, grayscale=False, colorbar=True,
                tol=None, ema_factor=0.1, crop=True, update_q=True):
    '''Run a multi-group network for at most `maxstep` update steps.

    net        : object providing update(update_q=...), act, nunits and
                 (for plotting) groups and group_unit_idx.
    maxstep    : maximum number of update steps.
    plot       : if True, show one activation heatmap per group afterwards.
    tol        : None to always run `maxstep` steps; otherwise the run stops
                 as soon as the smoothed speed falls below `tol`
                 (only when tol > 0). A negative tol terminates via sys.exit().
    ema_factor : weight of the previous smoothed speed in the update
                 ema = ema_factor * ema + (1 - ema_factor) * diff.
    crop       : accepted for compatibility; currently unused (the trace is
                 always cropped to the executed steps).
    update_q   : forwarded to net.update().

    Returns the list [number of executed steps, final net.act, activation
    trace with one row per executed step, last smoothed speed]. The last
    entry is -1 when `tol` is None or no step was executed.
    '''
    # To-do: needs to include conversion from distributed to local representations
    if tol is not None and tol < 0:
        sys.exit("The tolerance parameter should be a positive real number.")

    act_trace = np.zeros((maxstep, net.nunits))
    track_speed = tol is not None
    ema_speed = -1
    nsteps = 0  # stays 0 when maxstep == 0, so the return below is well defined
    if track_speed:
        prev_act = copy.deepcopy(net.act)

    for tt in range(maxstep):
        net.update(update_q=update_q)
        act_trace[tt, :] = net.act
        nsteps = tt + 1
        if not track_speed:
            continue

        # Speed = largest absolute change of any unit in this step.
        diff = max(abs(net.act - prev_act))  # divide it by dt?
        if tt == 0:
            ema_speed = diff
        else:
            ema_speed = ema_factor * ema_speed + (1-ema_factor) * diff
        prev_act = copy.deepcopy(net.act)

        if (ema_speed < tol) and (tol > 0):
            break

    # Keep every executed step, including the last one; the previous slice
    # [:tt] dropped the final row although tt+1 steps were reported.
    act_trace = act_trace[:nsteps, :]

    if plot:
        # A separate loop variable keeps `net` bound to the whole network, so
        # net.group_unit_idx can still be read inside the loop.
        for ii, group in enumerate(net.groups):
            group_trace = act_trace[:, net.group_unit_idx[ii]]
            actC_trace = group.S2C(group_trace.T).T
            heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings", yticklabels=group.binding_names, grayscale=grayscale, colorbar=colorbar)

    return [nsteps, net.act, act_trace, ema_speed]
1482
1483
1484
1485
# Harmony
1486
1487
# Total Harmony = H_G + q * H_Q
1488
# H_G = H_0 + H_1
1489
# H_0: Harmonic Grammar
1490
# H_1: Bowl
1491
# H_Q: Quantization
1492
1493
def bind(fillernames, rolenames, sep='/'):
    '''Return all filler/role binding names, "<filler><sep><role>".

    Roles vary slowest: all fillers of the first role come first.
    '''
    bindings = []
    for role in rolenames:
        for filler in fillernames:
            bindings.append(filler + sep + role)
    return bindings
1495
1496
1497
def H(net, act, quant_list=None):
    # Total harmony at [act] in neural coordinates: Hg + q * Q.
    # For now, quant_list is ignored.
    # act: neural coordinate (n in Paul's derivation)
    # actC: conceptual coordinate (a in Paul's derivation)
    grammar_part = Hg(net, act)
    quantization_part = net.q * Q(n=act, net=net)
    return grammar_part + quantization_part
1503
1504
def HGrad(net, act, quant_list=None):
    # Gradient of the total harmony: HgGrad + q * QGrad.
    # quant_list is accepted but not used.
    grammar_grad = HgGrad(net, act)
    quantization_grad = net.q * QGrad(n=act, net=net)
    return grammar_grad + quantization_grad
1506
1507
def Hg(net, act):
    # Sum of the H0 and H1 harmony terms at [act].
    # act: neural coordinate of the current state
    # WG: weights in neural coordinates
    h0_term = H0(net, act)
    h1_term = H1(net, act)
    return h0_term + h1_term  # + constant
1511
1512
def HgGrad(net, act):
    # Gradient of Hg: sum of the H0 and H1 gradients at [act].
    h0_grad = H0Grad(net, act)
    h1_grad = H1Grad(net, act)
    return h0_grad + h1_grad
1514
1515
def H0(net, act):
    # Quadratic form 0.5 * act.WG.act plus the linear term (bG + ext).act
    # act: neural coordinate of the current state
    # WG: weights in neural coordinates
    quadratic = 0.5 * act.dot(net.WG).dot(act)
    linear = (net.bG + net.ext).dot(act)
    return quadratic + linear
1519
1520
def H0Grad(net, act):
    # Gradient of H0: WG.act + bG + ext
    weighted_input = net.WG.dot(act)
    return weighted_input + net.bG + net.ext
1522
1523
def H1(net, act):
    # -0.5 * beta * squared distance between act and zeta
    # zeta: bowl center in neural coordinates
    offset = act - net.zeta
    return -0.5 * net.beta * offset.dot(offset)
    #return 0.5 * act.dot(net.WB).dot(act) + net.bB.dot(act)
1527
1528
def H1Grad(net, act):
    # Gradient of H1: -beta * (act - zeta)
    offset = act - net.zeta
    return -net.beta * offset
1530
1531
1532
1533
#def Hq(net, act, quant_list):
1534
# # Compute quantization harmony Q. By default, it considers filler competition in each role.
1535
# # For now, use dynamics in the pattern coordinates.
1536
# # Should be replaced with distributed dynamics.
1537
# actC = net.S2C(act)
1538
#
1539
# if quant_list is None:
1540
# # By default, assumes filler-competition in each role.
1541
# if net.nfillers == 1:
1542
# const1 = 0
1543
# const2 = -np.sum(actC**2 * (1-actC)**2)
1544
# Q = const2
1545
# else:
1546
# const1 = -np.sum((np.sum(actC.reshape(net.nfillers,net.nroles,order='F')**2, axis=0)-1)**2)
1547
# const2 = -np.sum(actC**2 * (1-actC)**2)
1548
# Q = net.c * const1 + (1-net.c) * const2
1549
# else:
1550
# Q = 0
1551
# for jj, group in enumerate(quant_list):
1552
# temp_ind = [net.binding_names.index(bb) for bb in group]
1553
# actC0 = actC[temp_ind]
1554
# Q = Q + q2(net, actC0)
1555
# return Q
1556
1557
#def HqGrad(net, act, quant_list=None):
1558
# # Should be replaced with distributed dynamics.
1559
# actC = net.S2C(act)
1560
# if quant_list is None:
1561
# # By default, assume filler competition in each role.
1562
# if net.nfillers == 1:
1563
# # One-unit network case
1564
# # 'Sum_of_square = 1' does not make sense.
1565
# q_grad = -2 * actC * (actC-1) * (2*actC-1)
1566
# else:
1567
# const1 = -4 * actC * ( np.tile( np.sum(actC.reshape(net.nfillers,net.nroles,order='F')**2, axis=0), (net.nfillers,1) ).flatten('F') - 1)
1568
# const2 = -2 * actC * (actC-1) * (2*actC-1)
1569
# q_grad = net.c * const1 + (1-net.c) * const2
1570
# else:
1571
# q_grad = np.zeros(net.nbindings)
1572
# for jj, subgroup in enumerate(quant_list):
1573
# curr_q_grad = np.zeros(net.nbindings)
1574
# temp_ind = [net.binding_names.index(bb) for bb in subgroup]
1575
# actC0 = actC[temp_ind]
1576
# curr_q_grad[temp_ind] = q2grad(net, actC0)
1577
# q_grad = q_grad + curr_q_grad
1578
# return net.C2S(q_grad)
1579
1580
1581
1582
1583
1584
1585
#def harmony(net, act, quant_list=None):
1586
# '''Compute total harmony at a given activation state.'''
1587
# # gharmony = H_0, bharmony = H_1, q2haromny = H_Q
1588
# return gharmony(net, act) + bharmony(net, act) + net.q * q2harmony(net, act, quant_list)
1589
1590
#def oharmony(net, act):
1591
# '''Compute total harmony at a given activation state.'''
1592
# #return gharmony(net, act) + bharmony(net, act)
1593
# return 0.5 * act.dot(net.W).dot(act) + act.dot(net.b) + act.dot(net.ext) # distributed version
1594
# #return 0.5 * net.S2C(act).dot(net.WC).dot(net.S2C(act)) + net.S2C(act).dot(net.bC) + net.S2C(act).dot(net.extC)
1595
1596
#def gharmony(net, act):
1597
# '''Compute harmony (given a harmonic grammar) at a given activation state.'''
1598
# return 0.5 * act.dot(net.WG).dot(act) + (net.bG + net.ext).dot(act)
1599
1600
#def bharmony(net, act):
1601
# '''Compute bowl harmony. act should be an array object.'''
1602
# return -0.5 * net.beta * (net.S2C(act) - net.z).dot(net.S2C(act) - net.z)
1603
1604
def q2harmony(net, act=None, quant_list=None):
    '''Compute quantization harmony Q.

    act defaults to net.act and is converted with net.S2C first.
    Without quant_list, fillers compete within each role: Q combines, weighted
    by net.c and (1 - net.c), a term penalizing per-role sums of squares that
    differ from 1 and a term penalizing activations away from 0 and 1. With a
    single filler only the second term is used. With quant_list, Q is the sum
    of q2() over the listed groups of binding names.
    '''
    state = net.act if act is None else act
    actC = net.S2C(state)

    if quant_list is not None:
        total = 0
        for names in quant_list:
            members = [net.binding_names.index(bb) for bb in names]
            total = total + q2(net, actC[members])
        return total

    # Default: filler competition in each role.
    binary_term = -np.sum(actC**2 * (1-actC)**2)
    if net.nfillers == 1:
        return binary_term

    per_role_sq = np.sum(actC.reshape(net.nfillers,net.nroles,order='F')**2, axis=0)
    unit_norm_term = -np.sum((per_role_sq-1)**2)
    return net.c * unit_norm_term + (1-net.c) * binary_term
1627
1628
def q2(net, actC):
    '''Quantization harmony of one competition group.

    actC : binding-level activations of the group (1d array).

    Combines, weighted by net.c and (1 - net.c),
      const1 = -(sum(actC**2) - 1)**2      (sum of squares should be 1)
      const2 = -sum(actC**2 * (1-actC)**2) (each value should be 0 or 1)
    For a single-element group only const2 is returned.

    const1 previously subtracted 1 from every element before summing, which
    did not match the gradient used in q2grad() and was non-zero at valid
    grid points such as [1, 0].
    '''
    # Within-group competition
    const2 = -np.sum(actC**2 * (1-actC)**2)
    if len(actC) == 1:
        return const2
    const1 = -(np.sum(actC**2) - 1)**2
    return net.c * const1 + (1-net.c) * const2
1637
1638
#def hgrad(net, act, quant_list=None):
1639
# #return ghgrad(net, act) + bhgrad(net, act) + net.q * q2hgrad(net, act, quant_list)
1640
# return ohgrad(net, act) + net.q * q2hgrad(net, act, quant_list)
1641
1642
#def ghgrad(net, act):
1643
# return net.WG.dot(act) + net.bG + net.ext
1644
1645
#def bhgrad(net, act):
1646
# return -net.beta * (net.S2C(act) - net.z)
1647
# #return net.WB.dot(net.act) + net.bB # should be same -net.beta * (net.S2C(act) - net.z)
1648
1649
#def ohgrad(net, act):
1650
# return net.W.dot(act) + net.b + net.ext
1651
1652
def q2hgrad(net, act, quant_list=None):
    '''Gradient of the quantization harmony, returned via net.C2S.

    act is converted with net.S2C first. Without quant_list, fillers compete
    within each role; with a single filler only the "0 or 1" term is used, so
    the gradient matches the harmony computed by q2harmony() in that case
    (it was previously scaled by (1 - net.c) or left undefined). With
    quant_list, the gradients of the listed groups (q2grad) are summed.
    '''
    actC = net.S2C(act)
    if quant_list is None:
        # Gradient of -sum(a**2 * (1-a)**2)
        const2 = -2 * actC * (actC-1) * (2*actC-1)
        if net.nfillers == 1:
            # One-unit-per-role case: 'sum of squares = 1' does not apply.
            q_grad = const2
        else:
            # Per-role sum of squares, repeated for every filler of the role.
            role_sq = np.sum(actC.reshape(net.nfillers,net.nroles,order='F')**2, axis=0)
            const1 = -4 * actC * (np.tile(role_sq, (net.nfillers,1)).flatten('F') - 1)
            q_grad = net.c * const1 + (1-net.c) * const2
    else:
        q_grad = np.zeros(net.nbindings)
        for subgroup in quant_list:
            curr_q_grad = np.zeros(net.nbindings)
            temp_ind = [net.binding_names.index(bb) for bb in subgroup]
            curr_q_grad[temp_ind] = q2grad(net, actC[temp_ind])
            q_grad = q_grad + curr_q_grad
    return net.C2S(q_grad)
1671
1672
def q2grad(net, actC):
    '''q2grad for each group: gradient of q2() with respect to actC.

    const1 is the gradient of -(sum(actC**2) - 1)**2 and const2 the gradient
    of -sum(actC**2 * (1-actC)**2). For a single-element group only const2 is
    returned, mirroring q2(), which drops the sum-of-squares term there.
    '''
    const2 = -2 * actC * (actC-1) * (2*actC-1)
    if len(actC) == 1:
        return const2
    const1 = -4 * actC * (np.sum(actC**2) - 1)
    return net.c * const1 + (1-net.c) * const2
1677
1678
# Hinton diagram from ( http://wiki.scipy.org/Cookbook/Matplotlib/HintonDiagrams )
1679
# The functions were slightly modified for consistency
1680
def _blob(x,y,area,colour):
    """
    Fill a square of the given area (< 1), centred on (x, y), in the
    given colour.
    http://wiki.scipy.org/Cookbook/Matplotlib/HintonDiagrams
    """
    # Replaced N with np for consistency
    half_side = np.sqrt(area) / 2
    left, right = x - half_side, x + half_side
    bottom, top = y - half_side, y + half_side
    # Corners listed counter-clockwise starting at the bottom-left.
    xcorners = np.array([left, right, right, left])
    ycorners = np.array([bottom, bottom, top, top])
    P.fill(xcorners, ycorners, colour, edgecolor=colour)
1691
1692
def hinton(W, maxWeight=None, xlabels=None, ylabels=None):
    """
    Draws a Hinton diagram for visualizing a weight matrix.

    W         : 2d weight matrix; positive entries are drawn as white
                squares, negative entries as black squares, with the area
                proportional to |w| / maxWeight (capped at 1).
    maxWeight : weight that fills a whole cell; by default the smallest
                power of two not below the largest |w|.
    xlabels, ylabels : optional tick labels. When neither is given the
                axes are hidden.

    Temporarily disables matplotlib interactive mode if it is on,
    otherwise this takes forever. Interactive mode is switched back on
    after the squares are drawn (the flag controlling this was previously
    never set, so interactive mode stayed off).
    http://wiki.scipy.org/Cookbook/Matplotlib/HintonDiagrams
    """
    # Replaced N with np for consistency
    # xrange was replaced with range because it is not available in Python3
    # To do: Consider adding the f/r binding information as X and YTick labels
    reenable = False
    if P.isinteractive():
        P.ioff()
        reenable = True
    P.clf()
    height, width = W.shape
    if not maxWeight:
        maxWeight = 2**np.ceil(np.log(np.max(np.abs(W)))/np.log(2))

    # Gray background covering the whole matrix area.
    P.fill(np.array([0,width,width,0]),np.array([0,0,height,height]),'gray')
    #P.axis('off')
    #P.axis('equal')
    for x in range(width):
        for y in range(height):
            _x = x+1
            _y = y+1
            w = W[y,x]
            # Row 0 of W is drawn at the top of the diagram.
            if w > 0:
                _blob(_x - 0.5, height - _y + 0.5, min(1,w/maxWeight),'white')
            elif w < 0:
                _blob(_x - 0.5, height - _y + 0.5, min(1,-w/maxWeight),'black')
    if reenable:
        P.ion()

    # PWC
    label = False
    if xlabels is not None:
        P.xticks(np.arange(len(xlabels))+0.5, xlabels, rotation='vertical')
        label = True

    if ylabels is not None:
        # Reversed so the first label lines up with the top row.
        P.yticks(np.arange(len(ylabels))+0.5, ylabels[::-1])
        label = True

    if label:
        P.xlim([0,width])
        P.ylim([0,height])
        P.gca().set_aspect('equal', adjustable='box')
        P.gca().tick_params(direction='out')
        #P.axis('equal')
    else:
        P.axis('off')
        P.axis('equal')

    P.show()
1746
1747
1748
def heatmap(data, xlabel=None, ylabel=None, xticklabels=None, yticklabels=None,
            grayscale=False, colorbar=True):
    """
    Plot a 2d-array (e.g., an activation trace) as a heatmap and show it.

    data       : the 2d-array to display.
    xlabel     : optional label of the x axis.
    ylabel     : optional label of the y axis.
    xticklabels: optional tick labels, placed at x = 0, 1, 2, ...
    yticklabels: optional tick labels, placed at y = 0, 1, 2, ...
    grayscale  : use the "gray_r" colormap instead of "Reds".
    colorbar   : add a colorbar to the figure.
    """
    # plt.get_cmap is used because matplotlib.cm.get_cmap is no longer
    # available in recent Matplotlib releases.
    cmap = plt.get_cmap("gray_r" if grayscale else "Reds")
    plt.imshow(data, cmap=cmap, interpolation="nearest", aspect='auto')

    if xlabel is not None:
        plt.xlabel(xlabel, fontsize=16)
    if ylabel is not None:
        plt.ylabel(ylabel, fontsize=16)
    if xticklabels is not None:
        plt.xticks(np.arange(len(xticklabels)), xticklabels)
    if yticklabels is not None:
        plt.yticks(np.arange(len(yticklabels)), yticklabels)
    if colorbar:
        plt.colorbar()
    plt.show()
1767
1768
# create symbol representations
1769
def encode_symbols(nsymbols, reptype='local', dp=0, ndim=None):
    """
    Create vector representations of nsymbols symbols.

    nsymbols: the number of symbols to encode.
    reptype : 'local' returns the identity matrix (dp and ndim are ignored).
              Any other value requests a distributed representation.
    dp      : for distributed representations, either a scalar (the pairwise
              dot product of every two different symbols) or a full matrix
              of pairwise dot products.
    ndim    : the dimension of the distributed vectors. Defaults to nsymbols.

    Returns the matrix produced by np.eye, dot_products or dot_products2.
    """
    if reptype == 'local':
        return np.eye(nsymbols)

    if ndim is None:
        ndim = nsymbols

    if isinstance(dp, Number):
        # dp is a scalar that represents the expected pairwise similarity (dot product).
        return dot_products(nsymbols, ndim, dp)
    return dot_products2(nsymbols, ndim, dp)
1782
1783
def dot_products(nsymbols, ndim, s):
    """
    Create nsymbols vectors of dimension ndim whose pairwise dot product is s.

    The target dot-product matrix has s everywhere off the main diagonal and
    s + (1-s) on the main diagonal. The search itself is done by
    dot_products2, whose result is returned.
    """
    off_diagonal = s * np.ones((nsymbols, nsymbols))
    diagonal_fill = (1-s) * np.eye(nsymbols, nsymbols)
    target = off_diagonal + diagonal_fill
    return dot_products2(nsymbols, ndim, target)
1787
1788
def dot_products2(nsymbols, ndim, dp_mat, max_iter = 100000):
    """
    Find nsymbols vectors of dimension ndim whose pairwise dot products
    match dp_mat.

    nsymbols: the number of vectors.
    ndim    : the dimension of each vector.
    dp_mat  : a symmetric nsymbols-by-nsymbols matrix of target dot products
              with ones on the main diagonal. The program exits (sys.exit)
              if either condition is violated.
    max_iter: the maximum number of gradient steps.

    Returns an ndim-by-nsymbols array M; the vectors are its columns.

    Algorithm: Find a matrix M such that M'*M = dp_mat. This is done via
    gradient descent on a cost function that is the square of the frobenius
    norm of (M'*M - dp_mat), starting from a random (uniform) M. The search
    stops as soon as every entry of M'*M is within 1e-6 of dp_mat. If that
    does not happen within max_iter steps, a message is printed and the
    last M is returned anyway.

    NOTE: It has trouble finding more than about 16 vectors, possibly for
    dumb numerical reasons (like stepsize and tolerance), which might be
    fixable if necessary.
    """
    if (dp_mat.T != dp_mat).any():
        sys.exit('dot_products2: dp_mat must be symmetric')

    if not (np.diag(dp_mat) == 1).all():
        sys.exit('dot_products2: dp_mat must have all ones on the main diagonal')

    vecs = np.random.uniform(size=ndim*nsymbols).reshape(ndim, nsymbols, order='F')
    step_cap = .1
    tol = 1e-6

    for _ in range(max_iter):
        descent = vecs.dot(vecs.T.dot(vecs) - dp_mat)
        # Limit the largest change of any entry to .01 per iteration.
        step = min(step_cap, .01/np.abs(descent).max())
        vecs = vecs - step * descent
        mismatch = np.abs(vecs.T.dot(vecs) - dp_mat).max()
        if mismatch <= tol:
            break
    else:
        print("Didn't converge after %d iterations" % max_iter)

    return vecs
1827
1828
def compute_TPmat(R, F):
    """
    Build the tensor-product matrix and its inverse.

    TP converts local to distributed representations (C-space to S-space).
    See http://en.wikipedia.org/wiki/Vectorization_(mathematics) for
    justification of kronecker product.

    R, F: 2d-arrays whose kronecker product np.kron(R, F) is TP.

    Returns (TP, TPinv). TPinv is the inverse of TP if TP is square and the
    pseudo-inverse otherwise.
    """
    TP = np.kron(R, F)
    is_square = TP.shape[0] == TP.shape[1]
    # TP may be a non-square matrix. So use pseudo-inverse in that case.
    invert = la.inv if is_square else la.pinv
    return TP, invert(TP)
1837
1838
def compute_sspace_biases(b_local, net):
    """Return TPinv' * b_local, where TPinv is taken from net."""
    return np.dot(net.TPinv.T, b_local)
1840
1841
def compute_sspace_weights(W_local, net):
    """Return TPinv' * W_local * TPinv, where TPinv is taken from net."""
    left = np.dot(net.TPinv.T, W_local)
    return np.dot(left, net.TPinv)
1843
1844
def compute_projmat(A):
    """
    Return the matrix A * inv(A' * A) * A'.

    A is an n x m matrix of basis (column) vectors of the subspace.
    Only works if the rank of A is equal to the number of columns of A.
    """
    A_t = A.T
    gram_inv = la.inv(A_t.dot(A))
    return A.dot(gram_inv).dot(A_t)
1848
1849
#class sim_params():
1850
#
1851
# def __init__(self):
1852
#
1853
# self.testvar = 'ema_speed'
1854
# self.convergence_check = True
1855
# self.norm_ord = np.inf
1856
# self.dist_space = 'S'
1857
# self.tol = 1e-2
1858
# self.ema_factor = 0.1
1859
# self.grid_points = None
1860
# self.input_method = 'ext'
1861
# self.update_T = True
1862
# self.update_q = True
1863
# self.maxstep = np.array(10000)
1864
# self.nrep = 1
1865
# self.seed_num = time.time()
1866
#
1867
# def set_params(self, name, val):
1868
#
1869
# setattr(self, name, val)
1870
1871
class sim():
    """
    Use this class to run simulations and store the results.

    net   : an object of the GscNet class.
    params: a dictionary for parameter setting. The methods below read the
            keys 'nrep', 'maxstep' (an array with one entry per word),
            'grid_points', 'input_method' ('ext' or 'clamp'), 'input_inhib',
            'cumulative_input', 'norm_ord', 'tol', 'update_T', 'update_q',
            'q_fun', 'ext_overlap' and 'ext_overlap_steps'.

    The attributes input_list (a list of bindings, one per word) and
    input_vals (the matching input values) are not set by any method of
    this class; assign them on the object before calling simulate().
    """

    def __init__(self, net, params):
        self.net = net
        self.params = params

        # Before running, save all gridpoints
        if self.net.getGPset:
            self.net.all_grid_points()
            self.gpset = copy.deepcopy(self.net.gpset)
            self.gpset_gh = copy.deepcopy(self.net.gpset_gh)

    def set_params(self, p_name, p_val):
        """Set the parameter p_name to p_val."""
        self.params[p_name] = p_val

    def set_seed(self, num):
        """Seed the global NumPy random number generator."""
        np.random.seed(num)

    def _rt_step(self, rep_ind, word_ind):
        # self.rt is stored as a float array, but it is used to index the
        # traces, and NumPy only accepts integers as indices.
        return int(self.rt[rep_ind, word_ind])

    def set_input(self, input_list, input_vals):
        """
        Build self.inputC, the external input of every word.

        input_list: a list of bindings (strings whose part after the first
                    '/' is the role), or a single binding.
        input_vals: the input value of each binding (a single value is
                    accepted together with a single binding).

        self.inputC has one column per word and one row per binding. In each
        column the given binding receives its value. If 'input_inhib' is set
        and 'input_method' is 'ext', all the other bindings of the same role
        receive the negated value. If 'cumulative_input' is set, the columns
        are accumulated from the first word to the last.
        """
        if not isinstance(input_list, list):
            input_list = [input_list]
            # A single binding may come with a single (scalar) value.
            if not isinstance(input_vals, (list, tuple, np.ndarray)):
                input_vals = [input_vals]

        nwords = len(input_list)
        self.inputC = np.zeros((self.net.nbindings, nwords))
        role_list = [bb.split('/')[1] for bb in self.net.binding_names]
        inhibit = self.params['input_inhib'] and self.params['input_method'] == 'ext'

        for ii, binding in enumerate(input_list):
            val = input_vals[ii]
            inputC = np.zeros(self.net.nbindings)
            curr_role = binding.split('/')[1]
            binding_idx = self.net.find_bindings([binding])

            if inhibit:
                role_idx = [jj for jj, rr in enumerate(role_list) if curr_role == rr]
                inputC[role_idx] = -val

            # Written last so that the target binding keeps the positive value.
            inputC[binding_idx] = val
            self.inputC[:, ii] = inputC

        if self.params['cumulative_input']:
            self.inputC = self.inputC.cumsum(axis=1)

    def simulate(self, params=None):
        """
        Run the network on self.input_list 'nrep' times and store the traces.

        params: an optional parameter dictionary. If given, it replaces
                self.params, so that the stored parameters always describe
                the stored results.

        For every repetition the network is reset, then the words are
        presented one at a time (as external input or by clamping, depending
        on 'input_method') and the network is run for at most
        'maxstep'[word] steps. The results are stored in act_trace, T_trace,
        q_trace, speed_trace, ema_speed_trace, converged, rt, gp and, if
        'grid_points' is given, dist_trace. The first two axes of every
        array are the repetition and the word.
        """
        if params is not None:
            self.params = params
        params = self.params

        nrep = params['nrep']
        maxstep = params['maxstep'].max()
        nwords = len(self.input_list)  # including the end of sentence markers, if any.
        trace_shape = (nrep, nwords, maxstep+1)

        self.act_trace = np.zeros(trace_shape + (self.net.nunits,))
        self.T_trace = np.zeros(trace_shape)
        self.q_trace = np.zeros(trace_shape)
        self.speed_trace = np.zeros(trace_shape)
        self.ema_speed_trace = np.zeros(trace_shape)
        self.converged = np.ones((nrep, nwords), dtype=bool)
        self.rt = np.zeros((nrep, nwords))
        self.gp = [[[[] for _ in range(maxstep+1)] for _ in range(nwords)] for _ in range(nrep)]

        track_dist = params['grid_points'] is not None
        if track_dist:
            grid_points = params['grid_points']
            if not any(isinstance(pp, list) for pp in grid_points):
                grid_points = [grid_points]
            self.dist_trace = np.zeros(trace_shape + (len(grid_points),))

        for rep_num in range(nrep):
            self.net.reset()  # Now reset() will do unclamp and clear_input
            self.set_input(self.input_list, self.input_vals)

            for word_num in range(nwords):
                curr_maxstep = params['maxstep'][word_num]
                nsteps = curr_maxstep + 1

                if params['input_method'] == 'ext':
                    self.net.extC = copy.deepcopy(self.inputC[:, word_num])
                    self.net.ext = compute_sspace_biases(self.net.extC, self.net)

                    if word_num > 0:
                        self.net.prev_extC = copy.deepcopy(self.inputC[:, word_num-1])
                        self.net.prev_ext = compute_sspace_biases(self.net.prev_extC, self.net)

                elif params['input_method'] == 'clamp':
                    if params['cumulative_input']:
                        self.net.clamp(self.input_list[:word_num+1])
                    else:
                        self.net.clamp(self.input_list[word_num])

                self.net.run(curr_maxstep,
                             norm_ord=params['norm_ord'],
                             tol=params['tol'],
                             update_T=params['update_T'],
                             update_q=params['update_q'],
                             q_fun=params['q_fun'],
                             grid_points=params['grid_points'],
                             ext_overlap=params['ext_overlap'],
                             ext_overlap_steps=params['ext_overlap_steps'])

                self.act_trace[rep_num, word_num, :nsteps, :] = self.net.act_trace
                self.rt[rep_num, word_num] = self.net.rt
                self.T_trace[rep_num, word_num, :nsteps] = self.net.T_trace
                self.q_trace[rep_num, word_num, :nsteps] = self.net.q_trace
                self.speed_trace[rep_num, word_num, :nsteps] = self.net.speed_trace
                self.ema_speed_trace[rep_num, word_num, :nsteps] = self.net.ema_speed_trace
                self.converged[rep_num, word_num] = self.net.converged
                if track_dist:
                    self.dist_trace[rep_num, word_num, :nsteps, :] = self.net.dist_trace
                self.gp[rep_num][word_num][:nsteps] = self.net.gp_trace

    def plot_act_trace(self, rep_num):
        """
        Plot the activation trace of repetition rep_num (1-based) as a heatmap.

        The traces of the words are joined up to their rt; the first step of
        every word after the first one is left out. The joined trace is
        converted with net.S2C before plotting.
        """
        rep_ind = rep_num - 1
        nwords = self.act_trace.shape[1]

        segments = [self.act_trace[rep_ind, 0, :(self._rt_step(rep_ind, 0)+1), :]]
        for wind in range(1, nwords):
            last = self._rt_step(rep_ind, wind)
            segments.append(self.act_trace[rep_ind, wind, 1:(last+1), :])
        curr_act_trace = np.concatenate(segments, axis=0)

        actC_trace = self.net.S2C(curr_act_trace.T).T
        heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                xticklabels=None, yticklabels=self.net.binding_names, grayscale=False)

    def plot_trace(self, varname, rep_num):
        """
        Plot the trace self.<varname>_trace of repetition rep_num (1-based).

        The traces of the words are joined in the same way as in
        plot_act_trace. A green vertical line is drawn at the cumulative rt
        of every word.
        """
        rep_ind = rep_num - 1
        nwords = self.act_trace.shape[1]
        trace = getattr(self, varname+'_trace')

        segments = [trace[rep_ind, 0, :(self._rt_step(rep_ind, 0)+1)]]
        for wind in range(1, nwords):
            # Each word is cut at its own rt.
            last = self._rt_step(rep_ind, wind)
            segments.append(trace[rep_ind, wind, 1:(last+1)])
        curr_trace = np.concatenate(segments, axis=0)

        curr_rt_trace = self.rt[rep_ind, :].cumsum()

        plt.plot(curr_trace)
        # The range of the lines ignores the first step. Works for both
        # 1d (e.g., T, q) and 2d (act) traces.
        lowest = curr_trace[1:].min()
        highest = curr_trace[1:].max()
        for rt in curr_rt_trace:
            plt.plot([rt, rt], [lowest, highest], 'g-')

        plt.xlabel('Timestep', fontsize=16)
        plt.ylabel(varname, fontsize=16)

    def read_state(self, rep_num, word_num):
        """Pass the state at the rt of the given repetition and word (both 1-based) to net.read_state."""
        rep_ind, word_ind = rep_num - 1, word_num - 1
        act = self.act_trace[rep_ind, word_ind, self._rt_step(rep_ind, word_ind), :]
        self.net.read_state(act)

    def read_grid_point(self, rep_num, word_num):
        """Pass the state at the rt of the given repetition and word (both 1-based) to net.read_grid_point."""
        rep_ind, word_ind = rep_num - 1, word_num - 1
        act = self.act_trace[rep_ind, word_ind, self._rt_step(rep_ind, word_ind), :]
        self.net.read_grid_point(act)

    def compute_gpdist(self):
        """
        Compute how often each grid point was visited across repetitions.

        Returns an array of shape (nwords, max('maxstep')+1, len(self.gpset)).
        Entry [word, step, k] is the number of repetitions whose entry
        self.gp[rep][word][step] equals self.gpset[k], divided by 'nrep'.
        Only the steps before the rt of a repetition are counted.
        """
        nrep = self.params['nrep']
        nwords = len(self.input_list)
        # 'maxstep' holds one value per word; the traces are as long as the largest.
        maxstep = int(np.max(self.params['maxstep']))
        gp_count = np.zeros((nrep, nwords, maxstep+1, len(self.gpset)))

        for rep_ind in range(nrep):
            for word_ind in range(nwords):
                for step in range(self._rt_step(rep_ind, word_ind)):
                    idx = self.gpset.index(self.gp[rep_ind][word_ind][step])
                    gp_count[rep_ind, word_ind, step, idx] += 1

        return gp_count.sum(axis=0)/nrep

    def plot_gp_prob_trace(self, word_num, gp_prob_trace, ngp=None, yticklab=True, grayscale=False, colorbar=True):
        """
        Plot the output of compute_gpdist for the word word_num (1-based).

        gp_prob_trace: the array returned by compute_gpdist.
        ngp          : limits the grid points that are shown. Defaults to
                       all of them.
        yticklab     : label the rows with the grid points and the matching
                       values of self.gpset_gh.

        The trace is cut at the smallest rt that the word has in any
        repetition.
        """
        if ngp is None:
            ngp = gp_prob_trace.shape[-1]

        word_ind = word_num - 1
        # The smallest rt of this word over all repetitions (first axis of rt).
        minrt = int(np.min(self.rt[:, word_ind]))
        # NOTE(review): the slices below keep ngp+1 grid points, not ngp - confirm.
        curr_trace = gp_prob_trace[word_ind, :minrt, :(ngp+1)]

        yticklabels = None
        if yticklab:
            yticklabels = [','.join(gp)+('(%.2f)' % self.gpset_gh[ii])
                           for ii, gp in enumerate(self.gpset[:(ngp+1)])]

        heatmap(curr_trace.T, xlabel="Timestep", ylabel="",
                xticklabels=None, yticklabels=yticklabels,
                grayscale=grayscale, colorbar=colorbar)
2086
2087
2088
2089
def b_ind(f, r, net):
    """Return the index of the binding of filler f and role r (f + r * net.nfillers)."""
    role_offset = r * net.nfillers
    return f + role_offset
2091
2092
def u_ind(phi, rho, net):
    """Return the index of the unit (phi, rho), i.e., phi + rho * net.ndim_f."""
    rho_offset = rho * net.ndim_f
    return phi + rho_offset
2094
2095
def w(f, r, phi, rho, net):
    """Return the entry of net.TPinv in the row of binding (f, r) and the column of unit (phi, rho)."""
    binding = b_ind(f, r, net)
    unit = u_ind(phi, rho, net)
    return net.TPinv[binding, unit]
2098
2099
def get_a(n, net, f, r):
    """
    Return the activation of the binding (f, r) given the unit activations n.

    It is the sum of w(f, r, phi, rho, net) * n[u_ind(phi, rho, net)] over
    all units (phi, rho).
    """
    total = 0
    units = itertools.product(range(net.ndim_f), range(net.ndim_r))
    for phi, rho in units:
        total += w(f, r, phi, rho, net) * n[u_ind(phi, rho, net)]
    return total
2105
2106
def n2a(n, net, f=None, r=None):
    """
    Convert the unit activations n to binding activations.

    f and r given: the activation of the binding (f, r) (a single value).
    only r given : the activations of all fillers in role r.
    only f given : the activations of filler f in all roles.
    neither given: the activations of all bindings, ordered by b_ind.
    """
    # quant_list
    if f is not None and r is not None:
        return get_a(n, net, f, r)

    if f is not None:
        across_roles = np.zeros(net.nroles)
        for role in range(net.nroles):
            across_roles[role] = get_a(n, net, f, role)
        return across_roles

    if r is not None:
        across_fillers = np.zeros(net.nfillers)
        for filler in range(net.nfillers):
            across_fillers[filler] = get_a(n, net, filler, r)
        return across_fillers

    all_bindings = np.zeros(net.nbindings)
    for filler in range(net.nfillers):
        for role in range(net.nroles):
            all_bindings[b_ind(filler, role, net)] = get_a(n, net, filler, role)
    return all_bindings
2126
2127
def Q0(net, n):
    """
    Return -sum(a**2 * (1-a)**2), where a = net.S2C(n).

    Every term is zero if and only if the corresponding entry of a is
    0 or 1.
    """
    a = net.S2C(n)
    per_binding = np.square(a) * np.square(1 - a)
    return -np.sum(per_binding)
2136
2137
def Q0GradE(net, n):
    """
    Return the gradient of Q0 with respect to the unit activations n.

    Elementwise computation. Very slow. Q0Grad is the vectorized version.
    Based on the first derivation.

    For every unit (phi, rho) the result is the negated sum over all
    bindings (f, r) of w(f, r, phi, rho, net) * g_fr, where
    g_fr = 2 * a_fr * (1 - a_fr) * (1 - 2 * a_fr) and a_fr = n2a(n, net, f, r).
    """
    # g_fr does not depend on the unit, so it is computed once per binding
    # (in the order in which it is summed below).
    g_terms = []
    for f in range(net.nfillers):
        for r in range(net.nroles):
            a_fr = n2a(n, net, f, r)
            g_terms.append((f, r, 2 * a_fr * (1 - a_fr) * (1 - 2 * a_fr)))

    q0grad = np.zeros(net.nunits)
    for phi in range(net.ndim_f):
        for rho in range(net.ndim_r):
            unit_grad = 0.0
            for f, r, g_fr in g_terms:
                unit_grad += w(f, r, phi, rho, net) * g_fr
            q0grad[u_ind(phi, rho, net)] = unit_grad
    return -q0grad
2150
2151
def Q0Grad(net, n):
    """
    Return the gradient of Q0 with respect to the unit activations n.

    Based on the first derivation. With a = net.S2C(n) and
    g = 2 * a * (1-a) * (1-2*a), every row of net.TPinv is scaled by the
    matching entry of g; the rows are then summed and the sum is negated.
    """
    a = net.S2C(n)
    g = 2 * a * (1-a) * (1-2*a)   # a vectorized version of g_{fr}
    # A column vector broadcasts over the net.nunits columns of TPinv.
    scaled_rows = net.TPinv * np.reshape(g, (-1, 1))
    return -np.sum(scaled_rows, axis=0)
2159
2160
2161
def Q1(net, n):
    """
    Return -sum_r (sum_f a_fr**2 - 1)**2, where a = net.S2C(n).

    net.vec2mat turns a into a matrix whose columns are summed (axis=0);
    one column is taken to hold the bindings of one role. Every term is
    zero if and only if the squared activations in a column sum to 1.
    """
    # quant_list (for now, ignore this)
    amat = net.vec2mat(net.S2C(n))
    per_role = np.sum(amat**2, axis=0) - 1
    # Square every role term before summing (this is the function whose
    # gradient is computed by Q1Grad), not the sum of the terms.
    return -np.sum(per_role**2)
2168
2169
def Q1GradE(net, n): # Elementwise computation
    """
    Return the gradient of Q1 with respect to the unit activations n.

    Not yet tested. Q1Grad is the vectorized version.

    For every unit (phi, rho) the result is the negated sum over the roles
    r of 4 * (sum_f a_fr**2 - 1) * sum_f (a_fr * w(f, r, phi, rho, net)),
    where a_fr = n2a(n, net, f, r).
    """
    fillers = range(net.nfillers)
    roles = range(net.nroles)

    # Neither the activations nor the role terms depend on the unit, so
    # they are computed once instead of once per unit.
    acts = [[n2a(n, net, f, r) for r in roles] for f in fillers]
    role_terms = [np.sum(n2a(n, net, r=r)**2) - 1 for r in roles]

    q1grad = np.zeros(net.nunits)
    for phi in range(net.ndim_f):
        for rho in range(net.ndim_r):
            unit_grad = 0.0
            for r in roles:
                var2 = 0.0
                for f in fillers:
                    var2 += acts[f][r] * w(f, r, phi, rho, net)
                unit_grad += 4 * role_terms[r] * var2
            q1grad[u_ind(phi, rho, net)] = unit_grad
    return -q1grad
2183
2184
def Q1Grad(net, n):
    """
    Return the gradient of Q1 with respect to the unit activations n.

    Not yet tested.

    With a = net.S2C(n), every role contributes
    (sum(a_role**2) - 1) * sum_rows(TPinv_role * a_role), where a_role and
    TPinv_role are the entries of a and the rows of net.TPinv selected by
    net.find_roles(role). The contributions are summed, multiplied by 4
    and negated.
    """
    a = net.S2C(n)
    total = 0.0
    for role in net.role_names:
        rows = net.find_roles(role)
        a_role = a[rows]
        # A column vector broadcasts over the net.nunits columns of TPinv.
        weighted = np.sum(net.TPinv[rows, :] * np.reshape(a_role, (-1, 1)), axis=0)
        deviation = np.sum(a_role ** 2) - 1
        total += deviation * weighted
    return -(4 * total)
2197
2198
2199
def QGrad(n, net):
    """
    Return net.c * Q0Grad(net, n) + (1-net.c) * Q1Grad(net, n).

    So far, the c parameter indicates the relative strength of Q1.
    In this function, it indicates the relative strength of Q0.
    Note that n comes before net here, unlike in Q0Grad and Q1Grad.
    """
    q0_part = net.c * Q0Grad(net, n)
    q1_part = (1-net.c) * Q1Grad(net, n)
    return q0_part + q1_part
2203
2204
def Q(net, n): # Q
    """Return net.c * Q0(net, n) + (1-net.c) * Q1(net, n)."""
    q0_part = net.c * Q0(net, n)
    q1_part = (1-net.c) * Q1(net, n)
    return q0_part + q1_part
2206
2207