Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
Args:
dtypes (dict):
mapping of field names and dtypes.
pii_fields (dict):
mapping of pii field names and categories.
Returns:
dict:
mapping of field names and transformer instances.
"""
transformers_dict = dict()
for name, dtype in dtypes.items():
dtype = np.dtype(dtype)
if dtype.kind == 'i':
transformer = transformers.NumericalTransformer(dtype=int)
elif dtype.kind == 'f':
transformer = transformers.NumericalTransformer(dtype=float)
elif dtype.kind == 'O':
anonymize = pii_fields.get(name)
transformer = transformers.CategoricalTransformer(anonymize=anonymize)
elif dtype.kind == 'b':
transformer = transformers.BooleanTransformer()
elif dtype.kind == 'M':
transformer = transformers.DatetimeTransformer()
else:
raise ValueError('Unsupported dtype: {}'.format(dtype))
LOGGER.info('Loading transformer %s for field %s',
transformer.__class__.__name__, name)
transformers_dict[name] = transformer
},
'categorical_transformer': {
'type': 'str',
'default': 'categoircal_fuzzy',
'description': 'Type of transformer to use for the categorical variables',
'choices': [
'categorical',
'categorical_fuzzy',
'one_hot_encoding',
'label_encoding'
]
}
}
# Name of the default categorical transformer; the value matches one of the
# keys of CATEGORICAL_TRANSFORMERS below.
DEFAULT_TRANSFORMER = 'one_hot_encoding'

# Categorical transformer options keyed by name. The two ``categorical*``
# entries are already-configured CategoricalTransformer instances (fuzzy
# disabled / enabled); the remaining entries are transformer classes.
CATEGORICAL_TRANSFORMERS = {
    'categorical': rdt.transformers.CategoricalTransformer(fuzzy=False),
    'categorical_fuzzy': rdt.transformers.CategoricalTransformer(fuzzy=True),
    'one_hot_encoding': rdt.transformers.OneHotEncodingTransformer,
    'label_encoding': rdt.transformers.LabelEncodingTransformer,
}

# Transformer template for the 'O' key.
# NOTE(review): 'O' looks like the numpy dtype kind code for object columns,
# as used when selecting transformers by ``np.dtype(...).kind`` -- confirm
# against the code that consumes this mapping.
TRANSFORMER_TEMPLATES = {
    'O': rdt.transformers.OneHotEncodingTransformer,
}
def __init__(self, distribution=None, categorical_transformer=None, *args, **kwargs):
    """Initialize the model, reading defaults from the stored metadata.

    Any extra positional and keyword arguments are forwarded unchanged to
    the parent initializer. Afterwards, if the instance metadata contains a
    ``model_kwargs`` entry and no ``distribution`` was passed, the
    ``distribution`` stored in that entry is picked up.

    Args:
        distribution:
            Distribution to use. When ``None``, the value found under
            ``model_kwargs['distribution']`` in the metadata is used, if
            such an entry exists.
        categorical_transformer:
            Accepted by the signature; not read in the visible part of
            this method.
        *args:
            Forwarded to the parent ``__init__``.
        **kwargs:
            Forwarded to the parent ``__init__``.
    """
    super().__init__(*args, **kwargs)

    metadata = self._metadata
    has_model_kwargs = metadata is not None and 'model_kwargs' in metadata._metadata
    if has_model_kwargs:
        model_kwargs = metadata._metadata['model_kwargs']
        if distribution is None:
            # NOTE(review): the resolved value is only bound to the local
            # name here; nothing visible stores it on the instance.
            distribution = model_kwargs['distribution']
def _get_transformers(self, dtypes):
"""Create the transformer instances needed to process the given dtypes.
Args:
dtypes (dict):
mapping of field names and dtypes.
Returns:
dict:
mapping of field names and transformer instances.
"""
transformer_templates = {
'i': rdt.transformers.NumericalTransformer(dtype=int),
'f': rdt.transformers.NumericalTransformer(dtype=float),
'O': rdt.transformers.CategoricalTransformer,
'b': rdt.transformers.BooleanTransformer,
'M': rdt.transformers.DatetimeTransformer,
}
transformer_templates.update(self._transformer_templates)
transformers = dict()
for name, dtype in dtypes.items():
transformer_template = transformer_templates[np.dtype(dtype).kind]
if isinstance(transformer_template, type):
transformer = transformer_template()
else:
transformer = copy.deepcopy(transformer_template)
LOGGER.info('Loading transformer %s for field %s',