# input image
_x = tf.placeholder(tf.float32, [None, 784])
# reshape image for convolution
x = tf.reshape(_x, [-1, 28, 28, 1])
# first layer of convolution
with tf.variable_scope('conv1'):
    # create 256 filters of kernel size 9x9
    w = tf.get_variable('w', shape=[9, 9, 1, 256], dtype=tf.float32,
                        initializer=tf.contrib.layers.xavier_initializer())
    # stride = 1
    conv1 = tf.nn.conv2d(x, w, [1, 1, 1, 1], padding='VALID', name='conv1')
    # relu activation
    conv1 = tf.nn.relu(conv1)
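With a 9x9 kernel, VALID padding and stride 1, the 28x28 input comes out of conv1 as a 20x20x256 feature map, which is exactly what the primary capsule layer below expects. A quick sanity check on the static shape:

print(conv1.get_shape().as_list())  # expect [None, 20, 20, 256]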
In convolutional capsule layers, each unit in a capsule is a convolutional unit, so each capsule outputs a grid of vectors rather than a single vector. Here, a 9x9 stride-2 convolution over the 20x20 conv1 feature map gives a 6x6 grid for each of 32 capsule types, i.e. 6x6x32 = 1152 primary capsules, each an 8-D vector.
with tf.variable_scope('primary_caps'):
    # 9x9 filters, 32*8=256 channels, stride=2
    primary_capsules = tf.contrib.slim.conv2d(inputs=conv1,
                                              num_outputs=32*8,
                                              kernel_size=9,
                                              stride=2,
                                              padding='VALID',
                                              activation_fn=None)
    # apply "squash" non-linearity
    primary_capsules = squash(primary_capsules)
    # primary capsules : 32 x [6x6] grids of 8-D vectors
    num_capsules = 32*6*6
    primary_capsule_dim = 8
    # reshape primary capsules for calculating prediction vectors
    primary_capsules = tf.reshape(primary_capsules,
                                  [-1, 1, num_capsules, 1, primary_capsule_dim])
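The squash function isn't defined in this snippet. A minimal sketch of the non-linearity from the paper, assuming it operates along the last axis of its input:

def squash(s, axis=-1, eps=1e-8):
    # squash(s) = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)
    squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keep_dims=True)
    norm = tf.sqrt(squared_norm + eps)
    return (squared_norm / (1. + squared_norm)) * (s / norm)

The small eps keeps the division stable when a capsule's output is all zeros.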
# next capsule layer (digit capsules) : [10, 16]
num_digits = 10
digit_capsule_dim = 16
# weight matrix
Wij = tf.get_variable('Wij',
                      [num_digits, num_capsules, primary_capsule_dim, digit_capsule_dim],
                      dtype=tf.float32)
# tile primary capsules for multiplication with weight matrix
tiled_prim_caps = tf.tile(primary_capsules, [1, num_digits, 1, 1, 1])
# yeah.. we need a loop :(
# help me fix this!
cap_predictions = tf.scan(lambda _, x: tf.matmul(x, Wij),  # fn
                          tiled_prim_caps,                  # elements
                          initializer=tf.zeros([num_digits, num_capsules, 1, digit_capsule_dim]))
# squeeze dummy dimensions
cap_predictions = tf.squeeze(cap_predictions, [3])
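The tf.scan above walks over the batch one example at a time. One possible alternative, assuming tf.einsum in your TensorFlow version supports this contraction, is to let einsum do the per-capsule matrix multiplies on the un-tiled capsules directly:

# u_hat[b, d, n, :] = primary_capsules[b, n, :] x Wij[d, n, :, :]
prim_caps_flat = tf.reshape(primary_capsules,
                            [-1, num_capsules, primary_capsule_dim])
cap_predictions_alt = tf.einsum('bnp,dnpq->bdnq', prim_caps_flat, Wij)
# same shape as cap_predictions: [batch, num_digits, num_capsules, digit_capsule_dim]

This skips both the tiling and the explicit loop.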
# { b_ij } log prior probabilities
priors = tf.get_variable('log_priors',
                         [num_digits, num_capsules],
                         initializer=tf.zeros_initializer())
# expand to support batch dimension
priors = tf.expand_dims(priors, axis=0)
routing_iterations = 3  # the paper uses 3 iterations of routing
for i in range(routing_iterations):
    with tf.variable_scope('routing_{}'.format(i)):
        # softmax along "digits" axis gives the coupling coefficients c_ij
        c = tf.nn.softmax(priors, dim=1)
        # reshape to multiply with predictions
        c_t = tf.expand_dims(c, axis=-1)
        s_t = cap_predictions * c_t
        # weighted sum over the primary capsules
        s = tf.reduce_sum(s_t, axis=2)
        digit_caps = squash(s)
        # agreement between predictions and outputs updates the priors
        delta_priors = tf.reduce_sum(
            cap_predictions * tf.expand_dims(digit_caps, 2), -1)
        priors = priors + delta_priors
return digit_caps
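The margin loss below works on the length of each digit capsule vector and on the one-hot labels. Neither digit_caps_norm nor _y appears in the snippet above; a minimal sketch, assuming digit_caps has shape [batch, 10, 16]:

# one-hot labels for the 10 digit classes
_y = tf.placeholder(tf.float32, [None, 10])
# length of each digit capsule vector -> [batch, 10]
digit_caps_norm = tf.sqrt(tf.reduce_sum(tf.square(digit_caps), axis=-1) + 1e-8)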
# positives
pos_loss = tf.maximum(0.,
                      0.9 - tf.reduce_sum(digit_caps_norm * _y,
                                          axis=1))
# square the positive loss and average over the batch
pos_loss = tf.reduce_mean(tf.square(pos_loss))
# negatives
y_negs = 1. - _y
neg_loss = tf.maximum(0., digit_caps_norm * y_negs - 0.1)
# down-weight the negative term by 0.5, as in the paper
neg_loss = tf.reduce_sum(tf.square(neg_loss), axis=-1) * 0.5
neg_loss = tf.reduce_mean(neg_loss)
margin_loss = pos_loss + neg_loss
# reconstruct original image with a 3-layered MLP
from tensorflow.contrib.layers import fully_connected

def reconstruct(target_cap):
    with tf.name_scope('reconstruct'):
        fc = fully_connected(target_cap, 512)
        fc = fully_connected(fc, 1024)
        fc = fully_connected(fc, 784, activation_fn=None)
        out = tf.sigmoid(fc)
        return out
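reconstruct is fed the 16-D activity vector of the capsule belonging to the true class. target_cap isn't defined above; a minimal sketch of masking out every other capsule, assuming _y is the one-hot label tensor:

# digit_caps : [batch, 10, 16], _y : [batch, 10]
# keep only the capsule of the true digit -> [batch, 16]
target_cap = tf.reduce_sum(digit_caps * tf.expand_dims(_y, axis=-1), axis=1)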
reconstruct_loss = tf.reduce_mean(tf.reduce_sum(
    tf.square(_x - reconstruct(target_cap)), axis=-1))
# scale down the reconstruction loss so it does not dominate the margin loss
total_loss = margin_loss + 0.0005 * reconstruct_loss
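From here, training is just a matter of minimizing total_loss. A minimal sketch using the Adam optimizer (the paper uses Adam with TensorFlow's default parameters) and a standard accuracy op:

# Adam with default hyperparameters
train_op = tf.train.AdamOptimizer().minimize(total_loss)
# prediction = digit capsule with the largest length
predictions = tf.argmax(digit_caps_norm, axis=1)
correct = tf.equal(predictions, tf.argmax(_y, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))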