multi_head_attention

ivy.neural_net_functional.layers.multi_head_attention(x, to_q_fn, to_kv_fn, to_out_fn, scale, num_heads, context=None, mask=None, to_q_v=None, to_kv_v=None, to_out_v=None)[source]

Applies multi-head scaled dot-product attention to input x.

Parameters
  • x (array) – The array to determine the queries from [batch_shape,num_queries,x_feat_dim].

  • to_q_fn (callable) – The function to compute queries from input x, returning queries [batch_shape,num_queries,num_heads×feat_dim].

  • to_kv_fn (callable) – The function to compute keys and values from the context.

  • to_out_fn (callable) – The function to compute the output from the scaled dot-product attention.

  • scale (float) – The value by which to scale the query-key similarity measure before softmax.

  • num_heads (int) – The number of attention heads to use.

  • context (array, optional) – The array to determine the keys and values from [batch_shape,num_keys,cont_feat_dim]. Default is None.

  • mask (array, optional) – The mask to apply to the query-key similarity scores before the softmax [batch_shape,num_queries,num_keys]. Default is None.

  • to_q_v (variables array, optional) – The variables for function to_q_fn. Default is None.

  • to_kv_v (variables array, optional) – The variables for function to_kv_fn. Default is None.

  • to_out_v (variables array, optional) – The variables for function to_out_fn. Default is None.

Returns
  The output following application of multi-head attention [batch_shape,num_queries,out_feat_dim].
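As a rough illustration of what this function computes, below is a minimal NumPy sketch of multi-head scaled dot-product attention following the parameter layout above. The helper names, the head-splitting logic, and the identity-style projections in the usage example are assumptions made for this sketch, not the library's implementation; the variables arguments (to_q_v, to_kv_v, to_out_v) are omitted for brevity.

```python
import numpy as np

def softmax(a, axis=-1):
    # Numerically stable softmax.
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, to_q_fn, to_kv_fn, to_out_fn, scale, num_heads,
                         context=None, mask=None):
    # Queries come from x; keys/values come from the context
    # (self-attention on x when no context is given).
    context = x if context is None else context
    q = to_q_fn(x)          # [batch, num_queries, num_heads*feat_dim]
    kv = to_kv_fn(context)  # [batch, num_keys, 2*num_heads*feat_dim]
    k, v = np.split(kv, 2, axis=-1)

    def split_heads(t):
        b, n, hd = t.shape
        t = t.reshape(b, n, num_heads, hd // num_heads)
        return t.transpose(0, 2, 1, 3)  # [batch, heads, n, feat_dim]

    q, k, v = map(split_heads, (q, k, v))
    # Scaled query-key similarity, optionally masked before the softmax.
    sim = np.einsum('bhqd,bhkd->bhqk', q, k) * scale
    if mask is not None:
        sim = np.where(mask[:, None, :, :], sim, -1e9)
    attn = softmax(sim, axis=-1)
    out = np.einsum('bhqk,bhkd->bhqd', attn, v)
    # Recombine the heads and project to the output feature dimension.
    b, h, n, d = out.shape
    out = out.transpose(0, 2, 1, 3).reshape(b, n, h * d)
    return to_out_fn(out)
```

A toy call with identity-style projections (hypothetical, purely for shape checking):

```python
x = np.random.randn(1, 4, 8)  # 2 heads of feat_dim 4
out = multi_head_attention(
    x,
    to_q_fn=lambda t: t,
    to_kv_fn=lambda t: np.concatenate([t, t], axis=-1),
    to_out_fn=lambda t: t,
    scale=8 ** -0.5,
    num_heads=2,
)
# out retains shape [batch_shape, num_queries, out_feat_dim] = (1, 4, 8)
```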


Supported Frameworks:

JAX, TensorFlow, PyTorch, MXNet, NumPy