import chainer

from espnet.nets.asr_interface import ASRInterface
from espnet.nets.chainer_backend.transformer import ctc
from espnet.nets.chainer_backend.transformer.attention import MultiHeadAttention
from espnet.nets.chainer_backend.transformer.decoder import Decoder
from espnet.nets.chainer_backend.transformer.encoder import Encoder
from espnet.nets.chainer_backend.transformer.label_smoothing_loss import LabelSmoothingLoss
from espnet.nets.chainer_backend.transformer.plot import PlotAttentionReport
from espnet.nets.ctc_prefix_score import CTCPrefixScore
CTC_SCORING_RATIO = 1.5
MAX_DECODER_OUTPUT = 5
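
# A minimal usage sketch for the pieces above: CTCPrefixScore and
# CTC_SCORING_RATIO are consumed during joint CTC/attention beam search.
# The log-posterior array, beam size, and token ids below are illustrative
# assumptions, not values taken from this module.
import numpy as np

lpz = np.log(np.full((100, 52), 1.0 / 52, dtype=np.float32))  # (T, odim) CTC log-posteriors
ctc_prefix_score = CTCPrefixScore(lpz, blank=0, eos=51, xp=np)
ctc_state = ctc_prefix_score.initial_state()

beam = 10
ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO))  # cap CTC rescoring work
scores, states = ctc_prefix_score([51], np.arange(ctc_beam), ctc_state)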
class E2E(ASRInterface, chainer.Chain):
"""E2E module.
Args:
        idim (int): Input dimensions.
        odim (int): Output dimensions.
args (Namespace): Training config.
ignore_id (int, optional): Id for ignoring a character.
flag_return (bool, optional): If true, return a list with (loss,
loss_ctc, loss_att, acc) in forward. Otherwise, return loss.
"""
@staticmethod
def add_arguments(parser):
"""Customize flags for transformer setup.
import logging

from chainer import reporter


class Reporter(chainer.Chain):
"""A chainer reporter wrapper."""
def report(self, loss_ctc, loss_att, acc, cer_ctc, cer, wer, mtl_loss):
"""Report at every step."""
reporter.report({'loss_ctc': loss_ctc}, self)
reporter.report({'loss_att': loss_att}, self)
reporter.report({'acc': acc}, self)
reporter.report({'cer_ctc': cer_ctc}, self)
reporter.report({'cer': cer}, self)
reporter.report({'wer': wer}, self)
logging.info('mtl loss:' + str(mtl_loss))
reporter.report({'loss': mtl_loss}, self)
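
# A hedged sketch of driving this wrapper under chainer's reporting machinery;
# registering it under the observer name 'main' mirrors chainer trainer
# conventions and is an assumption, not code from this module.
from chainer import reporter as chainer_reporter

rep = Reporter()
driver = chainer_reporter.Reporter()
driver.add_observer('main', rep)
with driver:
    observation = {}
    with chainer_reporter.report_scope(observation):
        rep.report(1.2, 0.8, 0.73, 0.21, 0.18, 0.25, 1.0)
# observation now maps keys such as 'main/loss' and 'main/acc' to these values.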
import torch


class E2E(ASRInterface, torch.nn.Module):
"""E2E module.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
@staticmethod
def add_arguments(parser):
"""Add arguments."""
E2E.encoder_add_arguments(parser)
E2E.attention_add_arguments(parser)
E2E.decoder_add_arguments(parser)
return parser
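
# Typical wiring for these static hooks, as a sketch; parsing an empty argv so
# that every flag takes its default is an assumption for illustration only.
import argparse

parser = argparse.ArgumentParser()
E2E.add_arguments(parser)     # installs the encoder/attention/decoder flags
args = parser.parse_args([])  # empty argv: all transformer flags at defaults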
def get_trained_model_state_dict(model_path):
    """Extract a trained model's state dict (and its type) for pre-initialization."""
    conf_path = os.path.join(os.path.dirname(model_path), 'model.json')
    if 'rnnlm' in model_path:
        # LM snapshots are stored as plain state dicts; return them directly.
        logging.warning('reading model parameters from ' + model_path)
        return torch.load(model_path), 'lm'
    idim, odim, args = get_model_conf(model_path, conf_path)
    logging.warning('reading model parameters from ' + model_path)
    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)
    assert isinstance(model, MTInterface) or isinstance(model, ASRInterface)
    return model.state_dict(), 'asr-mt'
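
# Usage sketch for the loader above; the experiment path is hypothetical, and
# the 'enc.' key prefix assumes the RNN backend's module naming.
state_dict, kind = get_trained_model_state_dict('exp/train/results/model.acc.best')
enc_keys = [k for k in state_dict if k.startswith('enc.')]  # keys usable for --enc-init warm starts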
def recog_v2(args):
    """Decode with custom models that implement the ScorerInterface.

    Args:
        args (namespace): The program arguments. See py:func:`espnet.bin.asr_recog.get_parser` for details.
    """
logging.warning("experimental API for custom LMs is selected by --api v2")
if args.batchsize > 1:
raise NotImplementedError("batch decoding is not implemented")
if args.streaming_mode is not None:
raise NotImplementedError("streaming mode is not implemented")
if args.word_rnnlm:
raise NotImplementedError("word LM is not implemented")
set_deterministic_pytorch(args)
model, train_args = load_trained_model(args.model)
assert isinstance(model, ASRInterface)
model.eval()
load_inputs_and_targets = LoadInputsAndTargets(
mode='asr', load_output=False, sort_in_input_length=False,
preprocess_conf=train_args.preprocess_conf
if args.preprocess_conf is None else args.preprocess_conf,
preprocess_args={'train': False})
if args.rnnlm:
lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        # NOTE: for compatibility with models older than version 0.5.0
lm_model_module = getattr(lm_args, "model_module", "default")
lm_class = dynamic_import_lm(lm_model_module, lm_args.backend)
lm = lm_class(len(train_args.char_list), lm_args)
torch_load(args.rnnlm, lm)
lm.eval()
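
# A sketch of the minimal namespace the checks above inspect; the field names
# follow espnet.bin.asr_recog, while every value here is a placeholder.
from argparse import Namespace

args = Namespace(
    model='exp/train/results/model.acc.best',  # hypothetical snapshot path
    batchsize=1,           # values > 1 are rejected above
    streaming_mode=None,   # streaming is rejected in this API
    word_rnnlm=None,       # word-level LMs are rejected in this API
    rnnlm=None,            # optional character-level LM snapshot
    rnnlm_conf=None,
    preprocess_conf=None,  # None: reuse the training-time preprocess config
)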
class Reporter(chainer.Chain):
    """A chainer reporter wrapper for the multi-encoder model."""

    def report(self, loss_ctc_list, loss_att, acc, cer_ctc_list, cer, wer, mtl_loss):
        """Report at every step."""
        # loss_ctc_list / cer_ctc_list: [weighted sum, enc_1, ..., enc_N]
        num_encs = len(loss_ctc_list) - 1
reporter.report({'loss_ctc': loss_ctc_list[0]}, self)
for i in range(num_encs):
reporter.report({'loss_ctc{}'.format(i + 1): loss_ctc_list[i + 1]}, self)
reporter.report({'loss_att': loss_att}, self)
reporter.report({'acc': acc}, self)
reporter.report({'cer_ctc': cer_ctc_list[0]}, self)
for i in range(num_encs):
reporter.report({'cer_ctc{}'.format(i + 1): cer_ctc_list[i + 1]}, self)
reporter.report({'cer': cer}, self)
reporter.report({'wer': wer}, self)
logging.info('mtl loss:' + str(mtl_loss))
reporter.report({'loss': mtl_loss}, self)
class E2E(ASRInterface, torch.nn.Module):
"""E2E module.
:param List idims: List of dimensions of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
@staticmethod
def add_arguments(parser):
"""Add arguments for multi-encoder setting."""
E2E.encoder_add_arguments(parser)
E2E.attention_add_arguments(parser)
E2E.decoder_add_arguments(parser)
E2E.ctc_add_arguments(parser)
return parser
from espnet.nets.pytorch_backend.e2e_asr import Reporter
from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask
from espnet.nets.pytorch_backend.nets_utils import th_accuracy
from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.decoder import Decoder
from espnet.nets.pytorch_backend.transformer.encoder import Encoder
from espnet.nets.pytorch_backend.transformer.initializer import initialize
from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import LabelSmoothingLoss
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask
from espnet.nets.pytorch_backend.transformer.mask import target_mask
from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport
from espnet.nets.scorers.ctc import CTCPrefixScorer
class E2E(ASRInterface, torch.nn.Module):
"""E2E module.
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
@staticmethod
def add_arguments(parser):
"""Add arguments."""
group = parser.add_argument_group("transformer model setting")
group.add_argument("--transformer-init", type=str, default="pytorch",
choices=["pytorch", "xavier_uniform", "xavier_normal",
"kaiming_uniform", "kaiming_normal"],
def recog(args):
"""Decode with the given args.
Args:
args (Namespace): The program arguments
"""
set_deterministic_pytorch(args)
model, train_args = load_trained_model(args.model)
assert isinstance(model, ASRInterface)
model.recog_args = args
# read rnnlm
if args.rnnlm:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit))
torch_load(args.rnnlm, rnnlm)
rnnlm.eval()
else:
rnnlm = None
if args.word_rnnlm:
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
word_dict = rnnlm_args.char_list_dict
"""ST Interface module."""
from espnet.nets.asr_interface import ASRInterface
from espnet.utils.dynamic_import import dynamic_import
class STInterface(ASRInterface):
"""ST Interface for ESPnet model implementation.
    NOTE: This class inherits from ASRInterface to enable joint translation
    and recognition when performing multi-task learning with the ASR task.
"""
    def translate(self, x, trans_args, char_list=None, rnnlm=None, ensemble_models=[]):
        """Translate x for evaluation.

        :param ndarray x: input acoustic feature (B, T, D) or (T, D)
        :param namespace trans_args: argument namespace containing options
        :param list char_list: list of characters
        :param torch.nn.Module rnnlm: language model module
        :return: N-best decoding results
        :rtype: list
        """
        raise NotImplementedError("translate method is not implemented")
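
# A minimal sketch of a concrete implementation; the class name and the trivial
# one-best output are hypothetical, though ESPnet N-best entries are dicts
# carrying 'yseq' and 'score' keys.
import torch


class DummyST(STInterface, torch.nn.Module):
    """Toy ST model returning a single empty hypothesis."""

    def __init__(self, odim):
        torch.nn.Module.__init__(self)
        self.sos = self.eos = odim - 1

    def translate(self, x, trans_args, char_list=None, rnnlm=None, ensemble_models=[]):
        return [{'yseq': [self.sos, self.eos], 'score': 0.0}]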
    def pit_process(self, losses):
        """Pick the minimum-loss permutation for each sample and average.

        :param torch.Tensor losses: per-permutation losses (B, num_perms)
        :return pit_loss
        :rtype torch.Tensor (B)
        :return permutation
        :rtype torch.LongTensor (B, 1|2|3)
        """
bs = losses.size(0)
ret = [self.min_pit_sample(losses[i]) for i in range(bs)]
loss_perm = torch.stack([r[0] for r in ret], dim=0).to(losses.device) # (B)
permutation = torch.tensor([r[1] for r in ret]).long().to(losses.device)
return torch.mean(loss_perm), permutation
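
# Usage sketch for the permutation search above, assuming the two-speaker PIT
# helper from e2e_asr_mix with its [h1r1, h1r2, h2r1, h2r2] loss layout.
pit = PIT(num_spkrs=2)
losses = torch.rand(4, 4)  # (B, num_spkrs**2) per-permutation losses
loss, perm = pit.pit_process(losses)
# loss: mean over each sample's best permutation; perm: (B, num_spkrs) indices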
class E2E(ASRInterface, torch.nn.Module):
"""E2E module
:param int idim: dimension of inputs
:param int odim: dimension of outputs
:param Namespace args: argument Namespace containing options
"""
def __init__(self, idim, odim, args):
torch.nn.Module.__init__(self)
self.mtlalpha = args.mtlalpha
assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
self.etype = args.etype
self.verbose = args.verbose
self.char_list = args.char_list
self.outdir = args.outdir
self.reporter = Reporter()
if args.mtlalpha == 1.0:
mtl_mode = 'ctc'
logging.info('Pure CTC mode')
elif args.mtlalpha == 0.0:
mtl_mode = 'att'
logging.info('Pure attention mode')
else:
mtl_mode = 'mtl'
logging.info('Multitask learning mode')
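
        # For reference, the mode chosen above governs how forward() blends the
        # objectives in ESPnet's hybrid loss:
        #   loss = mtlalpha * loss_ctc + (1 - mtlalpha) * loss_att
        # so mtlalpha == 1.0 is pure CTC and mtlalpha == 0.0 is pure attention.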
if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1:
model = load_trained_modules(idim_list[0], odim, args)
else:
model_class = dynamic_import(args.model_module)
model = model_class(idim_list[0] if args.num_encs == 1 else idim_list, odim, args)
assert isinstance(model, ASRInterface)
if args.rnnlm is not None:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
model.rnnlm = rnnlm
# write model config
if not os.path.exists(args.outdir):
os.makedirs(args.outdir)
model_conf = args.outdir + '/model.json'
with open(model_conf, 'wb') as f:
logging.info('writing a model config file to ' + model_conf)
        f.write(json.dumps((idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)),
                           indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))
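
# Round-trip sketch: get_model_conf reads the model.json written above back at
# decode time (both paths are hypothetical).
from espnet.asr.asr_utils import get_model_conf

idim, odim, conf = get_model_conf('exp/train/results/model.acc.best',
                                  'exp/train/results/model.json')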