Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Download
436 views
1
# -*- coding: utf-8 -*-
2
"""
3
GSC Software:
4
Author: Pyeong Whan Cho
5
Department of Cognitive Science, Johns Hopkins University
6
Summer 2015
7
"""
8
from math import *
9
import numpy as np
10
from matplotlib import pyplot as plt
11
from scipy import optimize as optim
12
import pandas as pd
13
import pylab as P
14
from scipy import linalg as la
15
from numbers import Number
16
import sys
17
import copy
18
19
20
class GscNet():
21
def __init__(self, filler_names, role_names,
             WGC=None, bGC=None, extC=None, prev_extC=None, z=0.5, beta=4,
             q_init=0, q_rate=4, q_max=50, q_fun='plinear', c=0.5,  # check q_max
             T_init=0, dt=0.001, reptype_r='local', reptype_f='local',
             dp_r=0, dp_f=0, ndim_r=None, ndim_f=None,
             T_decay_rate=0, T_min=0, quant_list=None, grid_points=None):
    '''Construct an instance of GscNet.

    filler_names, role_names: lists of filler / role names; their bindings
        are produced by bind().
    WGC, bGC, extC, prev_extC: C-space (local representation) weight
        matrix, bias vector, external input and previous external input.
        Each defaults to zeros of the matching size.
    z, beta: bowl center (vector or scalar) and bowl strength.
    q_init, q_rate, q_max, q_fun, c: quantization dynamics parameters.
    T_init, T_decay_rate, T_min: temperature schedule parameters.
    dt: time step constant.
    reptype_r/f, dp_r/f, ndim_r/f: forwarded to encode_symbols() to build
        the role (R) and filler (F) encoding matrices; ndim_* default to the
        number of roles / fillers.
    quant_list: optional list of lists of binding names.
    grid_points: optional grid point(s) whose distance is tracked.

    Exits via sys.exit() if beta is too small (see check_beta) or if the
    S-space weight matrix is not exactly symmetric.
    '''
    # ----- Symbols (local representation) -----
    self.binding_names = bind(filler_names, role_names)
    self.filler_names = filler_names
    self.role_names = role_names
    self.nbindings = len(self.binding_names)  # number of bindings
    self.nfillers = len(filler_names)         # number of fillers
    self.nroles = len(role_names)             # number of roles

    # ----- Network parameters in C-space (local representation) -----
    n = self.nbindings
    self.WGC = np.zeros((n, n), order='F') if WGC is None else WGC  # weight matrix
    self.bGC = np.zeros(n) if bGC is None else bGC                  # bias vector
    self.extC = np.zeros(n) if extC is None else extC               # external input
    self.prev_extC = np.zeros(n) if prev_extC is None else prev_extC

    # ----- Network parameters in S-space (distributed representation) -----
    # With 'local' encodings these coincide with the C-space parameters.
    if ndim_r is None:
        ndim_r = self.nroles
    if ndim_f is None:
        ndim_f = self.nfillers
    self.reptype_r = reptype_r
    self.reptype_f = reptype_f
    self.ndim_r = ndim_r
    self.ndim_f = ndim_f

    # R / F: matrices whose column vectors correspond to roles / fillers.
    # TP: matrix whose column vectors correspond to bindings (C -> S).
    # TPinv: inverse of TP (S -> C); for now, only works for symmetric TP.
    self.R = encode_symbols(self.nroles, reptype=reptype_r, dp=dp_r, ndim=ndim_r)
    self.F = encode_symbols(self.nfillers, reptype=reptype_f, dp=dp_f, ndim=ndim_f)
    self.TP, self.TPinv = compute_TPmat(self.R, self.F)  # mind the argument order

    self.nunits = ndim_r * ndim_f  # CHECK: should equal the number of rows of TP
    self.act = np.zeros(self.nunits)  # activation state in S-space
    self.actC = self.S2C()

    # ----- Bowl (C-space is assumed) -----
    self.z = z        # bowl center: vector or scalar
    self.beta = beta  # bowl strength
    self.WBC = -self.beta * np.eye(self.nbindings)
    self.bBC = self.beta * self.z * np.ones(self.nbindings)

    # Total = grammar + bowl
    self.WC = self.WGC + self.WBC
    self.bC = self.bGC + self.bBC

    # S-space counterparts of every C-space parameter
    self.WG = compute_sspace_weights(self.WGC, self)
    self.WB = compute_sspace_weights(self.WBC, self)
    self.bG = compute_sspace_biases(self.bGC, self)
    self.bB = compute_sspace_biases(self.bBC, self)

    self.W = compute_sspace_weights(self.WC, self)
    self.b = compute_sspace_biases(self.bC, self)
    self.ext = compute_sspace_biases(self.extC, self)
    self.prev_ext = compute_sspace_biases(self.prev_extC, self)

    self.check_beta()

    # ----- Quantization dynamics -----
    self.q_init = q_init  # initial quantization strength
    self.q = copy.deepcopy(self.q_init)
    self.q_rate = q_rate
    self.q_fun = q_fun
    self.q_max = q_max  # upper threshold read by update_q (fun='plinear')
    self.q_time = 0
    self.c = c  # quantization parameter

    # ----- Update parameters -----
    self.T_init = T_init
    self.T = copy.deepcopy(self.T_init)  # temperature
    self.T_decay_rate = T_decay_rate     # temperature decay rate
    self.T_min = T_min
    self.dt = dt                         # time step constant

    # ----- Clamping state -----
    self.clamped = False
    self.act_trace = None
    self.quant_list = quant_list

    # The S-space weight matrix must be (exactly) symmetric.
    if not (self.W.T == self.W).all():
        sys.exit("The weight matrix (2D array) is not symmetric. Please check it.")

    # Unit labels: 'U' followed by a zero-padded 1-based index.
    width = len(str(self.nunits))
    self.unit_names = ['U' + str(k).zfill(width) for k in range(1, self.nunits + 1)]

    self.speed = None
    self.ema_speed = None
    self.grid_points = grid_points
141
142
def randomize_state(self, minact=0, maxact=1):
    '''Set the activation state to a random vector inside a hypercube.

    Each binding activation (C-space) is drawn from a uniform distribution
    over [minact, maxact) and the result is converted to S-space and stored
    in self.act.  The cached C-space copy self.actC is refreshed as well, so
    that it does not go stale (update_state and clamp keep it in sync the
    same way).
    '''
    actC = np.random.uniform(minact, maxact, self.nbindings)
    self.act = self.C2S(actC)
    self.actC = self.S2C()
146
147
def randomize_weights(self, mu=0, sigma=1):
    '''Replace WC with a random symmetric matrix and recompute W.

    Entries are drawn from a normal distribution with mean mu and standard
    deviation sigma, then symmetrized by averaging with the transpose.
    NOTE(review): only WC and W are touched; WGC is left as it was, so a
    later set_weight() call rebuilds WC from WGC + WBC -- confirm intent.
    '''
    n = self.nbindings
    draws = np.random.randn(n**2)
    raw = sigma * draws.reshape(n, n, order='F') + mu
    self.WC = (raw + raw.T) / 2
    self.W = compute_sspace_weights(self.WC, self)
153
154
##########################################################################
155
###
156
### set weights and biases
157
###
158
##########################################################################
159
160
def set_weight(self, binding_name1, binding_name2, weight, symmetric=True):
    '''Set the grammar weight (WGC) of the connection between two bindings.

    The entry indexed [binding2, binding1] is always written; with
    symmetric=True (default) the mirrored entry [binding1, binding2] gets
    the same value.  WC, WG and W are then recomputed.
    '''
    first = self.find_bindings(binding_name1)
    second = self.find_bindings(binding_name2)
    if symmetric:
        self.WGC[first, second] = weight
    self.WGC[second, first] = weight

    # Propagate the change to the total and S-space matrices.
    self.WC = self.WGC + self.WBC
    self.WG = compute_sspace_weights(self.WGC, self)
    self.W = compute_sspace_weights(self.WC, self)
173
174
def set_bias(self, binding_name, bias):
    '''Set the grammar bias (bGC) of the given binding(s) to [bias].

    bC, bG and b are recomputed afterwards.
    '''
    positions = self.find_bindings(binding_name)
    self.bGC[positions] = bias

    # Refresh the derived bias vectors.
    self.bC = self.bGC + self.bBC
    self.bG = compute_sspace_biases(self.bGC, self)
    self.b = compute_sspace_biases(self.bC, self)
181
182
def set_role_bias(self, role_name, bias):
    '''Set the bias of every binding whose role is in [role_name] to [bias].

    role_name may be a single name or a list of names.  The role of a
    binding is the part of its name after the '/'.
    '''
    wanted = role_name if isinstance(role_name, list) else [role_name]
    roles = [name.split('/')[1] for name in self.binding_names]
    hits = [pos for pos, rr in enumerate(roles) if rr in wanted]
    self.bGC[hits] = bias

    # Refresh the derived bias vectors.
    self.bC = self.bGC + self.bBC
    self.bG = compute_sspace_biases(self.bGC, self)
    self.b = compute_sspace_biases(self.bC, self)
194
195
def set_filler_bias(self, filler_name, bias):
    '''Set the bias of every binding whose filler is in [filler_name] to [bias].

    filler_name may be a single name or a list of names.  The filler of a
    binding is the part of its name before the '/'.
    '''
    wanted = filler_name if isinstance(filler_name, list) else [filler_name]
    fillers = [name.split('/')[0] for name in self.binding_names]
    hits = [pos for pos, ff in enumerate(fillers) if ff in wanted]
    self.bGC[hits] = bias

    # Refresh the derived bias vectors.
    self.bC = self.bGC + self.bBC
    self.bG = compute_sspace_biases(self.bGC, self)
    self.b = compute_sspace_biases(self.bC, self)
206
207
##########################################################################
208
###
209
### Etc
210
###
211
##########################################################################
212
213
def vec2mat(self, actC=None):
    '''Reshape a C-space vector into a (nfillers x nroles) matrix.

    Column-major ('F') order is used.  When actC is omitted, the current
    state converted by S2C() is used.
    '''
    vec = self.S2C() if actC is None else actC
    return vec.reshape(self.nfillers, self.nroles, order='F')
217
218
def C2S(self, C=None):
    '''Change basis from C- to S-space (multiply by TP).

    Uses self.actC when C is omitted.
    '''
    source = self.actC if C is None else C
    return self.TP.dot(source)
224
225
def S2C(self, S=None):
    '''Change basis from S- to C-space (multiply by TPinv).

    Uses self.act when S is omitted.
    '''
    source = self.act if S is None else S
    return self.TPinv.dot(source)
231
232
def find_bindings(self, binding_names):
    '''Return the indices of the given binding name(s) in self.binding_names.

    A single name is treated as a one-element list.  An unknown name raises
    ValueError (from list.index).
    '''
    wanted = binding_names if isinstance(binding_names, list) else [binding_names]
    lookup = self.binding_names.index
    return [lookup(name) for name in wanted]
238
239
def find_roles(self, role_name):
    '''Return the indices of all bindings whose role is in [role_name].

    role_name may be a single name or a list; indices are grouped in the
    order in which the roles are given.
    '''
    wanted = role_name if isinstance(role_name, list) else [role_name]
    roles = [name.split('/')[1] for name in self.binding_names]
    return [pos
            for role in wanted
            for pos, rr in enumerate(roles)
            if rr == role]
252
253
def find_fillers(self, filler_name):
    '''Return the indices of all bindings whose filler is in [filler_name].

    filler_name may be a single name or a list; indices are grouped in the
    order in which the fillers are given.
    '''
    wanted = filler_name if isinstance(filler_name, list) else [filler_name]
    fillers = [name.split('/')[0] for name in self.binding_names]
    return [pos
            for filler in wanted
            for pos, ff in enumerate(fillers)
            if ff == filler]
264
265
def read_state(self, act=None):
    '''Print a state (converted to C-space) as a fillers-by-roles table.

    Uses the current state when act is omitted.  Requires pandas.
    '''
    state = self.act if act is None else act
    table = pd.DataFrame(self.vec2mat(self.S2C(state)),
                         index=self.filler_names,
                         columns=self.role_names)
    print(table)
273
274
def read_grid_point(self, act=None, disp=True):
    '''Return (and optionally print) the grid point closest to a state.

    The grid point is chosen by snapping (winner-takes-it-all).  Without a
    quant_list, one filler is chosen per role; otherwise, within every group
    of quant_list, every binding tied for the maximal activation is chosen.
    Uses the current state when act is omitted.
    '''
    state = self.act if act is None else act

    if self.quant_list is None:
        # By default a role can be bound with only one filler.
        grid = self.vec2mat(self.S2C(state))
        best_rows = np.argmax(grid, axis=0)
        winners = ["%s/%s" % (self.filler_names[row], role)
                   for row, role in zip(best_rows, self.role_names)]
    else:
        # quant_list: a list of lists of binding names.
        coords = self.S2C(state)  # pattern coordinate
        winners = []
        for group in self.quant_list:
            members = self.find_bindings(group)
            values = coords[members]
            top = max(values)
            winners.extend(self.binding_names[pos]
                           for pos, val in zip(members, values)
                           if val == top)

    if disp:
        print(winners)
    return winners
299
300
def read_weight(self, which='WGC'):
    '''Print the weight matrix named [which] as a labelled table.

    Attribute names ending in 'C' are labelled with binding names (pattern
    coordinate); all others are labelled with unit names.
    '''
    labels = self.binding_names if which[-1] == 'C' else self.unit_names
    print(pd.DataFrame(getattr(self, which), index=labels, columns=labels))
306
307
def read_bias(self, which='bGC', print_vertical=True):
    '''Print the bias vector named [which] as a labelled table.

    Attribute names ending in 'C' are C-space vectors (one entry per
    binding, labelled with binding names); all others are S-space vectors
    (one entry per unit, labelled with unit names).  The S-space vector is
    reshaped by the number of units rather than the number of bindings, so
    printing also works when the two differ.

    print_vertical: print as a column (default) or as a single row.
    '''
    if which[-1] == 'C':
        labels, size = self.binding_names, self.nbindings
    else:
        labels, size = self.unit_names, self.nunits

    values = getattr(self, which)
    if print_vertical:
        print(pd.DataFrame(values.reshape(size, 1), index=labels, columns=["bias"]))
    else:
        print(pd.DataFrame(values.reshape(1, size), index=["bias"], columns=labels))
319
320
def hinton(self, which='WGC', label=True):
    '''Draw a Hinton diagram of the matrix attribute named [which].

    With label=True the axes are labelled with binding names (attribute
    name ending in 'C') or unit names (otherwise); with label=False no
    labels are passed.
    '''
    # CHECK: which of W or W-beta*I should be visualized?
    if not label:
        labels = None
    elif which[-1] == 'C':
        labels = self.binding_names
    else:
        labels = self.unit_names
    hinton(getattr(self, which), xlabels=labels, ylabels=labels)
331
332
def act_clamped(self, act=None):
    '''Return the S-space state with the clamp applied.

    The state (current state when act is omitted) is multiplied by
    self.projmat and self.clampvec is added.  Both attributes are set by
    clamp().
    '''
    state = self.act if act is None else act
    projected = self.projmat.dot(state)
    return projected + self.clampvec
337
338
def check_beta(self):
    '''Check that beta exceeds the minimum implied by the C-space weights and biases.

    Three lower bounds are computed from WGC, bGC + extC and z; the program
    exits via sys.exit() when self.beta is not strictly greater than the
    largest of them.
    '''
    # WGC should be symmetric, so eigh() is used instead of eig().
    eigvals, _eigvecs = la.eigh(self.WGC)
    eig_max = max(eigvals)  # Condition 1: beta > eig_max to be stable
    net_bias = self.bGC + self.extC
    upper_gap = 1 - self.z

    if self.nbindings == 1:
        beta1 = -net_bias / self.z
        beta2 = (net_bias + eig_max) / upper_gap
    else:
        beta1 = -min(net_bias) / self.z                # Condition 2: beta > beta1
        beta2 = (max(net_bias) + eig_max) / upper_gap  # Condition 3: beta > beta2

    beta_min = max(eig_max, beta1, beta2)
    if self.beta <= beta_min:
        sys.exit("Beta (bowl strength) should be greater than %.4f." % beta_min)
351
352
def info(self):
    '''Print the fillers, roles, number of units and current temperature.'''
    summary = (
        ('Fillers = ', self.filler_names),
        ('Roles = ', self.role_names),
        ('Num_of_units = ', self.nunits),
        ('Current T = ', self.T),
    )
    for label, value in summary:
        print(label, value)
    # Add more information ...
359
360
def compute_dist(self, ref_point, norm_ord=2, space='S'):
    """
    Compute the distance of the current state from a grid point.

    ref_point: a binding name or a list of binding names; the grid point
        has activation 1 on those bindings and 0 elsewhere (C-space).
    norm_ord: order of the norm, passed to np.linalg.norm.
    space: 'S' to measure in S-space, 'C' to measure in C-space.

    Raises ValueError for any other value of [space] (previously this
    surfaced as an UnboundLocalError).
    """
    if space not in ('S', 'C'):
        raise ValueError("space must be 'S' or 'C', got %r" % (space,))

    idx = self.find_bindings(ref_point)
    destC = np.zeros(self.nbindings)
    destC[idx] = 1.0

    if space == 'S':
        state1 = self.act
        state2 = self.C2S(destC)
    else:
        state1 = self.S2C(self.act)
        state2 = destC

    return np.linalg.norm(state1 - state2, ord=norm_ord)
378
379
380
##########################################################################
381
###
382
### Set or clear external input / Clamp or unclamp
383
###
384
##########################################################################
385
386
def set_input(self, binding_names, ext_vals, inhib_comp=False):
    '''Set external input (C-space) on the given bindings.

    binding_names: a binding name or a list of binding names.
    ext_vals: a single value (shared by all bindings) or a list with one
        value per binding.
    inhib_comp: when True, the competitors of every target binding first
        receive the negative of that binding's input value.  Competitors
        are the bindings sharing the target's role (quant_list is None) or
        the members of every quant_list group containing the target.

    Any previous external input is cleared first.  Exits via sys.exit()
    when several values are given and their number differs from the number
    of bindings.
    '''
    if not isinstance(ext_vals, list):
        ext_vals = [ext_vals]
    if not isinstance(binding_names, list):
        binding_names = [binding_names]
    if len(ext_vals) > 1 and len(binding_names) != len(ext_vals):
        sys.exit("binding_names and ext_vals have different lengths.")

    self.clear_input()  # Consider removing this line.

    # Pair every binding with its own value; a single value is shared.
    if len(ext_vals) == 1:
        paired_vals = ext_vals * len(binding_names)
    else:
        paired_vals = ext_vals

    if inhib_comp:
        for bb, val in zip(binding_names, paired_vals):
            if self.quant_list is None:
                # Assume filler competition within the binding's role.
                competitors = self.find_roles(bb.split('/')[1])
            else:
                competitors = []
                for group in self.quant_list:
                    if bb in group:
                        competitors.extend(self.find_bindings(group))
            self.extC[competitors] = -val

    # Targets are written last so that they override the inhibition above.
    idx = self.find_bindings(binding_names)
    self.extC[idx] = ext_vals
    self.ext = compute_sspace_biases(self.extC, self)
416
417
def clear_input(self):
    '''Remove external input: zero extC and recompute ext.'''
    cleared = np.zeros(self.nbindings)
    self.extC = cleared
    self.ext = compute_sspace_biases(self.extC, self)
421
422
def clamp(self, binding_names, clamp_vals=1.0, clamp_comp=False):
    '''Clamp f/r bindings to fixed activation values.

    binding_names: a binding name or a list of binding names to clamp.
    clamp_vals: a single value (shared) or a list with one value per
        binding.
    clamp_comp: when True, the competitors of the clamped bindings are
        clamped to 0 as well.  Competitors are the bindings sharing a role
        with a clamped binding (quant_list is None) or the members of every
        quant_list group containing a clamped binding.

    Sets clamped, binding_names_clamped, clampvecC, clampvec and projmat,
    and applies the clamp to the current state.  Exits via sys.exit() when
    several values are given and their number differs from the number of
    bindings.
    '''
    if not isinstance(clamp_vals, list):
        clamp_vals = [clamp_vals]
    if not isinstance(binding_names, list):
        binding_names = [binding_names]
    if len(clamp_vals) > 1 and len(clamp_vals) != len(binding_names):
        sys.exit('The number of bindings clamped is not equal to the number of values provided.')

    self.clamped = True
    self.binding_names_clamped = binding_names
    clampvecC = np.zeros(self.nbindings)

    # Collect the competitors of ALL clamped bindings (not only those of
    # the last group visited); they stay at 0 in clampvecC.
    competitor_idx = []
    if clamp_comp:
        if self.quant_list is None:
            # Assume filler competition in each role.
            role_names = [b.split('/')[1] for b in binding_names]
            competitor_idx = list(self.find_roles(role_names))
        else:
            for bb in binding_names:
                for group in self.quant_list:
                    if bb in group:
                        competitor_idx.extend(self.find_bindings(group))

    idx = self.find_bindings(binding_names)
    clampvecC[idx] = clamp_vals
    self.clampvecC = clampvecC

    # Unclamped bindings are free to vary; their TP columns are the basis
    # vectors of the subspace the state is projected onto.
    fixed = set(idx) | set(competitor_idx)
    idx0 = [bb for bb in range(self.nbindings) if bb not in fixed]
    A = self.TP[:, idx0]
    if len(idx0) > 0:
        self.projmat = compute_projmat(A)
    else:
        self.projmat = np.zeros((self.nunits, self.nunits))
    self.clampvec = self.C2S(clampvecC)
    self.act = self.act_clamped(self.act)
    self.actC = self.S2C()
473
474
def unclamp(self):
    '''Remove the clamp: delete the clamp attributes and reset the flag.

    Does nothing unless self.clamped is True.
    '''
    if self.clamped is not True:
        return
    for attr in ('clampvec', 'clampvecC', 'projmat', 'binding_names_clamped'):
        delattr(self, attr)
    self.clamped = False
481
482
##########################################################################
483
###
484
### Harmony and harmony gradient
485
###
486
##########################################################################
487
488
def harmony(self, act=None, quant_list=None):
    '''Compute the total harmony of a state.

    act defaults to the current state and quant_list to self.quant_list;
    the computation is delegated to the module-level harmony().
    '''
    state = self.act if act is None else act
    groups = self.quant_list if quant_list is None else quant_list
    return harmony(net=self, act=state, quant_list=groups)
497
498
def gharmony(self, act=None):
    '''Compute the grammar harmony of a state (current state by default).

    Delegates to the module-level gharmony().
    '''
    state = self.act if act is None else act
    return gharmony(net=self, act=state)
503
504
def bharmony(self, act=None):
    '''Compute the harmony term returned by the module-level bharmony().

    Evaluated at the current state by default.
    NOTE(review): presumably the bowl harmony -- confirm against bharmony().
    '''
    state = self.act if act is None else act
    return bharmony(net=self, act=state)
509
510
def oharmony(self, act=None):
    '''Compute the harmony term returned by the module-level oharmony().

    Evaluated at the current state by default.
    NOTE(review): the meaning of the "o" term is not visible here -- see
    oharmony().
    '''
    state = self.act if act is None else act
    return oharmony(net=self, act=state)
515
516
def qharmony(self, act=None, quant_list=None):
    '''Compute the harmony term returned by the module-level q2harmony().

    act defaults to the current state and quant_list to self.quant_list.
    NOTE(review): presumably the quantization harmony -- confirm against
    q2harmony().
    '''
    state = self.act if act is None else act
    groups = self.quant_list if quant_list is None else quant_list
    return q2harmony(net=self, act=state, quant_list=groups)
525
526
def hgrad(self, act=None, quant_list=None):
    '''Compute the harmony gradient evaluated at a state.

    act defaults to the current state.  quant_list defaults to
    self.quant_list, matching harmony() and qharmony(), so that the
    gradient and the harmony it belongs to use the same quantization
    groups.  The computation is delegated to the module-level hgrad().
    '''
    if act is None:
        act = self.act
    if quant_list is None:
        quant_list = self.quant_list
    return hgrad(net=self, act=act, quant_list=quant_list)
531
532
# def hessian(self, act=None, quant_list=None):
533
# '''Compute the Hessian (the Jacobain of the gradient) evaluated at the current state'''
534
# if act is None:
535
# act = self.act
536
# return hhess(net=self, act=act, quant_list=quant_list)
537
538
##########################################################################
539
###
540
### Update state and parameters
541
###
542
##########################################################################
543
544
#def update(self, update_T=True, update_q=True, q_fun='log', quant_list=None):
545
def update(self, update_T=True, update_q=True, q_fun=None, space='S', norm_ord=2):
    '''Advance the network by one time step.

    The current state is saved in self.prev_act, the state and the speed
    are updated, the distances to self.grid_points are refreshed (when
    set), and then the temperature and the quantization strength are
    updated if requested.  q_fun defaults to self.q_fun.
    '''
    chosen_q_fun = self.q_fun if q_fun is None else q_fun

    self.prev_act = copy.deepcopy(self.act)
    self.update_state()
    self.update_speed(space=space, norm_ord=norm_ord)
    if self.grid_points is not None:
        self.update_dist(self.grid_points, space=space, norm_ord=norm_ord)

    if update_T:
        self.update_T()
    if update_q:
        self.update_q(fun=chosen_q_fun)
560
561
def update_state(self):
    '''Update the current state (with noise).

    Takes one step of size dt along the harmony gradient (in place), adds
    noise, re-applies the clamp when the network is clamped, and refreshes
    the C-space copy of the state.
    '''
    gradient = self.hgrad(quant_list=self.quant_list)
    self.act += self.dt * gradient  # deterministic change
    self.add_noise()                # stochastic part
    if self.clamped:
        self.act = self.act_clamped()
    self.actC = self.S2C()
569
570
def add_noise(self):
    '''Add noise to the current activation state (in place).

    The noise is standard normal, one draw per unit, scaled by
    sqrt(2 * T * dt).
    '''
    scale = sqrt(2 * self.T * self.dt)
    self.act += scale * np.random.randn(self.nunits)
575
576
def update_T(self, method='exponential'):
    '''Update the current temperature.

    With method='exponential' the excess of T over T_min is multiplied by
    (1 - T_decay_rate).  Any other method leaves T unchanged.
    '''
    # In the LDNet, the decay rate per timestep is adjusted based on the time step constant.
    # For simplicity, the decay rate is assumed to be per-step.
    if method != 'exponential':
        return
    self.T = (1-self.T_decay_rate) * (self.T - self.T_min) + self.T_min
583
584
def update_q(self, fun='plinear'):
    '''Update the quantization strength q.

    fun='log':     increment q_time and set q = q_rate * log(q_time).
    fun='linear':  add q_rate to q.
    fun='plinear': piecewise linear; add q_rate only while q <= q_max.
    Any other value leaves q unchanged.
    '''
    if fun == 'log':
        self.q_time += 1
        self.q = self.q_rate * log(self.q_time)
    elif fun == 'linear':
        self.q += self.q_rate
    elif fun == 'plinear' and self.q <= self.q_max:
        self.q += self.q_rate
596
597
def reset(self):
    '''Restore q and T to their initial values, zero q_time and randomize the state.'''
    self.q, self.q_time, self.T = self.q_init, 0, self.T_init
    self.randomize_state()
603
# self.unclamp()
604
# self.clear_input()
605
606
def update_speed(self, ema_factor=0.1, norm_ord=2, space='S'):
    '''Update the speed and its exponential moving average (EMA).

    The speed is the norm (order norm_ord) of the difference between the
    current and the previous state, measured in S-space (space='S') or in
    C-space (space='C').  The first call initializes ema_speed with the
    speed; later calls weight the old average by ema_factor and the new
    speed by (1 - ema_factor).

    Raises ValueError for any other value of [space] (previously this
    surfaced as an UnboundLocalError).
    '''
    if space == 'S':
        diff = self.act - self.prev_act
    elif space == 'C':
        diff = self.S2C(self.act) - self.S2C(self.prev_act)
    else:
        raise ValueError("space must be 'S' or 'C', got %r" % (space,))

    self.speed = np.linalg.norm(diff, ord=norm_ord)  # not divided by dt
    if self.ema_speed is None:
        self.ema_speed = self.speed
    else:
        self.ema_speed = ema_factor * self.ema_speed + (1-ema_factor) * self.speed
617
#return self.ema_speed
618
619
def update_dist(self, grid_points, norm_ord=2, space='S'):
    '''Store in self.dist the distance of the current state to each grid point.

    grid_points is a list of grid points (each a list of binding names); a
    single grid point, i.e. a list containing no nested list, is wrapped
    automatically.
    '''
    has_nested = any(isinstance(entry, list) for entry in grid_points)
    points = grid_points if has_nested else [grid_points]

    distances = np.zeros(len(points))
    distances[:] = [self.compute_dist(ref_point=point, norm_ord=norm_ord, space=space)
                    for point in points]
    self.dist = distances
629
630
def update_traces(self):
    '''Record the current state and summary variables at index self.step.

    The trace arrays are allocated by run().  Distances are recorded only
    when self.grid_points is set.
    '''
    step = self.step
    self.act_trace[step, :] = self.act

    # The harmony terms are evaluated in the listed order.
    for trace_name, compute in (('h_trace', self.harmony),
                                ('oh_trace', self.oharmony),
                                ('qh_trace', self.qharmony),
                                ('gh_trace', self.gharmony),
                                ('bh_trace', self.bharmony)):
        getattr(self, trace_name)[step] = compute()

    for trace_name, attr_name in (('speed_trace', 'speed'),
                                  ('ema_speed_trace', 'ema_speed'),
                                  ('q_trace', 'q'),
                                  ('T_trace', 'T')):
        getattr(self, trace_name)[step] = getattr(self, attr_name)

    if self.grid_points is not None:
        self.dist_trace[step, :] = self.dist
645
646
##########################################################################
647
###
648
### Run
649
###
650
##########################################################################
651
652
def run(self, maxstep, plot=False, grayscale=False, colorbar=True,
        tol=None, ema_factor=0.1, update_T=True, update_q=True, q_fun=None,
        grid_points=None, testvar='ema_speed', space='S', norm_ord=2,
        ext_overlap=False, ext_overlap_steps=1000):
    '''
    Run simulations for [maxstep] steps.

    plot, grayscale, colorbar: draw a heatmap of the C-space activation
        trace after the run.
    tol, testvar: when tol is given, check_convergence() is called after
        every step and the run stops as soon as self.converged is set.
    ema_factor: accepted for backward compatibility.
        NOTE(review): it is not forwarded to update()/update_speed(), so
        the default EMA factor of update_speed() is always used.
    update_T, update_q, q_fun, space, norm_ord: forwarded to update();
        q_fun defaults to self.q_fun.
    grid_points: grid point(s) used to size the distance trace; defaults to
        self.grid_points.
    ext_overlap, ext_overlap_steps: when ext_overlap is True, extC is
        linearly interpolated at every step from prev_extC towards the
        value extC had when run() was called.
        NOTE(review): self.ext is not recomputed from the interpolated extC
        here, and the interpolation weight exceeds 1 once the step count
        passes ext_overlap_steps -- confirm both against hgrad().

    Sets self.rt to the index of the last step executed.
    '''
    if q_fun is None:
        q_fun = self.q_fun

    if grid_points is None:
        grid_points = self.grid_points

    # Allocate the traces (index 0 holds the initial state).
    self.converged = False
    self.act_trace = np.zeros((maxstep+1, self.nunits))
    for trace_name in ('h_trace', 'oh_trace', 'qh_trace', 'gh_trace', 'bh_trace',
                       'speed_trace', 'ema_speed_trace', 'q_trace', 'T_trace'):
        setattr(self, trace_name, np.zeros(maxstep+1))

    if grid_points is not None:
        if not any(isinstance(pp, list) for pp in grid_points):
            grid_points = [grid_points]
        n_grid_points = len(grid_points)
        self.dist_trace = np.zeros((maxstep+1, len(grid_points)))

    # Log the initial state; distances are not known yet, hence NaN.
    # np.nan is used because the np.NAN alias was removed in NumPy 2.0.
    self.step = 0
    if grid_points is not None:
        self.dist = np.full(n_grid_points, np.nan)
    self.update_traces()

    curr_extC = self.extC
    for tt in np.arange(maxstep)+1:
        self.step = tt
        if ext_overlap:
            # Linear overlap between the previous and the current input.
            self.extC = (1 - tt/ext_overlap_steps) * self.prev_extC + (tt/ext_overlap_steps) * curr_extC

        self.update(update_T=update_T, update_q=update_q, q_fun=q_fun,
                    space=space, norm_ord=norm_ord)
        self.update_traces()

        if tol is not None:
            self.check_convergence(tol=tol, testvar=testvar)
            if self.converged:
                break

    self.rt = self.step

    if plot:
        actC_trace = self.S2C(self.act_trace[:(self.rt+1), :].T).T
        heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                yticklabels=self.binding_names, grayscale=grayscale, colorbar=colorbar)
714
715
def check_convergence(self, tol, testvar='ema_speed'):
    '''Set self.converged to True when the convergence criterion is met.

    testvar='dist':      any entry of self.dist is below tol.
    testvar='ema_speed': self.ema_speed is below tol.
    self.converged is never reset to False here.
    '''
    if testvar == 'dist' and (self.dist < tol).any():
        self.converged = True

    if testvar == 'ema_speed' and self.ema_speed < tol:
        self.converged = True
725
726
def simulate(self):
    '''Placeholder: does nothing and returns 0.'''
    status = 0
    return status
729
730
def plot_act_trace(self, timesteps=None):
    '''Draw a heatmap of the C-space activation trace.

    timesteps: indices of the steps to plot; defaults to every step of the
        last run (0 .. self.rt).

    Exits via sys.exit() when no trace is available.  The trace attribute
    is initialized to None by the constructor, so its value is tested
    rather than its mere existence.
    '''
    if getattr(self, 'act_trace', None) is None:
        sys.exit("There is no [act_trace] attribute in the current object.")

    if timesteps is None:
        timesteps = list(range(self.rt+1))

    actC_trace = self.S2C(self.act_trace[timesteps, :].T).T
    heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
            xticklabels=None, yticklabels=self.binding_names, grayscale=False)
741
742
def plot_trace(self, varname, timesteps=None):
    '''Plot the trace attribute named [varname]_trace against the timestep.

    timesteps defaults to every step of the last run (0 .. self.rt).
    '''
    steps = list(range(self.rt+1)) if timesteps is None else timesteps
    trace = getattr(self, varname+'_trace')
    plt.plot(trace[steps])
    plt.xlabel('Timestep', fontsize=16)
    plt.ylabel(varname, fontsize=16)
    plt.show()
750
751
def plot_dist_trace(self, timesteps=None):
    '''Plot the distance trace (one line per grid point) against the timestep.

    timesteps defaults to every step of the last run (0 .. self.rt).
    '''
    steps = list(range(self.rt+1)) if timesteps is None else timesteps
    plt.plot(self.dist_trace[steps, :])
    plt.xlabel('Timestep', fontsize=16)
    plt.ylabel('Distance', fontsize=16)
    plt.show()
758
759
##########################################################################
760
###
761
### Optimization and harmony landscape
762
###
763
##########################################################################
764
765
def optim(self, initvals=None, method='nelder-mead', options=None):
    '''Find a local optimum of the harmony given an initial guess.

    initvals: initial S-space state; None or an empty sequence means "use
        the current state".
    method: solver name passed to scipy.optimize.minimize (it used to be
        ignored in favour of a hard-coded 'nelder-mead').
    options: solver options; defaults to {'xtol': 1e-8, 'disp': True}.
        NOTE(review): recent SciPy releases call the Nelder-Mead tolerance
        'xatol' -- confirm which name the installed version honours.

    Returns the scipy result object with res.fun converted from the
    (locally) minimal negative harmony to the (locally) maximal harmony.
    '''
    # CHECK: does this work correctly when some bindings are clamped?
    if initvals is None or np.size(initvals) == 0:
        initvals = self.act
    if options is None:
        # Built per call so that no mutable default is shared between calls.
        options = {'xtol': 1e-8, 'disp': True}
    res = optim.minimize(lambda x: -harmony(self, x), initvals, method=method, options=options)
    res.fun = -res.fun
    return res
773
774
def harmony_landscape(self, binding_name, minact, maxact):
    '''1D harmony landscape.

    Evaluates the harmony on 1000 evenly spaced activation values of one
    binding between minact and maxact, keeping all other bindings at their
    current C-space values.

    binding_name: a binding name, or a list containing exactly one name.
    Returns (agrid, hgrid): the activation values and the harmony at each.
    Exits via sys.exit() when a list not holding exactly one name is given.
    '''
    # To-do: Currently it works only with local representation.
    if isinstance(binding_name, list):
        if len(binding_name) != 1:
            sys.exit("Choose only one binding.!")
        binding_name = binding_name[0]  # unwrap so that .index() can find it

    idx = self.binding_names.index(binding_name)
    agrid = np.linspace(minact, maxact, 1000)
    hgrid = np.zeros(len(agrid))
    actC = self.S2C()
    for ii, aa in enumerate(agrid):
        act0 = actC.copy()  # leave the reference state untouched
        act0[idx] = aa
        act = self.C2S(act0)
        hgrid[ii] = harmony(net=self, act=act)
    return agrid, hgrid
791
792
793
794
class WeightTP():
    '''Container for a symmetric C-space weight matrix (WGC) over named bindings.'''

    def __init__(self, binding_names):
        '''Start with an all-zero weight matrix over [binding_names].'''
        self.binding_names = binding_names
        self.nbindings = len(binding_names)
        size = self.nbindings
        self.WGC = np.zeros((size, size))

    def set_weight(self, binding1, binding2, weight):
        '''Set the weight between two bindings, in both directions.'''
        row = self.binding_names.index(binding1)
        col = self.binding_names.index(binding2)
        self.WGC[row, col] = weight
        self.WGC[col, row] = weight

    def show(self):
        '''Print the weight matrix labelled with binding names.'''
        labels = self.binding_names
        print(pd.DataFrame(self.WGC, index=labels, columns=labels))
808
809
class BiasTP():
    '''Container for a C-space bias vector (bGC) over named bindings.'''

    def __init__(self, binding_names):
        '''Start with an all-zero bias vector over [binding_names].'''
        self.binding_names = binding_names
        self.nbindings = len(binding_names)
        self.bGC = np.zeros(self.nbindings)

    def set_bias(self, binding, bias):
        '''Set the bias of one binding.'''
        position = self.binding_names.index(binding)
        self.bGC[position] = bias

    def set_role_bias(self, role, bias):
        '''Set the bias of every binding whose role (text after '/') equals [role].'''
        hits = [pos for pos, name in enumerate(self.binding_names)
                if name.split('/')[1] == role]
        self.bGC[hits] = bias

    def set_filler_bias(self, filler, bias):
        '''Set the bias of every binding whose filler (text before '/') equals [filler].'''
        hits = [pos for pos, name in enumerate(self.binding_names)
                if name.split('/')[0] == filler]
        self.bGC[hits] = bias

    def show(self):
        '''Print the bias vector labelled with binding names.'''
        print(pd.DataFrame(self.bGC, index=self.binding_names, columns=["bias"]))
831
832
833
# wh-question
834
# To-do: quant_list should be considered in read_grid_point()
835
class GscNet0():
836
# local representation first.
837
def __init__(self, net_list, group_names):
838
self.ngroups = len(net_list)
839
self.group_names = group_names
840
self.groups = net_list
841
842
self.binding_names = []
843
for idx, net in enumerate(self.groups):
844
self.binding_names = self.binding_names + [b + ':' + group_names[idx] for b in net.binding_names]
845
846
self.nbindings = len(self.binding_names)
847
848
self.WGC = np.zeros((self.nbindings, self.nbindings))
849
self.WBC = np.zeros((self.nbindings, self.nbindings))
850
for ii, net in enumerate(self.groups):
851
curr_bindings = [b + ':' + group_names[ii] for b in net.binding_names]
852
idx = [self.binding_names.index(bb) for bb in curr_bindings]
853
self.WGC[np.ix_(idx,idx)] = net.WGC
854
self.WBC[np.ix_(idx,idx)] = net.WBC
855
self.WC = self.WGC + self.WBC
856
857
self.bGC = np.zeros(self.nbindings)
858
self.bBC = np.zeros(self.nbindings)
859
self.extC = np.zeros(self.nbindings)
860
nunits = 0
861
self.group_unit_idx = list()
862
self.group_binding_idx = list()
863
for ii, net in enumerate(self.groups):
864
curr_bindings = [b + ':' + group_names[ii] for b in net.binding_names]
865
idx = [self.binding_names.index(bb) for bb in curr_bindings]
866
self.bGC[np.ix_(idx)] = net.bGC
867
self.bBC[np.ix_(idx)] = net.bBC
868
self.extC[np.ix_(idx)] = net.extC
869
unit_idx = np.arange(net.nunits) + nunits
870
self.group_unit_idx.append(unit_idx) # WRONG ....
871
self.group_binding_idx.append(idx)
872
nunits = nunits + net.nunits
873
self.bC = self.bGC + self.bBC
874
875
self.nunits = nunits
876
877
self.W = self.compute_sspace_weights()
878
self.b, self.ext = self.compute_sspace_b_and_ext()
879
self.act = np.zeros(self.nunits)
880
881
def compute_sspace_b_and_ext(self):
    '''Assemble the S-space bias and external-input vectors.

    Each sub-network's b and ext are copied into the unit slots listed for
    it in group_unit_idx.  Returns (b, ext).
    '''
    bias_all = np.zeros(self.nunits)
    ext_all = np.zeros(self.nunits)
    for units, net in zip(self.group_unit_idx, self.groups):
        bias_all[units] = net.b
        ext_all[units] = net.ext
    return bias_all, ext_all
888
889
# NOTE: called by __init__ and set_weight to (re)build the S-space weight matrix.
890
def compute_sspace_weights(self):
    '''Assemble the S-space weight matrix of the combined network.

    Diagonal blocks are the sub-networks' own W matrices.  An off-diagonal
    block (from net2 to net1) is the sum, over every pair of bindings, of
    the C-space weight times the outer product of the two bindings' TP
    columns, divided by the product of their squared norms.
    '''
    total = np.zeros((self.nunits, self.nunits))
    for ii, net1 in enumerate(self.groups):
        units1 = self.group_unit_idx[ii]
        binds1 = self.group_binding_idx[ii]
        for jj, net2 in enumerate(self.groups):
            if ii == jj:
                total[np.ix_(units1, units1)] = net1.W
                continue

            units2 = self.group_unit_idx[jj]
            binds2 = self.group_binding_idx[jj]
            W_local = self.WC[np.ix_(binds1, binds2)]
            block = np.zeros((net1.nunits, net2.nunits))  # from net2 to net1
            for kk in range(net1.nbindings):
                si = net1.TP[:, kk]
                for ll in range(net2.nbindings):
                    sj = net2.TP[:, ll]
                    block = block + W_local[kk, ll] * np.outer(si, sj) / (np.dot(si, si) * np.dot(sj, sj))
            total[np.ix_(units1, units2)] = block
    return total
911
912
def randomize_state(self, minact=0, maxact=1):
    '''Ask every group to randomize its own activation state.

    NOTE(review): minact and maxact are accepted but not forwarded;
    each group's randomize_state() is called without arguments.
    '''
    for group in self.groups:
        group.randomize_state()
916
917
# def randomize_weights(self, mu=0, sigma=1):
918
# '''Randomize WC'''
919
# WC = sigma*np.random.randn(self.nbindings**2).reshape(self.nbindings, self.nbindings, order='F') + mu
920
# WC = (WC + WC.T)/2
921
# self.WC = WC
922
# self.W = compute_sspace_weights(self.WC, self)
923
924
##########################################################################
925
###
926
### set weights and biases (Use this only for setting the weights between networks)
927
###
928
##########################################################################
929
930
def set_weight(self, binding_name1, binding_name2, weight, symmetric=True):
    '''Set a grammar weight (WGC) between two bindings and refresh WC and W.

    With symmetric=True both WGC[i1, i2] and WGC[i2, i1] are set;
    otherwise only WGC[i2, i1] is written, where i1 and i2 are the
    positions of binding_name1 and binding_name2 in self.binding_names.
    '''
    first = self.binding_names.index(binding_name1)
    second = self.binding_names.index(binding_name2)
    if not symmetric:
        self.WGC[second, first] = weight
    else:
        self.WGC[first, second] = weight
        self.WGC[second, first] = weight
    # Derived quantities must follow every change of WGC.
    self.WC = self.WGC + self.WBC
    self.W = self.compute_sspace_weights()
939
940
def set_bias(self, binding_name, bias):
    '''Set the grammar bias (bGC) of one binding and refresh bC, b and ext.'''
    pos = self.binding_names.index(binding_name)
    self.bGC[pos] = bias
    self.bC = self.bGC + self.bBC
    sspace = self.compute_sspace_b_and_ext()
    self.b, self.ext = sspace
945
946
def set_role_bias(self, role_name, bias):
    '''Set the grammar bias of every binding whose role part contains role_name.

    The role part is whatever follows the first '/' in a binding name
    (up to the next '/', if any); matching is by substring.
    '''
    hits = []
    for pos, name in enumerate(self.binding_names):
        role_part = name.split('/')[1]
        if role_name in role_part:
            hits.append(pos)
    self.bGC[hits] = bias
    self.bC = self.bGC + self.bBC
    sspace = self.compute_sspace_b_and_ext()
    self.b, self.ext = sspace
952
953
def set_filler_bias(self, filler_name, bias):
    '''Set the grammar bias of every binding whose filler part contains filler_name.

    The filler part is whatever precedes the first '/' in a binding
    name; matching is by substring.
    '''
    hits = []
    for pos, name in enumerate(self.binding_names):
        filler_part = name.split('/')[0]
        if filler_name in filler_part:
            hits.append(pos)
    self.bGC[hits] = bias
    self.bC = self.bGC + self.bBC
    sspace = self.compute_sspace_b_and_ext()
    self.b, self.ext = sspace
959
960
##########################################################################
961
###
962
### Etc --- CHECK if it works
963
###
964
##########################################################################
965
966
def set_seed(self, num):
    '''Seed NumPy's global random number generator with num.'''
    seed_value = num
    np.random.seed(seed_value)
968
969
def find_bindings(self, binding_names):
    '''Return the positions of the given binding names in self.binding_names.

    A single (non-list) name is treated as a one-element list.
    '''
    names = binding_names if isinstance(binding_names, list) else [binding_names]
    positions = []
    for name in names:
        positions.append(self.binding_names.index(name))
    return positions
975
976
def read_state(self):
    '''Print the current state of every group, one group after another.

    Each group's own read_state() does the printing of its state.
    '''
    for pos, group in enumerate(self.groups):
        header = "Group: %s" % self.group_names[pos]
        print(header)
        group.read_state()
        print("\n")
982
983
def read_grid_point(self):
    '''Print, for every group, what its read_grid_point(disp=False) returns.'''
    # CHECK: current method assumes filler competition in each role.
    # quant_list should be considered
    for pos, group in enumerate(self.groups):
        header = "Group: %s" % self.group_names[pos]
        print(header)
        grid_point = group.read_grid_point(disp=False)
        print(grid_point)
        print("\n")
990
991
def read_weight(self):
    '''Print the C-space weight matrix WC and bias vector bC as labelled tables.'''
    names = self.binding_names
    weight_table = pd.DataFrame(self.WC, index=names, columns=names)
    print(weight_table)
    bias_row = self.bC.reshape(1, self.nbindings)
    bias_table = pd.DataFrame(bias_row, index=["bias"], columns=names)
    print(bias_table)
994
995
def hinton(self, label=True):
    '''Draw a Hinton diagram of the C-space weight matrix WC.

    When label is true the binding names are used as tick labels on
    both axes; otherwise no labels are passed.
    '''
    # CHECK: which of W or W-beta*I should be visualized?
    tick_labels = self.binding_names if label else None
    hinton(self.WC, xlabels=tick_labels, ylabels=tick_labels)
1003
1004
# def act_clamped(self, act=None):
1005
# # Not used. Consider removing this part.
1006
# if act is None:
1007
# act = self.act
1008
# return self.projmat.dot(act) + self.clampvec
1009
#
1010
# def check_beta(self):
1011
# # Unused.
1012
# # To-do: eig() vs. eigh()
1013
# eigvals, eigvecs = la.eigh(self.W) # W should be a symmetric matrix.
1014
# eig_max = max(eigvals) # Condition 1: beta > eig_max to be stable
1015
# if self.nbindings == 1:
1016
# beta1 = -(self.b+self.ext)/self.z
1017
# beta2 = (self.b+self.ext+eig_max)/(1-self.z)
1018
# else:
1019
# beta1 = -min(self.b+self.ext)/self.z # Condition 2: beta > beta1
1020
# beta2 = (max(self.b+self.ext)+eig_max)/(1-self.z) # Condition 3: beta > beta2
1021
# beta_min = max(eig_max, beta1, beta2)
1022
# if self.beta <= beta_min:
1023
# sys.exit("Beta (bowl strength) should be greater than %.4f." % beta_min)
1024
#
1025
# def info(self):
1026
# '''Print the network information.'''
1027
# print('Fillers = ', self.fillers)
1028
# print('Roles = ', self.roles)
1029
# print('Num_of_units = ', self.nunits)
1030
# print('Current T = ', self.T)
1031
# # Add more information ...
1032
1033
##########################################################################
1034
###
1035
### Set or clear external input / Clamp or unclamp
1036
###
1037
##########################################################################
1038
1039
def set_input(self, binding_names, ext_vals):
    '''Replace the external input: ext_vals on the named bindings, zero elsewhere.

    Exits the program when the two arguments differ in length. The
    S-space input self.ext is recomputed from the new self.extC.
    '''
    same_length = len(binding_names) == len(ext_vals)
    if not same_length:
        sys.exit("binding_names and ext_vals have different lengths.")
    # Any previously set input is discarded first.
    self.extC = np.zeros(self.nbindings)
    positions = self.find_bindings(binding_names)
    self.extC[positions] = ext_vals
    self.ext = compute_sspace_biases(self.extC, self)
1046
1047
def clear_input(self):
    '''Reset the external input to zero in both C-space and S-space.'''
    zero_input = np.zeros(self.nbindings)
    self.extC = zero_input
    self.ext = compute_sspace_biases(self.extC, self)
1050
1051
def clamp(self, binding_names, clamp_vals):
    '''Clamp f/r bindings, delegating to the group each binding belongs to.

    Every binding name has the form '<binding>:<group name>'. The
    bindings are sorted by group and each affected group's own clamp()
    is called once with its bindings (group suffix removed) and the
    corresponding values. Exits the program when the two arguments
    differ in length.
    '''
    if len(clamp_vals) != len(binding_names):
        sys.exit('The number of bindings clamped is not equal to the number of values provided.')

    self.clamped = True

    # CHECK: clamping each group
    pieces = [name.split(':') for name in binding_names]
    owners = [self.group_names.index(piece[1]) for piece in pieces]
    local_names = [piece[0] for piece in pieces]

    for owner in set(owners):
        members = [pos for pos, grp in enumerate(owners) if grp == owner]
        group_bindings = []
        group_vals = []
        for pos in members:
            group_bindings.append(local_names[pos])
            group_vals.append(clamp_vals[pos])
        self.groups[owner].clamp(binding_names=group_bindings, clamp_vals=group_vals)
1083
1084
def unclamp(self):
    '''Clear the clamped flag and unclamp every group that is currently clamped.'''
    self.clamped = False
    for group in self.groups:
        if not group.clamped:
            continue
        group.unclamp()
1089
1090
# def clamp_group(self, group_name, clamp_vals):
1091
# self.group_clamped = group_name
1092
# group_idx = self.group_names.index(group_name)
1093
# binding_names = [b + ':' + group_name for b in self.groups[group_idx].binding_names]
1094
# self.clamp(binding_names, clamp_vals)
1095
1096
##########################################################################
1097
###
1098
### Update state and parameters
1099
###
1100
##########################################################################
1101
1102
def update(self, update_T=True, update_q=True):
    '''Update the current state (with noise) for each group.

    First, every group's external input net.ext is overwritten with the
    input it receives from all *other* groups: the sending group's
    state is converted with its TPinv, multiplied by the matching block
    of self.WC, summed, and mapped back with the receiving group's TP.
    Then every group runs its own update() and its new state is copied
    into self.act.
    '''
    def global_positions(pos, group):
        # Positions of one group's bindings inside self.binding_names.
        return [self.binding_names.index(b + ':' + self.group_names[pos])
                for b in group.binding_names]

    for to_pos, to_net in enumerate(self.groups):
        incoming = np.zeros(to_net.nbindings)
        to_idx = global_positions(to_pos, to_net)
        for from_pos, from_net in enumerate(self.groups):
            if from_pos == to_pos:
                continue  # only between-group weights contribute here
            from_idx = global_positions(from_pos, from_net)
            block = self.WC[np.ix_(to_idx, from_idx)]
            incoming = incoming + block.dot(from_net.TPinv.dot(from_net.act))
        to_net.ext = to_net.TP.dot(incoming)

    for pos, group in enumerate(self.groups):
        group.update(update_T=update_T, update_q=update_q)
        self.act[self.group_unit_idx[pos]] = group.act
1120
1121
def update_state(self):
    '''Advance the state by one step: gradient move, then noise, then clamping.

    NOTE(review): the clamped branch calls self.act_clamped(), which
    only appears as commented-out code in the visible part of this
    class - confirm it is defined elsewhere before relying on it.
    '''
    drift = self.dt * self.hgrad()  # deterministic change
    self.act += drift
    self.add_noise()  # stochastic part
    if self.clamped:
        self.act = self.act_clamped()
1127
1128
def add_noise(self):
    '''Add Gaussian noise with standard deviation sqrt(2*T*dt) to the state.'''
    scale = sqrt(2 * self.T * self.dt)
    self.act += scale * np.random.randn(self.nunits)
1132
1133
def update_T(self, method='exponential'):
    '''Update the current temperature.

    With method 'exponential' the distance of T from T_min shrinks by
    the factor (1 - T_decay_rate) per call. Any other method leaves T
    unchanged.
    '''
    # In the LDNet, the decay rate per timestep is adjusted based on the time step constant.
    # For simplicity, the decay rate is assumed to be per-step.
    if method != 'exponential':
        return
    excess = self.T - self.T_min
    self.T = (1-self.T_decay_rate) * excess + self.T_min
1139
1140
def update_q(self, method='log'):
    '''Advance the q schedule by one step.

    With method 'log' the step counter q_time is incremented and q is
    set to q_factor * log(q_time). Any other method changes nothing.
    '''
    if method != 'log':
        return
    self.q_time += 1
    self.q = self.q_factor * log(self.q_time)
1144
1145
def reset(self):
    '''Call reset() on every group.'''
    for group in self.groups:
        group.reset()
1148
1149
##########################################################################
1150
###
1151
### Harmony and harmony gradient
1152
###
1153
##########################################################################
1154
1155
def harmony(self, act=None, quant_list=None):
    '''Compute the total harmony at act (the current state when act is None).'''
    state = self.act if act is None else act
    return harmony(net=self, act=state, quant_list=quant_list)
1160
1161
def gharmony(self, act=None):
    '''Compute the grammar harmony at act (the current state when act is None).'''
    state = self.act if act is None else act
    return gharmony(net=self, act=state)
1166
1167
def hgrad(self, act=None, quant_list=None):
    '''Compute the harmony gradient at act (the current state when act is None).'''
    state = self.act if act is None else act
    return hgrad(net=self, act=state, quant_list=quant_list)
1172
1173
1174
##########################################################################
1175
###
1176
### Run
1177
###
1178
##########################################################################
1179
1180
def run(self, maxstep, tol=None):
    '''Run the network with run_GscNet0 and store its four results.

    The results are stored as self.rt, self.final_state, self.act_trace
    and self.ema_speed_trace, in that order.
    '''
    results = run_GscNet0(self, maxstep=maxstep, tol=tol)
    self.rt, self.final_state, self.act_trace, self.ema_speed_trace = results
1182
1183
def plot_act_trace(self):
    '''Draw one heatmap per group of the stored activation trace.

    The columns of self.act_trace belonging to a group are converted
    with that group's TPinv before plotting, so rows of the heatmap
    correspond to the group's bindings.
    '''
    for pos, group in enumerate(self.groups):
        group_trace = self.act_trace[:, self.group_unit_idx[pos]]
        binding_trace = group.TPinv.dot(group_trace.T).T
        heatmap(binding_trace.T, xlabel="Timestep", ylabel="Bindings",
                yticklabels=group.binding_names, grayscale=False, colorbar=True)
1188
1189
# def run(self, maxstep, tol=None):
1190
# if tol is not None:
1191
# self.rt, self.final_state, self.act_trace, self.t_trace, self.q_trace, self.ema_speed_trace = run(self, maxstep=maxstep, tol=tol)
1192
# else:
1193
# self.rt, self.final_state, self.act_trace, self.t_trace, self.q_trace, self.ema_speed_trace = run(self, maxstep=maxstep)
1194
#
1195
# def plot_act_trace(self):
1196
# if hasattr(self, 'act_trace'):
1197
# heatmap(self.act_trace.T, xlabel="Timestep", ylabel="Bindings",
1198
# xticklabels=None, yticklabels=self.binding_names, grayscale=False)
1199
# else:
1200
# sys.exit("There is no [act_trace] attribute in the current object.")
1201
1202
##########################################################################
1203
###
1204
### Optim (CHECK)
1205
###
1206
##########################################################################
1207
1208
def gharmony_given_clamping(self, x, group_name_clamped, clampval, quant_list, q):
    '''Harmony of a state whose clamped entries are fixed and whose free entries are x.

    The clamped positions (self.clampind) take their values from
    self.clampvec; all remaining positions are filled with x. The
    result is 0.5*x0.W.x0 + x0.b plus q times self.qharmony() of every
    sub-group listed in quant_list.

    NOTE(review): this relies on self.clamp_group, self.clampind,
    self.clampvec and self.qharmony, none of which is defined in the
    visible part of this class - confirm they exist before use.
    '''
    full_state = np.zeros(self.nbindings)
    self.clamp_group(group_name_clamped, clampval)
    full_state[self.clampind] = self.clampvec[self.clampind]
    free = list(set(np.arange(self.nbindings)) - set(self.clampind))
    full_state[free] = x
    total = 0.5 * full_state.dot(self.W.dot(full_state)) + full_state.dot(self.b)
    for members in quant_list:
        positions = [self.binding_names.index(bb) for bb in members]
        total = total + q* self.qharmony(full_state[positions])
    return total
1220
1221
def optimal_state_given_clamping(self, group_name_clamped, clampval, quant_list, q, method="Nelder-Mead"):
    '''Search for the free-binding values that maximize gharmony_given_clamping.

    The search starts from a uniform random point and minimizes the
    negated harmony, using scipy's basinhopping when method is
    "basinhopping" and scipy's minimize with the given method
    otherwise. The optimizer result and the assembled full state are
    printed; the full state is also stored in self.optim_state.
    '''
    group_ind = self.group_names.index(group_name_clamped)

    def neg_harmony(x):
        return -self.gharmony_given_clamping(x, group_name_clamped, clampval, quant_list, q)

    if method=="basinhopping":
        res = optim.basinhopping(neg_harmony,
                                 np.random.uniform(size = self.nbindings - self.groups[group_ind].nbindings))
    else:
        res = optim.minimize(neg_harmony,
                             np.random.uniform(size = self.nbindings - self.groups[group_ind].nbindings),
                             method = method)
    print(res)

    # Rebuild the full state: clamped values plus the optimized free values.
    full_state = np.zeros(self.nbindings)
    full_state[self.clampind] = self.clampvec[self.clampind]
    free = list(set(np.arange(self.nbindings)) - set(self.clampind))
    full_state[free] = res.x
    self.optim_state = full_state
    print(full_state)
1237
1238
1239
# CHECK: for now, consider all units' act change:
1240
def run_GscNet0(net, maxstep, plot=False, grayscale=False, colorbar=True,
                tol=None, ema_factor=0.1, crop=True, update_q=True):
    '''Run a multi-group network for up to maxstep steps and record its states.

    net          : object providing update(update_q=...), act and nunits
                   (and, for plotting, groups and group_unit_idx).
    maxstep      : maximum number of update steps.
    plot         : when true, draw one heatmap per group of the trace.
    grayscale, colorbar : passed on to heatmap().
    tol          : None runs all maxstep steps. Otherwise the run stops
                   early once the smoothed speed falls below tol; a
                   negative tol exits the program.
    ema_factor   : weight given to the previous smoothed speed.
    crop         : accepted for compatibility; not used.
    update_q     : passed on to net.update().

    Returns [rt, final_state, act_trace, ema_speed] where rt is the
    number of steps taken, final_state is net.act, act_trace holds the
    states recorded after each step except the last one taken, and
    ema_speed is -1 when tol is None and the final smoothed speed
    otherwise.
    '''
    # CHECK: for now, consider all units' act change.
    # To-do: needs to include conversion from distributed to local representations
    act_trace = np.zeros((maxstep, net.nunits))

    if tol is None:
        for tt in np.arange(maxstep):
            net.update(update_q=update_q)
            act_trace[tt, :] = net.act
        ema_speed_trace = -1
    else:
        if tol < 0:
            sys.exit("The tolerance parameter should be a positive real number.")
        ema_speed_trace = np.zeros(maxstep)
        prev_act = copy.deepcopy(net.act)
        ema_speed = 0
        for tt in np.arange(maxstep):
            net.update(update_q=update_q)
            # Speed is the largest absolute change of any unit in this step.
            diff = max(abs(net.act - prev_act))  # divide it by dt?
            if tt == 0:
                ema_speed = diff
            else:
                ema_speed = ema_factor * ema_speed + (1-ema_factor) * diff
            act_trace[tt, :] = net.act
            ema_speed_trace[tt] = ema_speed
            prev_act = copy.deepcopy(net.act)

            if (ema_speed < tol) and (tol > 0):
                break

        # Only the smoothed speed at the last step taken is reported.
        ema_speed_trace = ema_speed_trace[tt]

    if plot:
        # The loop variable must not be called `net`: rebinding it would
        # make group_unit_idx be looked up on a group and would change
        # the final state returned below.
        for ii, group in enumerate(net.groups):
            curr_act_trace = act_trace[:, net.group_unit_idx[ii]]
            actC_trace = group.S2C(curr_act_trace[:tt, :].T).T
            heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                    yticklabels=group.binding_names, grayscale=grayscale, colorbar=colorbar)

    # NOTE(review): act_trace[:tt] leaves out the state recorded at the
    # last step taken even though rt counts that step; kept unchanged so
    # that callers see the same shape as before.
    return [tt+1, net.act, act_trace[:tt,], ema_speed_trace]
1290
1291
1292
1293
1294
1295
1296
def bind(fillernames, rolenames):
    '''Return all 'filler/role' binding names, with the filler varying fastest.'''
    bindings = []
    for role in rolenames:
        for filler in fillernames:
            bindings.append(filler + '/' + role)
    return bindings
1298
1299
# Harmony, Harmony gradient, and hessian (jacobian of the gradient)
1300
def harmony(net, act, quant_list=None):
    '''Compute total harmony at a given activation state.

    Total harmony is grammar harmony plus bowl harmony plus net.q times
    the quantization harmony.
    '''
    grammar_part = gharmony(net, act)
    bowl_part = bharmony(net, act)
    quant_part = net.q * q2harmony(net, act, quant_list)
    return grammar_part + bowl_part + quant_part
1303
1304
def oharmony(net, act):
    '''Compute grammar harmony plus bowl harmony (no quantization term).'''
    grammar_part = gharmony(net, act)
    bowl_part = bharmony(net, act)
    return grammar_part + bowl_part
1307
1308
def gharmony(net, act):
    '''Compute harmony (given a harmonic grammar) at a given activation state.

    This is 0.5*act.WG.act + act.bG + act.ext.
    '''
    quadratic = 0.5 * act.dot(net.WG).dot(act)
    bias_term = act.dot(net.bG)
    input_term = act.dot(net.ext)
    return quadratic + bias_term + input_term
1311
1312
def bharmony(net, act):
    '''Compute bowl harmony. act should be an array object.

    The result is -0.5*beta times the squared distance between
    net.S2C(act) and net.z.
    '''
    offset = net.S2C(act) - net.z
    return -0.5 * net.beta * offset.dot(offset)
1315
1316
# CHECK: separate ...
1317
def q2harmony(net, act=None, quant_list=None):
    '''Compute quantization harmony Q. By default, it considers filler competition in each role.

    act defaults to net.act and is converted with net.S2C first. With a
    quant_list, Q is the sum of q2() over the listed sub-groups of
    bindings. Without one, Q mixes (weights c and 1-c) a per-role term
    and a per-binding term; with a single filler only the per-binding
    term is used.
    '''
    # CHECK: separate ...
    state = net.act if act is None else act
    actC = net.S2C(state)

    if quant_list is not None:
        total = 0
        for members in quant_list:
            positions = [net.binding_names.index(bb) for bb in members]
            total = total + q2(net, actC[positions])
        return total

    # By default, assumes filler-competition in each role.
    if net.nfillers == 1:
        return -np.sum(actC**2 * (1-actC)**2)

    per_role = np.sum(actC.reshape(net.nfillers,net.nroles,order='F')**2, axis=0)
    role_term = -np.sum((per_role-1)**2)
    unit_term = -np.sum(actC**2 * (1-actC)**2)
    return net.c * role_term + (1-net.c) * unit_term
1340
1341
def q2(net, actC):
    '''Quantization harmony of a single competition group.

    actC holds the C-space activations of the bindings in the group.
    Two penalty terms are combined with weights net.c and 1-net.c:

    * const1 = -(sum(actC**2) - 1)**2 is zero when the squared
      activations of the group sum to one;
    * const2 = -sum(actC**2 * (1-actC)**2) is zero when every
      activation is 0 or 1.

    For a group with a single binding only const2 is returned.
    '''
    # Within-group competition
    squared = actC**2
    # The "- 1" applies to the sum as a whole, matching the gradient in
    # q2grad and the default branch of q2harmony; subtracting 1 from
    # every element before summing would penalize grid points.
    const1 = -(np.sum(squared) - 1)**2
    const2 = -np.sum(squared * (1-actC)**2)
    if len(actC) == 1:
        return const2
    return net.c * const1 + (1-net.c) * const2
1351
1352
def hgrad(net, act, quant_list=None):
    '''Total harmony gradient: grammar part + bowl part + net.q * quantization part.'''
    grammar_part = ghgrad(net, act)
    bowl_part = bhgrad(net, act)
    quant_part = net.q * q2hgrad(net, act, quant_list)
    return grammar_part + bowl_part + quant_part
1354
1355
def ghgrad(net, act):
    '''Gradient of the grammar harmony: WG.act + bG + ext.'''
    weighted = net.WG.dot(act)
    return weighted + net.bG + net.ext
1357
1358
def bhgrad(net, act):
    '''Bowl-harmony gradient term: -beta * (S2C(act) - z).'''
    offset = net.S2C(act) - net.z
    return -net.beta * offset
    #return net.WB.dot(net.act) + net.bB # should be same -net.beta * (net.S2C(act) - net.z)
1361
1362
def q2hgrad(net, act, quant_list=None):
    '''Gradient of the quantization harmony, mapped back with net.C2S.

    act is converted with net.S2C first. With a quant_list, the
    gradient is the sum of q2grad() over the listed sub-groups, each
    written into its own binding positions. Without one, a per-role
    term and a per-binding term are mixed with weights c and 1-c (the
    per-role term is zero when there is a single filler).
    '''
    actC = net.S2C(act)
    if quant_list is not None:
        grad_C = np.zeros(net.nbindings)
        for members in quant_list:
            piece = np.zeros(net.nbindings)
            positions = [net.binding_names.index(bb) for bb in members]
            piece[positions] = q2grad(net, actC[positions])
            grad_C = grad_C + piece
        return net.C2S(grad_C)

    if net.nfillers == 1:
        role_term = 0
    else:
        per_role = np.sum(actC.reshape(net.nfillers,net.nroles,order='F')**2, axis=0)
        # Repeat each role's sum for every filler so it lines up with actC.
        role_term = -4 * actC * ( np.tile( per_role, (net.nfillers,1) ).flatten('F') - 1)
    unit_term = -2 * actC * (actC-1) * (2*actC-1)
    grad_C = net.c * role_term + (1-net.c) * unit_term
    return net.C2S(grad_C)
1381
1382
def q2grad(net, actC):
1383
'''q2grad for each group'''
1384
const1 = -4 * actC * (np.sum(actC**2) - 1)
1385
const2 = -2 * actC * (actC-1) * (2*actC-1)
1386
return net.c * const1 + (1-net.c) * const2
1387
1388
## CHECK
1389
#def hhess(net, act, quant_list):
1390
# # add quant_list
1391
# # Not yet tested.
1392
# A = np.eye(net.nroles)
1393
# B = np.ones((net.nfillers, net.nfillers)) - np.eye(net.nfillers)
1394
# K = np.kron(A, B)
1395
# res = net.W - net.beta * np.eye(net.nunits) + net.q * ( K.dot(-8 * net.c * np.outer(act, act)) + np.diag(-net.c*(4*( np.tile( np.sum(act.reshape(net.nfillers, net.nroles, order='F')**2, axis=0), (net.nfillers,1) ).flatten('F') - 1 + 2 * act**2)) - (1-net.c)*(2*(6*act**2-6*act+1))) )
1396
# return res
1397
1398
# Hinton diagram from ( http://wiki.scipy.org/Cookbook/Matplotlib/HintonDiagrams )
1399
# The functions were slightly modified for consistency
1400
def _blob(x,y,area,colour):
    """
    Draws a square-shaped blob with the given area (< 1) at
    the given coordinates.
    http://wiki.scipy.org/Cookbook/Matplotlib/HintonDiagrams
    """
    # Replaced N with np for consistency
    half_side = np.sqrt(area) / 2
    left, right = x - half_side, x + half_side
    bottom, top = y - half_side, y + half_side
    xcorners = np.array([left, right, right, left])
    ycorners = np.array([bottom, bottom, top, top])
    P.fill(xcorners, ycorners, colour, edgecolor=colour)
1411
1412
def hinton(W, maxWeight=None, xlabels=None, ylabels=None):
    """
    Draws a Hinton diagram for visualizing a weight matrix.
    Temporarily disables matplotlib interactive mode if it is on,
    otherwise this takes forever.
    http://wiki.scipy.org/Cookbook/Matplotlib/HintonDiagrams

    W         : 2-D weight matrix; positive entries are drawn as white
                squares and negative entries as black squares.
    maxWeight : weight that fills a whole cell. When not given, the
                largest absolute weight rounded up to a power of two.
    xlabels, ylabels : optional tick labels for the two axes.
    """
    # Replaced N with np for consistency
    # xrange was replaced with range because it is not available in Python3
    # To do: Consider adding the f/r binding information as X and YTick labels
    reenable = False
    if P.isinteractive():
        P.ioff()
        # Remember that interactive mode was on so it can be restored;
        # without this flag the mode would stay switched off.
        reenable = True
    P.clf()
    height, width = W.shape
    if not maxWeight:
        largest = np.max(np.abs(W))
        if largest > 0:
            maxWeight = 2**np.ceil(np.log(largest)/np.log(2))
        else:
            # All-zero matrix: nothing is drawn, avoid taking log(0).
            maxWeight = 1

    P.fill(np.array([0,width,width,0]),np.array([0,0,height,height]),'gray')
    #P.axis('off')
    #P.axis('equal')
    for x in range(width):
        for y in range(height):
            _x = x+1
            _y = y+1
            w = W[y,x]
            if w > 0:
                _blob(_x - 0.5, height - _y + 0.5, min(1,w/maxWeight),'white')
            elif w < 0:
                _blob(_x - 0.5, height - _y + 0.5, min(1,-w/maxWeight),'black')

    # PWC
    label = False
    if xlabels is not None:
        P.xticks(np.arange(len(xlabels))+0.5, xlabels, rotation='vertical')
        label = True

    if ylabels is not None:
        P.yticks(np.arange(len(ylabels))+0.5, ylabels[::-1])
        label = True

    if label:
        P.xlim([0,width])
        P.ylim([0,height])
        P.gca().set_aspect('equal', adjustable='box')
        P.gca().tick_params(direction='out')
        #P.axis('equal')
    else:
        P.axis('off')
        P.axis('equal')

    # Restore interactive mode only after all drawing commands were issued.
    if reenable:
        P.ion()

    P.show()
1466
1467
#def run(net, maxstep, plot=False, grayscale=False, colorbar=True,
1468
# tol=None, ema_factor=0.1, update_T=True, update_q=True, q_fun='plinear'):
1469
# # To-do: needs to include conversion from distributed to local representations
1470
# # ema factor
1471
# act_trace = np.zeros((maxstep, net.nunits))
1472
# t_trace = np.zeros(maxstep)
1473
# q_trace = np.zeros(maxstep)
1474
#
1475
# if tol is None:
1476
# for tt in np.arange(maxstep): #range(1, maxstep+1):
1477
# net.update(update_T=update_T, update_q=update_q, q_fun=q_fun)
1478
# t_trace[tt] = net.T
1479
# q_trace[tt] = net.q
1480
# act_trace[tt, :] = net.act
1481
# ema_speed_trace = -1
1482
#
1483
# else:
1484
# if tol <= 0:
1485
# sys.exit("The tolerance parameter should be set to a positive real number.")
1486
# else:
1487
# ema_speed_trace = np.zeros(maxstep)
1488
# prev_act = copy.deepcopy(net.act)
1489
# ema_speed = 0
1490
# ema_factor = ema_factor
1491
# for tt in np.arange(maxstep): #range(1, maxstep+1):
1492
# net.update(update_T=update_T, update_q=update_q, q_fun=q_fun)
1493
# t_trace[tt] = net.T
1494
# q_trace[tt] = net.q
1495
# act_trace[tt, :] = net.act
1496
# ema_speed_trace[tt] = net.ema_speed
1497
# prev_act = copy.deepcopy(net.act)
1498
# if (ema_speed < tol) and (tol > 0):
1499
# break
1500
#
1501
# ema_speed_trace = ema_speed_trace[:tt]
1502
#
1503
# if plot:
1504
# actC_trace = net.S2C(act_trace[:tt,:].T).T
1505
# heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings", yticklabels=net.binding_names, grayscale=grayscale, colorbar=colorbar)
1506
#
1507
# return [tt+1, net.act, act_trace[:tt,], t_trace[:tt], q_trace[:tt], ema_speed_trace]
1508
1509
def heatmap(data, xlabel=None, ylabel=None, xticklabels=None, yticklabels=None,
            grayscale=False, colorbar=True):
    '''Plot a 2-D array (e.g. an activation trace) as a heatmap and show it.

    data        : 2-D array passed to imshow.
    xlabel, ylabel : optional axis labels.
    xticklabels, yticklabels : optional tick labels, placed at
                  positions 0, 1, 2, ... along the respective axis.
    grayscale   : use the "gray_r" colormap instead of "Reds".
    colorbar    : add a colorbar when true.
    '''
    # plt.get_cmap is used because matplotlib.cm.get_cmap is no longer
    # available in recent matplotlib releases.
    cmap = plt.get_cmap("gray_r" if grayscale else "Reds")
    plt.imshow(data, cmap=cmap, interpolation="nearest", aspect='auto')
    if xlabel is not None:
        plt.xlabel(xlabel, fontsize=16)
    if ylabel is not None:
        plt.ylabel(ylabel, fontsize=16)
    if xticklabels is not None:
        plt.xticks(np.arange(len(xticklabels)), xticklabels)
    if yticklabels is not None:
        plt.yticks(np.arange(len(yticklabels)), yticklabels)
    if colorbar:
        plt.colorbar()
    plt.show()
1528
1529
# create symbol representations
1530
def encode_symbols(nsymbols, reptype='local', dp=0, ndim=None):
    '''Create symbol representations, one symbol per column.

    nsymbols : number of symbols.
    reptype  : 'local' returns the identity matrix; any other value
               requests distributed vectors.
    dp       : for distributed vectors, either a number (the pairwise
               dot product wanted between every two symbols) or a full
               matrix of pairwise dot products.
    ndim     : dimensionality of distributed vectors; defaults to
               nsymbols.
    '''
    if reptype == 'local':
        return np.eye(nsymbols)

    if ndim is None:
        ndim = nsymbols

    if isinstance(dp, Number):
        # dp is a scalar that represents the expected pairwise similarity (dot product).
        return dot_products(nsymbols, ndim, dp)
    return dot_products2(nsymbols, ndim, dp)
1543
1544
def dot_products(nsymbols, ndim, s):
    '''Find nsymbols vectors of dimension ndim whose pairwise dot products all equal s.

    The target matrix has s off the diagonal and 1 on it; the search
    itself is done by dot_products2.
    '''
    off_diagonal = s * np.ones((nsymbols, nsymbols))
    diagonal_fix = (1-s) * np.eye(nsymbols, nsymbols)
    target = off_diagonal + diagonal_fix
    return dot_products2(nsymbols, ndim, target)
1548
1549
def dot_products2(nsymbols, ndim, dp_mat):
    # Given a square, symmetric nsymbols-by-nsymbols matrix dp_mat with
    # ones on its main diagonal, find nsymbols vectors of dimension
    # ndim whose pairwise dot products match dp_mat. The vectors are
    # returned in the columns of an ndim-by-nsymbols matrix M.
    #
    # Algorithm: Find a matrix M such that M'*M = dp_mat. This is done
    # via gradient descent on a cost function that is the square of the
    # frobenius norm of (M'*M-dp_mat), starting from a uniform random
    # matrix.
    #
    # NOTE: It has trouble finding more than about 16 vectors, possibly for
    # dumb numerical reasons (like stepsize and tolerance), which might be
    # fixable if necessary.

    if (dp_mat.T != dp_mat).any():
        sys.exit('dot_products2: dp_mat must be symmetric')

    if not (np.diag(dp_mat) == 1).all():
        sys.exit('dot_products2: dp_mat must have all ones on the main diagonal')

    vectors = np.random.uniform(size=ndim*nsymbols).reshape(ndim, nsymbols, order='F')
    step_cap = .1
    tol = 1e-6
    max_iter = 100000
    for _ in range(max_iter):
        gram_error = vectors.T.dot(vectors) - dp_mat
        grad = vectors.dot(gram_error)
        # The step shrinks when the gradient is large.
        step = min(step_cap, .01/abs(grad).max())
        vectors = vectors - step * grad
        residual = abs(vectors.T.dot(vectors)-dp_mat).max()
        if residual <= tol:
            break
    else:
        print("Didn't converge after %d iterations" % max_iter)

    return vectors
1588
1589
def compute_TPmat(R, F):
    '''Return (TP, TPinv) for role matrix R and filler matrix F.

    TP is the Kronecker product of R and F and converts local to
    distributed representations (C-space to S-space). TPinv is its
    inverse, or its pseudo-inverse when TP is not square.
    '''
    # See http://en.wikipedia.org/wiki/Vectorization_(mathematics) for justification of kronecker product.
    TP = np.kron(R,F)
    nrows, ncols = TP.shape[0], TP.shape[1]
    invert = la.inv if nrows == ncols else la.pinv
    return TP, invert(TP)
1598
1599
def compute_sspace_biases(b_local, net):
    '''Convert a C-space bias vector into an S-space bias vector.

    Each binding's bias scales that binding's column of net.TP divided
    by the column's squared norm; the contributions are summed.
    '''
    total = np.zeros(net.nunits).reshape(net.nunits)
    for pos in range(net.nbindings):
        column = net.TP[:, pos]
        total = total + b_local[pos]*column/ np.dot(column, column)
    return total
1605
1606
def compute_sspace_weights(W_local, net):
    '''Convert a C-space weight matrix into an S-space weight matrix.

    Only the lower triangle of W_local (including the diagonal) is
    read. Each weight scales the symmetrized outer product of the two
    bindings' columns of net.TP, divided by the product of the columns'
    squared norms; diagonal entries get half weight because the
    symmetrized outer product counts them twice.
    '''
    total = np.zeros((net.nunits, net.nunits))
    for ii in range(net.nbindings):
        col_i = net.TP[:, ii]  # choose one binding
        for jj in range(ii+1):
            col_j = net.TP[:, jj]  # choose another binding
            symmetric_outer = np.outer(col_i, col_j) + np.outer(col_j, col_i)
            norm_product = np.dot(col_i, col_i) * np.dot(col_j, col_j)
            if ii == jj:
                total = total + .5 * W_local[ii,jj] * symmetric_outer / norm_product
            else:
                total = total + W_local[ii,jj] * symmetric_outer / norm_product
    return total
1617
1618
def compute_projmat(A):
    '''Return the matrix A (A'A)^-1 A' for the column vectors in A.'''
    # A is an n x m matrix of basis (column) vectors of the subspace.
    # Only works if the rank of A is equal to the number of columns of A.
    gram = A.T.dot(A)
    return A.dot(la.inv(gram)).dot(A.T)
1622
1623
#class sim_params():
1624
#
1625
# def __init__(self):
1626
#
1627
# self.testvar = 'ema_speed'
1628
# self.convergence_check = True
1629
# self.norm_ord = np.inf
1630
# self.dist_space = 'S'
1631
# self.tol = 1e-2
1632
# self.ema_factor = 0.1
1633
# self.grid_points = None
1634
# self.input_method = 'ext'
1635
# self.update_T = True
1636
# self.update_q = True
1637
# self.maxstep = np.array(10000)
1638
# self.nrep = 1
1639
# self.seed_num = time.time()
1640
#
1641
# def set_params(self, name, val):
1642
#
1643
# setattr(self, name, val)
1644
1645
class sim():
    '''Run simulations on a network and store the results.

    net    : an object of the GscNet class.
    params : a dictionary for parameter setting. The methods below read
             the keys 'input_inhib', 'input_method', 'cumulative_input',
             'nrep', 'maxstep', 'tol', 'update_T', 'update_q', 'q_fun',
             'grid_points', 'ext_overlap' and 'ext_overlap_steps'.
    '''

    def __init__(self, net, params):
        self.net = net
        self.params = params

    def set_params(self, p_name, p_val):
        '''Set a single entry of the parameter dictionary.'''
        self.params[p_name] = p_val

    def set_seed(self, num):
        '''Seed NumPy's global random number generator with num.'''
        np.random.seed(num)

    def set_input(self, input_list, input_vals):
        '''Encode a sequence of input bindings as columns of self.inputC.

        input_list : binding name(s) of the form 'filler/role', one per
                     word; a single name is treated as a one-element list.
        input_vals : input strength for each word (indexable).

        Column k holds input_vals[k] on the k-th binding. With
        params['input_inhib'] and input method 'ext', all bindings in
        the same role first receive -input_vals[k]. With
        params['cumulative_input'] the columns are accumulated over
        words. The arguments are also stored as self.input_list and
        self.input_vals, which simulate() reads.
        '''
        if not isinstance(input_list, list):
            input_list = [input_list]

        nwords = len(input_list)
        self.inputC = np.zeros((self.net.nbindings, nwords))

        for ii, binding in enumerate(input_list):
            val = input_vals[ii]
            inputC = np.zeros(self.net.nbindings)
            curr_role = binding.split('/')[1]
            role_list = [bb.split('/')[1] for bb in self.net.binding_names]
            role_idx = [jj for jj, rr in enumerate(role_list) if curr_role == rr]
            if not isinstance(binding, list):
                binding = [binding]
            binding_idx = self.net.find_bindings(binding)

            if self.params['input_inhib']:
                if self.params['input_method'] == 'ext':
                    inputC[role_idx] = -val

            # The target binding itself always gets the positive value.
            inputC[binding_idx] = val
            self.inputC[:, ii] = inputC

        if self.params['cumulative_input']:
            self.inputC = self.inputC.cumsum(axis=1)

        # Remember the inputs so that simulate() can re-encode them.
        self.input_list = input_list
        self.input_vals = input_vals

    def simulate(self, params=None):
        '''Run params['nrep'] repetitions over the stored input sequence.

        params defaults to self.params. It controls the settings read
        here (repetitions, step limits, input method and the options
        passed to net.run); the input encoding done by set_input()
        always reads self.params. For every repetition and word the
        network's traces, response time and convergence flag are copied
        into self.act_trace, self.T_trace, self.q_trace,
        self.speed_trace, self.ema_speed_trace, self.rt and
        self.converged.
        '''
        if params is None:
            params = self.params

        nrep = params['nrep']
        maxstep = params['maxstep'].max()
        nwords = len(self.input_list)  # including the end of sentence markers, if any.
        self.act_trace = np.zeros((nrep, nwords, maxstep+1, self.net.nunits))
        self.T_trace = np.zeros((nrep, nwords, maxstep+1))
        self.q_trace = np.zeros((nrep, nwords, maxstep+1))
        self.speed_trace = np.zeros((nrep, nwords, maxstep+1))
        self.ema_speed_trace = np.zeros((nrep, nwords, maxstep+1))
        self.converged = np.ones((nrep, nwords), dtype=bool)
        self.rt = np.zeros((nrep, nwords))

        # create a storage to save grid points at the end of processing a word.
        # create a storage to save the states at the end of processing a word.

        for rep_num in np.arange(nrep):

            self.net.unclamp()
            self.net.clear_input()
            self.net.reset()
            self.set_input(self.input_list, self.input_vals)

            for word_num, curr_input in enumerate(self.input_list):

                curr_maxstep = params['maxstep'][word_num]

                curr_inputC = self.inputC[:, word_num]
                if params['input_method'] == 'ext':
                    self.net.extC = curr_inputC
                    self.net.ext = compute_sspace_biases(self.net.extC, self.net)

                elif params['input_method'] == 'clamp':
                    sys.exit('Not yet implementd')

                self.net.run(curr_maxstep, tol=params['tol'],
                             update_T=params['update_T'],
                             update_q=params['update_q'],
                             q_fun=params['q_fun'],
                             grid_points=params['grid_points'],
                             ext_overlap=params['ext_overlap'],
                             ext_overlap_steps=params['ext_overlap_steps'])

                self.net.prev_extC = self.net.extC

                self.act_trace[rep_num, word_num, :(curr_maxstep+1), :] = self.net.act_trace
                self.rt[rep_num, word_num] = self.net.rt
                self.T_trace[rep_num, word_num, :(curr_maxstep+1)] = self.net.T_trace
                self.q_trace[rep_num, word_num, :(curr_maxstep+1)] = self.net.q_trace
                self.speed_trace[rep_num, word_num, :(curr_maxstep+1)] = self.net.speed_trace
                self.ema_speed_trace[rep_num, word_num, :(curr_maxstep+1)] = self.net.ema_speed_trace
                self.converged[rep_num, word_num] = self.net.converged

    def _last_step(self, rep_ind, word_ind):
        '''Return the stored response time of one word as an integer index.'''
        # self.rt is a float array; NumPy does not accept floats as indices.
        return int(self.rt[rep_ind, word_ind])

    def plot_act_trace(self, rep_num):
        '''Plot the activation trace of repetition rep_num (1-based) as a heatmap.

        The traces of successive words are joined; the first row of
        every word after the first is skipped.
        '''
        rep_ind = rep_num - 1

        nwords = self.act_trace.shape[1]
        curr_act_trace = self.act_trace[rep_ind, 0, :(self._last_step(rep_ind, 0)+1), :]
        if nwords > 1:
            for wind in np.arange(1, nwords):
                temp_act_trace = self.act_trace[rep_ind, wind, 1:(self._last_step(rep_ind, wind)+1), :]
                curr_act_trace = np.concatenate((curr_act_trace, temp_act_trace), axis=0)

        actC_trace = self.net.S2C(curr_act_trace.T).T
        heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                xticklabels=None, yticklabels=self.net.binding_names, grayscale=False)

    def plot_trace(self, varname, rep_num):
        '''Plot the stored '<varname>_trace' of repetition rep_num (1-based).

        The traces of successive words are joined and a vertical green
        line is drawn at the cumulative response time of every word.
        '''
        rep_ind = rep_num - 1
        nwords = self.act_trace.shape[1]
        trace = getattr(self, varname+'_trace')
        curr_trace = trace[rep_ind, 0, :(self._last_step(rep_ind, 0)+1)]
        if nwords > 1:
            for wind in np.arange(1, nwords):
                # Each word is cut at its own response time.
                temp_trace = trace[rep_ind, wind, 1:(self._last_step(rep_ind, wind)+1)]
                curr_trace = np.concatenate((curr_trace, temp_trace), axis=0)

        curr_rt_trace = self.rt[rep_ind, :]
        curr_rt_trace = curr_rt_trace.cumsum()

        plt.plot(curr_trace)
        for ii, rt in enumerate(curr_rt_trace):
            plt.plot([rt, rt], [min(curr_trace), max(curr_trace)], 'g-')

        plt.xlabel('Timestep', fontsize=16)
        plt.ylabel(varname, fontsize=16)

    def read_state(self, rep_num, word_num):
        '''Pass the final state of a given repetition and word (both 1-based) to net.read_state.'''
        rep_ind, word_ind = rep_num-1, word_num-1
        act = self.act_trace[rep_ind, word_ind, self._last_step(rep_ind, word_ind), :]
        self.net.read_state(act)

    def read_grid_point(self, rep_num, word_num):
        '''Pass the final state of a given repetition and word (both 1-based) to net.read_grid_point.'''
        rep_ind, word_ind = rep_num-1, word_num-1
        act = self.act_trace[rep_ind, word_ind, self._last_step(rep_ind, word_ind), :]
        self.net.read_grid_point(act)
1803
1804