"""
GSC Software:
Author: Pyeong Whan Cho
Department of Cognitive Science, Johns Hopkins University
Summer 2015
"""
from math import *
import numpy as np
from matplotlib import pyplot as plt
from scipy import optimize as optim
import pandas as pd
import pylab as P
from scipy import linalg as la
from numbers import Number
import sys
import copy
class GscNet():
    """Single-group GSC network defined over filler/role bindings.

    The state is kept in two coordinate systems: the C-space (one coordinate
    per binding; attributes ending in ``C``) and the S-space (one coordinate
    per unit). ``self.TP`` maps C -> S and ``self.TPinv`` maps S -> C
    (see :meth:`C2S` and :meth:`S2C`).
    """

    def __init__(self, filler_names, role_names,
                 WGC=None, bGC=None, extC=None, prev_extC=None, z=0.5, beta=4,
                 q_init=0, q_rate=4, q_max=50, q_fun='plinear', c=0.5,
                 T_init=0, dt=0.001, reptype_r='local', reptype_f='local',
                 dp_r=0, dp_f=0, ndim_r=None, ndim_f=None,
                 T_decay_rate=0, T_min=0, quant_list=None, grid_points=None):
        '''Construct an instance of GscNet.

        filler_names, role_names: lists of symbols; bindings are built by bind().
        WGC, bGC: grammar weights and biases in C-space (default: zeros).
        extC, prev_extC: current / previous external input in C-space.
        z, beta: bowl centre and bowl strength (see bharmony / check_beta).
        q_init, q_rate, q_max, q_fun: quantization schedule (see update_q).
        c: mixing weight read by the quantization harmony functions.
        T_init, T_decay_rate, T_min: temperature schedule (see update_T).
        dt: step size used by update_state / add_noise.
        reptype_*, dp_*, ndim_*: forwarded to encode_symbols.
        quant_list: optional list of binding groups used for quantization.
        grid_points: optional reference points tracked by update_dist.

        Exits via sys.exit if beta is too small (check_beta) or if the
        S-space weight matrix is not symmetric.
        '''
        self.binding_names = bind(filler_names, role_names)
        self.filler_names = filler_names
        self.role_names = role_names
        self.nbindings = len(self.binding_names)
        self.nfillers = len(filler_names)
        self.nroles = len(role_names)

        nb = self.nbindings
        self.WGC = np.zeros((nb, nb)) if WGC is None else WGC
        self.bGC = np.zeros(nb) if bGC is None else bGC
        self.extC = np.zeros(nb) if extC is None else extC
        self.prev_extC = np.zeros(nb) if prev_extC is None else prev_extC

        if ndim_r is None:
            ndim_r = self.nroles
        if ndim_f is None:
            ndim_f = self.nfillers
        self.reptype_r = reptype_r
        self.reptype_f = reptype_f
        self.ndim_r = ndim_r
        self.ndim_f = ndim_f
        self.R = encode_symbols(self.nroles, reptype=reptype_r, dp=dp_r, ndim=ndim_r)
        self.F = encode_symbols(self.nfillers, reptype=reptype_f, dp=dp_f, ndim=ndim_f)
        self.TP, self.TPinv = compute_TPmat(self.R, self.F)
        self.nunits = ndim_r * ndim_f
        self.act = np.zeros(self.nunits)
        self.actC = self.S2C()

        # Bowl (baseline) weights and biases in C-space.
        self.z = z
        self.beta = beta
        self.WBC = -self.beta * np.eye(nb)
        self.bBC = self.beta * self.z * np.ones(nb)
        self.WC = self.WGC + self.WBC
        self.bC = self.bGC + self.bBC

        # S-space counterparts of every C-space quantity.
        self.WG = compute_sspace_weights(self.WGC, self)
        self.WB = compute_sspace_weights(self.WBC, self)
        self.bG = compute_sspace_biases(self.bGC, self)
        self.bB = compute_sspace_biases(self.bBC, self)
        self.W = compute_sspace_weights(self.WC, self)
        self.b = compute_sspace_biases(self.bC, self)
        self.ext = compute_sspace_biases(self.extC, self)
        self.prev_ext = compute_sspace_biases(self.prev_extC, self)
        self.check_beta()

        self.q_init = q_init
        self.q = copy.deepcopy(self.q_init)
        self.q_rate = q_rate
        self.q_fun = q_fun
        self.q_max = q_max
        self.q_time = 0
        self.c = c
        self.T_init = T_init
        self.T = copy.deepcopy(self.T_init)
        self.T_decay_rate = T_decay_rate
        self.T_min = T_min
        self.dt = dt
        self.clamped = False
        self.act_trace = None
        self.quant_list = quant_list
        if not (self.W.T == self.W).all():
            sys.exit("The weight matrix (2D array) is not symmetric. Please check it.")
        ndigits = len(str(self.nunits))
        self.unit_names = ['U' + str(ii + 1).zfill(ndigits) for ii in range(self.nunits)]
        self.speed = None
        self.ema_speed = None
        self.grid_points = grid_points

    def _refresh_biases(self):
        '''Recompute bC and the S-space biases after bGC has changed.'''
        self.bC = self.bGC + self.bBC
        self.bG = compute_sspace_biases(self.bGC, self)
        self.b = compute_sspace_biases(self.bC, self)

    def randomize_state(self, minact=0, maxact=1):
        '''Set the activation state to a random vector inside a unit hypercube'''
        actC = np.random.uniform(minact, maxact, self.nbindings)
        self.act = self.C2S(actC)
        # Keep the cached C-space state in step with the new S-space state.
        self.actC = self.S2C()

    def randomize_weights(self, mu=0, sigma=1):
        '''Randomize WC with a symmetric Gaussian matrix and recompute W.'''
        # NOTE(review): only WC / W are replaced here; WGC and WG (which the
        # harmony and gradient functions read) are left untouched -- confirm
        # that this is intended.
        nb = self.nbindings
        WC = sigma * np.random.randn(nb**2).reshape(nb, nb, order='F') + mu
        WC = (WC + WC.T) / 2
        self.WC = WC
        self.W = compute_sspace_weights(self.WC, self)

    def set_weight(self, binding_name1, binding_name2, weight, symmetric=True):
        '''Set the weight of a connection between binding1 and binding2.
        When symmetric is set to True (default), the connection weight from
        binding2 to binding1 is set to the same value.'''
        idx1 = self.find_bindings(binding_name1)
        idx2 = self.find_bindings(binding_name2)
        if symmetric:
            self.WGC[idx1, idx2] = self.WGC[idx2, idx1] = weight
        else:
            self.WGC[idx2, idx1] = weight
        self.WC = self.WGC + self.WBC
        self.WG = compute_sspace_weights(self.WGC, self)
        self.W = compute_sspace_weights(self.WC, self)

    def set_bias(self, binding_name, bias):
        '''Set the bias of a binding to [bias].'''
        idx = self.find_bindings(binding_name)
        self.bGC[idx] = bias
        self._refresh_biases()

    def set_role_bias(self, role_name, bias):
        '''Set the bias of bindings of all fillers with particular roles to [bias].'''
        self.bGC[self.find_roles(role_name)] = bias
        self._refresh_biases()

    def set_filler_bias(self, filler_name, bias):
        '''Set the bias of bindings of all roles with particular fillers to [bias].'''
        self.bGC[self.find_fillers(filler_name)] = bias
        self._refresh_biases()

    def vec2mat(self, actC=None):
        '''Reshape a C-space vector into an (nfillers, nroles) matrix (column-major).'''
        if actC is None:
            actC = self.S2C()
        return actC.reshape(self.nfillers, self.nroles, order='F')

    def C2S(self, C=None):
        '''Change basis: from C- to S-space.'''
        if C is None:
            C = self.actC
        return self.TP.dot(C)

    def S2C(self, S=None):
        '''Change basis: from S- to C-space.'''
        if S is None:
            S = self.act
        return self.TPinv.dot(S)

    def find_bindings(self, binding_names):
        '''Find the indices of the bindings from the list of binding names.'''
        if not isinstance(binding_names, list):
            binding_names = [binding_names]
        return [self.binding_names.index(bb) for bb in binding_names]

    def find_roles(self, role_name):
        '''Return the indices of all bindings whose role is in [role_name].'''
        if not isinstance(role_name, list):
            role_name = [role_name]
        role_list = [bb.split('/')[1] for bb in self.binding_names]
        role_idx = []
        for role in role_name:
            role_idx += [ii for ii, rr in enumerate(role_list) if role == rr]
        return role_idx

    def find_fillers(self, filler_name):
        '''Return the indices of all bindings whose filler is in [filler_name].'''
        if not isinstance(filler_name, list):
            filler_name = [filler_name]
        filler_list = [bb.split('/')[0] for bb in self.binding_names]
        filler_idx = []
        for filler in filler_name:
            filler_idx += [ii for ii, ff in enumerate(filler_list) if filler == ff]
        return filler_idx

    def read_state(self, act=None):
        '''Print the current state (C-SPACE) in a readable format. Pandas should be installed.'''
        if act is None:
            act = self.act
        actC = self.vec2mat(self.S2C(act))
        print(pd.DataFrame(actC, index=self.filler_names, columns=self.role_names))

    def read_grid_point(self, act=None, disp=True):
        '''Print a grid point close to the current state. The grid point will be
        chosen by the snapping method (winner-takes-it-all).

        Without quant_list the most active filler is chosen per role; with
        quant_list every binding tied for the maximum in each group is
        returned. Returns the list of winning binding names.'''
        if act is None:
            act = self.act
        if self.quant_list is None:
            actC = self.vec2mat(self.S2C(act))
            winner_idx = np.argmax(actC, axis=0)
            winners = ["%s/%s" % (self.filler_names[ii], rr)
                       for ii, rr in zip(winner_idx, self.role_names)]
        else:
            actC = self.S2C(act)
            winners = []
            for group in self.quant_list:
                idx = self.find_bindings(group)
                group_act = actC[idx]
                top = max(group_act)
                winners.extend(self.binding_names[pos]
                               for pos, val in zip(idx, group_act) if val == top)
        if disp:
            print(winners)
        return winners

    def read_weight(self, which='WGC'):
        '''Print the weight matrix named [which]; names ending in 'C' are
        labelled with binding names, others with unit names.'''
        labels = self.binding_names if which[-1] == 'C' else self.unit_names
        print(pd.DataFrame(getattr(self, which), index=labels, columns=labels))

    def read_bias(self, which='bGC', print_vertical=True):
        '''Print the bias vector named [which]; names ending in 'C' are
        labelled with binding names, others with unit names.'''
        labels = self.binding_names if which[-1] == 'C' else self.unit_names
        vec = np.asarray(getattr(self, which))
        # reshape(-1, ...) lets S-space vectors (length nunits) print too;
        # they need not have nbindings entries.
        if print_vertical:
            print(pd.DataFrame(vec.reshape(-1, 1), index=labels, columns=["bias"]))
        else:
            print(pd.DataFrame(vec.reshape(1, -1), index=["bias"], columns=labels))

    def hinton(self, which='WGC', label=True):
        '''Draw a hinton diagram'''
        if not label:
            labels = None
        elif which[-1] == 'C':
            labels = self.binding_names
        else:
            labels = self.unit_names
        hinton(getattr(self, which), xlabels=labels, ylabels=labels)

    def act_clamped(self, act=None):
        '''Return [act] (S-space) with the clamp applied: projmat.dot(act) + clampvec.
        Requires clamp() to have been called first.'''
        if act is None:
            act = self.act
        return self.projmat.dot(act) + self.clampvec

    def check_beta(self):
        '''Check beta against a lower bound computed from WGC, bGC, extC and z.

        Exits via sys.exit when beta is not strictly greater than the bound.'''
        eigvals, _ = la.eigh(self.WGC)
        eig_max = max(eigvals)
        total_bias = self.bGC + self.extC
        # np.min / np.max give plain scalars for any number of bindings,
        # so the one-binding case needs no special branch.
        beta1 = -np.min(total_bias) / self.z
        beta2 = (np.max(total_bias) + eig_max) / (1 - self.z)
        beta_min = max(eig_max, beta1, beta2)
        if self.beta <= beta_min:
            sys.exit("Beta (bowl strength) should be greater than %.4f." % beta_min)

    def info(self):
        '''Print the network information.'''
        print('Fillers = ', self.filler_names)
        print('Roles = ', self.role_names)
        print('Num_of_units = ', self.nunits)
        print('Current T = ', self.T)

    def compute_dist(self, ref_point, norm_ord=2, space='S'):
        """
        Compute the distance of the current state from a grid point.
        [grid point] is a set of bindings.

        space selects the coordinate system ('S' or 'C'); anything else
        raises ValueError.
        """
        idx = self.find_bindings(ref_point)
        destC = np.zeros(self.nbindings)
        destC[idx] = 1.0
        if space == 'S':
            state1 = self.act
            state2 = self.C2S(destC)
        elif space == 'C':
            state1 = self.S2C(self.act)
            state2 = destC
        else:
            raise ValueError("space must be 'S' or 'C'.")
        return np.linalg.norm(state1 - state2, ord=norm_ord)

    def set_input(self, binding_names, ext_vals, inhib_comp=False):
        '''Set external input.

        Any previous input is cleared. With inhib_comp=True the competitors
        of each binding (same role, or same quant_list group) first receive
        the negated input.'''
        if not isinstance(ext_vals, list):
            ext_vals = [ext_vals]
        if not isinstance(binding_names, list):
            binding_names = [binding_names]
        if len(ext_vals) > 1 and len(binding_names) != len(ext_vals):
            sys.exit("binding_names and ext_vals have different lengths.")
        self.clear_input()
        if inhib_comp:
            inhib = -np.asarray(ext_vals)
            if self.quant_list is None:
                role_names = [b.split('/')[1] for b in binding_names]
                self.extC[self.find_roles(role_names)] = inhib
            else:
                for bb in binding_names:
                    for group in self.quant_list:
                        if bb in group:
                            self.extC[self.find_bindings(group)] = inhib
        # Written last so the named bindings override the inhibition above.
        idx = self.find_bindings(binding_names)
        self.extC[idx] = ext_vals
        self.ext = compute_sspace_biases(self.extC, self)

    def clear_input(self):
        '''Remove external input.'''
        self.extC = np.zeros(self.nbindings)
        self.ext = compute_sspace_biases(self.extC, self)

    def clamp(self, binding_names, clamp_vals=1.0, clamp_comp=False):
        '''Clamp f/r bindings to [clamp_vals].

        With clamp_comp=True the competitors of each binding (same role, or
        every quant_list group containing it) are clamped to 0 as well.'''
        if not isinstance(clamp_vals, list):
            clamp_vals = [clamp_vals]
        if not isinstance(binding_names, list):
            binding_names = [binding_names]
        if len(clamp_vals) > 1 and len(clamp_vals) != len(binding_names):
            sys.exit('The number of bindings clamped is not equal to the number of values provided.')
        self.clamped = True
        self.binding_names_clamped = binding_names
        clampvecC = np.zeros(self.nbindings)

        # Collect competitors across ALL matching groups, not just the last
        # one visited, so every one of them is removed from the free set.
        comp_idx = []
        if clamp_comp:
            if self.quant_list is None:
                role_names = [b.split('/')[1] for b in binding_names]
                comp_idx = self.find_roles(role_names)
            else:
                for bb in binding_names:
                    for group in self.quant_list:
                        if bb in group:
                            comp_idx += self.find_bindings(group)
            clampvecC[comp_idx] = 0.0

        idx = self.find_bindings(binding_names)
        clampvecC[idx] = clamp_vals
        self.clampvecC = clampvecC

        fixed = set(idx) | set(comp_idx)
        idx0 = [bb for bb in range(self.nbindings) if bb not in fixed]
        A = self.TP[:, idx0]
        if len(idx0) > 0:
            self.projmat = compute_projmat(A)
        else:
            # Nothing is free: the state is fully determined by clampvec.
            self.projmat = np.zeros((self.nunits, self.nunits))
        self.clampvec = self.C2S(clampvecC)
        self.act = self.act_clamped(self.act)
        self.actC = self.S2C()

    def unclamp(self):
        '''Remove the clamp set by clamp(), if any.'''
        if self.clamped is True:
            del self.clampvec
            del self.clampvecC
            del self.projmat
            del self.binding_names_clamped
            self.clamped = False

    def harmony(self, act=None, quant_list=None):
        '''Compute the total harmony of the current state'''
        if act is None:
            act = self.act
        if quant_list is None:
            quant_list = self.quant_list
        return harmony(net=self, act=act, quant_list=quant_list)

    def gharmony(self, act=None):
        '''Compute the grammar harmony of the current state'''
        if act is None:
            act = self.act
        return gharmony(net=self, act=act)

    def bharmony(self, act=None):
        '''Compute the bowl harmony of the current state'''
        if act is None:
            act = self.act
        return bharmony(net=self, act=act)

    def oharmony(self, act=None):
        '''Compute grammar + bowl harmony (no quantization term) of the current state'''
        if act is None:
            act = self.act
        return oharmony(net=self, act=act)

    def qharmony(self, act=None, quant_list=None):
        '''Compute the quantization harmony of the current state'''
        if act is None:
            act = self.act
        if quant_list is None:
            quant_list = self.quant_list
        return q2harmony(net=self, act=act, quant_list=quant_list)

    def hgrad(self, act=None, quant_list=None):
        '''Compute the harmony gradient evaluated at the current state'''
        if act is None:
            act = self.act
        # Same default as harmony()/qharmony(), so the gradient matches the
        # function it differentiates.
        if quant_list is None:
            quant_list = self.quant_list
        return hgrad(net=self, act=act, quant_list=quant_list)

    def update(self, update_T=True, update_q=True, q_fun=None, space='S', norm_ord=2,
               ema_factor=0.1):
        '''Advance one step: state, speed, optional distances, then T and q.'''
        if q_fun is None:
            q_fun = self.q_fun
        self.prev_act = copy.deepcopy(self.act)
        self.update_state()
        self.update_speed(ema_factor=ema_factor, space=space, norm_ord=norm_ord)
        if self.grid_points is not None:
            self.update_dist(self.grid_points, space=space, norm_ord=norm_ord)
        if update_T:
            self.update_T()
        if update_q:
            self.update_q(fun=q_fun)

    def update_state(self):
        '''Update the current state (with noise)'''
        self.act += self.dt * self.hgrad(quant_list=self.quant_list)
        self.add_noise()
        if self.clamped:
            self.act = self.act_clamped()
        self.actC = self.S2C()

    def add_noise(self):
        '''Add a noise to the current activation state.'''
        noise = sqrt(2 * self.T * self.dt) * np.random.randn(self.nunits)
        self.act += noise

    def update_T(self, method='exponential'):
        '''Update the current temperature'''
        if method == 'exponential':
            self.T = (1 - self.T_decay_rate) * (self.T - self.T_min) + self.T_min

    def update_q(self, fun='plinear'):
        '''Advance q according to [fun]: 'log', 'linear' or 'plinear'.'''
        if fun == 'log':
            self.q_time += 1
            self.q = self.q_rate * log(self.q_time)
        elif fun == 'linear':
            self.q += self.q_rate
        elif fun == 'plinear':
            # NOTE(review): the last increment can carry q past q_max by up
            # to q_rate; kept as is so existing simulations are unchanged.
            if self.q <= self.q_max:
                self.q += self.q_rate

    def reset(self):
        '''Restore q, q_time and T to their initial values and randomize the state.'''
        self.q = self.q_init
        self.q_time = 0
        self.T = self.T_init
        self.randomize_state()

    def update_speed(self, ema_factor=0.1, norm_ord=2, space='S'):
        '''Update speed (norm of the last state change) and its moving average.

        ema_factor is the weight given to the previous average.'''
        if space == 'S':
            diff = self.act - self.prev_act
        elif space == 'C':
            diff = self.S2C(self.act) - self.S2C(self.prev_act)
        else:
            raise ValueError("space must be 'S' or 'C'.")
        self.speed = np.linalg.norm(diff, ord=norm_ord)
        if self.ema_speed is None:
            self.ema_speed = self.speed
        else:
            self.ema_speed = ema_factor * self.ema_speed + (1 - ema_factor) * self.speed

    def update_dist(self, grid_points, norm_ord=2, space='S'):
        '''Store in self.dist the distance from the state to each grid point.'''
        if not any(isinstance(tt, list) for tt in grid_points):
            grid_points = [grid_points]
        dist = np.zeros(len(grid_points))
        for ii, grid_point in enumerate(grid_points):
            dist[ii] = self.compute_dist(ref_point=grid_point, norm_ord=norm_ord, space=space)
        self.dist = dist

    def update_traces(self):
        '''Record the current state and diagnostics at row self.step of every trace.'''
        self.act_trace[self.step, :] = self.act
        self.h_trace[self.step] = self.harmony()
        self.oh_trace[self.step] = self.oharmony()
        self.qh_trace[self.step] = self.qharmony()
        self.gh_trace[self.step] = self.gharmony()
        self.bh_trace[self.step] = self.bharmony()
        self.speed_trace[self.step] = self.speed
        self.ema_speed_trace[self.step] = self.ema_speed
        self.q_trace[self.step] = self.q
        self.T_trace[self.step] = self.T
        if self.grid_points is not None:
            self.dist_trace[self.step, :] = self.dist

    def run(self, maxstep, plot=False, grayscale=False, colorbar=True,
            tol=None, ema_factor=0.1, update_T=True, update_q=True, q_fun=None,
            grid_points=None, testvar='ema_speed', space='S', norm_ord=2,
            ext_overlap=False, ext_overlap_steps=1000):
        '''
        Run simulations for [maxstep] steps.

        Traces are allocated with maxstep+1 rows (row 0 is the initial
        state). When tol is given the run stops as soon as
        check_convergence() reports convergence on [testvar]. The number of
        steps taken is stored in self.rt.

        With ext_overlap=True the external input is cross-faded from
        prev_extC to the current extC over ext_overlap_steps steps.
        '''
        if q_fun is None:
            q_fun = self.q_fun
        if grid_points is None:
            grid_points = self.grid_points
        self.converged = False
        self.act_trace = np.zeros((maxstep + 1, self.nunits))
        self.h_trace = np.zeros(maxstep + 1)
        self.oh_trace = np.zeros(maxstep + 1)
        self.qh_trace = np.zeros(maxstep + 1)
        self.gh_trace = np.zeros(maxstep + 1)
        self.bh_trace = np.zeros(maxstep + 1)
        self.speed_trace = np.zeros(maxstep + 1)
        self.ema_speed_trace = np.zeros(maxstep + 1)
        self.q_trace = np.zeros(maxstep + 1)
        self.T_trace = np.zeros(maxstep + 1)
        if grid_points is not None:
            if not any(isinstance(pp, list) for pp in grid_points):
                grid_points = [grid_points]
            # update() and update_traces() read self.grid_points, so a value
            # passed to run() must be stored for it to take effect.
            self.grid_points = grid_points
            self.dist_trace = np.zeros((maxstep + 1, len(grid_points)))
            self.dist = np.full(len(grid_points), np.nan)
        self.step = 0
        self.update_traces()
        curr_extC = self.extC
        for tt in np.arange(maxstep) + 1:
            self.step = tt
            if ext_overlap:
                # Cap the mixing weight at 1 so the input settles on
                # curr_extC instead of extrapolating beyond it.
                mix = min(tt / ext_overlap_steps, 1.0)
                self.extC = (1 - mix) * self.prev_extC + mix * curr_extC
                # The gradient reads the S-space input, so keep it in sync.
                self.ext = compute_sspace_biases(self.extC, self)
            self.update(update_T=update_T, update_q=update_q, q_fun=q_fun,
                        space=space, norm_ord=norm_ord, ema_factor=ema_factor)
            self.update_traces()
            if tol is not None:
                self.check_convergence(tol=tol, testvar=testvar)
                if self.converged:
                    break
        self.rt = self.step
        if plot:
            actC_trace = self.S2C(self.act_trace[:(self.rt + 1), :].T).T
            heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                    yticklabels=self.binding_names, grayscale=grayscale, colorbar=colorbar)

    def check_convergence(self, tol, testvar='ema_speed'):
        '''Check if the convergence criterion (distance vs. ema_speed) has been satisfied.'''
        if testvar == 'dist':
            if (self.dist < tol).any():
                self.converged = True
        if testvar == 'ema_speed':
            if self.ema_speed < tol:
                self.converged = True

    def simulate(self):
        '''Placeholder; always returns 0.'''
        return 0

    def plot_act_trace(self, timesteps=None):
        '''Plot the recorded activation trace (C-space) as a heatmap.'''
        # act_trace always exists (set to None in __init__), so test its
        # value rather than its presence.
        if getattr(self, 'act_trace', None) is None:
            sys.exit("There is no [act_trace] attribute in the current object.")
        if timesteps is None:
            timesteps = list(range(self.rt + 1))
        actC_trace = self.S2C(self.act_trace[timesteps, :].T).T
        heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                xticklabels=None, yticklabels=self.binding_names, grayscale=False)

    def plot_trace(self, varname, timesteps=None):
        '''Plot the trace stored in attribute [varname]_trace.'''
        if timesteps is None:
            timesteps = list(range(self.rt + 1))
        curr_trace = getattr(self, varname + '_trace')
        plt.plot(curr_trace[timesteps])
        plt.xlabel('Timestep', fontsize=16)
        plt.ylabel(varname, fontsize=16)
        plt.show()

    def plot_dist_trace(self, timesteps=None):
        '''Plot the recorded distances to the grid points.'''
        if timesteps is None:
            timesteps = list(range(self.rt + 1))
        plt.plot(self.dist_trace[timesteps, :])
        plt.xlabel('Timestep', fontsize=16)
        plt.ylabel('Distance', fontsize=16)
        plt.show()

    def optim(self, initvals=None, method='nelder-mead', options=None):
        '''Find a local optimum given an initial guess (by default, the current state is used).

        Returns the scipy result object with res.fun flipped back to harmony.'''
        if initvals is None or len(initvals) == 0:
            initvals = self.act
        if options is None:
            # NOTE(review): recent SciPy releases call this Nelder-Mead
            # option 'xatol'; confirm against the installed version.
            options = {'xtol': 1e-8, 'disp': True}
        # Inside a method the bare name `optim` is the module-level
        # scipy.optimize alias, not this method.
        res = optim.minimize(lambda x: -harmony(self, x), initvals,
                             method=method, options=options)
        res.fun = -res.fun
        return res

    def harmony_landscape(self, binding_name, minact, maxact):
        '''1D harmony landscape: vary one binding over [minact, maxact]
        (1000 points) with the other bindings held at their current values.
        Returns (activation grid, harmony values).'''
        if isinstance(binding_name, list):
            if len(binding_name) > 1:
                sys.exit("Choose only one binding.!")
            binding_name = binding_name[0]
        idx = self.binding_names.index(binding_name)
        agrid = np.linspace(minact, maxact, 1000)
        hgrid = np.zeros(len(agrid))
        actC = self.S2C()
        for ii, aa in enumerate(agrid):
            act0 = actC.copy()
            act0[idx] = aa
            hgrid[ii] = harmony(net=self, act=self.C2S(act0))
        return agrid, hgrid
class WeightTP():
    """Symmetric table of C-space weights indexed by binding name."""

    def __init__(self, binding_names):
        """Start with an all-zero (nbindings x nbindings) weight matrix."""
        size = len(binding_names)
        self.binding_names = binding_names
        self.nbindings = size
        self.WGC = np.zeros((size, size))

    def set_weight(self, binding1, binding2, weight):
        """Store [weight] in both directions between the two bindings."""
        row = self.binding_names.index(binding1)
        col = self.binding_names.index(binding2)
        self.WGC[row, col] = weight
        self.WGC[col, row] = weight

    def show(self):
        """Print the weight matrix labelled with binding names."""
        names = self.binding_names
        table = pd.DataFrame(self.WGC, index=names, columns=names)
        print(table)
class BiasTP():
    """Table of C-space biases indexed by binding name ('filler/role')."""

    def __init__(self, binding_names):
        """Start with an all-zero bias vector, one entry per binding."""
        self.binding_names = binding_names
        self.nbindings = len(binding_names)
        self.bGC = np.zeros(self.nbindings)

    def set_bias(self, binding, bias):
        """Set the bias of a single binding."""
        position = self.binding_names.index(binding)
        self.bGC[position] = bias

    def set_role_bias(self, role, bias):
        """Set the bias of every binding whose role part equals [role]."""
        roles = [name.split('/')[1] for name in self.binding_names]
        hits = [pos for pos, current in enumerate(roles) if current == role]
        self.bGC[hits] = bias

    def set_filler_bias(self, filler, bias):
        """Set the bias of every binding whose filler part equals [filler]."""
        fillers = [name.split('/')[0] for name in self.binding_names]
        hits = [pos for pos, current in enumerate(fillers) if current == filler]
        self.bGC[hits] = bias

    def show(self):
        """Print the bias vector as a one-column table."""
        table = pd.DataFrame(self.bGC, index=self.binding_names, columns=["bias"])
        print(table)
class GscNet0():
    """Container coupling several GscNet groups through C-space weights.

    Binding names are the groups' own names suffixed with ':<group name>'.
    Each group keeps its own state; update() feeds every group the
    cross-group input and then advances the groups one at a time.
    """

    def __init__(self, net_list, group_names):
        '''Assemble the combined C-space / S-space quantities from [net_list].'''
        self.ngroups = len(net_list)
        self.group_names = group_names
        self.groups = net_list
        # update_state() reads this flag, so it has to exist before clamp().
        self.clamped = False
        self.binding_names = []
        for ii in range(self.ngroups):
            self.binding_names = self.binding_names + self._qualified_names(ii)
        self.nbindings = len(self.binding_names)

        nb = self.nbindings
        self.WGC = np.zeros((nb, nb))
        self.WBC = np.zeros((nb, nb))
        self.bGC = np.zeros(nb)
        self.bBC = np.zeros(nb)
        self.extC = np.zeros(nb)
        self.group_unit_idx = []
        self.group_binding_idx = []
        nunits = 0
        for ii, net in enumerate(self.groups):
            idx = [self.binding_names.index(bb) for bb in self._qualified_names(ii)]
            # Within-group blocks are copied from the group itself.
            self.WGC[np.ix_(idx, idx)] = net.WGC
            self.WBC[np.ix_(idx, idx)] = net.WBC
            self.bGC[idx] = net.bGC
            self.bBC[idx] = net.bBC
            self.extC[idx] = net.extC
            self.group_unit_idx.append(np.arange(net.nunits) + nunits)
            self.group_binding_idx.append(idx)
            nunits = nunits + net.nunits
        self.WC = self.WGC + self.WBC
        self.bC = self.bGC + self.bBC
        self.nunits = nunits
        self.W = self.compute_sspace_weights()
        self.b, self.ext = self.compute_sspace_b_and_ext()
        self.act = np.zeros(self.nunits)

    def _qualified_names(self, ii):
        '''Binding names of group [ii] with the ':<group name>' suffix.'''
        suffix = ':' + self.group_names[ii]
        return [b + suffix for b in self.groups[ii].binding_names]

    def compute_sspace_b_and_ext(self):
        '''Concatenate the groups' own S-space biases and external inputs.'''
        # NOTE(review): values come from each group's b / ext, so edits made
        # only to self.bGC / self.extC are not reflected here -- confirm.
        b_distributed = np.zeros(self.nunits)
        ext_distributed = np.zeros(self.nunits)
        for ii, net in enumerate(self.groups):
            b_distributed[self.group_unit_idx[ii]] = net.b
            ext_distributed[self.group_unit_idx[ii]] = net.ext
        return b_distributed, ext_distributed

    def compute_sspace_weights(self):
        '''Build the S-space weight matrix: each diagonal block is the
        group's own W; off-diagonal blocks are accumulated from self.WC.'''
        W_distributed = np.zeros((self.nunits, self.nunits))
        for ii, net1 in enumerate(self.groups):
            idx1 = self.group_unit_idx[ii]
            bidx1 = self.group_binding_idx[ii]
            for jj, net2 in enumerate(self.groups):
                if ii == jj:
                    W_distributed[np.ix_(idx1, idx1)] = net1.W
                    continue
                idx2 = self.group_unit_idx[jj]
                bidx2 = self.group_binding_idx[jj]
                W_local = self.WC[np.ix_(bidx1, bidx2)]
                W_temp = np.zeros((net1.nunits, net2.nunits))
                for kk in range(net1.nbindings):
                    si = net1.TP[:, kk]
                    for ll in range(net2.nbindings):
                        sj = net2.TP[:, ll]
                        W_temp = W_temp + W_local[kk, ll] * np.outer(si, sj) / (np.dot(si, si) * np.dot(sj, sj))
                W_distributed[np.ix_(idx1, idx2)] = W_temp
        return W_distributed

    def randomize_state(self, minact=0, maxact=1):
        '''Set the activation state to a random vector inside a unit hypercube'''
        for ii, net in enumerate(self.groups):
            # Forward the requested range instead of each group's defaults.
            net.randomize_state(minact, maxact)
            self.act[self.group_unit_idx[ii]] = net.act

    def set_weight(self, binding_name1, binding_name2, weight, symmetric=True):
        '''Set a C-space weight between two (group-qualified) bindings.'''
        idx1 = self.binding_names.index(binding_name1)
        idx2 = self.binding_names.index(binding_name2)
        if symmetric:
            self.WGC[idx1, idx2] = self.WGC[idx2, idx1] = weight
        else:
            self.WGC[idx2, idx1] = weight
        self.WC = self.WGC + self.WBC
        self.W = self.compute_sspace_weights()

    def set_bias(self, binding_name, bias):
        '''Set the C-space bias of one (group-qualified) binding.'''
        idx = self.binding_names.index(binding_name)
        self.bGC[idx] = bias
        self.bC = self.bGC + self.bBC
        self.b, self.ext = self.compute_sspace_b_and_ext()

    def set_role_bias(self, role_name, bias):
        '''Set the bias of every binding with role [role_name]; the role may
        be given bare ('r1') or group-qualified ('r1:g').'''
        role_list = [bb.split('/')[1] for bb in self.binding_names]
        # Exact comparison: a substring test would make 'r1' also hit 'r10'.
        idx = [ii for ii, rr in enumerate(role_list)
               if role_name == rr or role_name == rr.split(':')[0]]
        self.bGC[idx] = bias
        self.bC = self.bGC + self.bBC
        self.b, self.ext = self.compute_sspace_b_and_ext()

    def set_filler_bias(self, filler_name, bias):
        '''Set the bias of every binding with filler [filler_name].'''
        filler_list = [bb.split('/')[0] for bb in self.binding_names]
        # Exact comparison: a substring test would make 'a' also hit 'ab'.
        idx = [ii for ii, ff in enumerate(filler_list) if filler_name == ff]
        self.bGC[idx] = bias
        self.bC = self.bGC + self.bBC
        self.b, self.ext = self.compute_sspace_b_and_ext()

    def set_seed(self, num):
        '''Seed NumPy's global random generator.'''
        np.random.seed(num)

    def find_bindings(self, binding_names):
        '''Find the indices of the bindings from the list of binding names.'''
        if not isinstance(binding_names, list):
            binding_names = [binding_names]
        return [self.binding_names.index(bb) for bb in binding_names]

    def read_state(self):
        '''Print the current state (C-SPACE) in a readable format. Pandas should be installed.'''
        for ii, net in enumerate(self.groups):
            print("Group: %s" % self.group_names[ii])
            net.read_state()
            print("\n")

    def read_grid_point(self):
        '''Print, per group, the grid point nearest to the current state.'''
        for ii, net in enumerate(self.groups):
            print("Group: %s" % self.group_names[ii])
            print(net.read_grid_point(disp=False))
            print("\n")

    def read_weight(self):
        '''Print the combined C-space weights followed by the biases.'''
        print(pd.DataFrame(self.WC, index=self.binding_names, columns=self.binding_names))
        print(pd.DataFrame(self.bC.reshape(1, self.nbindings), index=["bias"], columns=self.binding_names))

    def hinton(self, label=True):
        '''Draw a hinton diagram'''
        labels = self.binding_names if label else None
        hinton(self.WC, xlabels=labels, ylabels=labels)

    def set_input(self, binding_names, ext_vals):
        '''Replace the external input with [ext_vals] on [binding_names].'''
        if len(binding_names) != len(ext_vals):
            sys.exit("binding_names and ext_vals have different lengths.")
        self.extC = np.zeros(self.nbindings)
        ind = self.find_bindings(binding_names)
        self.extC[ind] = ext_vals
        # NOTE(review): compute_sspace_biases is handed this container rather
        # than a single group; confirm it supports that.
        self.ext = compute_sspace_biases(self.extC, self)

    def clear_input(self):
        '''Remove the external input.'''
        self.extC = np.zeros(self.nbindings)
        # NOTE(review): same caveat as in set_input.
        self.ext = compute_sspace_biases(self.extC, self)

    def clamp(self, binding_names, clamp_vals):
        '''Clamp f/r bindings; each 'binding:group' name is dispatched to its group.'''
        if len(clamp_vals) != len(binding_names):
            sys.exit('The number of bindings clamped is not equal to the number of values provided.')
        self.clamped = True
        temp = [b.split(':') for b in binding_names]
        group_num = [self.group_names.index(b[1]) for b in temp]
        bb_names = [b[0] for b in temp]
        for gg in set(group_num):
            members = [jj for jj, val in enumerate(group_num) if val == gg]
            curr_bindings = [bb_names[jj] for jj in members]
            curr_vals = [clamp_vals[jj] for jj in members]
            self.groups[gg].clamp(binding_names=curr_bindings, clamp_vals=curr_vals)

    def unclamp(self):
        '''Release the clamp on every group.'''
        self.clamped = False
        for net in self.groups:
            if net.clamped:
                net.unclamp()

    def update(self, update_T=True, update_q=True):
        '''Update the current state (with noise) for each group'''
        # Pass 1: give every group the input coming from all OTHER groups,
        # computed from their states before anyone moves.
        for ii, net1 in enumerate(self.groups):
            extC = np.zeros(net1.nbindings)
            idx1 = self.group_binding_idx[ii]
            for jj, net2 in enumerate(self.groups):
                if ii == jj:
                    continue
                idx2 = self.group_binding_idx[jj]
                curr_WC = self.WC[np.ix_(idx1, idx2)]
                extC = extC + curr_WC.dot(net2.TPinv.dot(net2.act))
            net1.ext = net1.TP.dot(extC)
        # Pass 2: advance each group and mirror its state.
        for ii, net in enumerate(self.groups):
            net.update(update_T=update_T, update_q=update_q)
            self.act[self.group_unit_idx[ii]] = net.act

    def update_state(self):
        '''Update the current state (with noise)'''
        # NOTE(review): relies on self.dt and self.act_clamped, which this
        # class does not define in the code visible here.
        self.act += self.dt * self.hgrad()
        self.add_noise()
        if self.clamped:
            self.act = self.act_clamped()

    def add_noise(self):
        '''Add a noise to the current activation state.'''
        # NOTE(review): self.T / self.dt are not set by __init__.
        noise = sqrt(2 * self.T * self.dt) * np.random.randn(self.nunits)
        self.act += noise

    def update_T(self, method='exponential'):
        '''Update the current temperature'''
        if method == 'exponential':
            self.T = (1 - self.T_decay_rate) * (self.T - self.T_min) + self.T_min

    def update_q(self, method='log'):
        '''Advance q logarithmically in q_time.'''
        # NOTE(review): self.q_time / self.q_factor are not set by __init__.
        if method == 'log':
            self.q_time += 1
            self.q = self.q_factor * log(self.q_time)

    def reset(self):
        '''Reset every group.'''
        for net in self.groups:
            net.reset()

    def harmony(self, act=None, quant_list=None):
        '''Compute the total harmony of the current state'''
        if act is None:
            act = self.act
        return harmony(net=self, act=act, quant_list=quant_list)

    def gharmony(self, act=None):
        '''Compute the grammar harmony of the current state'''
        if act is None:
            act = self.act
        return gharmony(net=self, act=act)

    def hgrad(self, act=None, quant_list=None):
        '''Compute the harmony gradient evaluated at the current state'''
        if act is None:
            act = self.act
        return hgrad(net=self, act=act, quant_list=quant_list)

    def run(self, maxstep, tol=None):
        '''Run run_GscNet0 and store rt, final_state, act_trace and ema_speed_trace.'''
        self.rt, self.final_state, self.act_trace, self.ema_speed_trace = run_GscNet0(self, maxstep=maxstep, tol=tol)

    def plot_act_trace(self):
        '''Plot each group's recorded activation trace (C-space) as a heatmap.'''
        for ii, net in enumerate(self.groups):
            curr_act_trace = self.act_trace[:, self.group_unit_idx[ii]]
            actC_trace = net.TPinv.dot(curr_act_trace.T).T
            heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings", yticklabels=net.binding_names, grayscale=False, colorbar=True)

    def gharmony_given_clamping(self, x, group_name_clamped, clampval, quant_list, q):
        '''Harmony of the full state obtained by filling the unclamped
        coordinates with [x] while one group is clamped.'''
        # NOTE(review): clamp_group, clampind, clampvec and qharmony are not
        # defined on this class in the code visible here.
        x0 = np.zeros(self.nbindings)
        self.clamp_group(group_name_clamped, clampval)
        x0[self.clampind] = self.clampvec[self.clampind]
        ind = list(set(np.arange(self.nbindings)) - set(self.clampind))
        x0[ind] = x
        H = 0.5 * x0.dot(self.W.dot(x0)) + x0.dot(self.b)
        for subgroup in quant_list:
            temp_ind = [self.binding_names.index(bb) for bb in subgroup]
            qx = x0[temp_ind]
            H = H + q * self.qharmony(qx)
        return H

    def optimal_state_given_clamping(self, group_name_clamped, clampval, quant_list, q, method="Nelder-Mead"):
        '''Maximise gharmony_given_clamping from a random start; the optimum
        is printed and stored in self.optim_state.'''
        group_ind = self.group_names.index(group_name_clamped)
        nfree = self.nbindings - self.groups[group_ind].nbindings

        def neg_harmony(x):
            return -self.gharmony_given_clamping(x, group_name_clamped, clampval, quant_list, q)

        start = np.random.uniform(size=nfree)
        if method == "basinhopping":
            res = optim.basinhopping(neg_harmony, start)
        else:
            res = optim.minimize(neg_harmony, start, method=method)
        print(res)
        x0 = np.zeros(self.nbindings)
        x0[self.clampind] = self.clampvec[self.clampind]
        ind = list(set(np.arange(self.nbindings)) - set(self.clampind))
        x0[ind] = res.x
        self.optim_state = x0
        print(x0)
def run_GscNet0(net, maxstep, plot=False, grayscale=False, colorbar=True,
                tol=None, ema_factor=0.1, crop=True, update_q=True):
    '''Run a multi-group network for up to [maxstep] steps.

    net must provide update(update_q=...), act, nunits and groups; plotting
    also needs group_unit_idx.
    tol: if given (must be >= 0), stop once the moving average of the
        largest absolute activation change drops below it.
    ema_factor: weight given to the previous moving average.
    crop: accepted for compatibility; the trace is always cropped.

    Returns [steps taken, final activation, activation trace (one row per
    step taken), final moving-average speed or -1 when tol is None].
    '''
    if tol is not None and tol < 0:
        sys.exit("The tolerance parameter should be a positive real number.")

    act_trace = np.zeros((maxstep, net.nunits))
    nsteps = 0
    if tol is None:
        for tt in range(maxstep):
            net.update(update_q=update_q)
            act_trace[tt, :] = net.act
            nsteps = tt + 1
        ema_speed_trace = -1
    else:
        speeds = np.zeros(maxstep)
        prev_act = copy.deepcopy(net.act)
        ema_speed = 0
        for tt in range(maxstep):
            net.update(update_q=update_q)
            diff = np.max(np.abs(net.act - prev_act))
            if tt == 0:
                ema_speed = diff
            else:
                ema_speed = ema_factor * ema_speed + (1 - ema_factor) * diff
            act_trace[tt, :] = net.act
            speeds[tt] = ema_speed
            prev_act = copy.deepcopy(net.act)
            nsteps = tt + 1
            if (ema_speed < tol) and (tol > 0):
                break
        # Only the last value is reported, as before.
        ema_speed_trace = speeds[nsteps - 1] if nsteps > 0 else ema_speed

    if plot:
        # A separate loop name keeps `net` bound to the container for the
        # index lookup here and for the return value below.
        for ii, group in enumerate(net.groups):
            curr_act_trace = act_trace[:nsteps, net.group_unit_idx[ii]]
            actC_trace = group.S2C(curr_act_trace.T).T
            heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                    yticklabels=group.binding_names, grayscale=grayscale, colorbar=colorbar)
    # Rows 0 .. nsteps-1 were written, so keep all of them.
    return [nsteps, net.act, act_trace[:nsteps, ], ema_speed_trace]
def bind(fillernames, rolenames):
    '''Return the names of all filler/role bindings as 'filler/role' strings.

    Roles vary slowest: every filler is paired with the first role, then
    every filler with the second role, and so on.
    '''
    bindings = []
    for role in rolenames:
        for filler in fillernames:
            bindings.append(filler + '/' + role)
    return bindings
def harmony(net, act, quant_list=None):
    '''Return total harmony at activation state `act`.

    The total is the sum of grammar harmony, bowl harmony and quantization
    harmony, the last one weighted by `net.q`. `quant_list` is passed on
    to `q2harmony`.
    '''
    grammar_part = gharmony(net, act)
    bowl_part = bharmony(net, act)
    quant_part = net.q * q2harmony(net, act, quant_list)
    return grammar_part + bowl_part + quant_part
def oharmony(net, act):
    '''Return grammar harmony plus bowl harmony at activation state `act`.

    Unlike `harmony`, no quantization term is added.
    '''
    grammar_part = gharmony(net, act)
    bowl_part = bharmony(net, act)
    return grammar_part + bowl_part
def gharmony(net, act):
    '''Return grammar harmony at activation state `act`.

    Computed as 0.5 * act' WG act + act' bG + act' ext, using the
    network's `WG`, `bG` and `ext` attributes.
    '''
    quadratic = 0.5 * act.dot(net.WG).dot(act)
    bias_term = act.dot(net.bG)
    ext_term = act.dot(net.ext)
    return quadratic + bias_term + ext_term
def bharmony(net, act):
    '''Return bowl harmony at `act` (an array).

    Computed as -0.5 * beta * ||S2C(act) - z||^2 with `beta`, `z` and
    `S2C` taken from `net`.
    '''
    deviation = net.S2C(act) - net.z
    return -0.5 * net.beta * deviation.dot(deviation)
def q2harmony(net, act=None, quant_list=None):
    '''Return quantization harmony Q.

    `act` defaults to `net.act` and is converted with `net.S2C` first.
    Without `quant_list`, fillers compete within each role: the converted
    activations are reshaped to (nfillers, nroles) in column-major order and
    each column's sum of squares is pushed towards 1; with a single filler
    only the a^2 (1-a)^2 term is used. With `quant_list` (a sequence of
    groups of binding names), Q is the sum of `q2` over those groups.
    '''
    state = net.act if act is None else act
    actC = net.S2C(state)

    if quant_list is not None:
        total = 0
        for members in quant_list:
            member_idx = [net.binding_names.index(name) for name in members]
            total = total + q2(net, actC[member_idx])
        return total

    # Zero exactly where every activation is 0 or 1.
    binary_term = -np.sum(actC**2 * (1-actC)**2)
    if net.nfillers == 1:
        return binary_term

    sq_per_role = np.sum(actC.reshape(net.nfillers, net.nroles, order='F')**2, axis=0)
    unit_term = -np.sum((sq_per_role - 1)**2)
    return net.c * unit_term + (1-net.c) * binary_term
def q2(net, actC):
    '''Return quantization harmony for one group of bindings.

    `actC` holds the activations of the bindings in the group.
    Two terms are combined with weight `net.c`:

    * -(sum(actC**2) - 1)**2, maximal when the squared activations of the
      group sum to 1;
    * -sum(actC**2 * (1-actC)**2), maximal when every activation is 0 or 1.

    For a group with a single binding only the second term is used.
    '''
    # The 1 is subtracted from the SUM of squares, not from each element;
    # this matches the gradient in q2grad and the default branch of
    # q2harmony.
    const1 = -(np.sum(actC**2) - 1)**2
    const2 = -np.sum(actC**2 * (1-actC)**2)
    if len(actC) == 1:
        return const2
    return net.c * const1 + (1-net.c) * const2
def hgrad(net, act, quant_list=None):
    '''Return the gradient of total harmony at `act`.

    Sum of the grammar-harmony gradient, the bowl-harmony gradient and the
    quantization-harmony gradient weighted by `net.q`. `quant_list` is
    passed on to `q2hgrad`.
    '''
    grammar_grad = ghgrad(net, act)
    bowl_grad = bhgrad(net, act)
    quant_grad = net.q * q2hgrad(net, act, quant_list)
    return grammar_grad + bowl_grad + quant_grad
def ghgrad(net, act):
    '''Return WG.act + bG + ext, the gradient of grammar harmony at `act`.'''
    weighted_input = net.WG.dot(act)
    return weighted_input + net.bG + net.ext
def bhgrad(net, act):
    '''Return -beta * (S2C(act) - z), the bowl-harmony gradient term.

    NOTE(review): the result is returned as is, whereas q2hgrad passes its
    result through net.C2S; confirm this is intended when S2C is not the
    identity.
    '''
    deviation = net.S2C(act) - net.z
    return -net.beta * deviation
def q2hgrad(net, act, quant_list=None):
    '''Return the gradient of quantization harmony at `act`.

    Mirrors `q2harmony`: `act` is converted with `net.S2C`, the gradient is
    formed on the converted activations and the result is passed through
    `net.C2S`.

    Without `quant_list`, fillers compete within each role (activations
    reshaped to (nfillers, nroles) in column-major order); with a single
    filler only the a^2 (1-a)^2 term contributes, unweighted, exactly as in
    `q2harmony`. With `quant_list` (groups of binding names) the gradients
    returned by `q2grad` for each group are accumulated.
    '''
    actC = net.S2C(act)
    if quant_list is None:
        # d/da of -a^2 (1-a)^2
        const2 = -2 * actC * (actC-1) * (2*actC-1)
        if net.nfillers == 1:
            # q2harmony returns the a^2 (1-a)^2 term alone in this case, so
            # the gradient must not be scaled by (1 - c).
            q_grad = const2
        else:
            sq_per_role = np.sum(actC.reshape(net.nfillers, net.nroles, order='F')**2, axis=0)
            # Each binding sees the sum of squares of its own role.
            const1 = -4 * actC * (np.repeat(sq_per_role, net.nfillers) - 1)
            q_grad = net.c * const1 + (1-net.c) * const2
    else:
        q_grad = np.zeros(net.nbindings)
        for subgroup in quant_list:
            curr_q_grad = np.zeros(net.nbindings)
            temp_ind = [net.binding_names.index(bb) for bb in subgroup]
            curr_q_grad[temp_ind] = q2grad(net, actC[temp_ind])
            q_grad = q_grad + curr_q_grad
    return net.C2S(q_grad)
def q2grad(net, actC):
    '''Return the gradient of `q2` for one group of bindings.

    `actC` holds the activations of the bindings in the group. For a group
    with a single binding `q2` uses only the a^2 (1-a)^2 term, so only that
    term's gradient is returned; otherwise the two gradients are combined
    with weight `net.c`.
    '''
    # d/da of -a^2 (1-a)^2
    const2 = -2 * actC * (actC-1) * (2*actC-1)
    if len(actC) == 1:
        return const2
    # d/da of -(sum(a^2) - 1)^2
    const1 = -4 * actC * (np.sum(actC**2) - 1)
    return net.c * const1 + (1-net.c) * const2
def _blob(x, y, area, colour):
    """
    Fill a square of the given area (< 1), centred on (x, y), with
    the given colour (used for both face and edge).
    http://wiki.scipy.org/Cookbook/Matplotlib/HintonDiagrams
    """
    half_side = np.sqrt(area) / 2
    left, right = x - half_side, x + half_side
    bottom, top = y - half_side, y + half_side
    # Corners listed counter-clockwise starting at the bottom-left.
    P.fill(np.array([left, right, right, left]),
           np.array([bottom, bottom, top, top]),
           colour, edgecolor=colour)
def hinton(W, maxWeight=None, xlabels=None, ylabels=None):
    """
    Draws a Hinton diagram for visualizing a weight matrix.
    Temporarily disables matplotlib interactive mode if it is on,
    otherwise this takes forever.
    http://wiki.scipy.org/Cookbook/Matplotlib/HintonDiagrams

    W         : 2-D array of weights; positive entries are drawn as white
                squares, negative ones as black squares, zeros are skipped.
    maxWeight : weight drawn as a full-size square. When falsy it is set
                to the smallest power of two >= max(|W|).
    xlabels   : optional tick labels for the columns (drawn vertically).
    ylabels   : optional tick labels for the rows (top row first).
    """
    reenable = False
    if P.isinteractive():
        P.ioff()
        # Remember to switch interactive mode back on once drawing is done;
        # without this flag the mode stayed off for the rest of the session.
        reenable = True
    P.clf()
    height, width = W.shape
    if not maxWeight:
        maxWeight = 2**np.ceil(np.log(np.max(np.abs(W)))/np.log(2))
    # Gray background covering the whole matrix.
    P.fill(np.array([0, width, width, 0]), np.array([0, 0, height, height]), 'gray')
    for x in range(width):
        for y in range(height):
            _x = x+1
            _y = y+1
            w = W[y, x]
            # Row 0 is drawn at the top, hence the flipped y coordinate.
            if w > 0:
                _blob(_x - 0.5, height - _y + 0.5, min(1, w/maxWeight), 'white')
            elif w < 0:
                _blob(_x - 0.5, height - _y + 0.5, min(1, -w/maxWeight), 'black')
    if reenable:
        P.ion()
    label = False
    if xlabels is not None:
        P.xticks(np.arange(len(xlabels))+0.5, xlabels, rotation='vertical')
        label = True
    if ylabels is not None:
        # Reversed so that the first label lines up with the top row.
        P.yticks(np.arange(len(ylabels))+0.5, ylabels[::-1])
        label = True
    if label:
        P.xlim([0, width])
        P.ylim([0, height])
        P.gca().set_aspect('equal', adjustable='box')
        P.gca().tick_params(direction='out')
    else:
        P.axis('off')
        P.axis('equal')
    P.show()
def heatmap(data, xlabel=None, ylabel=None, xticklabels=None, yticklabels=None,
            grayscale=False, colorbar=True):
    '''Show `data` as a heatmap with optional axis labels and tick labels.

    data        : 2-D array passed to `plt.imshow`.
    xlabel      : optional x-axis label.
    ylabel      : optional y-axis label.
    xticklabels : optional labels placed at x positions 0, 1, 2, ...
    yticklabels : optional labels placed at y positions 0, 1, 2, ...
    grayscale   : use the "gray_r" colormap instead of "Reds".
    colorbar    : draw a colorbar next to the image.
    '''
    # The colormap is passed by name: plt.cm.get_cmap is no longer available
    # in recent Matplotlib releases, while names are accepted by all of them.
    cmap_name = "gray_r" if grayscale else "Reds"
    plt.imshow(data, cmap=cmap_name, interpolation="nearest", aspect='auto')
    if xlabel is not None:
        plt.xlabel(xlabel, fontsize=16)
    if ylabel is not None:
        plt.ylabel(ylabel, fontsize=16)
    if xticklabels is not None:
        plt.xticks(np.arange(len(xticklabels)), xticklabels)
    if yticklabels is not None:
        plt.yticks(np.arange(len(yticklabels)), yticklabels)
    if colorbar:
        plt.colorbar()
    plt.show()
def encode_symbols(nsymbols, reptype='local', dp=0, ndim=None):
    '''Return a matrix whose columns are the vectors encoding `nsymbols` symbols.

    nsymbols : number of symbols to encode.
    reptype  : 'local' gives the identity matrix (`dp` and `ndim` are then
               ignored); any other value builds vectors from `dp`.
    dp       : either a single number, handled by `dot_products`, or a
               matrix of pairwise dot products, handled by `dot_products2`.
    ndim     : dimensionality of the vectors; defaults to `nsymbols`.
    '''
    if reptype == 'local':
        return np.eye(nsymbols)
    if ndim is None:
        ndim = nsymbols
    if isinstance(dp, Number):
        return dot_products(nsymbols, ndim, dp)
    return dot_products2(nsymbols, ndim, dp)
def dot_products(nsymbols, ndim, s):
    '''Return symbol vectors whose pairwise dot products all equal `s`.

    Builds the target dot-product matrix (ones on the diagonal, `s`
    everywhere else) and hands it to `dot_products2`.
    '''
    all_s = s * np.ones((nsymbols, nsymbols))
    diagonal_fill = (1-s) * np.eye(nsymbols, nsymbols)
    return dot_products2(nsymbols, ndim, all_s + diagonal_fill)
def dot_products2(nsymbols, ndim, dp_mat):
    '''Return an (ndim, nsymbols) matrix whose columns have the pairwise
    dot products given in `dp_mat`.

    Starts from a uniform random matrix and repeatedly subtracts a scaled
    multiple of sym_mat (sym_mat' sym_mat - dp_mat) until every entry of
    sym_mat' sym_mat is within 1e-6 of `dp_mat`, or 100000 iterations have
    been run (a message is printed in that case and the last iterate is
    returned).

    Exits via `sys.exit` when `dp_mat` is not symmetric or its diagonal is
    not all ones.
    '''
    is_symmetric = (dp_mat.T == dp_mat).all()
    if not is_symmetric:
        sys.exit('dot_products2: dp_mat must be symmetric')
    has_bad_diagonal = (np.diag(dp_mat) != 1).any()
    if has_bad_diagonal:
        sys.exit('dot_products2: dp_mat must have all ones on the main diagonal')

    sym_mat = np.random.uniform(size=ndim*nsymbols).reshape(ndim, nsymbols, order='F')
    step_cap, tol, max_iter = .1, 1e-6, 100000
    for _ in range(max_iter):
        gram_error = sym_mat.T.dot(sym_mat) - dp_mat
        inc = sym_mat.dot(gram_error)
        # The step is capped, and also limited so that no entry of sym_mat
        # moves by more than 0.01 in one iteration.
        step = min(step_cap, .01/abs(inc).max())
        sym_mat = sym_mat - step * inc
        if abs(sym_mat.T.dot(sym_mat)-dp_mat).max() <= tol:
            break
    else:
        # Reached only when the loop ran out of iterations without a break.
        print("Didn't converge after %d iterations" % max_iter)
    return sym_mat
def compute_TPmat(R, F):
    '''Return the Kronecker product of `R` and `F` together with its inverse.

    The exact inverse is used when the product is square, the
    pseudo-inverse otherwise.
    '''
    TP = np.kron(R, F)
    is_square = TP.shape[0] == TP.shape[1]
    invert = la.inv if is_square else la.pinv
    return TP, invert(TP)
def compute_sspace_biases(b_local, net):
    '''Convert per-binding biases into a bias vector over the units.

    For each binding i, column i of `net.TP` is scaled by `b_local[i]`
    divided by the column's squared norm, and the results are summed.
    '''
    total = np.zeros(net.nunits).reshape(net.nunits)
    for idx in np.arange(0, net.nbindings):
        column = net.TP[:, idx]
        sq_norm = np.dot(column, column)
        total = total + b_local[idx]*column / sq_norm
    return total
def compute_sspace_weights(W_local, net):
W_distributed = np.zeros((net.nunits, net.nunits))
for ii in np.arange(0, net.nbindings):
si = net.TP[:, ii]
for jj in np.arange(0, ii+1):
sj = net.TP[:, jj]
if ii != jj:
W_distributed = W_distributed + W_local[ii,jj] * (np.outer(si, sj) + np.outer(sj, si)) / (np.dot(si, si) * np.dot(sj, sj))
else:
W_distributed = W_distributed + .5 * W_local[ii,jj] * (np.outer(si, sj) + np.outer(sj, si)) / (np.dot(si, si) * np.dot(sj, sj))
return W_distributed
def compute_projmat(A):
    '''Return A (A'A)^-1 A' for the matrix `A`.'''
    gram_inverse = la.inv(A.T.dot(A))
    left = A.dot(gram_inverse)
    return left.dot(A.T)
class sim():
    '''Run a network on a sequence of input bindings and keep the traces.

    `net` is the network to drive and `params` a dict of settings. The keys
    read by this class are 'input_inhib', 'input_method', 'cumulative_input',
    'nrep', 'maxstep', 'tol', 'update_T', 'update_q', 'q_fun', 'grid_points',
    'ext_overlap' and 'ext_overlap_steps'.

    In the plotting and read-out methods `rep_num` and `word_num` are
    1-based.
    '''

    def __init__(self, net, params):
        self.net = net
        self.params = params

    def set_params(self, p_name, p_val):
        '''Set a single entry of the parameter dict.'''
        self.params[p_name] = p_val

    def set_seed(self, num):
        '''Seed NumPy's global random number generator.'''
        np.random.seed(num)

    def set_input(self, input_list, input_vals):
        '''Build the input matrix `self.inputC` (bindings x words).

        input_list : binding name ('filler/role') or list of binding names,
                     one per word.
        input_vals : input strength for each entry of `input_list`; a single
                     number is accepted together with a single binding name.

        Column i holds `input_vals[i]` at the i-th binding. When
        params['input_inhib'] is true and params['input_method'] is 'ext',
        every other binding sharing that role receives `-input_vals[i]`.
        With params['cumulative_input'] the columns are accumulated from
        left to right.

        The inputs are also stored as `self.input_list` / `self.input_vals`
        because `simulate` reads them from there.
        '''
        if not isinstance(input_list, list):
            input_list = [input_list]
            if np.ndim(input_vals) == 0:
                input_vals = [input_vals]
        self.input_list = input_list
        self.input_vals = input_vals

        nwords = len(input_list)
        self.inputC = np.zeros((self.net.nbindings, nwords))
        # The role of every binding, in binding order (same for all words).
        role_list = [bb.split('/')[1] for bb in self.net.binding_names]
        for ii, binding in enumerate(input_list):
            val = input_vals[ii]
            inputC = np.zeros(self.net.nbindings)
            curr_role = binding.split('/')[1]
            role_idx = [jj for jj, rr in enumerate(role_list) if curr_role == rr]
            binding_idx = self.net.find_bindings([binding])
            if self.params['input_inhib']:
                if self.params['input_method'] == 'ext':
                    inputC[role_idx] = -val
            # Written after the inhibition so the target binding itself
            # ends up with +val.
            inputC[binding_idx] = val
            self.inputC[:, ii] = inputC
        if self.params['cumulative_input']:
            self.inputC = self.inputC.cumsum(axis=1)

    def simulate(self, params=None):
        '''Run params['nrep'] repetitions over `self.input_list`.

        params : optional parameter dict. When given it replaces
                 `self.params` before the run, so that `set_input` (called
                 for every repetition) uses the same settings.

        params['maxstep'] holds the step limit of each word; a single number
        applies to every word.

        Results are stored on the instance: `act_trace`, `T_trace`,
        `q_trace`, `speed_trace`, `ema_speed_trace` (indexed by repetition,
        word, step), and `converged`, `rt` (indexed by repetition, word).
        '''
        if params is not None:
            self.params = params
        params = self.params

        nrep = params['nrep']
        nwords = len(self.input_list)
        maxsteps = np.asarray(params['maxstep'])
        if maxsteps.ndim == 0:
            maxsteps = np.full(nwords, maxsteps.item())
        maxstep = maxsteps.max()

        self.act_trace = np.zeros((nrep, nwords, maxstep+1, self.net.nunits))
        self.T_trace = np.zeros((nrep, nwords, maxstep+1))
        self.q_trace = np.zeros((nrep, nwords, maxstep+1))
        self.speed_trace = np.zeros((nrep, nwords, maxstep+1))
        self.ema_speed_trace = np.zeros((nrep, nwords, maxstep+1))
        self.converged = np.ones((nrep, nwords), dtype=bool)
        self.rt = np.zeros((nrep, nwords))

        for rep_num in range(nrep):
            self.net.unclamp()
            self.net.clear_input()
            self.net.reset()
            self.set_input(self.input_list, self.input_vals)
            for word_num, curr_input in enumerate(self.input_list):
                curr_maxstep = maxsteps[word_num]
                curr_inputC = self.inputC[:, word_num]
                if params['input_method'] == 'ext':
                    self.net.extC = curr_inputC
                    self.net.ext = compute_sspace_biases(self.net.extC, self.net)
                elif params['input_method'] == 'clamp':
                    sys.exit('Not yet implementd')
                self.net.run(curr_maxstep, tol=params['tol'],
                             update_T=params['update_T'],
                             update_q=params['update_q'],
                             q_fun=params['q_fun'],
                             grid_points=params['grid_points'],
                             ext_overlap=params['ext_overlap'],
                             ext_overlap_steps=params['ext_overlap_steps'])
                self.net.prev_extC = self.net.extC
                # Words with a smaller step limit fill only the first
                # curr_maxstep+1 slots; the rest stay zero.
                upto = curr_maxstep + 1
                self.act_trace[rep_num, word_num, :upto, :] = self.net.act_trace
                self.rt[rep_num, word_num] = self.net.rt
                self.T_trace[rep_num, word_num, :upto] = self.net.T_trace
                self.q_trace[rep_num, word_num, :upto] = self.net.q_trace
                self.speed_trace[rep_num, word_num, :upto] = self.net.speed_trace
                self.ema_speed_trace[rep_num, word_num, :upto] = self.net.ema_speed_trace
                self.converged[rep_num, word_num] = self.net.converged

    def _last_step(self, rep_ind, word_ind):
        '''Return rt[rep_ind, word_ind] as an int usable for indexing.

        `rt` is stored as a float array, and NumPy rejects float indices.
        '''
        return int(self.rt[rep_ind, word_ind])

    def plot_act_trace(self, rep_num):
        '''Plot the activation trace of repetition `rep_num` as a heatmap.

        The traces of all words are concatenated; for every word after the
        first, step 0 is skipped.
        '''
        rep_ind = rep_num - 1
        nwords = self.act_trace.shape[1]
        pieces = [self.act_trace[rep_ind, 0, :(self._last_step(rep_ind, 0)+1), :]]
        for wind in range(1, nwords):
            pieces.append(
                self.act_trace[rep_ind, wind, 1:(self._last_step(rep_ind, wind)+1), :])
        curr_act_trace = np.concatenate(pieces, axis=0)
        actC_trace = self.net.S2C(curr_act_trace.T).T
        heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                xticklabels=None, yticklabels=self.net.binding_names, grayscale=False)

    def plot_trace(self, varname, rep_num):
        '''Plot the trace `<varname>_trace` of repetition `rep_num`.

        The traces of all words are concatenated (step 0 skipped for every
        word after the first) and a green vertical line is drawn at each
        cumulative `rt` value.
        '''
        rep_ind = rep_num - 1
        nwords = self.act_trace.shape[1]
        trace = getattr(self, varname+'_trace')
        pieces = [trace[rep_ind, 0, :(self._last_step(rep_ind, 0)+1)]]
        for wind in range(1, nwords):
            # Each word is cut at its own rt, as in plot_act_trace.
            pieces.append(trace[rep_ind, wind, 1:(self._last_step(rep_ind, wind)+1)])
        curr_trace = np.concatenate(pieces, axis=0)
        curr_rt_trace = self.rt[rep_ind, :].cumsum()
        plt.plot(curr_trace)
        lowest, highest = min(curr_trace), max(curr_trace)
        for rt in curr_rt_trace:
            plt.plot([rt, rt], [lowest, highest], 'g-')
        plt.xlabel('Timestep', fontsize=16)
        plt.ylabel(varname, fontsize=16)

    def _act_at_rt(self, rep_num, word_num):
        '''Return the activation vector stored at step rt for the given run.'''
        rep_ind, word_ind = rep_num - 1, word_num - 1
        return self.act_trace[rep_ind, word_ind, self._last_step(rep_ind, word_ind), :]

    def read_state(self, rep_num, word_num):
        '''Pass the activation at step rt to `net.read_state`.'''
        self.net.read_state(self._act_at_rt(rep_num, word_num))

    def read_grid_point(self, rep_num, word_num):
        '''Pass the activation at step rt to `net.read_grid_point`.'''
        self.net.read_grid_point(self._act_at_rt(rep_num, word_num))