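# Excerpts of loss and decoder Neural Modules in the style of NVIDIA NeMo's
# 0.x API. The imports and the EPS constant below are assumed from that API
# and are not part of the original snippets; SmoothedCrossEntropyLoss and
# BeamSearchSequenceGenerator come from the nemo_nlp collection.
import torch
import torch.nn as nn

from nemo.backends.pytorch.nm import LossNM, TrainableNM
from nemo.core.neural_types import (AxisType, BatchTag, ChannelTag,
                                    LogProbabilityTag, NeuralType, TimeTag)

EPS = 1e-5  # small constant guarding division by an empty mask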
"loss": NeuralType(None)
}
return input_ports, output_ports
def __init__(self, num_classes, **kwargs):
LossNM.__init__(self, **kwargs)
self._criterion = nn.CrossEntropyLoss()
self.num_classes = num_classes
def _loss_function(self, logits, labels):
logits = logits.view(-1, self.num_classes)
labels = labels.view(-1)
loss = self._criterion(logits, labels)
return loss
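# A minimal standalone sketch of the flattening above (values assumed, not
# from the original): [batch, time, num_classes] logits and [batch, time]
# labels collapse to 2-D/1-D for nn.CrossEntropyLoss.
logits = torch.randn(2, 5, 10)            # [batch=2, time=5, num_classes=10]
labels = torch.randint(0, 10, (2, 5))     # [batch, time]
loss = nn.CrossEntropyLoss()(logits.view(-1, 10), labels.view(-1))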
class SequenceRegressionLoss(LossNM):
    @staticmethod
    def create_ports():
        input_ports = {
            "logits": NeuralType({
                0: AxisType(BatchTag),
                1: AxisType(ChannelTag)
            }),
            "labels": NeuralType({
                0: AxisType(BatchTag)
            })
        }
        output_ports = {
            "loss": NeuralType(None)
        }
        return input_ports, output_ports
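    # The snippet ends at create_ports. A regression LossNM would also need
    # __init__ and _loss_function; this completion is a sketch assuming an
    # MSE criterion, not code from the original.
    def __init__(self, **kwargs):
        LossNM.__init__(self, **kwargs)
        self._criterion = nn.MSELoss()

    def _loss_function(self, logits, labels):
        return self._criterion(logits.view(-1), labels.view(-1))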
            torch.nn.ReLU(),
            # reshape to [-1, 120]
            torch.nn.Flatten(),
            torch.nn.Linear(120, 84),
            torch.nn.ReLU(),
            torch.nn.Linear(84, 10),
            # per-class log-probabilities, as expected by NLLLoss below
            torch.nn.LogSoftmax(dim=1)
        )
        self.to(self._device)

    def forward(self, images):
        predictions = self.model(images)
        return predictions
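# Usage sketch (the class header is not shown above, so the name here is
# hypothetical): a batch of images maps to [batch, 10] log-probabilities.
# classifier = ImageClassifier()                 # hypothetical name
# log_probs = classifier(images=image_batch)     # -> [batch, 10]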
class NLLLoss(LossNM):
    @staticmethod
    def create_ports():
        input_ports = {
            "predictions": NeuralType(
                {0: AxisType(BatchTag),
                 1: AxisType(LogProbabilityTag)}),
            "targets": NeuralType({0: AxisType(BatchTag)}),
        }
        output_ports = {"loss": NeuralType(None)}
        return input_ports, output_ports

    def __init__(self, **kwargs):
        # Neural Module API specific
        LossNM.__init__(self, **kwargs)
        # End of Neural Module API specific
        self._criterion = torch.nn.NLLLoss()

    def _loss_function(self, predictions, targets):
        # The original snippet stops at the criterion; every LossNM needs a
        # _loss_function, so this body is assumed from the port names above.
        return self._criterion(predictions, targets)
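# Wiring sketch (assumed): `log_probs` would come from a LogSoftmax head such
# as the image classifier above.
# nll = NLLLoss()
# loss = nll(predictions=log_probs, targets=labels)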
            NeuralType({0: AxisType(BatchTag)}),
        }
        output_ports = {"loss": NeuralType(None)}
        return input_ports, output_ports

    def __init__(self, **kwargs):
        TrainableNM.__init__(self, **kwargs)
        self._loss_fn = SequenceClassificationLoss()

    def forward(self, log_probs, labels):
        loss = self._loss_fn(log_probs, labels)
        return loss
class LossAggregatorNM(LossNM):
    @staticmethod
    def create_ports(num_losses=2):
        input_ports = {}
        for i in range(num_losses):
            input_ports["loss_" + str(i + 1)] = NeuralType(None)
        output_ports = {"loss": NeuralType(None)}
        return input_ports, output_ports

    def __init__(self, *, num_inputs, **kwargs):
        # num_inputs determines how many loss_<i> input ports are created.
        kwargs["create_port_args"] = {"num_losses": num_inputs}
        LossNM.__init__(self, **kwargs)

    def _loss_function(self, **kwargs):
        values = [kwargs[x] for x in sorted(kwargs.keys())]
        loss = values[0]
        for loss_i in values[1:]:
            # Add the tensors directly; calling .item() here would detach the
            # extra losses from the graph and block their gradients.
            loss = loss.add(loss_i)
        return loss
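# Usage sketch (assumed wiring; the two loss tensors come from other LossNMs):
# agg = LossAggregatorNM(num_inputs=2)
# total_loss = agg(loss_1=intent_loss, loss_2=slot_loss)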
            }),
        }
        output_ports = {"loss": NeuralType(None)}
        return input_ports, output_ports

    def __init__(self, label_smoothing=0.0, **kwargs):
        LossNM.__init__(self, **kwargs)
        self._criterion = SmoothedCrossEntropyLoss(label_smoothing)

    def _loss_function(self, logits, output_ids, output_mask):
        loss = self._criterion(logits, output_ids, output_mask)
        return loss
        # Zero out padded positions, then normalize.
        pad_mask = pad_mask.float()
        loss = -torch.sum(loss * pad_mask)
        if self.sample_wise:
            # average over sequences in the batch
            loss /= target_log_probs.size(0)
        else:
            # average over non-padded tokens (EPS guards against empty masks)
            loss /= pad_mask.sum() + EPS
        return loss

    def _ctc_loss(self, log_probs, targets, pad_mask):
        lengths = pad_mask.sum(-1)
        # nn.CTCLoss expects log-probs as [time, batch, vocab]
        loss = self.ctc(log_probs.transpose(0, 1), targets, lengths, lengths)
        loss = torch.mean(loss)
        return loss
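# A standalone sketch of the CTC call above (shapes and values assumed, not
# from the original): nn.CTCLoss wants log-probs as [time, batch, vocab].
ctc = torch.nn.CTCLoss(reduction='none')
lp = torch.randn(2, 50, 28).log_softmax(-1)   # [batch=2, time=50, vocab=28]
tgt = torch.randint(1, 28, (2, 30))           # padded targets, blank id is 0
in_len = torch.full((2,), 50, dtype=torch.long)
tgt_len = torch.tensor([30, 22])
loss = ctc(lp.transpose(0, 1), tgt, in_len, tgt_len).mean()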
class CrossEntropyLoss(LossNM):
    """
    CrossEntropyLoss
    """
    @staticmethod
    def create_ports():
        input_ports = {
            "logits": NeuralType({
                0: AxisType(BatchTag),
                1: AxisType(ChannelTag)
            }),
            "labels": NeuralType({
                0: AxisType(BatchTag),
            })
        }
        output_ports = {"loss": NeuralType(None)}
        return input_ports, output_ports
    # The parameters before `slots` are reconstructed from how they are used
    # in the body below; treat the exact order as an assumption.
    def _loss_function(self,
                       intent_logits,
                       slot_logits,
                       input_mask,
                       intents,
                       slots,
                       intent_loss_weight=0.6):
        intent_loss = self._criterion(intent_logits, intents)
        # Only score slots at non-padded positions.
        active_loss = input_mask.view(-1) > 0.5
        active_logits = slot_logits.view(-1, self.num_slots)[active_loss]
        active_labels = slots.view(-1)[active_loss]
        slot_loss = self._criterion(active_logits, active_labels)
        # Weighted sum of the two objectives.
        loss = intent_loss * intent_loss_weight + \
            slot_loss * (1 - intent_loss_weight)
        return loss
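# Worked example of the weighting: with intent_loss_weight=0.6,
# intent_loss=1.0 and slot_loss=2.0 give 0.6 * 1.0 + 0.4 * 2.0 = 1.4.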
class PaddedSmoothedCrossEntropyLossNM(LossNM):
    """
    Neural module which calculates CrossEntropyLoss and
    1) excludes padding tokens from the loss calculation
    2) supports label smoothing regularization
    3) can restrict the loss to the desired number of last tokens
    Args:
        label_smoothing: label smoothing regularization coefficient
        predict_last_k: how many of the last tokens to use in the loss
    """
    @staticmethod
    def create_ports():
        # The port dict below is reconstructed to match the
        # _loss_function(logits, output_ids, output_mask) signature shown
        # earlier; treat the names and shapes as an assumption.
        input_ports = {
            "logits": NeuralType({
                0: AxisType(BatchTag),
                1: AxisType(TimeTag),
                2: AxisType(ChannelTag)
            }),
            "output_ids": NeuralType({
                0: AxisType(BatchTag),
                1: AxisType(TimeTag)
            }),
            "output_mask": NeuralType({
                0: AxisType(BatchTag),
                1: AxisType(TimeTag)
            })
        }
        output_ports = {"loss": NeuralType(None)}
        return input_ports, output_ports
"batch_size": params.get("batch_size", 1),
"beam_size": params.get("beam_size", 4),
"len_pen": params.get("length_penalty", 0)
}
self.generator = BeamSearchSequenceGenerator(
decoder.embedding_layer, decoder.decoder, log_softmax,
**generator_params)
def forward(self, hidden_states_src, input_mask_src):
output_ids = self.generator(
encoder_hidden_states=hidden_states_src,
encoder_input_mask=input_mask_src)
return output_ids
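# Wiring sketch (assumed names): feed encoder outputs and the source mask to
# the beam-search module to obtain decoded token ids.
# output_ids = translator(hidden_states_src=encoder_outputs,
#                         input_mask_src=src_mask)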
class TokenClassificationLoss(LossNM):
    @staticmethod
    def create_ports():
        input_ports = {
            "logits": NeuralType({
                0: AxisType(BatchTag),
                1: AxisType(TimeTag),
                2: AxisType(ChannelTag)
            }),
            "labels": NeuralType({
                0: AxisType(BatchTag),
                1: AxisType(TimeTag)
            }),
            "input_mask": NeuralType({
                0: AxisType(BatchTag),
                1: AxisType(TimeTag)
            })
        }
        output_ports = {"loss": NeuralType(None)}
        return input_ports, output_ports
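# Port-shape summary for TokenClassificationLoss (assumed wiring): logits are
# [batch, time, num_classes]; labels and input_mask are [batch, time], with
# the mask used to skip padded positions, as in the intent/slot loss above.
# tc_loss = TokenClassificationLoss(num_classes=num_tags)  # arg name assumed
# loss = tc_loss(logits=logits, labels=labels, input_mask=input_mask)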