[go: up one dir, main page]

Skip to content

Commit

Permalink
removed query masking
Browse files Browse the repository at this point in the history
  • Loading branch information
kyubyong park committed Sep 24, 2019
1 parent 3a6723b commit 0fbf949
Showing 1 changed file with 47 additions and 52 deletions.
99 changes: 47 additions & 52 deletions modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@ def get_token_embeddings(vocab_size, num_units, zero_pad=True):
embeddings[1:, :]), 0)
return embeddings

def scaled_dot_product_attention(Q, K, V,
def scaled_dot_product_attention(Q, K, V, key_masks,
causality=False, dropout_rate=0.,
training=True,
scope="scaled_dot_product_attention"):
'''See 3.2.1.
Q: Packed queries. 3d tensor. [N, T_q, d_k].
K: Packed keys. 3d tensor. [N, T_k, d_k].
V: Packed values. 3d tensor. [N, T_k, d_v].
key_masks: A 2d tensor with shape of [N, key_seqlen]
causality: If True, applies masking for future blinding
dropout_rate: A floating point number of [0, 1].
training: boolean for controlling dropout
Expand All @@ -76,7 +77,7 @@ def scaled_dot_product_attention(Q, K, V,
outputs /= d_k ** 0.5

# key masking
outputs = mask(outputs, Q, K, type="key")
outputs = mask(outputs, key_masks=key_masks, type="key")

# causality or future blinding masking
if causality:
Expand All @@ -87,8 +88,8 @@ def scaled_dot_product_attention(Q, K, V,
attention = tf.transpose(outputs, [0, 2, 1])
tf.summary.image("attention", tf.expand_dims(attention[:1], -1))

# query masking
outputs = mask(outputs, Q, K, type="query")
# # query masking
# outputs = mask(outputs, Q, K, type="query")

# dropout
outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)
Expand All @@ -98,65 +99,58 @@ def scaled_dot_product_attention(Q, K, V,

return outputs

def mask(inputs, queries=None, keys=None, type=None):

def mask(inputs, key_masks=None, type=None):
"""Masks paddings on keys or queries to inputs
inputs: 3d tensor. (N, T_q, T_k)
queries: 3d tensor. (N, T_q, d)
keys: 3d tensor. (N, T_k, d)
inputs: 3d tensor. (h*N, T_q, T_k)
key_masks: 2d tensor. (N, T_k)
type: string. "key" | "future"
e.g.,
>> queries = tf.constant([[[1.],
[2.],
[0.]]], tf.float32) # (1, 3, 1)
>> keys = tf.constant([[[4.],
[0.]]], tf.float32) # (1, 2, 1)
>> inputs = tf.constant([[[4., 0.],
[8., 0.],
[0., 0.]]], tf.float32)
>> mask(inputs, queries, keys, "key")
array([[[ 4.0000000e+00, -4.2949673e+09],
[ 8.0000000e+00, -4.2949673e+09],
[ 0.0000000e+00, -4.2949673e+09]]], dtype=float32)
>> inputs = tf.constant([[[1., 0.],
[1., 0.],
[1., 0.]]], tf.float32)
>> mask(inputs, queries, keys, "query")
array([[[1., 0.],
[1., 0.],
[0., 0.]]], dtype=float32)
>> inputs = tf.zeros([4, 2, 3], dtype=tf.float32)
>> key_masks = tf.constant([[0., 0., 1.],
[0., 1., 1.]])
>> mask(inputs, key_masks=key_masks, type="key")
array([[[ 0.0000000e+00, 0.0000000e+00, -4.2949673e+09],
[ 0.0000000e+00, 0.0000000e+00, -4.2949673e+09]],
[[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]],
[[ 0.0000000e+00, 0.0000000e+00, -4.2949673e+09],
[ 0.0000000e+00, 0.0000000e+00, -4.2949673e+09]],
[[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]]], dtype=float32)
"""
padding_num = -2 ** 32 + 1
if type in ("k", "key", "keys"):
# Generate masks
masks = tf.sign(tf.reduce_sum(tf.abs(keys), axis=-1)) # (N, T_k)
masks = tf.expand_dims(masks, 1) # (N, 1, T_k)
masks = tf.tile(masks, [1, tf.shape(queries)[1], 1]) # (N, T_q, T_k)

# Apply masks to inputs
paddings = tf.ones_like(inputs) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs) # (N, T_q, T_k)
elif type in ("q", "query", "queries"):
# Generate masks
masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1)) # (N, T_q)
masks = tf.expand_dims(masks, -1) # (N, T_q, 1)
masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]]) # (N, T_q, T_k)

# Apply masks to inputs
outputs = inputs*masks
key_masks = tf.to_float(key_masks)
key_masks = tf.tile(key_masks, [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1]) # (h*N, seqlen)
key_masks = tf.expand_dims(key_masks, 1) # (h*N, 1, seqlen)
outputs = inputs + key_masks * padding_num
# elif type in ("q", "query", "queries"):
# # Generate masks
# masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1)) # (N, T_q)
# masks = tf.expand_dims(masks, -1) # (N, T_q, 1)
# masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]]) # (N, T_q, T_k)
#
# # Apply masks to inputs
# outputs = inputs*masks
elif type in ("f", "future", "right"):
diag_vals = tf.ones_like(inputs[0, :, :]) # (T_q, T_k)
tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense() # (T_q, T_k)
masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1]) # (N, T_q, T_k)
future_masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1]) # (N, T_q, T_k)

paddings = tf.ones_like(masks) * padding_num
outputs = tf.where(tf.equal(masks, 0), paddings, inputs)
paddings = tf.ones_like(future_masks) * padding_num
outputs = tf.where(tf.equal(future_masks, 0), paddings, inputs)
else:
print("Check if you entered type correctly!")


return outputs

def multihead_attention(queries, keys, values,

def multihead_attention(queries, keys, values, key_masks,
num_heads=8,
dropout_rate=0,
training=True,
Expand All @@ -166,6 +160,7 @@ def multihead_attention(queries, keys, values,
queries: A 3d tensor with shape of [N, T_q, d_model].
keys: A 3d tensor with shape of [N, T_k, d_model].
values: A 3d tensor with shape of [N, T_k, d_model].
key_masks: A 2d tensor with shape of [N, key_seqlen]
num_heads: An int. Number of heads.
dropout_rate: A floating point number.
training: Boolean. Controller of mechanism for dropout.
Expand All @@ -178,17 +173,17 @@ def multihead_attention(queries, keys, values,
d_model = queries.get_shape().as_list()[-1]
with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
# Linear projections
Q = tf.layers.dense(queries, d_model, use_bias=False) # (N, T_q, d_model)
K = tf.layers.dense(keys, d_model, use_bias=False) # (N, T_k, d_model)
V = tf.layers.dense(values, d_model, use_bias=False) # (N, T_k, d_model)
Q = tf.layers.dense(queries, d_model, use_bias=True) # (N, T_q, d_model)
K = tf.layers.dense(keys, d_model, use_bias=True) # (N, T_k, d_model)
V = tf.layers.dense(values, d_model, use_bias=True) # (N, T_k, d_model)

# Split and concat
Q_ = tf.concat(tf.split(Q, num_heads, axis=2), axis=0) # (h*N, T_q, d_model/h)
K_ = tf.concat(tf.split(K, num_heads, axis=2), axis=0) # (h*N, T_k, d_model/h)
V_ = tf.concat(tf.split(V, num_heads, axis=2), axis=0) # (h*N, T_k, d_model/h)

# Attention
outputs = scaled_dot_product_attention(Q_, K_, V_, causality, dropout_rate, training)
outputs = scaled_dot_product_attention(Q_, K_, V_, key_masks, causality, dropout_rate, training)

# Restore shape
outputs = tf.concat(tf.split(outputs, num_heads, axis=0), axis=2 ) # (N, T_q, d_model)
Expand Down

0 comments on commit 0fbf949

Please sign in to comment.