Source code for kospeech.data.data_loader

# Copyright (c) 2020, Soohwan Kim. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import threading
import torch
import random

from omegaconf import DictConfig
from torch.utils.data import Dataset
from kospeech.data import load_dataset
from kospeech.utils import logger
from kospeech.data import SpectrogramParser
from kospeech.vocabs import Vocabulary


class SpectrogramDataset(Dataset, SpectrogramParser):
    """
    Dataset for feature & transcript matching

    Args:
        audio_paths (list): list of audio paths
        transcripts (list): list of transcripts
        sos_id (int): identification of <start of sequence>
        eos_id (int): identification of <end of sequence>
        spec_augment (bool): flag indicating whether to use spec-augmentation or not (default: False)
        config (DictConfig): set of configurations
        dataset_path (str): path of dataset
    """
    def __init__(
            self,
            audio_paths: list,              # list of audio paths
            transcripts: list,              # list of transcript paths
            sos_id: int,                    # identification of start of sequence token
            eos_id: int,                    # identification of end of sequence token
            config: DictConfig,             # set of configurations
            spec_augment: bool = False,     # flag indicating whether to use spec-augmentation or not
            dataset_path: str = None,       # path of dataset
            audio_extension: str = 'pcm',   # audio extension
    ) -> None:
        super(SpectrogramDataset, self).__init__(
            feature_extract_by=config.audio.feature_extract_by,
            sample_rate=config.audio.sample_rate,
            n_mels=config.audio.n_mels,
            frame_length=config.audio.frame_length,
            frame_shift=config.audio.frame_shift,
            del_silence=config.audio.del_silence,
            input_reverse=config.audio.input_reverse,
            normalize=config.audio.normalize,
            freq_mask_para=config.audio.freq_mask_para,
            time_mask_num=config.audio.time_mask_num,
            freq_mask_num=config.audio.freq_mask_num,
            sos_id=sos_id,
            eos_id=eos_id,
            dataset_path=dataset_path,
            transform_method=config.audio.transform_method,
            audio_extension=audio_extension,
        )
        self.audio_paths = list(audio_paths)
        self.transcripts = list(transcripts)
        self.augment_methods = [self.VANILLA] * len(self.audio_paths)
        self.dataset_size = len(self.audio_paths)
        self._augment(spec_augment)
        self.shuffle()

    def get_item(self, idx):
        """ Gets the feature vector & transcript pair at the given index """
        feature = self.parse_audio(os.path.join(self.dataset_path, self.audio_paths[idx]), self.augment_methods[idx])
        transcript = self.parse_transcript(self.transcripts[idx])
        return feature, transcript

    def parse_transcript(self, transcript):
        """ Parses a transcript of space-separated token ids, wrapping it with <sos> and <eos> """
        tokens = transcript.split(' ')
        transcript = list()

        transcript.append(int(self.sos_id))
        for token in tokens:
            transcript.append(int(token))
        transcript.append(int(self.eos_id))

        return transcript

    def _augment(self, spec_augment):
        """ Doubles the dataset by appending a spec-augmented copy of every sample """
        if spec_augment:
            logger.info("Applying Spec Augmentation...")

            for idx in range(self.dataset_size):
                self.augment_methods.append(self.SPEC_AUGMENT)
                self.audio_paths.append(self.audio_paths[idx])
                self.transcripts.append(self.transcripts[idx])

    def shuffle(self):
        """ Shuffles the dataset, keeping paths, transcripts & augment methods aligned """
        tmp = list(zip(self.audio_paths, self.transcripts, self.augment_methods))
        random.shuffle(tmp)
        self.audio_paths, self.transcripts, self.augment_methods = zip(*tmp)

    def __len__(self):
        return len(self.audio_paths)

    def count(self):
        return len(self.audio_paths)
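
# A minimal usage sketch (illustration only, not part of the original module).
# The config keys mirror exactly the fields read in __init__ above; the option
# values, ids, file names, and paths below are hypothetical placeholders.
def _example_spectrogram_dataset():
    from omegaconf import OmegaConf

    config = OmegaConf.create({
        'audio': {
            'feature_extract_by': 'librosa',
            'sample_rate': 16000,
            'n_mels': 80,
            'frame_length': 20,
            'frame_shift': 10,
            'del_silence': False,
            'input_reverse': False,
            'normalize': True,
            'freq_mask_para': 18,
            'time_mask_num': 4,
            'freq_mask_num': 2,
            'transform_method': 'mel',
        }
    })

    dataset = SpectrogramDataset(
        audio_paths=['sample1.pcm', 'sample2.pcm'],   # hypothetical files
        transcripts=['20 9 135', '45 7'],             # space-separated token ids
        sos_id=1,                                     # placeholder ids
        eos_id=2,
        config=config,
        spec_augment=False,
        dataset_path='/path/to/dataset',              # hypothetical root
    )

    # parse_transcript wraps the token ids with <sos>/<eos>:
    # '20 9 135' -> [1, 20, 9, 135, 2]
    # With real files in place, get_item returns a (feature, transcript) pair:
    feature, transcript = dataset.get_item(0)
    return feature, transcript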
class AudioDataLoader(threading.Thread):
    """
    Audio Data Loader

    Args:
        dataset (SpectrogramDataset): dataset for feature & transcript matching
        queue (queue.Queue): queue for threading
        batch_size (int): size of batch
        thread_id (int): identification of thread
        pad_id (int): identification of pad token
    """
    def __init__(self, dataset, queue, batch_size, thread_id, pad_id):
        threading.Thread.__init__(self)
        self.collate_fn = _collate_fn
        self.dataset = dataset
        self.queue = queue
        self.index = 0
        self.batch_size = batch_size
        self.dataset_count = dataset.count()
        self.thread_id = thread_id
        self.pad_id = pad_id

    def _create_empty_batch(self):
        seqs = torch.zeros(0, 0, 0)
        targets = torch.zeros(0, 0).to(torch.long)

        seq_lengths = list()
        target_lengths = list()

        return seqs, targets, seq_lengths, target_lengths

    def run(self):
        """ Loads data from the SpectrogramDataset and pushes collated batches into the queue """
        logger.debug('loader %d start' % self.thread_id)

        while True:
            items = list()

            for _ in range(self.batch_size):
                if self.index >= self.dataset_count:
                    break

                feature_vector, transcript = self.dataset.get_item(self.index)

                if feature_vector is not None:
                    items.append((feature_vector, transcript))

                self.index += 1

            if len(items) == 0:
                batch = self._create_empty_batch()
                self.queue.put(batch)
                break

            batch = self.collate_fn(items, self.pad_id)
            self.queue.put(batch)

        logger.debug('loader %d stop' % self.thread_id)

    def count(self):
        return math.ceil(self.dataset_count / self.batch_size)
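
# A consumer-side sketch (illustration only): run() pushes collated batches
# into the queue and signals exhaustion with an empty batch, so the consumer
# loops until it receives a batch with zero sequences. `dataset` is assumed to
# be a SpectrogramDataset as sketched above; batch_size 4 and pad_id 0 are
# placeholder values.
def _example_audio_data_loader(dataset):
    import queue

    batch_queue = queue.Queue()
    loader = AudioDataLoader(dataset, batch_queue, batch_size=4, thread_id=0, pad_id=0)
    loader.start()

    while True:
        seqs, targets, seq_lengths, target_lengths = batch_queue.get()
        if seqs.size(0) == 0:   # empty batch marks the end of the dataset
            break
        # seqs: (batch, time, feature), targets: (batch, max_target_len)
        print(seqs.shape, targets.shape)

    loader.join()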
def _collate_fn(batch, pad_id):
    """ Pads every sample in the batch to the maximum sequence length """
    def seq_length_(p):
        return len(p[0])

    def target_length_(p):
        return len(p[1])

    # sort by sequence length for rnn.pack_padded_sequence()
    batch = sorted(batch, key=lambda sample: sample[0].size(0), reverse=True)

    seq_lengths = [len(s[0]) for s in batch]
    target_lengths = [len(s[1]) - 1 for s in batch]

    max_seq_sample = max(batch, key=seq_length_)[0]
    max_target_sample = max(batch, key=target_length_)[1]

    max_seq_size = max_seq_sample.size(0)
    max_target_size = len(max_target_sample)

    feat_size = max_seq_sample.size(1)
    batch_size = len(batch)

    seqs = torch.zeros(batch_size, max_seq_size, feat_size)

    targets = torch.zeros(batch_size, max_target_size).to(torch.long)
    targets.fill_(pad_id)

    for x in range(batch_size):
        sample = batch[x]
        tensor = sample[0]
        target = sample[1]
        seq_length = tensor.size(0)

        seqs[x].narrow(0, 0, seq_length).copy_(tensor)
        targets[x].narrow(0, 0, len(target)).copy_(torch.LongTensor(target))

    seq_lengths = torch.IntTensor(seq_lengths)

    return seqs, targets, seq_lengths, target_lengths
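
# A worked padding example (illustration only). Two features of different
# lengths are zero-padded to the longer one, and each target row is filled
# with pad_id beyond its transcript's length; note target_lengths counts one
# token fewer per sample, matching the `len(s[1]) - 1` above. The token ids
# and feature sizes are placeholders.
def _example_collate_fn():
    long_feature = torch.rand(70, 80)    # 70 frames, 80 mel bins
    short_feature = torch.rand(50, 80)   # 50 frames, 80 mel bins

    batch = [
        (long_feature, [1, 20, 9, 135, 2]),   # <sos> 20 9 135 <eos>
        (short_feature, [1, 45, 7, 2]),       # <sos> 45 7 <eos>
    ]

    seqs, targets, seq_lengths, target_lengths = _collate_fn(batch, pad_id=0)

    # seqs: (2, 70, 80)  -- the short feature is zero-padded to 70 frames
    # targets: (2, 5)    -- the short transcript is padded with pad_id
    # seq_lengths: IntTensor([70, 50]); target_lengths: [4, 3]
    return seqs, targets, seq_lengths, target_lengths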
class MultiDataLoader(object):
    """
    Multi Data Loader using Threads.

    Args:
        dataset_list (list): list of SpectrogramDataset
        queue (queue.Queue): queue for threading
        batch_size (int): size of batch
        num_workers (int): the number of cpu cores used
        pad_id (int): identification of pad token
    """
    def __init__(self, dataset_list, queue, batch_size, num_workers, pad_id):
        self.dataset_list = dataset_list
        self.queue = queue
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.loader = list()

        for idx in range(self.num_workers):
            self.loader.append(AudioDataLoader(self.dataset_list[idx], self.queue, self.batch_size, idx, pad_id))

    def start(self):
        """ Runs all worker threads """
        for idx in range(self.num_workers):
            self.loader[idx].start()

    def join(self):
        """ Waits for all worker threads to finish """
        for idx in range(self.num_workers):
            self.loader[idx].join()
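
# A wiring sketch (illustration only): one AudioDataLoader thread per dataset
# in dataset_list feeds the shared queue. `trainset_list` is assumed to come
# from split_dataset below, with one SpectrogramDataset per worker; batch_size
# 32 and pad_id 0 are placeholder values.
def _example_multi_data_loader(trainset_list, pad_id=0):
    import queue

    batch_queue = queue.Queue()
    num_workers = len(trainset_list)

    train_loader = MultiDataLoader(trainset_list, batch_queue, batch_size=32,
                                   num_workers=num_workers, pad_id=pad_id)
    train_loader.start()

    # consume batches from batch_queue as in the AudioDataLoader sketch above;
    # each worker signals exhaustion with one empty batch

    train_loader.join()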
def split_dataset(config: DictConfig, transcripts_path: str, vocab: Vocabulary):
    """
    Splits the dataset into a training set and a validation set.

    Args:
        config (DictConfig): set of configurations
        transcripts_path (str): path of transcripts
        vocab (Vocabulary): vocabulary providing sos_id & eos_id

    Returns: train_time_step, trainset_list, validset
        - **train_time_step** (int): number of time steps for training
        - **trainset_list** (list): list of training datasets
        - **validset** (data_loader.SpectrogramDataset): validation dataset
    """
    logger.info("split dataset start !!")

    trainset_list = list()

    if config.train.dataset == 'kspon':
        train_num = 620000
        valid_num = 2545
    elif config.train.dataset == 'libri':
        train_num = 281241
        valid_num = 5567
    else:
        raise NotImplementedError("Unsupported Dataset : {0}".format(config.train.dataset))

    audio_paths, transcripts = load_dataset(transcripts_path)

    total_time_step = math.ceil(len(audio_paths) / config.train.batch_size)
    valid_time_step = math.ceil(valid_num / config.train.batch_size)
    train_time_step = total_time_step - valid_time_step

    train_audio_paths = audio_paths[:train_num + 1]
    train_transcripts = transcripts[:train_num + 1]

    valid_audio_paths = audio_paths[train_num + 1:]
    valid_transcripts = transcripts[train_num + 1:]

    if config.audio.spec_augment:
        train_time_step <<= 1

    train_num_per_worker = math.ceil(train_num / config.train.num_workers)

    # audio_paths & transcripts shuffled in the same order
    # for separating train & validation
    tmp = list(zip(train_audio_paths, train_transcripts))
    random.shuffle(tmp)
    train_audio_paths, train_transcripts = zip(*tmp)

    # separating the train dataset by the number of workers
    for idx in range(config.train.num_workers):
        train_begin_idx = train_num_per_worker * idx
        train_end_idx = min(train_num_per_worker * (idx + 1), train_num)

        trainset_list.append(
            SpectrogramDataset(
                train_audio_paths[train_begin_idx:train_end_idx],
                train_transcripts[train_begin_idx:train_end_idx],
                vocab.sos_id,
                vocab.eos_id,
                config=config,
                spec_augment=config.audio.spec_augment,
                dataset_path=config.train.dataset_path,
                audio_extension=config.audio.audio_extension,
            )
        )

    validset = SpectrogramDataset(
        audio_paths=valid_audio_paths,
        transcripts=valid_transcripts,
        sos_id=vocab.sos_id,
        eos_id=vocab.eos_id,
        config=config,
        spec_augment=False,
        dataset_path=config.train.dataset_path,
        audio_extension=config.audio.audio_extension,
    )

    logger.info("split dataset complete !!")

    return train_time_step, trainset_list, validset
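
# An end-to-end sketch (illustration only) tying split_dataset to the loaders
# above. `config` is assumed to carry the `train` and `audio` groups read in
# this module (dataset, batch_size, num_workers, dataset_path, spec_augment,
# audio_extension, ...), and `vocab` to be a kospeech Vocabulary exposing
# sos_id and eos_id; the transcripts path is a hypothetical placeholder.
def _example_split_dataset(config, vocab):
    train_time_step, trainset_list, validset = split_dataset(
        config=config,
        transcripts_path='/path/to/transcripts.txt',   # hypothetical path
        vocab=vocab,
    )

    # one training dataset per worker thread, plus a single validation set
    assert len(trainset_list) == config.train.num_workers
    return train_time_step, trainset_list, validset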