Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
for i, y in enumerate(lines):
self.assertEqual(data[i], y)
self.assertEqual(len(data), len(lines))
self.assertEqual(data._length, len(lines))
# check if length is cached
self.assertEqual(len(data), len(lines))
self.assertIsInstance(data._dataset, easyfile.TextFile)
data = data.map(str.split)
for x, y in zip(data, lines):
self.assertEqual(x, y.split())
self.assertIsInstance(data, lineflow.core.MapDataset)
self.assertIsInstance(data._dataset, TextDataset)
def __init__(self, dataset: DatasetMixin, map_func: Callable[[Any], Any]) -> None:
    """Store the mapping function and initialize the parent with ``dataset``.

    Args:
        dataset (DatasetMixin): The dataset forwarded to the parent
            initializer.
        map_func (Callable[[Any], Any]): The callable kept on the instance
            as ``_map_func``.

    Raises:
        AssertionError: If ``map_func`` is not callable (only while
            assertions are enabled, i.e. not under ``python -O``).
    """
    assert callable(map_func)

    # The function is recorded before the parent initializer runs.
    self._map_func = map_func
    super().__init__(dataset)
def map(self, map_func: Callable[[Any], Any]) -> 'MapDataset':
    """Wraps this dataset in a ``MapDataset`` that uses ``map_func``.

    Args:
        map_func (Callable[[Any], Any]): The function handed to the new
            ``MapDataset``.

    Returns ('MapDataset'):
        A new ``MapDataset`` built from this dataset and ``map_func``.
    """
    mapped = MapDataset(self, map_func)
    return mapped
pkl_path = os.path.join(root, 'aclImdb', 'imdb.pkl')
return download.cache_or_load_file(pkl_path, creator, loader)
# Memoized variant of get_imdb: calls go through lru_cache() with its default
# settings, so a repeated call with the same arguments reuses the earlier
# result instead of invoking get_imdb again.
# NOTE(review): presumably lru_cache is functools.lru_cache - confirm against
# the file's imports.
cached_get_imdb = lru_cache()(get_imdb)
def _imdb_loader(path: str) -> Tuple[str, int]:
    """Read one review file and derive its label from its parent directory.

    Args:
        path (str): Path to a UTF-8 encoded text file. The file is expected
            to sit directly inside a directory named ``pos`` or ``neg``
            (the aclImdb layout).

    Returns (Tuple[str, int]):
        ``(text, label)`` where ``text`` is the whole file content and
        ``label`` is ``0`` when the parent directory is named ``pos`` and
        ``1`` otherwise.
    """
    import os

    with open(path, 'rt', encoding='utf-8') as f:
        text = f.read()

    # Compare the parent directory name only. A substring test on the whole
    # path would label every review as positive whenever an unrelated path
    # component contains "pos" (for example a cache root under ".../repos/").
    parent_dir = os.path.basename(os.path.dirname(path))
    label = 0 if parent_dir == 'pos' else 1
    return (text, label)
class Imdb(MapDataset):
    """IMDB dataset: one split of ``cached_get_imdb()`` mapped through a loader.

    Args:
        split (str): Which split to use, either ``'train'`` or ``'test'``.
        loader: Callable applied to each entry of the selected split.
            Defaults to ``_imdb_loader``, which takes a path string, so the
            entries are presumably file paths (confirm against ``get_imdb``).

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, split: str = 'train', loader=_imdb_loader) -> None:
        valid_splits = {'train', 'test'}
        if split not in valid_splits:
            raise ValueError(f"only 'train' and 'test' are valid for 'split', but '{split}' is given.")

        # Fetch the (cached) split table only after the argument is validated.
        splits = cached_get_imdb()
        selected = splits[split]
        super().__init__(selected, loader)