Skip to content
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
Cannot retrieve contributors at this time
1031 lines (856 sloc) 34.6 KB
'''
Info: Replication of the results of a published paper
Original repo:

General conventions:
    1. The perception model follows O --> S --> A
        - O stands for observation
        - S stands for mental believed state
        - A stands for the action
    2. For vectors, I usually use the column convention, which means Nx1
    3. For probability distributions, I use
        - p for the generic notation of a probability
        - psi for the perception and state encoder
        - pi for the policy, from state to action
    4. Understanding the conditional probability variables
        - psi_s1o means the probability of state given obs, the 1 stands for |.
          The first variable indexes the columns and the second variable
          indexes the rows. I know this is a little bit counter-intuitive,
          but I find it hard to change my habits.
        - To differentiate conditional and joint distributions, I do not
          include the 1 for a joint distribution; for example, p_so denotes
          the joint distribution of s and o.
    5. About the order of variables. When coming across multiple variables,
       I sort them in the following sequence
        - highest priority -1: Util_matrix
        - priority -2: O observation
        - priority -3: S internal mental state
        - lowest priority -4: action
'''
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import logsumexp # for partition function
from matplotlib.widgets import Slider # interact plot
# directory that contains this script; the figure functions below build
# their output file names from it
path = os.path.dirname(os.path.abspath(__file__))
def I( p_x, p_y1x):
    '''MUTUAL INFORMATION I(X;Y) of a channel, in nats

    Computed as H(Y) - H(Y|X) with natural logarithms.

    Args:
        p_x: sender distribution, nX x 1 (a flat vector of length nX
             is accepted as well)
        p_y1x: channel p(y|x), nX x nY, one row per value of x

    Returns:
        scalar, the mutual information between sender and receiver
    '''
    # force the column convention, otherwise a flat p_x would broadcast
    # over the columns of the channel instead of over its rows
    p_x = np.asarray( p_x).reshape( [-1, 1])
    p_y1x = np.asarray( p_y1x)
    # receiver marginal p(y) = ∑_x p(x)p(y|x), 1 x nY
    p_y = p_x.T @ p_y1x
    # the 1e-20 keeps the log finite for zero-probability entries
    H_y = -np.sum( p_y * np.log( p_y + 1e-20))
    H_y1x = -np.sum( p_x * p_y1x * np.log( p_y1x + 1e-20))
    return H_y - H_y1x
def setup_predator_prey_example(mating_Utility=False):
    '''Build the predator-prey environment

    The agent observes the size w of another animal and picks a reaction.
    The generic reactions are
        a1: wait and attack  (ambush)
        a2: stalk and attack (sneak up)
        a3: flee

    small w:
        a1++    might not come towards you
        a2+++   will not hear you
        a3-     no food
    medium w:
        a1++    might not come towards you
        a2+     hear you and flee
        a3-     no food
    large w:
        a1--    die
        a2--    die
        a3++++  survive

    On top of the generic reactions there is one size-specific hunting
    skill per small / medium size; it is encoded by giving the action
    the same value as the size it is specialised for.

    Args:
        mating_Utility: select the mating variant of the action set.
            Only its action labels are defined; no utility function is
            available for it, so requesting it raises.

    Returns:
        obs_vals: list of observable sizes, length nO
        obs_vars: list of str, labels of the observations
        p_obs: nO x 1, uniform distribution over the observations
        act_vals: list of action codes, length nA
        act_vars: list of str, labels of the actions
        util_mat: nO x nA utility table U(o,a)

    Raises:
        NotImplementedError: if mating_Utility is True
    '''
    # set up observation
    obs_vals = [ 2, 3, 4, 6, 7, 8, 10, 11, 12]           # opponent size
    nO = len( obs_vals)                                  # obs' cardinality
    obs_vars = [ str(obs_val) for obs_val in obs_vals]   # obs' semantic meaning
    p_obs = np.ones( [ nO, 1]) / nO                      # obs distribution

    # set up action
    if mating_Utility:
        act_vars = [ 'display', 'flee']
        act_vals = [ 400, 500]
    else:
        act_vars = [ f'sneak up w={sz}' for sz in obs_vals[0:3]] + \
                   [ f'ambush w={sz}' for sz in obs_vals[3:6]] + \
                   [ 'ambush', 'sneak up', 'flee']       # act' semantic meaning
        act_vals = obs_vals[ 0: 6] + [ 100, 200, 300]    # action codes

    if mating_Utility:
        # NOTE(review): only the hunting utility below is defined; it
        # relies on the nine-action layout, so it cannot serve the
        # two-action mating set. Fail loudly instead of mis-scoring.
        raise NotImplementedError(
            'no utility function is defined for mating_Utility=True')

    def utility_fn( obs, act):
        # set val
        sur_util = 5
        best_hunt_util = 3.5
        sneak_small_util = 3
        ambush_util = 2.3
        sneak_medium_util = 1.5
        flee_small_medium_util = .5
        eaten_util = 0.

        # for small group, the best act is sneak up,
        # for each size, there is a best special sneak up skill
        if (obs < 5) and (act == obs):
            return best_hunt_util

        # for medium group, the best act is ambush,
        # for each specific size, there is a best ambush skill
        if (obs < 9) and (act == obs):
            return best_hunt_util

        # for both small, medium, there is a generic sneak skill
        # for small, sneak up +++
        # for medium, sneak up +
        if act == act_vals[-2]:
            if obs < 5:
                return sneak_small_util
            elif (5 < obs) and (obs < 9):
                return sneak_medium_util

        # for both small, medium, there is a generic ambush skill
        # for small, ambush ++
        # for medium, ambush ++
        if (act == act_vals[-3]) and (obs < 9):
            return ambush_util

        # for both small, medium, flee is a bad choice, because there
        # is no food
        if (act == act_vals[-1]) and (obs < 9):
            return flee_small_medium_util

        # for large, flee is the only choice, anything else gets eaten
        if (9 < obs):
            if (act == act_vals[-1]):
                return sur_util
            else:
                return eaten_util

        # only size-specific skills applied to the wrong size reach here
        # small group:
        #   wrong specific sneak up, effect * 80%
        #   specific ambush of a medium size, equals to generic ambush
        if (obs < 5):
            if (act < 5):
                return sneak_small_util * .8
            else:
                return ambush_util

        # medium group:
        #   specific sneak up of a small size, equals to generic sneak up
        #   wrong specific ambush, effect * 80%
        if (obs < 9):
            if (act < 5):
                return sneak_medium_util
            else:
                return ambush_util * .8

    # make util table
    util_mat = make_util_mat( utility_fn, obs_vals, act_vals)

    return obs_vals, obs_vars, p_obs, act_vals, act_vars, util_mat
## medical system env
def setup_medical_example(uniform_obs=True):
    '''Build the medical diagnosis environment

    There are three kinds of disease: H, L12, L34.
    Each kind has two sub types (h1/h2, l1/l2, l3/l4).
    For each specific type, the specific treatment works the best (high utility)
    Within the same kind of disease, cross type treatment is ok but
    less effective (medium utility)
    Cross kind treatment results in bad results (low utility)
    There are also general treatments: h, l12, l34 and l

    Observations are coded 1..6 in the order h1, h2, l1, l2, l3, l4.
    Actions are coded 1..10: the six specific treatments in the same
    order, followed by l12 (7), l34 (8), h (9) and l (10).

    Args:
        uniform_obs: if True every disease is equally likely, otherwise
            the two heart diseases get three times the weight of each
            lung disease.

    Returns:
        obs_vals: array of observation codes, length nO
        obs_vars: list of str, labels of the observations
        p_obs: nO x 1, distribution over the observations
        act_vals: array of action codes, length nA
        act_vars: list of str, labels of the actions
        util_mat: nO x nA utility table U(o,a)
    '''
    # set up observations
    obs_vars = [ 'h1', 'h2', 'l1', 'l2', 'l3', 'l4']   # obs' semantic meaning
    nO = len( obs_vars)                                # obs' cardinality
    obs_vals = np.arange( 1, nO+1)                     # disease codes
    if uniform_obs:
        p_obs = np.ones( [ nO, 1]) / nO                # obs distribution
    else:
        p_obs = np.ones( [ nO, 1])                     # load the non uniform
        p_obs[ 0:2, 0] = 3
        p_obs = p_obs / np.sum( p_obs)                 # normalize

    # set up actions
    act_vars = [ f'treat={o}' for o in obs_vars ] \
             + [ 'treat=l12', 'treat=l34', 'treat=h', 'treat=l']
    nA = len( act_vars)
    act_vals = np.arange( 1, nA+1)

    def utility_fn( obs, act):
        correct_util = 3                        # correct
        wrong_heart_util = correct_util * .3    # other heart sub type
        general_heart_util = 1.5
        wrong_lung_util1 = correct_util * .5    # other sub type, same kind
        wrong_lung_util2 = correct_util * 0.    # sub type of the other lung kind
        general_lung_util = 1.5
        general_lung_util12 = 2.5
        general_lung_util34 = 2.5

        # correct treatment
        if obs == act:
            return correct_util

        # heart disease, within kind wrong treatment
        if ( obs < 3) and ( act < 3):
            return wrong_heart_util

        # lung disease l1/l2, treated with the other l1/l2 treatment
        if ( 2 < obs) and ( obs < 5) and ( 2 < act) and ( act < 5):
            return wrong_lung_util1

        # lung disease l1/l2, treated with an l3/l4 treatment
        if ( 2 < obs) and ( obs < 5) and ( 4 < act) and ( act < 7):
            return wrong_lung_util2

        # lung disease l3/l4, treated with the other l3/l4 treatment
        if ( 4 < obs) and ( obs < 7) and ( 4 < act) and ( act < 7):
            return wrong_lung_util1

        # lung disease l3/l4, treated with an l1/l2 treatment
        if ( 4 < obs) and ( obs < 7) and ( 2 < act) and ( act < 5):
            return wrong_lung_util2

        # general heart treatment
        if ( obs < 3) and ( act == 9):
            return general_heart_util

        # general lung treatments
        if ( 2 < obs) and ( obs < 7):
            if ( act == 7) and ( obs < 5):
                return general_lung_util12
            if ( act == 8) and ( 4 < obs):
                return general_lung_util34
            if act == 10:
                return general_lung_util

        # wrong treatment for wrong cause
        return 0

    util_mat = make_util_mat( utility_fn, obs_vals, act_vals)

    return obs_vals, obs_vars, p_obs, act_vals, act_vars, util_mat
# from utility function to utility matrix
def make_util_mat( utility_fn, obs_vals, act_vals):
    '''Tabulate utility_fn(o, a) over every observation / action pair

    Returns a len(obs_vals) x len(act_vals) float table whose entry
    [i, j] is utility_fn( obs_vals[i], act_vals[j]).
    '''
    n_rows, n_cols = len( obs_vals), len( act_vals)
    table = np.zeros( [ n_rows, n_cols])
    # row-major walk over the cells, one utility call per cell
    for row, col in np.ndindex( n_rows, n_cols):
        table[ row, col] = utility_fn( obs_vals[row], act_vals[col])
    return table
# handcrafted state encode and p(s|o)
def psi_hand( obs_vals, state_vals, lamb):
    '''STATE ENCODER p(s|o,λ)

    state is the noisy perception of observation
    The distribution is approximated using the sampling method: for
    each observation o, draw round( o + N(0,1)/λ) and keep the draws
    that land on one of the states 1..nS.

    Args:
        obs_vals: observation values, length nO
        state_vals: state values, length nS; only the length is used,
            the states are taken to be the integers 1..nS
        lamb: precision λ of the perception, the noise std is 1/λ

    Returns:
        qs1o: nO x nS, row io is the estimated p(s|o=obs_vals[io])

    Note: as with any rejection sampler, an observation whose noisy
    copies never land inside 1..nS keeps the sampler running forever.
    '''
    nO = len( obs_vals)
    nS = len( state_vals)
    qs1o = np.zeros( [nO, nS])
    nsamples = 5000 # num of accepted samples per observation
    # histogram edges centred on the integer states 1..nS
    bins = np.arange( .5, nS+1.5)

    for io in range( nO):

        # rejection sampling, done in batches instead of one draw at a
        # time: keep drawing until nsamples draws have been accepted
        accepted = np.zeros( [0,])
        while accepted.size < nsamples:
            draws = np.round( obs_vals[io] + np.random.randn( nsamples) / lamb)
            in_range = (draws > 0) & (draws < nS+1)
            accepted = np.concatenate( [ accepted, draws[in_range]])

        # count frequencies over state and
        # normalize them as a probability distribution
        freq, _ = np.histogram( accepted[:nsamples], bins)
        qs1o[ io, :] = freq / np.sum( freq)

    return qs1o
def BA_algs( util_mat, p_x, q_y1x,
             tol = 1e-4, max_iter=10000, beta=1.):
    '''BLAHUT-ARIMOTO ALG for a single channel

    Alternates between the channel update
        q(y|x) ∝ p(y)exp( β U(x,y))
    and the marginal update
        p(y) = ∑_x p(x)q(y|x)
    until the channel stops changing.

    Args:
        util_mat --> NX x NY:
            utility matrix
        p_x --> NX x 1:
            sender's distribution
        q_y1x --> NX x NY:
            initial channel, one row per x
        tol --> scalar:
            tolerance for convergence checking (summed absolute change
            of the channel between two iterations)
        max_iter --> scalar:
            maximum iteration number
        beta --> scalar:
            inverse temperature. The body used this name without it
            being a parameter; it is a trailing keyword so that calls
            passing ( util_mat, p_x, q_y1x, tol, max_iter) positionally
            keep working.

    Returns:
        q_y1x --> NX x NY: the optimized channel
        p_y --> NY x 1: receiver's distribution
    '''
    # init for iteration
    i = 0
    done = False
    p_y = p_x.T @ q_y1x      # 1 x NY

    while not done:

        # cache data to check convergence
        old_q_y1x = q_y1x

        # update the channel q(y|x), normalized in log space
        log_q_y1x = beta * util_mat + np.log( p_y + 1e-20)
        q_y1x = np.exp( log_q_y1x - logsumexp( log_q_y1x, axis=1, keepdims=True))

        # update the marginal policy
        p_y = p_x.T @ q_y1x

        # update counter
        i += 1

        # check convergence
        if np.sum(abs(old_q_y1x - q_y1x)) < tol:
            done = True
        if i >= max_iter:
            print( f'BA alg reached maximum iteration {max_iter}, results might be inaccurate')
            done = True

    return q_y1x, p_y.T
def get_pi_a1s( util_mat, psi_s1o, p_a1os,
                p_o, p_s, p_a,
                beta2, beta3):
    '''COMPUTE π(a|s)

    Args:
        util_mat: nO x nA, utility U(o,a)
        psi_s1o: nO x nS, perception ψ(s|o)
        p_a1os: nO x nS x nA, p(a|o,s); only used when beta3 != 0
        p_o: nO x 1, p_s: nS x 1, p_a: nA x 1, the marginals
        beta2: price of I(S;A); only used when beta3 == 0
        beta3: price of I(A;S,O); 0 selects the serial case

    Returns:
        pi_a1s: nS x nA, every row sums to 1
    '''
    # inference observation given mental state using bayes rule
    # ψ(o|s) = ψ(s|o)p(o)/p(s) : nSxnO
    # a state with p(s) = 0 would give 0/0 here; dividing by 1 instead
    # leaves its row at 0 and keeps nan out of the iteration
    safe_p_s = np.where( p_s.T > 0, p_s.T, 1.)
    psi_o1s = (p_o * psi_s1o / safe_p_s).T

    if beta3 == 0:
        # This is the special case, because S blocks O-->A
        #     O --> S --> A

        # bel_U(s,a) = ∑_o ψ(o|s)U(o,a)
        bel_util = psi_o1s @ util_mat
        # π(a|s) ∝ p(a)exp( β2*Bel_U(s,a)) nSxnA
        log_pi_a1s = beta2 * bel_util + np.log( p_a.T + 1e-20)
        pi_a1s = np.exp( log_pi_a1s - logsumexp( log_pi_a1s, axis=-1, keepdims=True))

    else:
        # A more general case
        # This time,  O --> S
        #              \   /
        #                A

        # π(a|s) = ∑_o ψ(o|s)p(a|o,s) nSxnA
        # !!! flagged by the author: this could be problematic !!!
        pi_a1s = np.sum( p_a1os * psi_o1s.T[ :, :, np.newaxis], axis=0) + 1e-20
        pi_a1s = pi_a1s / np.sum( pi_a1s ,axis=-1, keepdims=True)

    return pi_a1s
def get_p_a1os( util_mat, pi_a1s, p_a,
                beta2, beta3):
    '''COMPUTE p(a|o,s)

    Args:
        util_mat: nO x nA, utility U(o,a)
        pi_a1s: nS x nA, policy π(a|s)
        p_a: nA x 1, action marginal
        beta2: price of I(S;A)
        beta3: price of I(A;S,O); 0 selects the serial case, where the
            action cannot depend on o once s is known

    Returns:
        p_a1os: nO x nS x nA, normalized over the last axis
    '''
    nO, nA = util_mat.shape
    nS = pi_a1s.shape[0]

    if beta3 == 0:
        # serial case: p(a|o,s) = π(a|s) for every o
        p_a1os = np.zeros( [ nO, nS, nA])
        p_a1os[ :, :, :] = pi_a1s      # broadcast over the o axis
    else:
        # p(a|o,s) ∝ π(a|s) exp( β3 U(o,a) - β3/β2 log(π(a|s)/p(a)))
        log_p_a1os = beta3 * util_mat[ :, np.newaxis, :] \
                   - beta3 / beta2 * np.log( pi_a1s[ np.newaxis, :, :] + 1e-20)\
                   + beta3 / beta2 * np.log( p_a.reshape([-1])[ np.newaxis, np.newaxis, :] + 1e-20)\
                   + np.log( pi_a1s[ np.newaxis, :, :] + 1e-20)
        p_a1os = np.exp( log_p_a1os - logsumexp( log_p_a1os, axis=-1, keepdims=True))

    return p_a1os
def get_psi_s1o( util_mat, pi_a1s, p_a1os,
                 p_s, p_a,
                 beta1, beta2, beta3):
    '''COMPUTE ψ(s|o)

    Args:
        util_mat: nO x nA, utility U(o,a)
        pi_a1s: nS x nA, policy π(a|s)
        p_a1os: nO x nS x nA, p(a|o,s)
        p_s: nS x 1, p_a: nA x 1, the marginals
        beta1: price of I(O;S)
        beta2: price of I(S;A)
        beta3: price of I(A;S,O); 0 selects the serial case

    Returns:
        psi_s1o: nO x nS, every row sums to 1

    Raises:
        ValueError: in the serial case, if p(a|o,s) differs from π(a|s)
    '''
    if beta3 == 0:
        # compute EU(o,s): ∑_a π(a|s)U(o,a) nO x nS
        EU = util_mat @ pi_a1s.T
        # compute D[π(a|s)||p(a)] 1 x nS
        DKL1 = np.sum( pi_a1s * np.log( pi_a1s + 1e-20)
                     - pi_a1s * np.log( p_a.T + 1e-20), axis=-1).reshape([1,-1])
    else:
        # compute EU(o,s): ∑_a p(a|o,s)U(o,a) nOxnS
        EU = np.sum( util_mat[ :, np.newaxis, :] * p_a1os, axis=-1)
        # compute D[p(a|o,s)||p(a)] nOxnS
        DKL1 = np.sum( p_a1os * np.log( p_a1os + 1e-20)
                     - p_a1os * np.log( p_a.reshape([-1,])[ np.newaxis, np.newaxis, :] + 1e-20), axis=-1)

    # compute D[p(a|o,s)||π(a|s)] nOxnS
    DKL2 = np.sum( p_a1os * np.log( p_a1os + 1e-20)
                 - p_a1os * np.log( pi_a1s[ np.newaxis, :, :] + 1e-20), axis=-1)

    if beta3 == 0:
        # the serial objective has no DKL2 term, so it must vanish
        if np.sum( DKL2) > 0:
            raise ValueError( 'In sequntial case, p(a|o,s) should equal to π(a|s)')
        # Fser = EU(o,s) - 1/β2 D[π(a|s)||p(a)], nOxnS
        F = EU - 1/beta2 * DKL1
    else:
        # Fpar = EU(o,s) - 1/β2 D[p(a|o,s)||p(a)] - (1/β3 - 1/β2) * D[p(a|o,s)||π(a|s)] nOxnS
        F = EU - 1/beta2 * DKL1 - ( 1/beta3 - 1/beta2) * DKL2

    # ψ(s|o) ∝ p(s)exp( β1 F ), normalized in log space
    log_psi_s1o = beta1 * F + np.log( p_s.T + 1e-20)
    psi_s1o = np.exp( log_psi_s1o - logsumexp( log_psi_s1o, axis=-1, keepdims=True))

    return psi_s1o
def general_BA_algs( util_matrix, p_o, psi_s1o, p_a1os,
                     beta1, beta2, beta3,
                     tol, max_iter):
    '''General BA ALG

    The original BA algorithm can only be applied
    to one channel. If the architecture includes
    multiple channels, in either a cascade or a parallel
    structure, this algorithm, introduced by the
    titled paper, is the general solution to the multi
    channel structure.

    The shape of some vars
        p_o:     nO x 1
        p_s:     nS x 1
        p_a:     nA x 1
        psi_s1o: nO x nS (see the convention disclaimer at
                 the beginning for more details)
        psi_o1s: nS x nO
        pi_a1s:  nS x nA
        p_a1os:  nO x nS x nA

    Args:
        util_matrix: nO x nA, utility U(o,a)
        p_o: distribution of the observations
        psi_s1o: initial perception channel ψ(s|o)
        p_a1os: initial action channel p(a|o,s)
        beta1, beta2, beta3: prices of I(O;S), I(S;A) and I(A;S,O);
            beta3 == 0 selects the serial (cascade) case
        tol: tolerance on the summed absolute change of ψ(s|o) and
            p(a|o,s) between two iterations
        max_iter: maximum number of iterations

    Returns:
        psi_s1o, pi_a1s, p_a1os, p_s, p_a after the last iteration
    '''
    # compute all we need for the iteration; the values are not
    # necessarily correct, because this is just the initialization

    # compute p_s, p_a for initialization, p_o is given
    # p_o                                                        # nOx1
    p_s = (p_o.T @ psi_s1o).T                                    # nSx1
    # p(a|o) = ∑_s ψ(s|o)p(a|o,s)
    p_a1o = np.sum( psi_s1o[ :, :, np.newaxis] * p_a1os, axis=1)
    p_a = (p_o.T @ p_a1o).T                                      # nAx1

    ## compute π(a|s), how depends on the case
    pi_a1s = get_pi_a1s( util_matrix, psi_s1o, p_a1os,
                         p_o, p_s, p_a,
                         beta2, beta3)
    if beta3 == 0:
        # sequential case, make sure p(a|o,s) = π(a|s)
        p_a1os = get_p_a1os( util_matrix, pi_a1s, p_a,
                             beta2, beta3)

    # start iteration
    done = False
    i = 0

    while not done:

        # cache the current val for convergence checks
        old_p_a1os = p_a1os
        old_psi_s1o = psi_s1o

        # follow the sequence of the original paper
        # update ψ(s|o)
        psi_s1o = get_psi_s1o( util_matrix,
                               pi_a1s, p_a1os, p_s, p_a,
                               beta1, beta2, beta3)

        if beta3 == 0:
            # serial case: the policy comes first, p(a|o,s) is its copy
            # update π(a|s)
            pi_a1s = get_pi_a1s( util_matrix, psi_s1o, p_a1os,
                                 p_o, p_s, p_a,
                                 beta2, beta3)
            # update p(a|o,s)
            p_a1os = get_p_a1os( util_matrix, pi_a1s, p_a,
                                 beta2, beta3)
        else:
            # parallel case: π(a|s) is the marginal of the fresh p(a|o,s)
            # update p(a|o,s)
            p_a1os = get_p_a1os( util_matrix, pi_a1s, p_a,
                                 beta2, beta3)
            # update π(a|s)
            pi_a1s = get_pi_a1s( util_matrix, psi_s1o, p_a1os,
                                 p_o, p_s, p_a,
                                 beta2, beta3)

        # update marginal p(s), p(a); when calculating the marginals
        # we add a small value to prevent the distribution from becoming 0,
        # note that this is very important during the iteration.
        p_s = (p_o.T @ psi_s1o).T + 1e-20
        p_s = p_s / np.sum( p_s)                                             # nSx1
        # p(a|o) = ∑_s ψ(s|o)p(a|o,s)
        p_a1o = np.sum( psi_s1o[ :, :, np.newaxis] * p_a1os, axis=1)         # nOxnA
        p_a = (p_o.T @ p_a1o).T + 1e-20
        p_a = p_a / np.sum( p_a)                                             # nAx1

        # update counter
        i += 1

        # check convergence
        if np.sum(abs(p_a1os - old_p_a1os)) + np.sum(abs(psi_s1o - old_psi_s1o)) < tol:
            done = True
        if i >= max_iter:
            print( f'General BA alg reached maximum iteration {max_iter}, results might be inaccurate')
            done = True

    return psi_s1o, pi_a1s, p_a1os, p_s, p_a
def show_predator_prey_util() :
    '''Plot the predator-prey utility table next to its greedy policy

    The left panel shows U(o,a), the right panel a softmax over the
    actions of 100*U(o,a). The figure is written to
    predator_prey_utility.png inside `path`.
    '''
    env = setup_predator_prey_example()
    obs_vars, act_vars, util_mat = env[1], env[4], env[5]

    # near-deterministic policy: softmax with a large inverse temperature
    sharp_util = 100*util_mat
    opt_pi = np.exp( sharp_util - logsumexp( sharp_util, axis=1, keepdims=True))

    # (title, matrix to draw, extra colour limits) of each panel
    panels = (
        ( 'Utility of predator prey', util_mat, {}),
        ( 'Optimal policy', opt_pi, { 'vmin': 0, 'vmax': 1}),
    )
    obs_ticks = np.arange( len( obs_vars)) + .5
    act_ticks = np.arange( len( act_vars)) + .5

    plt.figure( figsize=( 21, 7))
    for k, ( title, mat, limits) in enumerate( panels, start=1):
        plt.subplot( 1, 2, k)
        plt.imshow( mat.T, cmap='Blues', origin='lower', **limits)
        plt.title( title)
        plt.xticks( obs_ticks, obs_vars)
        plt.yticks( act_ticks, act_vars)
        plt.xlabel( 'observed animal size')
        plt.ylabel( 'action')

    plt.savefig( f'{path}/predator_prey_utility.png')
def show_medical_util():
    '''Plot the medical utility table next to its greedy policy

    The left panel shows U(o,a), the right panel a softmax over the
    actions of 100*U(o,a). The figure is written to
    medical_utility.png inside `path`.
    '''
    env = setup_medical_example()
    obs_vars, act_vars, util_mat = env[1], env[4], env[5]

    # near-deterministic policy: softmax with a large inverse temperature
    sharp_util = 100*util_mat
    opt_pi = np.exp( sharp_util - logsumexp( sharp_util, axis=1, keepdims=True))

    # (title, matrix to draw, extra colour limits) of each panel
    panels = (
        ( 'Utility of medical example', util_mat, {}),
        ( 'Optimal policy', opt_pi, { 'vmin': 0, 'vmax': 1}),
    )
    obs_ticks = np.arange( len( obs_vars)) + .5
    act_ticks = np.arange( len( act_vars)) + .5

    plt.figure( figsize=( 21, 7))
    for k, ( title, mat, limits) in enumerate( panels, start=1):
        plt.subplot( 1, 2, k)
        plt.imshow( mat.T, cmap='Blues', origin='lower', **limits)
        plt.title( title)
        plt.xticks( obs_ticks, obs_vars)
        plt.yticks( act_ticks, act_vars)
        plt.xlabel( 'Disease type')
        plt.ylabel( 'Treatment')

    plt.savefig( f'{path}/medical_utility.png')
def illustrate_cascade_channel( lamb, beta1, beta2, beta3):
    '''Replicate the figure that illustrates how the cascade channel works

    The environment is the predator-prey environment. Two agents are
    compared: one with a fixed, hand-crafted perception ψ_λ(s|o) whose
    policy is optimized by the plain BA algorithm, and one whose
    perception and policy are optimized jointly by the general BA
    algorithm. The 2 x 3 panel figure is saved inside `path`.

    Args:
        lamb : precision of the hand-crafted perceptual model
        beta1: price for I(O;S)
        beta2: price for I(S;A)
        beta3: price for I(A;S,O), forwarded to general_BA_algs
    '''
    # set hyperparameters
    tol = 1e-4           # tolerance for convergence
    max_iter = 10000     # maximum number of BA iterations

    # load env and perception
    obs_vals, obs_vars, p_o, act_vals, act_vars, util_mat = setup_predator_prey_example()

    # init the internal representation state
    state_vals = np.arange( 1, 14)
    state_vars = [ str(s) for s in state_vals]

    # obtain the cardinality of each variable
    nO = len( obs_vals)
    nS = len( state_vals)
    nA = len( act_vals)

    #######################################
    ##       fix perception channel      ##
    #######################################

    # fix perception channel: encode observation into state
    # The basic idea of this handcrafted perception is
    # assuming the perception system can almost optimally encode the
    # observation with minor perturbation.
    psi_s1o = psi_hand( obs_vals, state_vals, lamb)

    # infer the observation based on the mental state
    # ψ(o|s) = ψ(s|o)p(o)/p(s), the inference is calculated using Bayes' rule
    psi_o1s = p_o * psi_s1o + 1e-20    # nOx1 * nOxnS, p(o) will broadcast
    psi_o1s = (psi_o1s / np.sum( psi_o1s, axis=0, keepdims=True)).T
    p_s = (p_o.T @ psi_s1o).T

    # now we have p(o), ψ(o|s), what we need is p(a|o,s) == π(a|s)
    # init π(a|s) as a uniform distribution
    pi_a1s = np.ones( [ nS, nA])
    pi_a1s = pi_a1s / np.sum( pi_a1s, axis=-1, keepdims=True)

    # before that we need to find the bel util matrix
    # because the policy channel is only about mental states and actions
    # belU(s,a) = ∑_o ψ(o|s)U(o,a)
    bel_util = psi_o1s @ util_mat

    # use BA iteration to find the optimal policy channel
    # NOTE(review): BA_algs scales the utility by a name `beta` that
    # this call does not supply -- confirm that beta2 (the price of
    # I(S;A)) is the intended inverse temperature.
    results = BA_algs( bel_util, p_s, pi_a1s,
                       tol, max_iter)

    # unpack the optimized results
    pi_a1s, p_a = results

    # calculate the observation policy π(a|o) = ∑_s ψ(s|o)π(a|s)
    pi_a1o = np.sum( psi_s1o[ :, :, np.newaxis] * pi_a1s[ np.newaxis, :, :], axis=1)

    # store fix perception results for visualization
    fix_psi = dict()
    fix_psi['p(o)'] = p_o
    fix_psi['psi(s|o)'] = psi_s1o
    fix_psi['p(s)'] = p_s
    fix_psi['pi(a|s)'] = pi_a1s
    fix_psi['pi(a|o)'] = pi_a1o
    fix_psi['p(a)'] = p_a
    fix_psi['EU'] = np.sum( p_o * pi_a1o * util_mat)
    fix_psi['I(o;s)'] = I( p_o, psi_s1o)
    fix_psi['I(s;a)'] = I( p_s, pi_a1s)
    fix_psi['Jser'] = fix_psi['EU'] - 1/beta1 * fix_psi['I(o;s)'] \
                                    - 1/beta2 * fix_psi['I(s;a)']

    #######################################
    ##       RD perception channel       ##
    #######################################

    # The BA algorithm is iterative and requires initial values.
    # According to the chain rule p(o,s,a) = p(o)ψ(s|o)p(a|o,s),
    # once we know these three distributions we can compute all others.
    # p(o) is given, for ψ(s|o) the handcrafted perception is the init,
    # so all that is left is to assume p(a|o,s)

    # init the p(a|o,s) as a uniform distribution
    p_a1os = np.ones( [ nO, nS, nA])
    p_a1os = p_a1os / np.sum( p_a1os, axis=-1, keepdims=True)

    # run the general Blahut-Arimoto algorithm to get the optimal
    # channel pairs
    results = general_BA_algs( util_mat, p_o, psi_s1o, p_a1os,   # utility & dist
                               beta1, beta2, beta3,              # price parameter
                               tol, max_iter)                    # iteration hyperparameter

    # unpack the results
    psi_s1o, pi_a1s, p_a1os, p_s, p_a = results

    # calculate the observation policy π(a|o) = ∑_s ψ(s|o)π(a|s)
    pi_a1o = np.sum( psi_s1o[ :, :, np.newaxis] * pi_a1s[ np.newaxis, :, :], axis=1)

    # store RD perception results for visualization
    RD_psi = dict()
    RD_psi['p(o)'] = p_o
    RD_psi['psi(s|o)'] = psi_s1o
    RD_psi['p(s)'] = p_s
    RD_psi['p(a|o,s)'] = p_a1os
    RD_psi['pi(a|s)'] = pi_a1s
    RD_psi['pi(a|o)'] = pi_a1o
    RD_psi['p(a)'] = p_a
    RD_psi['EU'] = np.sum( p_o * pi_a1o * util_mat)
    RD_psi['I(o;s)'] = I( p_o, psi_s1o)
    RD_psi['I(s;a)'] = I( p_s, pi_a1s)
    RD_psi['Jser'] = RD_psi['EU'] - 1/beta1 * RD_psi['I(o;s)'] \
                                  - 1/beta2 * RD_psi['I(s;a)']

    #######################################
    ##           Visualization           ##
    #######################################

    plt.figure( figsize=( 21, 14))

    # Panel A: visualize handcrafted perception channel ψ_λ(s|o)
    plt.subplot( 2, 3, 1)
    plt.imshow( fix_psi['psi(s|o)'].T, cmap='Reds', origin='lower', vmin=0, vmax=1)
    plt.xticks( np.arange(nO)+.5, obs_vars)
    plt.yticks( np.arange(nS)+.5, state_vars)
    # note: share xlabel with panel D
    plt.ylabel( 'mental believed size')

    # Panel B: visualize the observation policy π_λ(a|o) obtained
    # with fix perception ψ_λ(s|o)
    plt.subplot( 2, 3, 2)
    plt.imshow( fix_psi['pi(a|o)'].T, cmap='Blues', origin='lower', vmin=0, vmax=1)
    plt.xticks( np.arange(nO)+.5, obs_vars)
    plt.yticks( np.arange(nA)+.5, act_vars)
    # note: share xlabel with panel E
    plt.ylabel( 'action')

    # Panel C: visualize EU and Jser
    plt.subplot( 2, 3, 3)
    groups = [ 'rationality', 'resource-rationality']
    x = np.arange(len(groups))
    fix_group = [ fix_psi['EU'], fix_psi['Jser']]
    RD_group = [ RD_psi['EU'], RD_psi['Jser']]
    width = .35
    # two bars per group, side by side around the group position
    plt.bar( x-width/2, fix_group, width, label='fix ψ_λ(s|o)', color='salmon')
    plt.bar( x+width/2, RD_group, width, label='learnt ψ_RD(s|o)', color='royalblue')
    plt.xticks( x, groups)
    plt.ylabel( 'values')
    plt.ylim([ 0, 3.8])

    # Panel D: visualize the RD optimized perception channel ψ_RD(s|o)
    plt.subplot( 2, 3, 4)
    plt.imshow( RD_psi['psi(s|o)'].T, cmap='Reds', origin='lower', vmin=0, vmax=1)
    plt.xticks( np.arange(nO)+.5, obs_vars)
    plt.yticks( np.arange(nS)+.5, state_vars)
    plt.xlabel( 'observed animal size')
    plt.ylabel( 'mental believed size')

    # Panel E: visualize the observation policy π_RD(a|o) obtained
    # with adaptive perception ψ_RD(s|o)
    plt.subplot( 2, 3, 5)
    plt.imshow( RD_psi['pi(a|o)'].T, cmap='Blues', origin='lower', vmin=0, vmax=1)
    plt.xticks( np.arange(nO)+.5, obs_vars)
    plt.yticks( np.arange(nA)+.5, act_vars)
    plt.xlabel( 'observed animal size')
    plt.ylabel( 'action')

    # Panel F: visualize mutual information
    plt.subplot( 2, 3, 6)
    groups = [ 'I(o;s)', 'I(s;a)']
    x = np.arange(len(groups))
    fix_group = [ fix_psi['I(o;s)'], fix_psi['I(s;a)']]
    RD_group = [ RD_psi['I(o;s)'], RD_psi['I(s;a)']]
    width = .35
    plt.bar( x-width/2, fix_group, width, label='fix ψ_λ(s|o)', color='salmon')
    plt.bar( x+width/2, RD_group, width, label='learnt ψ_RD(s|o)', color='royalblue')
    plt.xticks( x, groups)
    plt.ylabel( 'values')
    plt.ylim([ 0, 3.8])

    fig_name = f'{path}/cascade_channel-lambda={lamb}-beta1={beta1}-beta2={beta2}.png'
    plt.savefig( fig_name)
def get_parl_results( beta1, beta2, beta3,
                      is_uniform, tol, max_iter):
    '''Run the general BA algorithm on the medical example

    Args:
        beta1, beta2, beta3: prices forwarded to general_BA_algs
        is_uniform: forwarded to setup_medical_example to choose the
            uniform or the biased disease prior
        tol, max_iter: stopping parameters forwarded to general_BA_algs

    Returns:
        dict with the keys 'p(o)', 'psi(s|o)', 'p(s)', 'p(a|o,s)',
        'pi(a|s)', 'pi(a|o)' and 'p(a)'
    '''
    # load env; the labels are not needed here
    env = setup_medical_example(is_uniform)
    obs_vals, p_o, act_vals, util_mat = env[0], env[2], env[3], env[5]

    # the three models, as said in the original paper on page 14, use
    # the same notation as the states, because this benefits comparison
    state_vals = np.arange( 1, 4)

    # cardinality of each variable
    n_obs, n_state, n_act = len( obs_vals), len( state_vals), len( act_vals)

    # The BA algorithm is iterative and needs a starting point for
    # ψ(s|o) and p(a|o,s); p(o) is given by the environment.

    # ψ(s|o) starts from random rows that are normalized to sum to 1
    # (this is not how the original authors initialize it)
    init_psi = np.random.rand( n_obs, n_state)
    init_psi = init_psi / np.sum( init_psi, axis=-1, keepdims=True)

    # p(a|o,s) starts as a uniform distribution over the actions
    init_p_a1os = np.ones( [ n_obs, n_state, n_act])
    init_p_a1os = init_p_a1os / np.sum( init_p_a1os, axis=-1, keepdims=True)

    # run the general Blahut-Arimoto algorithm to get the optimal
    # channel pairs
    psi_s1o, pi_a1s, p_a1os, p_s, p_a = general_BA_algs(
        util_mat, p_o, init_psi, init_p_a1os,   # utility & dist
        beta1, beta2, beta3,                    # price parameter
        tol, max_iter)                          # iteration hyperparameter

    # observation policy π(a|o) = ∑_s ψ(s|o)p(a|o,s)
    pi_a1o = np.sum( psi_s1o[ :, :, np.newaxis] * p_a1os, axis=1)

    # collect the results for visualization; EU, I(o;s), I(o;a|s) and
    # the objective value are not computed for the parallel case
    return {
        'p(o)': p_o,
        'psi(s|o)': psi_s1o,
        'p(s)': p_s,
        'p(a|o,s)': p_a1os,
        'pi(a|s)': pi_a1s,
        'pi(a|o)': pi_a1o,
        'p(a)': p_a,
    }
def illustrate_parallel_channel( beta1, beta2, beta3):
    '''Replicate the figure that illustrates how the parallel channel works

    The environment is the medical environment. The general BA
    algorithm is run twice, once with the uniform and once with the
    biased disease prior; each run fills one row of the 2 x 3 panel
    figure, which is saved inside `path`.

    Args:
        beta1: price for I(O;S)
        beta2: price for I(S;A)
        beta3: price for I(A;S,O)
    '''
    # set hyperparameters
    tol = 1e-3           # tolerance for convergence
    max_iter = 10000     # maximum number of BA iterations

    # load vars for plot
    _, obs_vars, _, _, act_vars, _ = setup_medical_example()
    state_vals = np.arange( 1, 4)
    state_vars = [ f'm={s}' for s in state_vals]

    # RD begin with uniform prior
    is_uniform = True
    p1 = get_parl_results( beta1, beta2, beta3,
                           is_uniform, tol, max_iter)

    # RD begin with biased prior
    is_uniform = False
    p2 = get_parl_results( beta1, beta2, beta3,
                           is_uniform, tol, max_iter)

    #######################################
    ##           Visualization           ##
    #######################################

    # one entry per column of the figure:
    # (result key, colormap, x tick labels, y tick labels, xlabel, ylabel)
    columns = (
        # higher-level model selector ψ(s|o)
        ( 'psi(s|o)', 'Reds', obs_vars, state_vars, 'disease type', 'first diagnosis'),
        # model policy π(a|s)
        ( 'pi(a|s)', 'Greens', state_vars, act_vars, 'first diagnosis', 'treatment'),
        # observation policy π(a|o)
        ( 'pi(a|o)', 'Blues', obs_vars, act_vars, 'disease type', 'treatment'),
    )

    plt.figure( figsize=( 21, 14))

    # top row: uniform prior (panels A-C), bottom row: biased prior (D-F)
    for row, result in enumerate( ( p1, p2)):
        for col, ( key, cmap, x_vars, y_vars, xlabel, ylabel) in enumerate( columns):
            plt.subplot( 2, 3, 3*row + col + 1)
            plt.imshow( result[key].T, cmap=cmap, origin='lower', vmin=0, vmax=1)
            plt.xticks( np.arange( len( x_vars))+.5, x_vars)
            plt.yticks( np.arange( len( y_vars))+.5, y_vars)
            # the top row shares its xlabel with the panel below it
            if row == 1:
                plt.xlabel( xlabel)
            plt.ylabel( ylabel)

    fig_name = f'{path}/parallel_channel-beta1={beta1}-beta3={beta3}.png'
    plt.savefig( fig_name)
if __name__ == '__main__':

    # Predator prey: one cascade figure per column of the table below
    lambdas = [ 1.65, 1.65, .4]
    beta1s  = [    8,    8,  1]
    beta2s  = [   10,    1,  1]
    beta3s  = [    0,    0,  0]   # 0 selects the serial case
    for setting in zip( lambdas, beta1s, beta2s, beta3s):
        illustrate_cascade_channel( *setting)

    # Medical: a single parallel-channel figure
    beta1, beta2, beta3 = 2., np.inf, .9
    illustrate_parallel_channel( beta1, beta2, beta3)