Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
import torch

# Record PyTorch statistics under their own sub-dict; the entry is created
# even when CUDA is unavailable, in which case it simply stays empty.
pytorch_stats = {}
stats['pytorch'] = pytorch_stats
if torch.cuda.is_available():
    # Peak bytes allocated for tensors on the current CUDA device (no device
    # argument is passed, so only that device is reported).
    pytorch_stats['gpu_max_memory_bytes'] = torch.cuda.max_memory_allocated()
# Only inspect TensorFlow if the experiment already imported it, so that
# TensorFlow is never imported just to collect statistics.
if 'tensorflow' in sys.modules:
    import tensorflow as tf
    stats['tensorflow'] = {}
    if int(tf.__version__.split('.')[0]) < 2:
        # TensorFlow 1.x: the peak-usage counter is exposed as a graph op.
        if tf.test.is_gpu_available():
            # MaxBytesInUse() only *builds* an op (a tf.Tensor). Storing that
            # object directly would put a Tensor into `stats`, which cannot be
            # encoded into the MongoDB document written later. Evaluate the op
            # in a session and store a plain int instead.
            with tf.Session() as sess:
                peak_bytes = sess.run(tf.contrib.memory_stats.MaxBytesInUse())
            stats['tensorflow']['gpu_max_memory_bytes'] = int(peak_bytes)
    elif tf.config.experimental.list_physical_devices('GPU'):
        # TensorFlow 2.x with at least one visible GPU: no peak value is
        # recorded here, only an informational log message.
        logging.info("SEML stats: There is currently no way to get actual GPU memory usage in TensorFlow 2.")
# Store the collected statistics on this experiment's document via a MongoDB
# `$set` update keyed by the experiment id.
collection = db_utils.get_collection(run.config['db_collection'])
experiment_filter = {'_id': exp_id}
collection.update_one(experiment_filter, {'$set': {'stats': stats}})