Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def mathy_load_a3c(env_name: str, gym_env: "MathyGymEnv", model: str):
    """Lazily create the module-wide A3C agent and return it.

    On the first call an ``A3CAgent`` is built from
    ``BaseConfig(model_dir=model)`` and cached in the module-global
    ``__agent``. Later calls return the cached agent and ignore their
    arguments.

    NOTE(review): a second definition of ``mathy_load_a3c`` appears later in
    this file and replaces this one once the module finishes executing.

    Args:
        env_name: Not used by this function.
        gym_env: Not used by this function. The annotation is quoted so the
            ``def`` does not require ``MathyGymEnv`` to be imported first.
        model: Directory of the model to load. Must not be ``None`` when the
            agent has not been created yet.

    Returns:
        The cached ``A3CAgent`` instance.

    Raises:
        ValueError: If no agent is cached yet and ``model`` is ``None``.
    """
    global __agent
    if __agent is None:
        # Fail fast before mutating process-wide state (env vars, TF logging).
        if model is None:
            raise ValueError("model is none, must be specified")
        import os

        # TF_CPP_MIN_LOG_LEVEL must be set before TensorFlow's native logging
        # initializes, so set it before importing tensorflow. It may have no
        # effect if tensorflow was already imported elsewhere in the process.
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "5"
        import tensorflow as tf

        tf.compat.v1.logging.set_verbosity("CRITICAL")
        args = BaseConfig(model_dir=model)
        __agent = A3CAgent(args)
    return __agent
from datetime import datetime
from typing import List, Optional
import gym
import matplotlib.pyplot as plt
import numpy as np
import plac
from tqdm import trange
from mathy.agents.a3c import A3CAgent, BaseConfig
from mathy.agents.mcts import MCTS
from mathy.env.gym import MathyGymEnv
# Module-level singletons, both initially empty.
# ``__agent`` is populated lazily by ``mathy_load_a3c``. ``__mcts`` is not
# assigned anywhere in the visible code (presumably used elsewhere in the file).
__mcts: Optional[MCTS]
__agent: Optional[A3CAgent]
__mcts = __agent = None
def mathy_load_a3c(env_name: str, gym_env: MathyGymEnv, model: str):
    """Lazily create the module-wide A3C agent and return it.

    On the first call an ``A3CAgent`` is built from
    ``BaseConfig(model_dir=model)`` and cached in the module-global
    ``__agent``. Later calls return the cached agent and ignore their
    arguments.

    Args:
        env_name: Not used by this function.
        gym_env: Not used by this function.
        model: Directory of the model to load. Must not be ``None`` when the
            agent has not been created yet.

    Returns:
        The cached ``A3CAgent`` instance.

    Raises:
        ValueError: If no agent is cached yet and ``model`` is ``None``.
    """
    global __agent
    if __agent is None:
        # Fail fast before mutating process-wide state (env vars, TF logging).
        if model is None:
            raise ValueError("model is none, must be specified")
        import os

        # TF_CPP_MIN_LOG_LEVEL must be set before TensorFlow's native logging
        # initializes, so set it before importing tensorflow. It may have no
        # effect if tensorflow was already imported elsewhere in the process.
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "5"
        import tensorflow as tf

        tf.compat.v1.logging.set_verbosity("CRITICAL")
        args = BaseConfig(model_dir=model)
        __agent = A3CAgent(args)
    return __agent
# Train a tiny single-worker A3C agent for a few episodes in a temporary
# directory, then load its policy model back from that directory.
import shutil
import tempfile

keep_model = False  # Set to True to keep the trained model on disk.
model_folder = tempfile.mkdtemp()
try:
    setup_tf_env()
    args = A3CConfig(
        max_eps=3,
        verbose=True,
        topics=["poly"],
        model_dir=model_folder,
        update_gradients_every=4,
        num_workers=1,
        units=4,
        embedding_units=4,
        lstm_units=4,
        print_training=True,
    )
    instance = A3CAgent(args)
    instance.train()
    # Load the model back in
    model_two = get_or_create_policy_model(
        args=args, predictions=PolySimplify().action_size, is_main=True
    )
finally:
    # Clean up even when training or reloading raises; mkdtemp() directories
    # are never deleted automatically.
    if not keep_model:
        shutil.rmtree(model_folder)
# Train a two-worker A3C agent with profiling enabled and check that each
# worker wrote a profile file into the model directory.
import os
import shutil
import tempfile

keep_model = False  # Set to True to keep the trained model on disk.
model_folder = tempfile.mkdtemp()
try:
    setup_tf_env()
    args = A3CConfig(
        profile=True,
        max_eps=2,
        verbose=True,
        mcts_sims=1,
        action_strategy="mcts_worker_0",
        topics=["poly-grouping"],
        model_dir=model_folder,
        num_workers=2,
        print_training=True,
    )
    A3CAgent(args).train()
    # One profile file is expected per worker (num_workers=2 above).
    assert os.path.isfile(os.path.join(args.model_dir, "worker_0.profile"))
    assert os.path.isfile(os.path.join(args.model_dir, "worker_1.profile"))
finally:
    # Clean up even when training or the checks fail; mkdtemp() directories
    # are never deleted automatically.
    if not keep_model:
        shutil.rmtree(model_folder)
# Train a two-worker A3C agent for one episode using the "mcts_worker_n"
# action strategy, in a temporary model directory.
import shutil
import tempfile

keep_model = False  # Set to True to keep the trained model on disk.
model_folder = tempfile.mkdtemp()
try:
    setup_tf_env()
    args = A3CConfig(
        action_strategy="mcts_worker_n",
        max_eps=1,
        verbose=True,
        topics=["poly-combine"],
        model_dir=model_folder,
        num_workers=2,
        print_training=True,
    )
    A3CAgent(args).train()
finally:
    # Clean up even when training raises; mkdtemp() directories are never
    # deleted automatically.
    if not keep_model:
        shutil.rmtree(model_folder)
# Train a two-worker A3C agent for one episode using the "mcts_worker_0"
# action strategy, in a temporary model directory.
import shutil
import tempfile

keep_model = False  # Set to True to keep the trained model on disk.
model_folder = tempfile.mkdtemp()
try:
    setup_tf_env()
    args = A3CConfig(
        max_eps=1,
        verbose=True,
        action_strategy="mcts_worker_0",
        topics=["poly-combine"],
        model_dir=model_folder,
        num_workers=2,
        print_training=True,
    )
    A3CAgent(args).train()
finally:
    # Clean up even when training raises; mkdtemp() directories are never
    # deleted automatically.
    if not keep_model:
        shutil.rmtree(model_folder)