"""
GSC Software:
Author: Pyeong Whan Cho
Department of Cognitive Science, Johns Hopkins University
Summer 2015
(based on the MATLAB program, LDNet2.0)
"""
from math import *
import numpy as np
from matplotlib import pyplot as plt
from scipy import optimize as optim
import pandas as pd
import pylab as P
from scipy import linalg as la
from numbers import Number
import sys
import copy
import itertools
import pprint
class Vividict(dict):
    '''A dict that creates (and stores) an empty nested Vividict whenever a
    missing key is looked up with d[key], so d['a']['b'] = v works without
    first creating d['a'].'''

    def __missing__(self, key):
        child = type(self)()
        self[key] = child
        return child
class RefSet():
    '''Named reference values stored in a nested auto-creating dict.

    Each key maps to a sub-dict holding a 'sym' entry (symbolic value) or a
    'num' entry (numeric value).'''

    def __init__(self):
        self.dict = Vividict()

    def add(self, key, val, isSym=True, isTree=False, isSpanRole=False):
        '''Store [val] under [key].

        Non-tree values go to 'sym' when [isSym] else 'num'. Tree values
        (isTree=True) are converted by joining each item of
        val.recursiveFRtoString with '/' and stored under 'sym'; span-role
        trees are not implemented and terminate via sys.exit.
        '''
        if isTree and isSpanRole:
            sys.exit('Not yet implemented. Check tree.py.')
        if isTree:
            # NOTE(review): recursiveFRtoString is iterated, not called;
            # confirm in tree.py that it is an iterable attribute/property.
            field = 'sym'
            stored = ['/'.join(binding) for binding in val.recursiveFRtoString]
        else:
            field = 'sym' if isSym else 'num'
            stored = val
        self.dict[key][field] = stored

    def disp(self):
        '''Pretty-print the stored references.'''
        pprint.pprint(self.dict)
class GscNet():
    '''Gradient Symbolic Computation (GSC) network over filler/role bindings.

    Bindings are named "filler/role". The activation state is kept in the
    S-space (self.act, one entry per unit) and mirrored in the C-space
    (self.actC, one entry per binding); C2S/S2C convert between the two
    with the matrices TP/TPinv returned by compute_TPmat.
    '''

    def __init__(self, filler_names, role_names, grammar=None,
                 WGC=None, bGC=None, extC=None, prev_extC=None, z=0.5, beta=4,
                 q_init=0, q_rate=4, q_max=50, q_fun='plinear', c=0.5,
                 T_init=0, dt=0.001, reptype_r='local', reptype_f='local',
                 dp_r=0, dp_f=0, ndim_r=None, ndim_f=None,
                 T_decay_rate=0, T_min=0, quant_list=None, grid_points=None,
                 F=None, R=None, getGPset=False):
        '''Construct an instance of GscNet.

        WGC/bGC/extC/prev_extC are C-space grammar weights, grammar biases,
        external input and previous external input (default: zeros).
        R/F, when given, replace the role/filler encodings produced by
        encode_symbols. [grammar] is accepted but not used here.
        Exits (sys.exit) if beta is too small (see check_beta) or if the
        resulting S-space weight matrix is not symmetric.
        '''
        self.binding_names = bind(filler_names, role_names)
        self.filler_names = filler_names
        self.role_names = role_names
        self.nbindings = len(self.binding_names)
        self.nfillers = len(filler_names)
        self.nroles = len(role_names)
        if WGC is None:
            WGC = np.zeros((self.nbindings, self.nbindings))
        if bGC is None:
            bGC = np.zeros(self.nbindings)
        if extC is None:
            extC = np.zeros(self.nbindings)
        if prev_extC is None:
            prev_extC = np.zeros(self.nbindings)
        self.WGC = WGC
        self.bGC = bGC
        self.extC = extC
        self.prev_extC = prev_extC
        if ndim_r is None:
            ndim_r = self.nroles
        if ndim_f is None:
            ndim_f = self.nfillers
        self.reptype_r = reptype_r
        self.reptype_f = reptype_f
        self.ndim_r = ndim_r
        self.ndim_f = ndim_f
        self.R = encode_symbols(self.nroles, reptype=reptype_r, dp=dp_r, ndim=ndim_r)
        self.F = encode_symbols(self.nfillers, reptype=reptype_f, dp=dp_f, ndim=ndim_f)
        if R is not None:
            self.R = R
        if F is not None:
            self.F = F
        self.TP, self.TPinv = compute_TPmat(self.R, self.F)
        # C2S returns TP.dot(actC), so the S-space size is TP's row count.
        # Reading it from TP keeps shapes consistent even when user-supplied
        # R/F have dimensions other than ndim_r/ndim_f.
        self.nunits = self.TP.shape[0]
        self.act = np.zeros(self.nunits)
        self.actC = self.S2C()
        self.z = z
        self.beta = beta
        # "Bowl" term: -beta * I weights and beta * z biases in the C-space.
        self.WBC = -self.beta * np.eye(self.nbindings)
        self.bBC = self.beta * self.z * np.ones(self.nbindings)
        self.WC = self.WGC + self.WBC
        self.bC = self.bGC + self.bBC
        self.WG = compute_sspace_weights(self.WGC, self)
        self.WB = compute_sspace_weights(self.WBC, self)
        self.bG = compute_sspace_biases(self.bGC, self)
        self.bB = compute_sspace_biases(self.bBC, self)
        self.W = compute_sspace_weights(self.WC, self)
        self.b = compute_sspace_biases(self.bC, self)
        self.zeta = self.TP.dot(self.z * np.ones(self.nbindings))
        self.ext = compute_sspace_biases(self.extC, self)
        self.prev_ext = compute_sspace_biases(self.prev_extC, self)
        self.check_beta()
        self.q_init = q_init
        self.q = copy.deepcopy(self.q_init)
        self.q_rate = q_rate
        self.q_fun = q_fun
        self.q_max = q_max
        self.q_time = 0
        self.c = c
        self.T_init = T_init
        self.T = copy.deepcopy(self.T_init)
        self.T_decay_rate = T_decay_rate
        self.T_min = T_min
        self.dt = dt
        self.clamped = False
        self.act_trace = None
        self.quant_list = quant_list
        if not np.allclose(self.W, self.W.T):
            sys.exit("The weight matrix (2D array) is not symmetric. Please check it.")
        ndigits = len(str(self.nunits))
        self.unit_names = ['U' + str(ii + 1).zfill(ndigits) for ii in range(self.nunits)]
        self.speed = None
        self.ema_speed = None
        self.grid_points = grid_points
        self.getGPset = getGPset
        if getGPset:
            self.all_grid_points()

    def randomize_state(self, minact=0, maxact=1):
        '''Set the C-space state to uniform random values in [minact, maxact)
        and update both act (S-space) and actC (C-space).'''
        actC = np.random.uniform(minact, maxact, self.nbindings)
        self.act = self.C2S(actC)
        self.actC = self.S2C()

    def set_init_state(self, mu=0.2, sd=0.2):
        '''Set the C-space state to normal random values N(mu, sd^2) and
        update both act (S-space) and actC (C-space).'''
        actC = np.random.normal(loc=mu, scale=sd, size=self.nbindings)
        self.act = self.C2S(actC)
        self.actC = self.S2C()

    def randomize_weights(self, mu=0, sigma=1):
        '''Replace WC by a random symmetric matrix and recompute W.

        Entries are sigma * N(0, 1) + mu, symmetrized as (WC + WC.T) / 2.
        Only WC and W are updated: WGC, WG and WBC are left untouched, so
        WC no longer equals WGC + WBC afterwards.
        '''
        WC = sigma * np.random.randn(self.nbindings**2).reshape(self.nbindings, self.nbindings, order='F') + mu
        self.WC = (WC + WC.T) / 2
        self.W = compute_sspace_weights(self.WC, self)

    def all_grid_points(self):
        '''Enumerate all grid points and sort them by grammar harmony.

        A grid point takes one binding from each group: the bindings of
        each role when quant_list is None, otherwise each list in
        quant_list. Results go to self.gpset (lists of binding names) and
        self.gpset_gh (their Hg values), both in descending order of Hg.
        '''
        if self.quant_list is None:
            quant_list = [[self.binding_names[ii] for ii in self.find_roles(role)]
                          for role in self.role_names]
        else:
            quant_list = self.quant_list
        gpset = [list(gp) for gp in itertools.product(*quant_list)]
        gpset_gh = np.zeros(len(gpset))
        for ii, gp in enumerate(gpset):
            actC = np.zeros(self.nbindings)
            actC[self.find_bindings(gp)] = 1.0
            gpset_gh[ii] = self.Hg(act=self.C2S(actC))
        idx = np.argsort(gpset_gh)[::-1]
        self.gpset = [gpset[ii] for ii in idx]
        self.gpset_gh = gpset_gh[idx]

    def set_weight(self, binding_name1, binding_name2, weight, symmetric=True):
        '''Set the weight of a connection between binding1 and binding2.
        When symmetric is set to True (default), the connection weight from
        binding2 to binding1 is set to the same value; otherwise only
        WGC[binding2, binding1] is set.'''
        idx1 = self.find_bindings(binding_name1)
        idx2 = self.find_bindings(binding_name2)
        if symmetric:
            self.WGC[idx1, idx2] = self.WGC[idx2, idx1] = weight
        else:
            self.WGC[idx2, idx1] = weight
        self.WC = self.WGC + self.WBC
        self.WG = compute_sspace_weights(self.WGC, self)
        self.W = compute_sspace_weights(self.WC, self)

    def _refresh_biases(self):
        '''Recompute bC, bG and b after bGC has changed.'''
        self.bC = self.bGC + self.bBC
        self.bG = compute_sspace_biases(self.bGC, self)
        self.b = compute_sspace_biases(self.bC, self)

    def set_bias(self, binding_name, bias):
        '''Set the grammar bias of a binding (or list of bindings) to [bias].'''
        self.bGC[self.find_bindings(binding_name)] = bias
        self._refresh_biases()

    def set_role_bias(self, role_name, bias):
        '''Set the bias of all bindings whose role is [role_name] (a name or
        a list of names) to [bias].'''
        self.bGC[self.find_roles(role_name)] = bias
        self._refresh_biases()

    def set_filler_bias(self, filler_name, bias):
        '''Set the bias of all bindings whose filler is [filler_name] (a
        name or a list of names) to [bias].'''
        self.bGC[self.find_fillers(filler_name)] = bias
        self._refresh_biases()

    def set_state(self, binding_names, vals=1.0):
        '''Set the C-space state to [vals] on [binding_names] and 0 elsewhere.'''
        idx = self.find_bindings(binding_names)
        self.actC = np.zeros(self.nbindings)
        self.actC[idx] = vals
        self.act = self.C2S()

    def vec2mat(self, act=None):
        '''Reshape a C-space vector (default: S2C of the current state) into
        an nfillers x nroles matrix (column-major, one column per role).'''
        if act is None:
            act = self.S2C()
        return act.reshape(self.nfillers, self.nroles, order='F')

    def C2S(self, C=None):
        '''Change basis: from C- to S-space (default: self.actC).'''
        if C is None:
            C = self.actC
        return self.TP.dot(C)

    def S2C(self, S=None):
        '''Change basis: from S- to C-space (default: self.act).'''
        if S is None:
            S = self.act
        return self.TPinv.dot(S)

    def find_bindings(self, binding_names):
        '''Find the indices of the bindings from the list of binding names.
        Raises ValueError for an unknown name.'''
        if not isinstance(binding_names, list):
            binding_names = [binding_names]
        return [self.binding_names.index(bb) for bb in binding_names]

    def find_roles(self, role_name):
        '''Return the indices of bindings whose role is [role_name] (a name
        or list of names), grouped by role in the order given.'''
        if not isinstance(role_name, list):
            role_name = [role_name]
        role_list = [bb.split('/')[1] for bb in self.binding_names]
        return [ii for role in role_name for ii, rr in enumerate(role_list) if rr == role]

    def find_fillers(self, filler_name):
        '''Return the indices of bindings whose filler is [filler_name] (a
        name or list of names), grouped by filler in the order given.'''
        if not isinstance(filler_name, list):
            filler_name = [filler_name]
        filler_list = [bb.split('/')[0] for bb in self.binding_names]
        return [ii for filler in filler_name for ii, ff in enumerate(filler_list) if ff == filler]

    def read_state(self, act=None):
        '''Print the current state (C-SPACE) as a filler x role table. Pandas should be installed.'''
        if act is None:
            act = self.act
        actC = self.vec2mat(self.S2C(act))
        print(pd.DataFrame(actC, index=self.filler_names, columns=self.role_names))

    def read_grid_point(self, act=None, disp=True):
        '''Return (and print if [disp]) the grid point closest to the state
        by winner-takes-all: the most active filler of each role, or, with
        a quant_list, the most active binding(s) of each group (ties keep
        every maximal binding).'''
        if act is None:
            act = self.act
        if self.quant_list is None:
            actC = self.vec2mat(self.S2C(act))
            winner_idx = np.argmax(actC, axis=0)
            winners = ["%s/%s" % (self.filler_names[ii], role)
                       for ii, role in zip(winner_idx, self.role_names)]
        else:
            actC = self.S2C(act)
            winners = []
            for group in self.quant_list:
                idx = self.find_bindings(group)
                vals = actC[idx]
                top = max(vals)
                winners.extend(self.binding_names[ii] for ii, vv in zip(idx, vals) if vv == top)
        if disp:
            print(winners)
        return winners

    def read_weight(self, which='WGC'):
        '''Print a weight matrix attribute; names ending in 'C' are labelled
        by binding (C-space), others by unit (S-space).'''
        names = self.binding_names if which[-1] == 'C' else self.unit_names
        print(pd.DataFrame(getattr(self, which), index=names, columns=names))

    def read_bias(self, which='bGC', print_vertical=True):
        '''Print a bias vector attribute; names ending in 'C' are labelled by
        binding (C-space), others by unit (S-space).'''
        names = self.binding_names if which[-1] == 'C' else self.unit_names
        # S-space vectors have nunits entries, C-space ones nbindings.
        nn = len(names)
        vec = getattr(self, which)
        if print_vertical:
            print(pd.DataFrame(vec.reshape(nn, 1), index=names, columns=["bias"]))
        else:
            print(pd.DataFrame(vec.reshape(1, nn), index=["bias"], columns=names))

    def hinton(self, which='WGC', label=True):
        '''Draw a hinton diagram of the weight attribute [which].'''
        if label:
            if which[-1] == 'C':
                labels = self.binding_names
            else:
                labels = self.unit_names
        else:
            labels = None
        hinton(getattr(self, which), xlabels=labels, ylabels=labels)

    def act_clamped(self, act=None):
        '''Project an S-space state onto the clamped subspace:
        projmat.dot(act) + clampvec. Requires a prior clamp().'''
        if act is None:
            act = self.act
        return self.projmat.dot(act) + self.clampvec

    def check_beta(self, disp=False):
        '''Check beta against the minimum recommended for the current C-space weights and biases.

        The bound is the largest of: the top eigenvalue of WGC,
        -min(bGC + extC) / z, and (max(bGC + extC) + top eigenvalue) / (1 - z).
        Exits (sys.exit) if beta <= bound; prints the bound if [disp].
        '''
        eigvals, _ = la.eigh(self.WGC)
        eig_max = np.max(eigvals)
        bias = self.bGC + self.extC
        # np.min/np.max handle the single-binding case too and return scalars.
        beta1 = -np.min(bias) / self.z
        beta2 = (np.max(bias) + eig_max) / (1 - self.z)
        beta_min = max(eig_max, beta1, beta2)
        if self.beta <= beta_min:
            sys.exit("Beta (bowl strength) should be greater than %.4f." % beta_min)
        if disp:
            print('Recommended beta min = %.3f' % beta_min)

    def info(self):
        '''Print the network information.'''
        print('Fillers = ', self.filler_names)
        print('Roles = ', self.role_names)
        print('Num_of_units = ', self.nunits)
        print('Current T = ', self.T)

    def compute_dist(self, ref_point, norm_ord=2, space='S'):
        """
        Compute the distance of the current state from a grid point.
        [ref_point] is a set of binding names whose C-space value is 1
        (all others 0). [space] is 'S' or 'C'; raises ValueError otherwise.
        """
        idx = self.find_bindings(ref_point)
        destC = np.zeros(self.nbindings)
        destC[idx] = 1.0
        if space == 'S':
            state1 = self.act
            state2 = self.C2S(destC)
        elif space == 'C':
            state1 = self.S2C(self.act)
            state2 = destC
        else:
            raise ValueError("space must be 'S' or 'C', got %r" % (space,))
        return np.linalg.norm(state1 - state2, ord=norm_ord)

    def _competitor_indices(self, binding_names):
        '''Indices of the bindings competing with [binding_names]: those
        sharing a role (quant_list is None) or sharing a quant_list group.'''
        if self.quant_list is None:
            return self.find_roles([bb.split('/')[1] for bb in binding_names])
        comp_idx = []
        for bb in binding_names:
            for group in self.quant_list:
                if bb in group:
                    comp_idx += self.find_bindings(group)
        return comp_idx

    def set_input(self, binding_names, ext_vals, inhib_comp=False):
        '''Set external input (C-space) on [binding_names], replacing any
        previous input. With [inhib_comp], each binding's competitors first
        receive the negated value of that binding's input.'''
        if not isinstance(ext_vals, list):
            ext_vals = [ext_vals]
        if not isinstance(binding_names, list):
            binding_names = [binding_names]
        if len(ext_vals) > 1:
            if len(binding_names) != len(ext_vals):
                sys.exit("binding_names and ext_vals have different lengths.")
        self.clear_input()
        if inhib_comp:
            # Pair each binding with its own value (a single value applies to all);
            # assigning the whole list onto the competitor indices would not broadcast.
            vals = ext_vals if len(ext_vals) > 1 else ext_vals * len(binding_names)
            for bb, val in zip(binding_names, vals):
                self.extC[self._competitor_indices([bb])] = -val
        idx = self.find_bindings(binding_names)
        self.extC[idx] = ext_vals
        self.ext = compute_sspace_biases(self.extC, self)

    def clear_input(self):
        '''Remove external input.'''
        self.extC = np.zeros(self.nbindings)
        self.ext = compute_sspace_biases(self.extC, self)

    def clamp(self, binding_names, clamp_vals=1.0, clamp_comp=False):
        '''Clamp f/r bindings to [clamp_vals].

        With [clamp_comp], the competitors of each clamped binding (see
        _competitor_indices) are clamped to 0 as well. The state is then
        projected onto the clamped subspace.
        '''
        if not isinstance(clamp_vals, list):
            clamp_vals = [clamp_vals]
        if not isinstance(binding_names, list):
            binding_names = [binding_names]
        if len(clamp_vals) > 1:
            if len(clamp_vals) != len(binding_names):
                sys.exit('The number of bindings clamped is not equal to the number of values provided.')
        self.clamped = True
        self.binding_names_clamped = binding_names
        idx = self.find_bindings(binding_names)
        clampvecC = np.zeros(self.nbindings)
        clampvecC[idx] = clamp_vals
        self.clampvecC = clampvecC
        fixed = set(idx)
        if clamp_comp:
            # Competitors from every matching group are fixed (at 0), not only the last one.
            fixed.update(self._competitor_indices(binding_names))
        idx0 = [bb for bb in range(self.nbindings) if bb not in fixed]
        if len(idx0) > 0:
            self.projmat = compute_projmat(self.TP[:, idx0])
        else:
            self.projmat = np.zeros((self.nunits, self.nunits))
        self.clampvec = self.C2S(clampvecC)
        self.act = self.act_clamped(self.act)
        self.actC = self.S2C()

    def unclamp(self):
        '''Release clamping and remove the clamping attributes.'''
        if self.clamped:
            del self.clampvec
            del self.clampvecC
            del self.projmat
            del self.binding_names_clamped
        self.clamped = False

    def H(self, act=None, quant_list=None):
        '''Total harmony: Hg(act) + q * Q(act). [quant_list] is accepted but unused.'''
        if act is None:
            act = self.act
        return Hg(net=self, act=act) + self.q * Q(net=self, n=act)

    def Hg(self, act=None):
        '''Compute the grammar harmony of the state (module-level Hg).'''
        if act is None:
            act = self.act
        return Hg(net=self, act=act)

    def Hq(self, act=None, quant_list=None):
        '''Return Q(act), the term that H() scales by q. [quant_list] is unused.'''
        if act is None:
            act = self.act
        return Q(net=self, n=act)

    def H0(self, act=None):
        '''Return the module-level H0 evaluated at [act] (default: current state).'''
        if act is None:
            act = self.act
        return H0(net=self, act=act)

    def H1(self, act=None):
        '''Return the module-level H1 evaluated at [act] (default: current state).'''
        if act is None:
            act = self.act
        return H1(net=self, act=act)

    def HGrad(self, act=None, quant_list=None):
        '''Total harmony gradient: HgGrad(act) + q * QGrad(act). [quant_list] is unused.'''
        if act is None:
            act = self.act
        return HgGrad(net=self, act=act) + self.q * QGrad(net=self, n=act)

    def HgGrad(self, act=None):
        '''Return the gradient of the grammar harmony (module-level HgGrad).'''
        if act is None:
            act = self.act
        return HgGrad(net=self, act=act)

    def H0Grad(self, act=None, quant_list=None):
        '''Return the module-level H0Grad (quant_list defaults to self.quant_list).'''
        if act is None:
            act = self.act
        if quant_list is None:
            quant_list = self.quant_list
        return H0Grad(net=self, act=act, quant_list=quant_list)

    def H1Grad(self, act=None):
        '''Return the module-level H1Grad evaluated at [act].'''
        if act is None:
            act = self.act
        return H1Grad(net=self, act=act)

    def HqGrad(self, act=None, quant_list=None):
        '''Return QGrad(act), the gradient of the term H() scales by q. [quant_list] is unused.'''
        if act is None:
            act = self.act
        return QGrad(net=self, n=act)

    def update(self, update_T=True, update_q=True, q_fun=None, space='S', norm_ord=2,
               ema_factor=0.1, grid_points=None):
        '''Advance one step: update the state, speed, distances to
        [grid_points] (default: self.grid_points), then optionally T and q.'''
        if q_fun is None:
            q_fun = self.q_fun
        if grid_points is None:
            grid_points = self.grid_points
        self.prev_act = copy.deepcopy(self.act)
        self.update_state()
        self.update_speed(ema_factor=ema_factor, space=space, norm_ord=norm_ord)
        if grid_points is not None:
            self.update_dist(grid_points, space=space, norm_ord=norm_ord)
        if update_T:
            self.update_T()
        if update_q:
            self.update_q(fun=q_fun)

    def update_state(self):
        '''Update the current state (with noise): one Euler step of size dt
        along HGrad, add noise, then re-apply clamping if active.'''
        self.act += self.dt * self.HGrad(quant_list=self.quant_list)
        self.add_noise()
        if self.clamped:
            self.act = self.act_clamped()
        self.actC = self.S2C()

    def add_noise(self):
        '''Add Gaussian noise with standard deviation sqrt(2 * T * dt) to each unit.'''
        noise = sqrt(2 * self.T * self.dt) * np.random.randn(self.nunits)
        self.act += noise

    def update_T(self, method='exponential'):
        '''Decay T towards T_min by the factor (1 - T_decay_rate).'''
        if method == 'exponential':
            self.T = (1 - self.T_decay_rate) * (self.T - self.T_min) + self.T_min

    def update_q(self, fun='plinear'):
        '''Update q: 'log' sets q_rate * log(step count), 'linear' adds
        q_rate, 'plinear' adds q_rate while q <= q_max (so q can end up to
        q_rate above q_max).'''
        if fun == 'log':
            self.q_time += 1
            self.q = self.q_rate * log(self.q_time)
        elif fun == 'linear':
            self.q += self.q_rate
        elif fun == 'plinear':
            if self.q <= self.q_max:
                self.q += self.q_rate

    def reset(self):
        '''Restore q, q_time and T, draw a random state, zero the previous
        input, release clamping, clear the input and forget the speed
        estimates so a new run does not inherit the old EMA.'''
        self.q = self.q_init
        self.q_time = 0
        self.T = self.T_init
        self.randomize_state()
        self.prev_extC = np.zeros(self.nbindings)
        self.prev_ext = compute_sspace_biases(self.prev_extC, self)
        self.unclamp()
        self.clear_input()
        self.speed = None
        self.ema_speed = None

    def update_speed(self, ema_factor=0.1, norm_ord=2, space='S'):
        '''Store the norm of the last state change in self.speed and update
        self.ema_speed = ema_factor * ema_speed + (1 - ema_factor) * speed.'''
        if space == 'S':
            diff = self.act - self.prev_act
        elif space == 'C':
            diff = self.S2C(self.act) - self.S2C(self.prev_act)
        else:
            raise ValueError("space must be 'S' or 'C', got %r" % (space,))
        self.speed = np.linalg.norm(diff, ord=norm_ord)
        if self.ema_speed is None:
            self.ema_speed = self.speed
        else:
            self.ema_speed = ema_factor * self.ema_speed + (1 - ema_factor) * self.speed

    def update_dist(self, grid_points, norm_ord=2, space='S'):
        '''Store in self.dist the distance from the state to each grid point
        (a single grid point may be given as a flat list of bindings).'''
        if not any(isinstance(tt, list) for tt in grid_points):
            grid_points = [grid_points]
        self.dist = np.array([self.compute_dist(ref_point=gp, norm_ord=norm_ord, space=space)
                              for gp in grid_points], dtype=float)

    def update_traces(self):
        '''Record the current values at row self.step of every *_trace.'''
        self.act_trace[self.step, :] = self.act
        self.H_trace[self.step] = self.H()
        self.Hg_trace[self.step] = self.Hg()
        self.Hq_trace[self.step] = self.Hq()
        self.H0_trace[self.step] = self.H0()
        self.H1_trace[self.step] = self.H1()
        self.speed_trace[self.step] = self.speed
        self.ema_speed_trace[self.step] = self.ema_speed
        self.q_trace[self.step] = self.q
        self.T_trace[self.step] = self.T
        self.gp_trace[self.step] = self.read_grid_point(disp=False)
        if getattr(self, 'dist_trace', None) is not None:
            self.dist_trace[self.step, :] = self.dist

    def run(self, maxstep, plot=False, grayscale=False, colorbar=True,
            tol=None, ema_factor=0.1, update_T=True, update_q=True, q_fun=None,
            grid_points=None, testvar='ema_speed', space='S', norm_ord=2,
            ext_overlap=False, ext_overlap_steps=0):
        '''
        Run simulations for at most [maxstep] steps.

        Traces of length maxstep + 1 (step 0 is the initial state) are kept
        in *_trace attributes. If [tol] is given, stops once the
        convergence test on [testvar] ('ema_speed' or 'dist') succeeds.
        [grid_points] (default: self.grid_points) are the reference points
        for self.dist / dist_trace. With [ext_overlap], the external input
        is linearly cross-faded from prev_extC to extC over the first
        [ext_overlap_steps] steps. The final step is stored in self.rt.
        '''
        if q_fun is None:
            q_fun = self.q_fun
        if grid_points is None:
            grid_points = self.grid_points
        if grid_points is not None and not any(isinstance(pp, list) for pp in grid_points):
            grid_points = [grid_points]
        self.converged = False
        nsteps = maxstep + 1
        self.act_trace = np.zeros((nsteps, self.nunits))
        self.H_trace = np.zeros(nsteps)
        self.Hg_trace = np.zeros(nsteps)
        self.Hq_trace = np.zeros(nsteps)
        self.H0_trace = np.zeros(nsteps)
        self.H1_trace = np.zeros(nsteps)
        self.speed_trace = np.zeros(nsteps)
        self.ema_speed_trace = np.zeros(nsteps)
        self.q_trace = np.zeros(nsteps)
        self.T_trace = np.zeros(nsteps)
        self.gp_trace = [[] for _ in range(nsteps)]
        if grid_points is not None:
            self.dist_trace = np.zeros((nsteps, len(grid_points)))
            # No distance is computed before the first update.
            self.dist = np.full(len(grid_points), np.nan)
        else:
            self.dist_trace = None
        self.step = 0
        self.update_traces()
        prev_extC = copy.deepcopy(self.prev_extC)
        curr_extC = copy.deepcopy(self.extC)
        for tt in range(1, maxstep + 1):
            self.step = tt
            if ext_overlap and (tt <= ext_overlap_steps):
                self.extC = (1 - tt/ext_overlap_steps) * prev_extC + (tt/ext_overlap_steps) * curr_extC
                self.ext = compute_sspace_biases(self.extC, self)
            self.update(update_T=update_T, update_q=update_q, q_fun=q_fun,
                        space=space, norm_ord=norm_ord, ema_factor=ema_factor,
                        grid_points=grid_points)
            self.update_traces()
            if tol is not None:
                self.check_convergence(tol=tol, testvar=testvar)
                if self.converged:
                    break
        self.rt = self.step
        if plot:
            actC_trace = self.S2C(self.act_trace[:(self.rt+1), :].T).T
            heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                    yticklabels=self.binding_names, grayscale=grayscale, colorbar=colorbar)

    def check_convergence(self, tol, testvar='ema_speed'):
        '''Set self.converged if any distance ('dist') or the ema_speed ('ema_speed') is below [tol].'''
        if testvar == 'dist':
            if (self.dist < tol).any():
                self.converged = True
        if testvar == 'ema_speed':
            if self.ema_speed < tol:
                self.converged = True

    def simulate(self):
        '''Placeholder; always returns 0.'''
        return 0

    def plot_act_trace(self, timesteps=None):
        '''Plot the C-space activation trace of the last run as a heatmap.'''
        # __init__ sets act_trace to None, so test the value, not just the attribute.
        if getattr(self, 'act_trace', None) is not None:
            if timesteps is None:
                timesteps = list(range(self.rt+1))
            actC_trace = self.S2C(self.act_trace[timesteps, :].T).T
            heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                    xticklabels=None, yticklabels=self.binding_names, grayscale=False)
        else:
            sys.exit("There is no [act_trace] attribute in the current object.")

    def plot_trace(self, varname, timesteps=None):
        '''Plot the recorded trace [varname]_trace against time step.'''
        if timesteps is None:
            timesteps = list(range(self.rt+1))
        curr_trace = getattr(self, varname+'_trace')
        plt.plot(curr_trace[timesteps])
        plt.xlabel('Timestep', fontsize=16)
        plt.ylabel(varname, fontsize=16)
        plt.show()

    def optim(self, initvals=None, method='nelder-mead', options=None):
        '''Find a local harmony maximum with scipy.optimize.minimize.

        [initvals] defaults to the current state (None or empty). [method]
        is passed to minimize; [options] defaults to
        {'xatol': 1e-8, 'disp': True} ('xatol' is the Nelder-Mead name for
        the old 'xtol' key). Returns the OptimizeResult with fun negated
        back to a harmony value.
        '''
        if initvals is None or len(initvals) == 0:
            initvals = self.act
        if options is None:
            options = {'xatol': 1e-8, 'disp': True}
        # NOTE(review): harmony() is a module-level helper not visible here; confirm it exists.
        res = optim.minimize(lambda x: -harmony(self, x), initvals, method=method, options=options)
        res.fun = -res.fun
        return res

    def harmony_landscape(self, binding_name, minact, maxact, npoints=1000):
        '''1D harmony landscape: vary one binding's C-space activation over
        [npoints] values in [minact, maxact] (others fixed at the current
        state) and return (activations, harmonies).'''
        if isinstance(binding_name, list):
            if len(binding_name) > 1:
                sys.exit("Choose only one binding.!")
            binding_name = binding_name[0]
        idx = self.binding_names.index(binding_name)
        agrid = np.linspace(minact, maxact, npoints)
        hgrid = np.zeros(len(agrid))
        act0 = self.S2C().copy()
        for ii, aa in enumerate(agrid):
            act0[idx] = aa
            hgrid[ii] = harmony(net=self, act=self.C2S(act0))
        return agrid, hgrid
class WeightTP():
    '''Standalone symmetric C-space weight matrix over named bindings.'''

    def __init__(self, binding_names):
        self.binding_names = binding_names
        self.nbindings = len(self.binding_names)
        size = self.nbindings
        self.WGC = np.zeros((size, size))

    def set_weight(self, binding1, binding2, weight):
        '''Set the weight between two bindings in both directions.'''
        row, col = (self.binding_names.index(name) for name in (binding1, binding2))
        self.WGC[row, col] = weight
        self.WGC[col, row] = weight

    def show(self):
        '''Print the weight matrix labelled by binding name.'''
        table = pd.DataFrame(self.WGC, index=self.binding_names, columns=self.binding_names)
        print(table)
class BiasTP():
    '''Standalone C-space bias vector over named "filler/role" bindings.'''

    def __init__(self, binding_names):
        self.binding_names = binding_names
        self.nbindings = len(binding_names)
        self.bGC = np.zeros(self.nbindings)

    def _matching(self, part, name):
        '''Indices of bindings whose '/'-separated component [part] equals [name].'''
        return [pos for pos, bb in enumerate(self.binding_names) if bb.split('/')[part] == name]

    def set_bias(self, binding, bias):
        '''Set the bias of a single binding.'''
        self.bGC[self.binding_names.index(binding)] = bias

    def set_role_bias(self, role, bias):
        '''Set the bias of every binding with role [role].'''
        self.bGC[self._matching(1, role)] = bias

    def set_filler_bias(self, filler, bias):
        '''Set the bias of every binding with filler [filler].'''
        self.bGC[self._matching(0, filler)] = bias

    def show(self):
        '''Print the bias vector labelled by binding name.'''
        print(pd.DataFrame(self.bGC, index=self.binding_names, columns=["bias"]))
class GscNet0():
def __init__(self, net_list, group_names):
self.ngroups = len(net_list)
self.group_names = group_names
self.groups = net_list
self.binding_names = []
for idx, net in enumerate(self.groups):
self.binding_names = self.binding_names + [b + ':' + group_names[idx] for b in net.binding_names]
self.nbindings = len(self.binding_names)
self.WGC = np.zeros((self.nbindings, self.nbindings))
self.WBC = np.zeros((self.nbindings, self.nbindings))
for ii, net in enumerate(self.groups):
curr_bindings = [b + ':' + group_names[ii] for b in net.binding_names]
idx = [self.binding_names.index(bb) for bb in curr_bindings]
self.WGC[np.ix_(idx,idx)] = net.WGC
self.WBC[np.ix_(idx,idx)] = net.WBC
self.WC = self.WGC + self.WBC
self.bGC = np.zeros(self.nbindings)
self.bBC = np.zeros(self.nbindings)
self.extC = np.zeros(self.nbindings)
nunits = 0
self.group_unit_idx = list()
self.group_binding_idx = list()
for ii, net in enumerate(self.groups):
curr_bindings = [b + ':' + group_names[ii] for b in net.binding_names]
idx = [self.binding_names.index(bb) for bb in curr_bindings]
self.bGC[np.ix_(idx)] = net.bGC
self.bBC[np.ix_(idx)] = net.bBC
self.extC[np.ix_(idx)] = net.extC
unit_idx = np.arange(net.nunits) + nunits
self.group_unit_idx.append(unit_idx)
self.group_binding_idx.append(idx)
nunits = nunits + net.nunits
self.bC = self.bGC + self.bBC
self.nunits = nunits
self.W = self.compute_sspace_weights()
self.b, self.ext = self.compute_sspace_b_and_ext()
self.act = np.zeros(self.nunits)
def compute_sspace_b_and_ext(self):
b_distributed = np.zeros(self.nunits)
ext_distributed = np.zeros(self.nunits)
for ii, net in enumerate(self.groups):
b_distributed[self.group_unit_idx[ii]] = net.b
ext_distributed[self.group_unit_idx[ii]] = net.ext
return b_distributed, ext_distributed
def compute_sspace_weights(self):
W_distributed = np.zeros((self.nunits, self.nunits))
for ii, net1 in enumerate(self.groups):
idx1 = self.group_unit_idx[ii]
bidx1 = self.group_binding_idx[ii]
for jj, net2 in enumerate(self.groups):
idx2 = self.group_unit_idx[jj]
bidx2 = self.group_binding_idx[jj]
W_local = self.WC[np.ix_(bidx1, bidx2)]
if ii == jj:
W_distributed[np.ix_(idx1, idx1)] = net1.W
else:
W_temp = np.zeros((net1.nunits, net2.nunits))
for kk in np.arange(0, net1.nbindings):
si = net1.TP[:, kk]
for ll in np.arange(0, net2.nbindings):
sj = net2.TP[:, ll]
W_temp = W_temp + W_local[kk,ll] * np.outer(si, sj) / (np.dot(si, si) * np.dot(sj, sj))
W_distributed[np.ix_(idx1, idx2)] = W_temp
return W_distributed
def randomize_state(self, minact=0, maxact=1):
'''Set the activation state to a random vector inside a unit hypercube'''
for ii, net in enumerate(self.groups):
net.randomize_state()
def set_weight(self, binding_name1, binding_name2, weight, symmetric=True):
idx1 = self.binding_names.index(binding_name1)
idx2 = self.binding_names.index(binding_name2)
if symmetric:
self.WGC[idx1, idx2] = self.WGC[idx2, idx1] = weight
else:
self.WGC[idx2, idx1] = weight
self.WC = self.WGC + self.WBC
self.W = self.compute_sspace_weights()
def set_bias(self, binding_name, bias):
idx = self.binding_names.index(binding_name)
self.bGC[idx] = bias
self.bC = self.bGC + self.bBC
self.b, self.ext = self.compute_sspace_b_and_ext()
def set_role_bias(self, role_name, bias):
role_list = [bb.split('/')[1] for bb in self.binding_names]
idx = [ii for ii, rr in enumerate(role_list) if role_name in rr]
self.bGC[idx] = bias
self.bC = self.bGC + self.bBC
self.b, self.ext = self.compute_sspace_b_and_ext()
def set_filler_bias(self, filler_name, bias):
filler_list = [bb.split('/')[0] for bb in self.binding_names]
idx = [ii for ii, ff in enumerate(filler_list) if filler_name in ff]
self.bGC[idx] = bias
self.bC = self.bGC + self.bBC
self.b, self.ext = self.compute_sspace_b_and_ext()
def set_seed(self, num):
np.random.seed(num)
def find_bindings(self, binding_names):
'''Find the indices of the bindings from the list of binding names.'''
if not isinstance(binding_names, list):
binding_names = [binding_names]
return [self.binding_names.index(bb) for bb in binding_names]
def read_state(self):
'''Print the current state (C-SPACE) in a readable format. Pandas should be installed.'''
for ii, net in enumerate(self.groups):
print("Group: %s" % self.group_names[ii])
net.read_state()
print("\n")
def read_grid_point(self):
for ii, net in enumerate(self.groups):
print("Group: %s" % self.group_names[ii])
print(net.read_grid_point(disp=False))
print("\n")
def read_weight(self):
print(pd.DataFrame(self.WC, index=self.binding_names, columns=self.binding_names))
print(pd.DataFrame(self.bC.reshape(1, self.nbindings), index=["bias"], columns=self.binding_names))
def hinton(self, label=True):
'''Draw a hinton diagram'''
if label:
labels=self.binding_names
else:
labels=None
hinton(self.WC, xlabels=labels, ylabels=labels)
def set_input(self, binding_names, ext_vals):
if len(binding_names) != len(ext_vals):
sys.exit("binding_names and ext_vals have different lengths.")
self.extC = np.zeros(self.nbindings)
ind = self.find_bindings(binding_names)
self.extC[ind] = ext_vals
self.ext = compute_sspace_biases(self.extC, self)
def clear_input(self):
self.extC = np.zeros(self.nbindings)
self.ext = compute_sspace_biases(self.extC, self)
def clamp(self, binding_names, clamp_vals):
'''Clamp f/r bindings'''
if len(clamp_vals) != len(binding_names):
sys.exit('The number of bindings clamped is not equal to the number of values provided.')
self.clamped = True
temp = [b.split(':') for b in binding_names]
group_num = [self.group_names.index(b[1]) for b in temp]
bb_names = [b[0] for b in temp]
for ii, gg in enumerate(set(group_num)):
idx1 = [jj for jj, val in enumerate(group_num) if val==gg]
curr_bindings = [bb_names[jj] for jj in idx1]
curr_vals = [clamp_vals[jj] for jj in idx1]
self.groups[gg].clamp(binding_names=curr_bindings, clamp_vals=curr_vals)
def unclamp(self):
self.clamped = False
for ii, net in enumerate(self.groups):
if net.clamped:
net.unclamp()
def update(self, update_T=True, update_q=True):
'''Update the current state (with noise) for each group'''
for ii, net1 in enumerate(self.groups):
extC = np.zeros(net1.nbindings)
binding_names1 = [b + ':' + self.group_names[ii] for b in net1.binding_names]
idx1 = [self.binding_names.index(b) for b in binding_names1]
for jj, net2 in enumerate(self.groups):
if not (ii==jj):
binding_names2 = [b + ':' + self.group_names[jj] for b in net2.binding_names]
idx2 = [self.binding_names.index(b) for b in binding_names2]
curr_WGC = self.WGC[np.ix_(idx1, idx2)]
extC = extC + curr_WGC.dot(net2.S2C(net2.act))
net1.ext = net1.C2S(extC)
for ii, net in enumerate(self.groups):
net.update(update_T = update_T, update_q = update_q)
self.act[self.group_unit_idx[ii]] = net.act
def reset(self):
for ii, net in enumerate(self.groups):
net.reset()
def harmony(self, act=None, quant_list=None):
'''Compute the total harmony of the current state'''
if act is None:
act = self.act
return harmony(net=self, act=act, quant_list=quant_list)
def gharmony(self, act=None):
'''Compute the grammar harmony of the current state'''
if act is None:
act = self.act
return gharmony(net=self, act=act)
def hgrad(self, act=None, quant_list=None):
'''Compute the harmony gradient evaluated at the current state'''
if act is None:
act = self.act
return hgrad(net=self, act=act, quant_list=quant_list)
def run(self, maxstep, tol=None):
self.rt, self.final_state, self.act_trace, self.ema_speed_trace = run_GscNet0(self, maxstep=maxstep, tol=tol)
def plot_act_trace(self):
for ii, net in enumerate(self.groups):
curr_act_trace = self.act_trace[:, self.group_unit_idx[ii]]
actC_trace = net.TPinv.dot(curr_act_trace.T).T
print("Group: %s" % self.group_names[ii])
heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings", yticklabels=net.binding_names, grayscale=False, colorbar=True)
def gharmony_given_clamping(self, x, group_name_clamped, clampval, quant_list, q):
x0 = np.zeros(self.nbindings)
self.clamp_group(group_name_clamped, clampval)
x0[self.clampind] = self.clampvec[self.clampind]
ind = list(set(np.arange(self.nbindings)) - set(self.clampind))
x0[ind] = x
H = 0.5 * x0.dot(self.W.dot(x0)) + x0.dot(self.b)
for jj, subgroup in enumerate(quant_list):
temp_ind = [self.binding_names.index(bb) for bb in subgroup]
qx = x0[temp_ind]
H = H + q* self.qharmony(qx)
return H
def optimal_state_given_clamping(self, group_name_clamped, clampval, quant_list, q, method="Nelder-Mead"):
    '''Maximise grammar harmony over the unclamped bindings; store it in ``self.optim_state``.

    The objective is the negated ``gharmony_given_clamping``.  The start
    point is drawn uniformly from [0, 1) for every binding outside the
    clamped group.  ``method == "basinhopping"`` uses
    ``optim.basinhopping``; any other value is passed as ``method`` to
    ``optim.minimize``.  The optimiser result and the assembled full state
    are printed.
    '''
    def neg_harmony(x):
        return -self.gharmony_given_clamping(x, group_name_clamped, clampval, quant_list, q)

    clamped_group = self.groups[self.group_names.index(group_name_clamped)]
    nfree = self.nbindings - clamped_group.nbindings
    if method == "basinhopping":
        res = optim.basinhopping(neg_harmony, np.random.uniform(size=nfree))
    else:
        res = optim.minimize(neg_harmony, np.random.uniform(size=nfree), method=method)
    print(res)
    # self.clampind / self.clampvec are read as left by the objective
    # (gharmony_given_clamping calls clamp_group on every evaluation).
    state = np.zeros(self.nbindings)
    state[self.clampind] = self.clampvec[self.clampind]
    free_idx = list(set(np.arange(self.nbindings)) - set(self.clampind))
    state[free_idx] = res.x
    self.optim_state = state
    print(state)
def run_GscNet0(net, maxstep, plot=False, grayscale=False, colorbar=True,
                tol=None, ema_factor=0.1, crop=True, update_q=True):
    """Run a grouped network for up to ``maxstep`` update steps.

    Parameters
    ----------
    net : object
        Network exposing ``update(update_q=...)``, ``act`` and ``nunits``;
        when ``plot`` is true also ``groups`` (each with ``S2C`` and
        ``binding_names``) and ``group_unit_idx``.
    maxstep : int
        Maximum number of update steps.
    plot : bool
        If true, draw one binding-space heatmap per group.
    grayscale, colorbar : bool
        Passed through to ``heatmap``.
    tol : float or None
        If given, stop as soon as the exponential moving average (EMA) of
        the per-step max-abs activation change falls below ``tol``
        (``tol == 0`` never stops early).  A negative value exits the
        program via ``sys.exit``.
    ema_factor : float
        Weight of the previous EMA in
        ``ema = ema_factor * ema + (1 - ema_factor) * diff``.
    crop : bool
        Unused; kept for backward compatibility.
    update_q : bool
        Passed through to ``net.update``.

    Returns
    -------
    list
        ``[rt, act, act_trace, ema_speed]``: ``rt`` is the number of steps
        run, ``act`` is ``net.act`` after the last step, ``act_trace`` has
        exactly ``rt`` rows (one per step run), and ``ema_speed`` is the
        final EMA speed (``-1`` when ``tol`` is None, ``0`` if no step ran).
    """
    if tol is not None and tol < 0:
        sys.exit("The tolerance parameter should be a positive real number.")
    act_trace = np.zeros((maxstep, net.nunits))
    prev_act = copy.deepcopy(net.act)
    ema_speed = 0
    tt = -1  # makes rt == 0 when maxstep == 0 (the loop body never runs)
    for tt in range(maxstep):
        net.update(update_q=update_q)
        act_trace[tt, :] = net.act
        if tol is None:
            continue
        diff = np.max(np.abs(net.act - prev_act))
        if tt == 0:
            ema_speed = diff
        else:
            ema_speed = ema_factor * ema_speed + (1 - ema_factor) * diff
        prev_act = copy.deepcopy(net.act)
        if (ema_speed < tol) and (tol > 0):
            break
    rt = tt + 1
    # Keep every executed step (rows 0..tt); the old slice [:tt] dropped the last one.
    act_trace = act_trace[:rt]
    if plot:
        # Use a distinct loop name so `net` still refers to the whole network.
        for ii, grp in enumerate(net.groups):
            grp_trace = act_trace[:, net.group_unit_idx[ii]]
            actC_trace = grp.S2C(grp_trace.T).T
            heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                    yticklabels=grp.binding_names, grayscale=grayscale, colorbar=colorbar)
    ema_result = -1 if tol is None else ema_speed
    return [rt, net.act, act_trace, ema_result]
def bind(fillernames, rolenames, sep='/'):
    """Return binding names ``filler + sep + role``, roles outer and fillers inner.

    >>> bind(['a', 'b'], ['x', 'y'])
    ['a/x', 'b/x', 'a/y', 'b/y']
    """
    names = []
    for role in rolenames:
        for filler in fillernames:
            names.append(filler + sep + role)
    return names
def H(net, act, quant_list=None):
    """Total harmony: grammar harmony ``Hg`` plus ``net.q`` times quantization harmony ``Q``.

    ``quant_list`` is accepted for interface compatibility but not used.
    """
    grammar_part = Hg(net, act)
    return grammar_part + net.q * Q(n=act, net=net)
def HGrad(net, act, quant_list=None):
    """Gradient of ``H``: ``HgGrad`` plus ``net.q`` times ``QGrad``.

    ``quant_list`` is accepted for interface compatibility but not used.
    """
    grammar_grad = HgGrad(net, act)
    return grammar_grad + net.q * QGrad(n=act, net=net)
def Hg(net, act):
    """Grammar harmony: the sum of ``H0`` and ``H1``."""
    weight_bias_term = H0(net, act)
    bowl_term = H1(net, act)
    return weight_bias_term + bowl_term
def HgGrad(net, act):
    """Gradient of ``Hg``: the sum of ``H0Grad`` and ``H1Grad``."""
    weight_bias_grad = H0Grad(net, act)
    bowl_grad = H1Grad(net, act)
    return weight_bias_grad + bowl_grad
def H0(net, act):
    """Return ``0.5 * act.WG.act + (bG + ext).act``."""
    quadratic = act.dot(net.WG).dot(act)
    total_bias = net.bG + net.ext
    return 0.5 * quadratic + total_bias.dot(act)
def H0Grad(net, act):
    """Gradient of ``H0``: ``WG.act + bG + ext``."""
    drive = net.WG.dot(act)
    return drive + net.bG + net.ext
def H1(net, act):
    """Return ``-0.5 * beta * |act - zeta|^2``."""
    offset = act - net.zeta
    return -0.5 * net.beta * offset.dot(offset)
def H1Grad(net, act):
    """Gradient of ``H1``: ``-beta * (act - zeta)``."""
    offset = act - net.zeta
    return -net.beta * offset
def q2harmony(net, act=None, quant_list=None):
    '''Compute quantization harmony Q. By default, it considers filler competition in each role.

    ``act`` defaults to ``net.act`` and is mapped to binding space with
    ``net.S2C``.  With ``quant_list`` given, Q is the sum of ``q2`` over
    each group of binding names.  Otherwise:
      * one filler:  Q = -sum(a^2 (1-a)^2)
      * otherwise:   Q = c * (-sum_r (sum_f a_fr^2 - 1)^2)
                         + (1-c) * (-sum(a^2 (1-a)^2))
    NOTE(review): the one-filler case uses full weight on the binary term,
    whereas q2hgrad scales that term's gradient by (1-c); confirm which is
    intended.
    '''
    actC = net.S2C(net.act if act is None else act)
    if quant_list is not None:
        total = 0
        for group in quant_list:
            members = [net.binding_names.index(bb) for bb in group]
            total = total + q2(net, actC[members])
        return total
    binary_term = -np.sum(actC**2 * (1-actC)**2)
    if net.nfillers == 1:
        return binary_term
    # Fillers vary fastest in the binding vector, hence order='F'.
    role_sq_sums = np.sum(actC.reshape(net.nfillers, net.nroles, order='F')**2, axis=0)
    unity_term = -np.sum((role_sq_sums-1)**2)
    return net.c * unity_term + (1-net.c) * binary_term
def q2(net, actC):
    '''Quantization harmony of one group of binding-space activations ``actC``.

    For a single binding: -a^2 (1-a)^2.  Otherwise
    c * (-(sum(a^2) - 1)^2) + (1-c) * (-sum(a^2 (1-a)^2)); the first term
    penalises a squared-activation total different from 1 and is the
    function whose gradient q2grad computes (-4 a (sum(a^2) - 1)).
    '''
    binary_term = -np.sum(actC**2 * (1-actC)**2)
    if len(actC) == 1:
        return binary_term
    # Parenthesised as (sum(a^2) - 1)^2; the previous form -np.sum(a**2-1)**2
    # computed (sum(a^2) - n)^2 and disagreed with q2grad.
    unity_term = -(np.sum(actC**2) - 1)**2
    return net.c * unity_term + (1-net.c) * binary_term
def q2hgrad(net, act, quant_list=None):
    '''Gradient of quantization harmony, mapped back with ``net.C2S``.

    ``act`` is mapped to binding space with ``net.S2C``.  With
    ``quant_list`` given, the per-group ``q2grad`` results are summed into
    a full binding-length vector.  Otherwise the gradient is
    c * (-4 a (sum_f a_fr^2 - 1)) + (1-c) * (-2 a (a-1) (2a-1)), where the
    first term is 0 when there is a single filler.
    '''
    actC = net.S2C(act)
    if quant_list is not None:
        q_grad = np.zeros(net.nbindings)
        for subgroup in quant_list:
            members = [net.binding_names.index(bb) for bb in subgroup]
            contribution = np.zeros(net.nbindings)
            contribution[members] = q2grad(net, actC[members])
            q_grad = q_grad + contribution
        return net.C2S(q_grad)
    binary_grad = -2 * actC * (actC-1) * (2*actC-1)
    if net.nfillers == 1:
        unity_grad = 0
    else:
        role_sq_sums = np.sum(actC.reshape(net.nfillers, net.nroles, order='F')**2, axis=0)
        # Broadcast each role's sum back to that role's bindings (fillers fastest).
        per_binding = np.tile(role_sq_sums, (net.nfillers, 1)).flatten('F')
        unity_grad = -4 * actC * (per_binding - 1)
    return net.C2S(net.c * unity_grad + (1-net.c) * binary_grad)
def q2grad(net, actC):
    '''q2grad for each group: c * (-4 a (sum(a^2) - 1)) + (1-c) * (-2 a (a-1) (2a-1)).'''
    sq_excess = np.sum(actC**2) - 1
    unity_grad = -4 * actC * sq_excess
    binary_grad = -2 * actC * (actC-1) * (2*actC-1)
    return net.c * unity_grad + (1-net.c) * binary_grad
def _blob(x, y, area, colour):
    """
    Draw a filled square of the given area centred at (x, y).

    The square's side is sqrt(area); fill and edge both use ``colour``.
    http://wiki.scipy.org/Cookbook/Matplotlib/HintonDiagrams
    """
    half_side = np.sqrt(area) / 2
    left, right = x - half_side, x + half_side
    bottom, top = y - half_side, y + half_side
    P.fill(np.array([left, right, right, left]),
           np.array([bottom, bottom, top, top]),
           colour, edgecolor=colour)
def hinton(W, maxWeight=None, xlabels=None, ylabels=None):
    """
    Draws a Hinton diagram for visualizing a weight matrix.
    Temporarily disables matplotlib interactive mode if it is on,
    otherwise this takes forever.
    http://wiki.scipy.org/Cookbook/Matplotlib/HintonDiagrams

    Positive weights are drawn as white squares and negative weights as
    black squares on a gray background; each square's area is
    |w| / maxWeight, capped at 1.

    Parameters
    ----------
    W : 2-D array
        Matrix to draw; W[y, x] appears in column x, row y (row 0 at top).
    maxWeight : float, optional
        Weight that fills a whole cell.  When falsy, the smallest power of
        two that is >= max |W| is used.
    xlabels, ylabels : sequence of str, optional
        Column / row tick labels.  When neither is given the axes are hidden.
    """
    # Remember the interactive state so it can be restored at the end.
    reenable = P.isinteractive()
    if reenable:
        P.ioff()
    P.clf()
    height, width = W.shape
    if not maxWeight:
        # log2 is exact for powers of two, so ceil() cannot overshoot.
        maxWeight = 2**np.ceil(np.log2(np.max(np.abs(W))))
    P.fill(np.array([0, width, width, 0]), np.array([0, 0, height, height]), 'gray')
    for x in range(width):
        for y in range(height):
            _x = x + 1
            _y = y + 1
            w = W[y, x]
            if w > 0:
                _blob(_x - 0.5, height - _y + 0.5, min(1, w/maxWeight), 'white')
            elif w < 0:
                _blob(_x - 0.5, height - _y + 0.5, min(1, -w/maxWeight), 'black')
    label = False
    if xlabels is not None:
        P.xticks(np.arange(len(xlabels))+0.5, xlabels, rotation='vertical')
        label = True
    if ylabels is not None:
        P.yticks(np.arange(len(ylabels))+0.5, ylabels[::-1])
        label = True
    if label:
        P.xlim([0, width])
        P.ylim([0, height])
        P.gca().set_aspect('equal', adjustable='box')
        P.gca().tick_params(direction='out')
    else:
        P.axis('off')
        P.axis('equal')
    if reenable:
        P.ion()
    P.show()
def heatmap(data, xlabel=None, ylabel=None, xticklabels=None, yticklabels=None,
            grayscale=False, colorbar=True):
    """Show ``data`` as an image with optional axis labels, tick labels and colorbar.

    Uses the "gray_r" colormap when ``grayscale`` is true and "Reds"
    otherwise.  The colormap is passed to ``imshow`` by name, which also
    works on matplotlib versions where ``matplotlib.cm.get_cmap`` (reached
    here as ``plt.cm.get_cmap``) no longer exists.
    """
    cmap = "gray_r" if grayscale else "Reds"
    plt.imshow(data, cmap=cmap, interpolation="nearest", aspect='auto')
    if xlabel is not None:
        plt.xlabel(xlabel, fontsize=16)
    if ylabel is not None:
        plt.ylabel(ylabel, fontsize=16)
    if xticklabels is not None:
        plt.xticks(np.arange(len(xticklabels)), xticklabels)
    if yticklabels is not None:
        plt.yticks(np.arange(len(yticklabels)), yticklabels)
    if colorbar:
        plt.colorbar()
    plt.show()
def encode_symbols(nsymbols, reptype='local', dp=0, ndim=None):
    """Return a matrix whose columns encode ``nsymbols`` symbols as vectors.

    ``reptype == 'local'`` gives one-hot (identity) encodings and ignores
    ``dp`` and ``ndim``.  Otherwise ``ndim`` rows are used (default
    ``nsymbols``) and the columns are fitted so that their pairwise dot
    products match ``dp``.  A number is the common off-diagonal dot product
    (via ``dot_products``); a matrix is the full target Gram matrix (via
    ``dot_products2``).
    """
    if reptype == 'local':
        return np.eye(nsymbols)
    if ndim is None:
        ndim = nsymbols
    if isinstance(dp, Number):
        return dot_products(nsymbols, ndim, dp)
    return dot_products2(nsymbols, ndim, dp)
def dot_products(nsymbols, ndim, s):
    """Fit ``nsymbols`` unit vectors in ``ndim`` dims whose pairwise dot products are all ``s``.

    Builds the target Gram matrix (1 on the diagonal, ``s`` elsewhere) and
    delegates to ``dot_products2``.
    """
    identity = np.eye(nsymbols, nsymbols)
    target = s * np.ones((nsymbols, nsymbols)) + (1-s) * identity
    return dot_products2(nsymbols, ndim, target)
def dot_products2(nsymbols, ndim, dp_mat, max_iter = 100000):
    """Fit an ``ndim`` x ``nsymbols`` matrix M with M.T @ M close to ``dp_mat``.

    ``dp_mat`` must be exactly symmetric with ones on the diagonal; otherwise
    the program exits via ``sys.exit``.  M starts from uniform random values
    (NumPy global RNG).  It is updated by steps along M (M.T M - dp_mat),
    which is proportional to the gradient of ||M.T M - dp_mat||^2.  Each step
    is capped at 0.1 and at 0.01 / max|step direction|.  Iteration stops once
    every entry of M.T M is within 1e-6 of ``dp_mat``.  After ``max_iter``
    iterations without that, a message is printed and the last M is returned.
    """
    if not (dp_mat.T == dp_mat).all():
        sys.exit('dot_products2: dp_mat must be symmetric')
    if (np.diag(dp_mat) != 1).any():
        sys.exit('dot_products2: dp_mat must have all ones on the main diagonal')
    vecs = np.random.uniform(size=ndim*nsymbols).reshape(ndim, nsymbols, order='F')
    max_step = .1
    tol = 1e-6
    for _ in range(max_iter):
        direction = vecs.dot(vecs.T.dot(vecs) - dp_mat)
        step = min(max_step, .01/abs(direction).max())
        vecs = vecs - step * direction
        if abs(vecs.T.dot(vecs) - dp_mat).max() <= tol:
            return vecs
    print("Didn't converge after %d iterations" % max_iter)
    return vecs
def compute_TPmat(R, F):
    """Return the tensor-product matrix ``kron(R, F)`` and its (pseudo-)inverse.

    A square product is inverted with ``la.inv``; a non-square one uses
    the Moore-Penrose pseudo-inverse ``la.pinv``.
    """
    TP = np.kron(R, F)
    nrows, ncols = TP.shape
    invert = la.inv if nrows == ncols else la.pinv
    return TP, invert(TP)
def compute_sspace_biases(b_local, net):
    """Map a binding-space (local) bias vector to unit space: ``TPinv.T . b_local``."""
    inv_t = net.TPinv.T
    return np.dot(inv_t, b_local)
def compute_sspace_weights(W_local, net):
    """Map a binding-space (local) weight matrix to unit space: ``TPinv.T . W_local . TPinv``."""
    inv = net.TPinv
    left = np.dot(inv.T, W_local)
    return np.dot(left, inv)
def compute_projmat(A):
    """Return the orthogonal projector onto the column space of ``A``: A (A^T A)^-1 A^T."""
    gram_inv = la.inv(np.dot(A.T, A))
    return np.dot(np.dot(A, gram_inv), A.T)
class sim():
    """Drive repeated simulations of a network over a sequence of inputs.

    ``params`` is a dict of simulation settings; the keys each method reads
    are listed in its docstring.  ``self.input_list`` and ``self.input_vals``
    are read by ``simulate`` and ``compute_gpdist`` but are not assigned
    anywhere in this class.  Presumably the caller sets them; verify against
    callers.
    """

    def __init__(self, net, params):
        self.net = net
        self.params = params
        if self.net.getGPset:
            # Enumerate the network's grid points and keep private copies of
            # the grid-point set and the matching harmonies.
            self.net.all_grid_points()
            self.gpset = copy.deepcopy(self.net.gpset)
            self.gpset_gh = copy.deepcopy(self.net.gpset_gh)

    def set_params(self, p_name, p_val):
        """Set simulation parameter ``p_name`` to ``p_val``."""
        self.params[p_name] = p_val

    def set_seed(self, num):
        """Seed NumPy's global random number generator."""
        np.random.seed(num)

    def _words(self):
        """Return ``self.input_list`` as a list (a lone item is wrapped)."""
        if isinstance(self.input_list, list):
            return self.input_list
        return [self.input_list]

    def set_input(self, input_list, input_vals, params=None):
        """Build the binding-space input matrix ``self.inputC`` (nbindings x nwords).

        Each item of ``input_list`` is a binding name "filler/role" or a list
        of such names presented together; ``input_vals[i]`` is the strength
        for item i.  If ``params['input_inhib']`` is true and
        ``params['input_method'] == 'ext'``, every binding sharing a role with
        an input binding first receives ``-val``; the input bindings are then
        set to ``val``.  If ``params['cumulative_input']`` is true the columns
        are accumulated across words.  ``params`` defaults to ``self.params``.
        """
        if params is None:
            params = self.params
        if not isinstance(input_list, list):
            input_list = [input_list]
        nwords = len(input_list)
        # Role of every binding; loop-invariant, so computed once.
        role_list = [bb.split('/')[1] for bb in self.net.binding_names]
        self.inputC = np.zeros((self.net.nbindings, nwords))
        for ii, binding in enumerate(input_list):
            val = input_vals[ii]
            # Normalise to a list before parsing roles so that a list of
            # bindings is handled instead of crashing on str.split.
            bindings = binding if isinstance(binding, list) else [binding]
            curr_roles = set(bb.split('/')[1] for bb in bindings)
            role_idx = [jj for jj, rr in enumerate(role_list) if rr in curr_roles]
            binding_idx = self.net.find_bindings(bindings)
            inputC = np.zeros(self.net.nbindings)
            if params['input_inhib'] and params['input_method'] == 'ext':
                inputC[role_idx] = -val
            inputC[binding_idx] = val
            self.inputC[:, ii] = inputC
        if params['cumulative_input']:
            self.inputC = self.inputC.cumsum(axis=1)

    def simulate(self, params=None):
        """Run ``params['nrep']`` repetitions over every word of ``self.input_list``.

        ``params`` defaults to ``self.params`` and is used for every setting
        below.  The keys read are:

        * ``nrep``
        * ``maxstep``: a scalar, or one value per word
        * ``grid_points``
        * ``input_method``: 'ext' or 'clamp'
        * ``cumulative_input``
        * ``input_inhib`` (via set_input)
        * the keyword arguments forwarded to ``net.run``: ``norm_ord``,
          ``tol``, ``update_T``, ``update_q``, ``q_fun``, ``grid_points``,
          ``ext_overlap`` and ``ext_overlap_steps``

        Results are stored in ``act_trace``, ``T_trace``, ``q_trace``,
        ``speed_trace``, ``ema_speed_trace``, ``converged``, ``rt`` (an int
        array), ``gp`` and, when grid points are given, ``dist_trace``.
        """
        if params is None:
            params = self.params
        words = self._words()
        nrep = params['nrep']
        nwords = len(words)
        maxsteps = np.asarray(params['maxstep'])
        if maxsteps.ndim == 0:
            maxsteps = np.full(nwords, maxsteps)
        maxstep = int(maxsteps.max())
        self.act_trace = np.zeros((nrep, nwords, maxstep+1, self.net.nunits))
        self.T_trace = np.zeros((nrep, nwords, maxstep+1))
        self.q_trace = np.zeros((nrep, nwords, maxstep+1))
        self.speed_trace = np.zeros((nrep, nwords, maxstep+1))
        self.ema_speed_trace = np.zeros((nrep, nwords, maxstep+1))
        self.converged = np.ones((nrep, nwords), dtype=bool)
        # Integer so that rt can be used directly as a slice bound / index.
        self.rt = np.zeros((nrep, nwords), dtype=int)
        self.gp = [[[[] for _ in range(maxstep+1)] for _ in range(nwords)] for _ in range(nrep)]
        if params['grid_points'] is not None:
            grid_points = params['grid_points']
            if not any(isinstance(pp, list) for pp in grid_points):
                grid_points = [grid_points]
            self.dist_trace = np.zeros((nrep, nwords, maxstep+1, len(grid_points)))
        for rep_num in range(nrep):
            self.net.reset()
            self.set_input(words, self.input_vals, params)
            for word_num in range(nwords):
                curr_maxstep = int(maxsteps[word_num])
                if params['input_method'] == 'ext':
                    self.net.extC = copy.deepcopy(self.inputC[:, word_num])
                    self.net.ext = compute_sspace_biases(self.net.extC, self.net)
                    if word_num > 0:
                        self.net.prev_extC = copy.deepcopy(self.inputC[:, word_num-1])
                        self.net.prev_ext = compute_sspace_biases(self.net.prev_extC, self.net)
                elif params['input_method'] == 'clamp':
                    if params['cumulative_input']:
                        self.net.clamp(words[:word_num+1])
                    else:
                        self.net.clamp(words[word_num])
                self.net.run(curr_maxstep,
                             norm_ord=params['norm_ord'],
                             tol=params['tol'],
                             update_T=params['update_T'],
                             update_q=params['update_q'],
                             q_fun=params['q_fun'],
                             grid_points=params['grid_points'],
                             ext_overlap=params['ext_overlap'],
                             ext_overlap_steps=params['ext_overlap_steps'])
                nsteps = curr_maxstep + 1
                self.act_trace[rep_num, word_num, :nsteps, :] = self.net.act_trace
                self.rt[rep_num, word_num] = self.net.rt
                self.T_trace[rep_num, word_num, :nsteps] = self.net.T_trace
                self.q_trace[rep_num, word_num, :nsteps] = self.net.q_trace
                self.speed_trace[rep_num, word_num, :nsteps] = self.net.speed_trace
                self.ema_speed_trace[rep_num, word_num, :nsteps] = self.net.ema_speed_trace
                self.converged[rep_num, word_num] = self.net.converged
                if params['grid_points'] is not None:
                    self.dist_trace[rep_num, word_num, :nsteps, :] = self.net.dist_trace
                self.gp[rep_num][word_num][:nsteps] = self.net.gp_trace

    def _concat_words(self, trace, rep_ind):
        """Join per-word traces of repetition ``rep_ind`` along time.

        Word 0 contributes steps 0..rt; each later word contributes steps
        1..rt, because its step 0 repeats the previous word's final state.
        """
        rts = np.asarray(self.rt[rep_ind]).astype(int)
        pieces = [trace[rep_ind, 0, :rts[0]+1]]
        pieces += [trace[rep_ind, wind, 1:rts[wind]+1] for wind in range(1, trace.shape[1])]
        return np.concatenate(pieces, axis=0)

    def plot_act_trace(self, rep_num):
        """Heatmap of the binding-space activation trace of repetition ``rep_num`` (1-based)."""
        curr_act_trace = self._concat_words(self.act_trace, rep_num - 1)
        actC_trace = self.net.S2C(curr_act_trace.T).T
        heatmap(actC_trace.T, xlabel="Timestep", ylabel="Bindings",
                xticklabels=None, yticklabels=self.net.binding_names, grayscale=False)

    def plot_trace(self, varname, rep_num):
        """Plot ``self.<varname>_trace`` for repetition ``rep_num`` (1-based).

        Green vertical lines mark the cumulative reaction time at the end of
        each word.  Both 1-D traces (e.g. 'T', 'q') and 2-D traces
        (e.g. 'act') are supported.
        """
        rep_ind = rep_num - 1
        curr_trace = self._concat_words(getattr(self, varname+'_trace'), rep_ind)
        word_ends = np.cumsum(np.asarray(self.rt[rep_ind]).astype(int))
        plt.plot(curr_trace)
        # Skip step 0 (the initial state) when choosing the line extent.
        ymin, ymax = curr_trace[1:].min(), curr_trace[1:].max()
        for rt in word_ends:
            plt.plot([rt, rt], [ymin, ymax], 'g-')
        plt.xlabel('Timestep', fontsize=16)
        plt.ylabel(varname, fontsize=16)

    def _final_act(self, rep_num, word_num):
        """Activation at the reaction time of (rep_num, word_num), both 1-based."""
        step = int(self.rt[rep_num-1, word_num-1])
        return self.act_trace[rep_num-1, word_num-1, step, :]

    def read_state(self, rep_num, word_num):
        """Pass the final state of (rep_num, word_num), 1-based, to ``net.read_state``."""
        self.net.read_state(self._final_act(rep_num, word_num))

    def read_grid_point(self, rep_num, word_num):
        """Pass the final state of (rep_num, word_num), 1-based, to ``net.read_grid_point``."""
        self.net.read_grid_point(self._final_act(rep_num, word_num))

    def compute_gpdist(self):
        """Return the fraction of repetitions in each grid point at each step.

        The result has shape (nwords, max(maxstep)+1, len(self.gpset)); only
        steps before each run's reaction time are counted.
        """
        nrep = self.params['nrep']
        nwords = len(self._words())
        maxstep = int(np.max(self.params['maxstep']))
        gp_count = np.zeros((nrep, nwords, maxstep+1, len(self.gpset)))
        for rep_ind in range(nrep):
            for word_ind in range(nwords):
                curr_rt = int(self.rt[rep_ind][word_ind])
                for step in range(curr_rt):
                    idx = self.gpset.index(self.gp[rep_ind][word_ind][step])
                    gp_count[rep_ind, word_ind, step, idx] += 1
        gp_prob = gp_count.sum(axis=0)/nrep
        return gp_prob

    def plot_gp_prob_trace(self, word_num, gp_prob_trace, ngp=None, yticklab=True, grayscale=False, colorbar=True):
        """Heatmap of grid-point probabilities over time for word ``word_num`` (1-based).

        ``gp_prob_trace`` is the output of ``compute_gpdist``.  Only the first
        ``ngp`` grid points are shown (default: all).  Time is cropped to the
        shortest reaction time across repetitions for that word.
        """
        if ngp is None:
            ngp = gp_prob_trace.shape[-1]
        minrt = int(np.min(np.asarray(self.rt)[:, word_num-1]))
        curr_trace = gp_prob_trace[word_num-1, :minrt, :ngp]
        yticklabels = None
        if yticklab:
            yticklabels = [','.join(gp)+('(%.2f)' % self.gpset_gh[ii])
                           for ii, gp in enumerate(self.gpset[:ngp])]
        heatmap(curr_trace.T, xlabel="Timestep", ylabel="",
                xticklabels=None, yticklabels=yticklabels,
                grayscale=grayscale, colorbar=colorbar)
def b_ind(f, r, net):
    """Flat binding index of (filler f, role r); fillers vary fastest."""
    role_offset = r * net.nfillers
    return f + role_offset
def u_ind(phi, rho, net):
    """Flat unit index of (filler dim phi, role dim rho); filler dims vary fastest."""
    role_offset = rho * net.ndim_f
    return phi + role_offset
def w(f, r, phi, rho, net):
    """Entry of ``net.TPinv`` linking binding (f, r) to unit (phi, rho)."""
    row = b_ind(f, r, net)
    col = u_ind(phi, rho, net)
    return net.TPinv[row, col]
def get_a(n, net, f, r):
    """Activation of binding (f, r): sum over units of ``w(f, r, phi, rho) * n[unit]``."""
    return sum(w(f, r, phi, rho, net) * n[u_ind(phi, rho, net)]
               for phi in range(net.ndim_f)
               for rho in range(net.ndim_r))
def n2a(n, net, f=None, r=None):
    """Binding activations computed element-wise from unit activations ``n``.

    * f and r both None: all bindings (length nbindings, index ``b_ind``)
    * only r given:      every filler in role r (length nfillers)
    * only f given:      filler f in every role (length nroles)
    * both given:        the scalar activation of binding (f, r)
    """
    if f is not None and r is not None:
        return get_a(n, net, f, r)
    if f is None and r is None:
        out = np.zeros(net.nbindings)
        for ff in range(net.nfillers):
            for rr in range(net.nroles):
                out[b_ind(ff, rr, net)] = get_a(n, net, ff, rr)
        return out
    if f is None:
        out = np.zeros(net.nfillers)
        for ff in range(net.nfillers):
            out[ff] = get_a(n, net, ff, r)
        return out
    out = np.zeros(net.nroles)
    for rr in range(net.nroles):
        out[rr] = get_a(n, net, f, rr)
    return out
def Q0(net, n):
    """Binary-ness harmony: ``-sum(a^2 (1-a)^2)`` over binding activations ``a = S2C(n)``."""
    acts = net.S2C(n)
    penalty = acts**2 * (1-acts)**2
    return -np.sum(penalty)
def Q0GradE(net, n):
    """Element-wise (loop) gradient of ``Q0`` with respect to unit activations ``n``.

    For each unit (phi, rho) the gradient is
    -sum_{f,r} w(f, r, phi, rho) * 2 a_fr (1 - a_fr) (1 - 2 a_fr),
    where ``a = n2a(n, net)``.  The binding activations do not depend on the
    unit being processed, so they are computed once up front instead of
    once per (unit, binding) pair.
    """
    avec = n2a(n, net)
    slope = 2 * avec * (1 - avec) * (1 - 2 * avec)
    q0grad = np.zeros(net.nunits)
    for phi in range(net.ndim_f):
        for rho in range(net.ndim_r):
            total = 0.0
            for f in range(net.nfillers):
                for r in range(net.nroles):
                    total += w(f, r, phi, rho, net) * slope[b_ind(f, r, net)]
            q0grad[u_ind(phi, rho, net)] = total
    return -q0grad
def Q0Grad(net, n):
    """Vectorised gradient of ``Q0``: ``-sum_b TPinv[b, :] * 2 a_b (1-a_b) (1-2 a_b)``."""
    acts = net.S2C(n)
    slope = 2 * acts * (1-acts) * (1-2*acts)
    weighted = net.TPinv * slope[:, np.newaxis]
    return -np.sum(weighted, axis=0)
def Q1(net, n):
    """Role-unity harmony: ``-sum_r (sum_f a_fr^2 - 1)^2`` with ``a = vec2mat(S2C(n))``.

    Each role's squared activations are penalised for not summing to 1.
    This is the function whose gradient ``Q1Grad`` / ``Q1GradE`` compute.
    The previous form squared the *sum* of the role deviations, so
    deviations of opposite sign cancelled.
    """
    amat = net.vec2mat(net.S2C(n))
    role_sq_sums = np.sum(amat**2, axis=0)
    return -np.sum((role_sq_sums - 1)**2)
def Q1GradE(net, n):
    """Element-wise (loop) gradient of ``Q1`` with respect to unit activations ``n``.

    For each unit (phi, rho):
    -sum_r 4 (sum_f a_fr^2 - 1) * sum_f a_fr w(f, r, phi, rho),
    with ``a = n2a(n, net)``.  The binding activations and the per-role
    excess terms are computed once instead of inside every unit iteration.
    """
    avec = n2a(n, net)
    role_members = [[b_ind(f, r, net) for f in range(net.nfillers)] for r in range(net.nroles)]
    role_excess = [np.sum(avec[members]**2) - 1 for members in role_members]
    q1grad = np.zeros(net.nunits)
    for phi in range(net.ndim_f):
        for rho in range(net.ndim_r):
            unit_grad = 0.0
            for r in range(net.nroles):
                var2 = 0.0
                for f in range(net.nfillers):
                    var2 += avec[b_ind(f, r, net)] * w(f, r, phi, rho, net)
                unit_grad += 4 * role_excess[r] * var2
            q1grad[u_ind(phi, rho, net)] = unit_grad
    return -q1grad
def Q1Grad(net, n):
    """Vectorised gradient of ``Q1``.

    Computes ``-4 * sum_r (sum_{b in r} a_b^2 - 1) * sum_{b in r} a_b TPinv[b, :]``,
    where ``a = S2C(n)`` and ``net.find_roles`` gives the binding indices
    of each role.
    """
    acts = net.S2C(n)
    total = 0.0
    for role in net.role_names:
        idx = net.find_roles(role)
        role_acts = acts[idx]
        projection = np.sum(net.TPinv[idx, :] * role_acts[:, np.newaxis], axis=0)
        excess = np.sum(role_acts ** 2) - 1
        total += excess * projection
    return -(4 * total)
def QGrad(n, net):
    """Gradient of ``Q``: ``c * Q0Grad + (1 - c) * Q1Grad``."""
    binary_grad = Q0Grad(net, n)
    unity_grad = Q1Grad(net, n)
    return net.c * binary_grad + (1-net.c) * unity_grad
def Q(net, n):
    """Quantization harmony: ``c * Q0 + (1 - c) * Q1``."""
    binary_part = Q0(net, n)
    unity_part = Q1(net, n)
    return net.c * binary_part + (1-net.c) * unity_part