def test_make_batchset(swap_io):
    dummy_json = make_dummy_json(128, [128, 512], [16, 128])
    # check w/o adaptive batch size
    batchset = make_batchset(dummy_json, 24, 2 ** 10, 2 ** 10,
                             min_batch_size=1, swap_io=swap_io)
    assert sum([len(batch) >= 1 for batch in batchset]) == len(batchset)
    print([len(batch) for batch in batchset])

    batchset = make_batchset(dummy_json, 24, 2 ** 10, 2 ** 10,
                             min_batch_size=10, swap_io=swap_io)
    assert sum([len(batch) >= 10 for batch in batchset]) == len(batchset)
    print([len(batch) for batch in batchset])

    # check w/ adaptive batch size
    batchset = make_batchset(dummy_json, 24, 256, 64,
                             min_batch_size=10, swap_io=swap_io)
    assert sum([len(batch) >= 10 for batch in batchset]) == len(batchset)
    print([len(batch) for batch in batchset])

    batchset = make_batchset(dummy_json, 24, 256, 64,
                             min_batch_size=10, swap_io=swap_io)
    assert sum([len(batch) >= 10 for batch in batchset]) == len(batchset)

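# For orientation: make_dummy_json fabricates an espnet-style data dict in
# which utterance ids map to "input"/"output" entries whose 'shape' fields
# drive all of the batching logic. A minimal hand-written equivalent might
# look like the sketch below (only the 'shape' fields matter to
# make_batchset; real espnet data.json files also carry feature paths,
# token ids, etc., and the values here are illustrative):
dummy_json_sketch = {
    'utt_0': {
        'input': [{'name': 'input1', 'shape': [300, 83]}],   # frames x feat dim
        'output': [{'name': 'target1', 'shape': [20, 52]}],  # tokens x vocab dim
    },
    'utt_1': {
        'input': [{'name': 'input1', 'shape': [150, 83]}],
        'output': [{'name': 'target1', 'shape': [8, 52]}],
    },
}
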
def test_gradient_noise_injection(module):
    args = make_arg(grad_noise=True)
    args_org = make_arg()
    dummy_json = make_dummy_json_st(2, [10, 20], [10, 20], [10, 20], idim=20, odim=5)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_st as m
    else:
        raise NotImplementedError
    batchset = make_batchset(dummy_json, 2, 2 ** 10, 2 ** 10, shortest_first=True)
    model = m.E2E(20, 5, args)
    model_org = m.E2E(20, 5, args_org)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=20, odim=5))
        loss_org = model_org(*convert_batch(batch, module, idim=20, odim=5))
        loss.backward()
        grad = [param.grad for param in model.parameters()][10]
        loss_org.backward()
        grad_org = [param.grad for param in model_org.parameters()][10]
        assert grad[0] != grad_org[0]

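# For context: gradient noise injection perturbs each parameter's gradient
# with zero-mean Gaussian noise whose variance is annealed over training
# iterations; the test above only checks that gradients diverge from the
# noise-free model. A minimal sketch of the idea, not espnet's actual
# implementation (the eta/scale_factor schedule below is illustrative):
import torch


def add_gradient_noise_sketch(model, iteration, eta=1.0, scale_factor=0.55):
    """Add annealed zero-mean Gaussian noise to every gradient in-place."""
    variance = eta / ((1 + iteration) ** scale_factor)
    for param in model.parameters():
        if param.grad is not None:
            param.grad += torch.randn_like(param.grad) * (variance ** 0.5)
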
def test_model_trainable_and_decodable(module, num_encs, model_dict):
    args = make_arg(num_encs=num_encs, **model_dict)
    batch = prepare_inputs("pytorch", num_encs)
    # test trainable
    m = importlib.import_module(module)
    model = m.E2E([40 for _ in range(num_encs)], 5, args)
    loss = model(*batch)
    loss.backward()  # trainable

    # test attention plot
    dummy_json = make_dummy_json(num_encs, [10, 20], [10, 20], idim=40, odim=5, num_inputs=num_encs)
    batchset = make_batchset(dummy_json, 2, 2 ** 10, 2 ** 10, shortest_first=True)
    att_ws = model.calculate_all_attentions(*convert_batch(
        batchset[0], "pytorch", idim=40, odim=5, num_inputs=num_encs))
    from espnet.asr.asr_utils import PlotAttentionReport
    tmpdir = tempfile.mkdtemp()
    plot = PlotAttentionReport(model.calculate_all_attentions, batchset[0], tmpdir, None, None, None)
    for i in range(num_encs):
        # att-encoder
        att_w = plot.get_attention_weight(0, att_ws[i][0])
        plot._plot_and_save_attention(att_w, '{}/att{}.png'.format(tmpdir, i))
    # han (hierarchical attention over the encoders, stored after the per-encoder weights)
    att_w = plot.get_attention_weight(0, att_ws[num_encs][0])
    plot._plot_and_save_attention(att_w, '{}/han.png'.format(tmpdir), han_mode=True)

    # test decodable
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = [np.random.randn(10, 40) for _ in range(num_encs)]
        model.recognize(in_data, args, args.char_list)  # decodable

def test_sortagrad_trainable_with_batch_bins(module):
    args = make_arg(sortagrad=1)
    idim = 20
    odim = 5
    dummy_json = make_dummy_json_st(4, [10, 20], [10, 20], [10, 20], idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_st as m
    else:
        raise NotImplementedError
    batch_elems = 2000
    batchset = make_batchset(dummy_json, batch_bins=batch_elems, shortest_first=True)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            ilen = int(info['input'][0]['shape'][0])
            olen = int(info['output'][0]['shape'][0])
            n += ilen * idim + olen * odim
        # the accumulated element count must respect the bin budget
        assert n <= batch_elems
    model = m.E2E(20, 5, args)
    for batch in batchset:
        loss = model(*convert_batch(batch, module, idim=20, odim=5))
        if isinstance(loss, tuple):
            # chainer returns several values as a tuple
            loss[0].backward()  # trainable
        else:
            loss.backward()  # trainable

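# The batch_bins strategy caps the approximate number of elements
# (length x dimension, summed over input and output) per minibatch, which
# is what the assertion above verifies. A simplified sketch of the greedy
# grouping rule; espnet's real batchfy also accounts for padding,
# min_batch_size, and num_batches, so treat this as illustrative only:
def batchfy_by_bin_sketch(sorted_utts, batch_bins, idim, odim):
    batches, current, n = [], [], 0
    for uttid, info in sorted_utts:
        cost = (int(info['input'][0]['shape'][0]) * idim
                + int(info['output'][0]['shape'][0]) * odim)
        # Close the current batch once adding this utterance would
        # overflow the bin budget.
        if current and n + cost > batch_bins:
            batches.append(current)
            current, n = [], 0
        current.append((uttid, info))
        n += cost
    if current:
        batches.append(current)
    return batches
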
def test_sortagrad_trainable_with_batch_frames(module):
    args = make_arg(sortagrad=1)
    idim = 6
    odim = 5
    dummy_json = make_dummy_json_mt(4, [10, 20], [10, 20], idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_mt as m
    else:
        import espnet.nets.chainer_backend.e2e_mt as m
    batch_frames_in = 20
    batch_frames_out = 20
    batchset = make_batchset(dummy_json,
                             batch_frames_in=batch_frames_in,
                             batch_frames_out=batch_frames_out,
                             shortest_first=True,
                             mt=True, iaxis=1, oaxis=0)
    for batch in batchset:
        i = 0
        o = 0
        for uttid, info in batch:
            # in MT mode the source text is stored as the second "output" entry
            i += int(info['output'][1]['shape'][0])
            o += int(info['output'][0]['shape'][0])
        assert i <= batch_frames_in
        assert o <= batch_frames_out
    model = m.E2E(6, 5, args)
    for batch in batchset:
        attn_loss = model(*convert_batch(batch, module, idim=6, odim=5))
        attn_loss.backward()  # trainable

def test_sortagrad_trainable_with_batch_bins(module):
    args = make_arg(sortagrad=1)
    idim = 6
    odim = 5
    dummy_json = make_dummy_json_mt(4, [10, 20], [10, 20], idim=idim, odim=odim)
    if module == "pytorch":
        import espnet.nets.pytorch_backend.e2e_mt as m
    else:
        import espnet.nets.chainer_backend.e2e_mt as m
    batch_elems = 2000
    batchset = make_batchset(dummy_json, batch_bins=batch_elems, shortest_first=True, mt=True, iaxis=1, oaxis=0)
    for batch in batchset:
        n = 0
        for uttid, info in batch:
            ilen = int(info['output'][1]['shape'][0])
            olen = int(info['output'][0]['shape'][0])
            n += ilen * idim + olen * odim
        # the accumulated element count must respect the bin budget
        assert n <= batch_elems
    model = m.E2E(6, 5, args)
    for batch in batchset:
        attn_loss = model(*convert_batch(batch, module, idim=6, odim=5))
        attn_loss.backward()
    with torch.no_grad(), chainer.no_backprop_mode():
        in_data = np.random.randint(0, 5, (1, 100))
        model.translate(in_data, args, args.char_list)

    :param str batch_sort_key: how to sort data before creating minibatches
        ["input", "output", "shuffle"]
    :return: List[List[Tuple[str, dict]]] list of batches

    Reference: https://github.com/espnet/espnet/pull/759/files
               https://github.com/espnet/espnet/commit/dc0a0d3cfc271af945804f391e81cd5824b08725
               https://github.com/espnet/espnet/commit/73018318a65d18cf2e644a45aa725323c9e4a0e6
    """
    assert task in TASK_SET
    # swap_io: if True, use "input" as output and "output" as input in the `data` dict
    swap_io = False
    if task == TASK_SET['tts']:
        swap_io = True
    minibatches = make_batchset_espnet(
        data,
        batch_size=batch_size,
        max_length_in=max_length_in,
        max_length_out=max_length_out,
        num_batches=num_batches,
        min_batch_size=min_batch_size,
        shortest_first=shortest_first,
        batch_sort_key=batch_sort_key,
        swap_io=swap_io,
        count=batch_strategy,
        batch_bins=batch_bins,
        batch_frames_in=batch_frames_in,
        batch_frames_out=batch_frames_out,
        batch_frames_inout=batch_frames_inout)
    return minibatches

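# A possible call site for this wrapper, assuming an espnet-style data.json
# and that 'asr' is a valid TASK_SET key (only 'tts' is visible above; the
# path and the remaining keyword defaults are assumptions as well):
import json

with open('dump/train/data.json', 'rb') as f:  # illustrative path
    data = json.load(f)['utts']

# Cap each minibatch at ~800k padded elements instead of a fixed batch size.
minibatches = make_batchset(data,
                            task='asr',  # assumed TASK_SET key
                            batch_strategy='bin',
                            batch_bins=800000,
                            shortest_first=True)
print(len(minibatches), len(minibatches[0]))
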
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0, oaxis=0)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          iaxis=0, oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='asr', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter(idim=idim)

    # read json data
    with open(args.train_json, 'rb') as f:
        train_json = json.load(f)['utts']
    with open(args.valid_json, 'rb') as f:
        valid_json = json.load(f)['utts']

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(train_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          shortest_first=use_sortagrad,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          mt=True, iaxis=1, oaxis=0)
    valid = make_batchset(valid_json, args.batch_size,
                          args.maxlen_in, args.maxlen_out, args.minibatches,
                          min_batch_size=args.ngpu if args.ngpu > 1 else 1,
                          count=args.batch_count,
                          batch_bins=args.batch_bins,
                          batch_frames_in=args.batch_frames_in,
                          batch_frames_out=args.batch_frames_out,
                          batch_frames_inout=args.batch_frames_inout,
                          mt=True, iaxis=1, oaxis=0)

    load_tr = LoadInputsAndTargets(
        mode='mt', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': True}  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode='mt', load_output=True, preprocess_conf=args.preprocess_conf,
        preprocess_args={'train': False}  # Switch the mode of preprocessing
    )

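# Downstream, espnet wraps these batch lists in a lazy dataset/iterator
# pair: each element of `train` is already a full minibatch (a list of
# (uttid, info) tuples), so the iterator uses batch_size=1, and load_tr /
# load_cv turn the json metadata into actual arrays per step. A rough
# sketch using plain chainer iterators (espnet's own trainer uses a
# shuffling-toggleable SerialIterator variant, so the class choice here
# is illustrative):
from chainer.iterators import SerialIterator
from espnet.utils.dataset import TransformDataset

train_iter = SerialIterator(TransformDataset(train, load_tr),
                            batch_size=1, shuffle=not use_sortagrad)
valid_iter = SerialIterator(TransformDataset(valid, load_cv),
                            batch_size=1, repeat=False, shuffle=False)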