Speech Transformer

Transformer

class kospeech.models.transformer.model.SpeechTransformer(input_dim: int, num_classes: int, extractor: str, num_encoder_layers: int = 12, num_decoder_layers: int = 6, encoder_dropout_p: float = 0.2, decoder_dropout_p: float = 0.2, d_model: int = 512, d_ff: int = 2048, pad_id: int = 0, sos_id: int = 1, eos_id: int = 2, num_heads: int = 8, joint_ctc_attention: bool = False, max_length: int = 400)[source]

A Speech Transformer model. Users can modify the attributes as needed. The model is based on the paper “Attention Is All You Need”.

Parameters
  • input_dim (int) – dimension of input vector

  • num_classes (int) – number of classes

  • extractor (str) – type of CNN extractor (default: vgg)

  • num_encoder_layers (int, optional) – number of encoder layers (default: 12)

  • num_decoder_layers (int, optional) – number of decoder layers (default: 6)

  • encoder_dropout_p (float, optional) – dropout probability of encoder (default: 0.2)

  • decoder_dropout_p (float, optional) – dropout probability of decoder (default: 0.2)

  • d_model (int) – dimension of model (default: 512)

  • d_ff (int) – dimension of feed forward net (default: 2048)

  • pad_id (int) – identification of <PAD_token> (default: 0)

  • sos_id (int) – identification of <SOS_token> (default: 1)

  • eos_id (int) – identification of <EOS_token> (default: 2)

  • num_heads (int) – number of attention heads (default: 8)

  • max_length (int, optional) – max decoding step (default: 400)

  • joint_ctc_attention (bool, optional) – flag indicating whether to use joint CTC-attention (default: False)

Inputs: inputs, input_lengths, targets
  • inputs (torch.Tensor): tensor of padded acoustic feature sequences of size (batch, seq_length, dimension). This information is forwarded to the encoder.

  • input_lengths (torch.Tensor): tensor containing the lengths of the input sequences.

  • targets (torch.Tensor): tensor of sequences, whose length is the batch size and within which each sequence is a list of token IDs. This information is forwarded to the decoder.

Returns

(Tensor, Tensor, Tensor)

  • predicted_log_probs (torch.FloatTensor): Log probability of model predictions.

  • encoder_output_lengths: The length of encoder outputs. (batch)

  • encoder_log_probs: Log probabilities of the encoder outputs, which are passed to the CTC loss.

    If joint_ctc_attention is False, None is returned.

forward(inputs: torch.Tensor, input_lengths: torch.Tensor, targets: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor][source]

Forward propagates an (inputs, targets) pair for training.

Parameters
  • inputs (torch.FloatTensor) – An input sequence passed to the encoder. Typically a padded FloatTensor of size (batch, seq_length, dimension).

  • input_lengths (torch.LongTensor) – The length of each input sequence. (batch)

  • targets (torch.LongTensor) – A target sequence passed to the decoder. LongTensor of size (batch, seq_length)

Returns

(Tensor, Tensor, Tensor)

  • predicted_log_probs (torch.FloatTensor): Log probability of model predictions.

  • encoder_output_lengths: The length of encoder outputs. (batch)

  • encoder_log_probs: Log probabilities of the encoder outputs, which are passed to the CTC loss.

    If joint_ctc_attention is False, None is returned.
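
A minimal usage sketch of the documented interface is shown below. The feature dimension (80 filterbank features), vocabulary size, and batch shapes are illustrative assumptions, not values prescribed by the library:

    import torch
    from kospeech.models.transformer.model import SpeechTransformer

    # Illustrative shapes: 3 utterances, up to 256 frames of 80-dim features.
    batch_size, num_frames, feature_dim = 3, 256, 80
    num_classes = 2000  # vocabulary size, including <PAD>, <SOS>, <EOS>

    model = SpeechTransformer(
        input_dim=feature_dim,
        num_classes=num_classes,
        extractor='vgg',
        joint_ctc_attention=False,
    )

    inputs = torch.randn(batch_size, num_frames, feature_dim)  # (batch, seq_length, dimension)
    input_lengths = torch.LongTensor([256, 200, 180])
    targets = torch.randint(3, num_classes, (batch_size, 30))  # token IDs, (batch, seq_length)

    predicted_log_probs, encoder_output_lengths, encoder_log_probs = model(
        inputs, input_lengths, targets
    )
    # encoder_log_probs is None because joint_ctc_attention=False.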

Encoder

class kospeech.models.transformer.encoder.TransformerEncoder(input_dim: int, extractor: str = 'vgg', d_model: int = 512, d_ff: int = 2048, num_layers: int = 6, num_heads: int = 8, dropout_p: float = 0.3, joint_ctc_attention: bool = False, num_classes: int = None)[source]

The TransformerEncoder is composed of a stack of N identical layers. Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network.

Parameters
  • input_dim – dimension of feature vector

  • extractor (str) – convolutional extractor

  • d_model – dimension of model (default: 512)

  • d_ff – dimension of feed forward network (default: 2048)

  • num_layers – number of encoder layers (default: 6)

  • num_heads – number of attention heads (default: 8)

  • dropout_p – probability of dropout (default: 0.3)

Inputs:
  • inputs: batch of padded acoustic feature sequences of size (batch, seq_length, dimension)

  • input_lengths: list of sequence lengths

forward(inputs: torch.Tensor, input_lengths: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor][source]

Forward propagates inputs for encoder training.

Parameters
  • inputs (torch.FloatTensor) – An input sequence passed to the encoder. Typically a padded FloatTensor of size (batch, seq_length, dimension).

  • input_lengths (torch.LongTensor) – The length of each input sequence. (batch)

Returns

  • outputs: The output sequence of the encoder. FloatTensor of size (batch, seq_length, dimension)

  • output_lengths: The length of encoder outputs. (batch)

  • encoder_log_probs: Log probabilities of the encoder outputs, which are passed to the CTC loss.

    If joint_ctc_attention is False, None is returned.

Return type

(Tensor, Tensor, Tensor)
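
The encoder can also be run on its own. A short sketch of the documented interface follows; the 80-dim feature size and batch shapes are illustrative assumptions:

    import torch
    from kospeech.models.transformer.encoder import TransformerEncoder

    encoder = TransformerEncoder(input_dim=80, extractor='vgg', joint_ctc_attention=False)

    inputs = torch.randn(2, 128, 80)              # (batch, seq_length, dimension)
    input_lengths = torch.LongTensor([128, 100])

    outputs, output_lengths, encoder_log_probs = encoder(inputs, input_lengths)
    # outputs: (batch, seq_length', d_model), where the CNN extractor may have
    # subsampled the time axis; encoder_log_probs is None because
    # joint_ctc_attention=False.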

class kospeech.models.transformer.encoder.TransformerEncoderLayer(d_model: int = 512, num_heads: int = 8, d_ff: int = 2048, dropout_p: float = 0.3)[source]

An EncoderLayer is made up of self-attention and a feed-forward network. This standard encoder layer is based on the paper “Attention Is All You Need”.

Parameters
  • d_model – dimension of model (default: 512)

  • num_heads – number of attention heads (default: 8)

  • d_ff – dimension of feed forward network (default: 2048)

  • dropout_p – probability of dropout (default: 0.3)
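
The data flow through one encoder layer can be illustrated with plain torch.nn building blocks. This is a simplified sketch, not the library's implementation, which composes the AddNorm and PositionwiseFeedForward sublayers documented below:

    import torch
    import torch.nn as nn

    class EncoderLayerSketch(nn.Module):
        """Self-attention followed by a position-wise feed-forward network,
        each wrapped in a residual connection plus layer normalization."""
        def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout_p=0.3):
            super().__init__()
            self.self_attn = nn.MultiheadAttention(d_model, num_heads,
                                                   dropout=dropout_p, batch_first=True)
            self.attn_norm = nn.LayerNorm(d_model)
            self.feed_forward = nn.Sequential(
                nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout_p),
                nn.Linear(d_ff, d_model),
            )
            self.ff_norm = nn.LayerNorm(d_model)

        def forward(self, x, key_padding_mask=None):
            attn_out, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
            x = self.attn_norm(x + attn_out)             # Add & Norm around self-attention
            x = self.ff_norm(x + self.feed_forward(x))   # Add & Norm around feed-forward
            return x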

Decoder

class kospeech.models.transformer.decoder.TransformerDecoder(num_classes: int, d_model: int = 512, d_ff: int = 512, num_layers: int = 6, num_heads: int = 8, dropout_p: float = 0.3, pad_id: int = 0, sos_id: int = 1, eos_id: int = 2, max_length: int = 400)[source]

The TransformerDecoder is composed of a stack of N identical layers. Each layer has three sub-layers: the first is a masked multi-head self-attention mechanism, the second is a multi-head attention mechanism over the encoder outputs, and the third is a position-wise feed-forward network.

Parameters
  • num_classes – number of classes

  • d_model – dimension of model

  • d_ff – dimension of feed forward network

  • num_layers – number of decoder layers

  • num_heads – number of attention heads

  • dropout_p – probability of dropout

  • pad_id – identification of pad token

  • sos_id – identification of start of sentence token

  • eos_id – identification of end of sentence token

decode(encoder_outputs: torch.Tensor, encoder_output_lengths: torch.Tensor) → torch.Tensor[source]

Decode encoder_outputs.

Parameters
  • encoder_outputs (torch.FloatTensor) – An output sequence of the encoder. FloatTensor of size (batch, seq_length, dimension)

  • encoder_output_lengths (torch.LongTensor) – The length of encoder outputs. (batch)

Returns

Log probability of model predictions.

Return type

  • predicted_log_probs (torch.FloatTensor)

forward(targets: torch.Tensor, encoder_outputs: torch.Tensor, encoder_output_lengths: torch.Tensor) → torch.Tensor[source]

Forward propagates targets and encoder_outputs for training.

Parameters
  • targets (torch.LongTensor) – A target sequence passed to the decoder. LongTensor of size (batch, seq_length)

  • encoder_outputs (torch.FloatTensor) – An output sequence of the encoder. FloatTensor of size (batch, seq_length, dimension)

  • encoder_output_lengths – The length of encoder outputs. (batch)

Returns

Log probability of model predictions.

Return type

  • predicted_log_probs (torch.FloatTensor)
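
A usage sketch covering teacher-forced training (forward) and inference (decode). The encoder outputs here are random placeholders, and the vocabulary size is an illustrative assumption:

    import torch
    from kospeech.models.transformer.decoder import TransformerDecoder

    num_classes, d_model = 2000, 512
    decoder = TransformerDecoder(num_classes=num_classes, d_model=d_model)

    encoder_outputs = torch.randn(2, 64, d_model)       # (batch, seq_length, d_model)
    encoder_output_lengths = torch.LongTensor([64, 50])
    targets = torch.randint(3, num_classes, (2, 20))    # (batch, seq_length)

    # Training: teacher forcing with ground-truth targets.
    log_probs = decoder(targets, encoder_outputs, encoder_output_lengths)

    # Inference: autoregressive decoding for up to max_length steps.
    predicted_log_probs = decoder.decode(encoder_outputs, encoder_output_lengths)
    predictions = predicted_log_probs.max(dim=-1)[1]    # argmax over the class dimension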

class kospeech.models.transformer.decoder.TransformerDecoderLayer(d_model: int = 512, num_heads: int = 8, d_ff: int = 2048, dropout_p: float = 0.3)[source]

A DecoderLayer is made up of self-attention, multi-head (encoder-decoder) attention, and a feed-forward network. This standard decoder layer is based on the paper “Attention Is All You Need”.

Parameters
  • d_model – dimension of model (default: 512)

  • num_heads – number of attention heads (default: 8)

  • d_ff – dimension of feed forward network (default: 2048)

  • dropout_p – probability of dropout (default: 0.3)
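
As with the encoder layer, the decoder layer's data flow can be sketched with standard PyTorch modules. This is an illustration of the three sub-layers, not the library's exact code:

    import torch
    import torch.nn as nn

    class DecoderLayerSketch(nn.Module):
        """Masked self-attention, encoder-decoder attention, then a
        position-wise feed-forward network, each followed by Add & Norm."""
        def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout_p=0.3):
            super().__init__()
            self.self_attn = nn.MultiheadAttention(d_model, num_heads,
                                                   dropout=dropout_p, batch_first=True)
            self.cross_attn = nn.MultiheadAttention(d_model, num_heads,
                                                    dropout=dropout_p, batch_first=True)
            self.feed_forward = nn.Sequential(
                nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
            self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(3)])

        def forward(self, tgt, memory, tgt_mask=None):
            out, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)
            tgt = self.norms[0](tgt + out)                 # Add & Norm: self-attention
            out, _ = self.cross_attn(tgt, memory, memory)  # attend to encoder outputs
            tgt = self.norms[1](tgt + out)                 # Add & Norm: cross-attention
            return self.norms[2](tgt + self.feed_forward(tgt))  # Add & Norm: feed-forward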

Sublayers

class kospeech.models.transformer.sublayers.AddNorm(sublayer: torch.nn.modules.module.Module, d_model: int = 512)[source]

Add & Norm layer proposed in “Attention Is All You Need”. The Transformer employs a residual connection around each of the two sub-layers (Multi-Head Attention & Feed-Forward), followed by layer normalization.
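
A minimal sketch of such a wrapper (illustrative, not the library's exact code):

    import torch.nn as nn

    class AddNormSketch(nn.Module):
        """Applies output = LayerNorm(sublayer(x) + x)."""
        def __init__(self, sublayer: nn.Module, d_model: int = 512):
            super().__init__()
            self.sublayer = sublayer
            self.layer_norm = nn.LayerNorm(d_model)

        def forward(self, *args):
            residual = args[0]
            output = self.sublayer(*args)
            if isinstance(output, tuple):   # e.g. attention returns (output, weights)
                output = output[0]
            return self.layer_norm(output + residual)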

class kospeech.models.transformer.sublayers.PositionwiseFeedForward(d_model: int = 512, d_ff: int = 2048, dropout_p: float = 0.3)[source]

Position-wise Feedforward Networks proposed in “Attention Is All You Need”. Fully connected feed-forward network, which is applied to each position separately and identically. This consists of two linear transformations with a ReLU activation in between. Another way of describing this is as two convolutions with kernel size 1.
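
A sketch of the two-linear-layer formulation described above (illustrative, not the library's exact code):

    import torch
    import torch.nn as nn

    class PositionwiseFeedForwardSketch(nn.Module):
        """FFN(x) = W2 * ReLU(W1 * x + b1) + b2, applied to each position independently."""
        def __init__(self, d_model: int = 512, d_ff: int = 2048, dropout_p: float = 0.3):
            super().__init__()
            self.feed_forward = nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Dropout(dropout_p),
                nn.Linear(d_ff, d_model),
                nn.Dropout(dropout_p),
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, seq_length, d_model); the same transformation is
            # applied at every position.
            return self.feed_forward(x)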

Embeddings

class kospeech.models.transformer.embeddings.Embedding(num_embeddings: int, pad_id: int, d_model: int = 512)[source]

Embedding layer. As in other sequence transduction models, the Transformer uses learned embeddings to convert the input tokens and output tokens to vectors of dimension d_model. In the embedding layers, those weights are multiplied by sqrt(d_model).
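
A sketch of the scaled embedding described above (illustrative, not the library's exact code):

    import math
    import torch
    import torch.nn as nn

    class ScaledEmbeddingSketch(nn.Module):
        """Token embedding whose output is multiplied by sqrt(d_model)."""
        def __init__(self, num_embeddings: int, pad_id: int, d_model: int = 512):
            super().__init__()
            self.sqrt_dim = math.sqrt(d_model)
            self.embedding = nn.Embedding(num_embeddings, d_model, padding_idx=pad_id)

        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            return self.embedding(tokens) * self.sqrt_dim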

class kospeech.models.transformer.embeddings.PositionalEncoding(d_model: int = 512, max_len: int = 5000)[source]

Positional Encoding proposed in “Attention Is All You Need”. Since the Transformer contains no recurrence and no convolution, some positional information must be added for the model to make use of the order of the sequence.

“Attention Is All You Need” uses sine and cosine functions of different frequencies:

PE_(pos, 2i)   = sin(pos / power(10000, 2i / d_model))
PE_(pos, 2i+1) = cos(pos / power(10000, 2i / d_model))
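
A sketch of how the sinusoidal table can be pre-computed and added to the input embeddings (illustrative; the library's class may expose a slightly different interface):

    import math
    import torch
    import torch.nn as nn

    class PositionalEncodingSketch(nn.Module):
        """Pre-computes the sinusoidal table and adds it to its input."""
        def __init__(self, d_model: int = 512, max_len: int = 5000):
            super().__init__()
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float()
                                 * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)   # even indices: sine
            pe[:, 1::2] = torch.cos(position * div_term)   # odd indices: cosine
            self.register_buffer('pe', pe.unsqueeze(0))    # (1, max_len, d_model)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (batch, seq_length, d_model)
            return x + self.pe[:, :x.size(1)]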

Mask

kospeech.models.transformer.mask.get_attn_pad_mask(inputs, input_lengths, expand_length)[source]

Masked (padding) positions are set to 1.
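
A sketch of how such a padding mask is typically built from input_lengths (illustrative; the returned dtype and exact construction may differ in the library):

    import torch

    def attn_pad_mask_sketch(inputs: torch.Tensor,
                             input_lengths: torch.Tensor,
                             expand_length: int) -> torch.Tensor:
        """BoolTensor of shape (batch, expand_length, seq_length); padded key
        positions are True (i.e. set to 1)."""
        batch_size, seq_length = inputs.size(0), inputs.size(1)
        positions = torch.arange(seq_length, device=input_lengths.device).unsqueeze(0)
        pad_mask = positions >= input_lengths.unsqueeze(1)   # (batch, seq_length)
        return pad_mask.unsqueeze(1).expand(batch_size, expand_length, seq_length)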

kospeech.models.transformer.mask.get_decoder_self_attn_mask(seq_k: torch.Tensor, seq_q: torch.Tensor, pad_id)[source]

Builds the mask for decoder self-attention.
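
This mask typically combines a padding mask (built from pad_id positions in seq_k) with a look-ahead mask so each position attends only to earlier positions. A hedged sketch of that standard construction, not necessarily the library's exact code:

    import torch

    def decoder_self_attn_mask_sketch(seq_k: torch.Tensor,
                                      seq_q: torch.Tensor,
                                      pad_id: int) -> torch.Tensor:
        """BoolTensor of shape (batch, q_len, k_len); True where the key is
        padding or lies in the future relative to the query position."""
        batch_size, k_len = seq_k.size(0), seq_k.size(1)
        q_len = seq_q.size(1)
        pad_mask = seq_k.eq(pad_id).unsqueeze(1).expand(batch_size, q_len, k_len)
        subsequent_mask = torch.triu(
            torch.ones(q_len, k_len, dtype=torch.bool, device=seq_k.device), diagonal=1
        ).unsqueeze(0).expand(batch_size, q_len, k_len)
        return pad_mask | subsequent_mask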

kospeech.models.transformer.mask.get_non_pad_mask(inputs: torch.Tensor, input_lengths: torch.Tensor) → torch.Tensor[source]

Padding positions are set to 0; the mask can be built from either input_lengths or pad_id.
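
A sketch of building the non-pad mask from input_lengths (illustrative; a pad_id-based variant would instead compare token IDs against pad_id):

    import torch

    def non_pad_mask_sketch(inputs: torch.Tensor,
                            input_lengths: torch.Tensor) -> torch.Tensor:
        """FloatTensor of shape (batch, seq_length, 1); 1 at real positions,
        0 at padded positions."""
        batch_size, seq_length = inputs.size(0), inputs.size(1)
        non_pad_mask = inputs.new_ones(batch_size, seq_length)
        for i, length in enumerate(input_lengths):
            non_pad_mask[i, int(length):] = 0.0
        return non_pad_mask.unsqueeze(-1)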