How to use the magent.utility module in MAgent

To help you get started, we’ve selected a few examples of magent.utility, based on popular ways it is used in public projects.


geek-ai/MAgent on GitHub: python/magent/model.py
    """target function for sub-processing to host a model

    Parameters
    ----------
    addr: socket address
    sample_buffer_capacity: int
        the maximum number of samples (s,r,a,s') to collect in a game round
    RLModel: BaseModel
        the RL algorithm class
    model_args: dict
        keyword arguments passed to RLModel
    """
    import magent.utility

    model = RLModel(**model_args)
    sample_buffer = magent.utility.EpisodesBuffer(capacity=sample_buffer_capacity)

    conn = multiprocessing.connection.Client(addr)

    while True:
        cmd = conn.recv()
        if cmd[0] == 'act':
            policy = cmd[1]
            eps = cmd[2]
            array_info = cmd[3]

            view, feature, ids = NDArrayPackage(array_info).recv_from(conn)
            obs = (view, feature)

            acts = model.infer_action(obs, ids, policy=policy, eps=eps)
            package = NDArrayPackage(acts)
            conn.send(package.info)
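The excerpt creates a magent.utility.EpisodesBuffer with capacity=sample_buffer_capacity but ends before the buffer is filled. As a rough illustration of what such a buffer does, here is a minimal, hypothetical stand-in. The class and method names are illustrative, not the library's API. It groups transitions by agent id and, following the docstring's description of the capacity, caps the total number of stored samples. It is a sketch under those assumptions, not magent's implementation.

```python
class EpisodeEntry:
    """One agent's trajectory, stored as parallel lists (illustrative only)."""

    def __init__(self):
        self.observations = []
        self.actions = []
        self.rewards = []
        self.terminal = False  # becomes True once the agent is reported dead

    def append(self, obs, action, reward, alive):
        self.observations.append(obs)
        self.actions.append(action)
        self.rewards.append(reward)
        if not alive:
            self.terminal = True


class SimpleEpisodesBuffer:
    """Hypothetical stand-in, NOT magent.utility.EpisodesBuffer.

    Groups transitions by agent id and stores at most `capacity` transitions in total.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.size = 0
        self.entries = {}  # agent id -> EpisodeEntry, in first-seen order

    def record_step(self, ids, obs, actions, rewards, alives):
        """Record one environment step for every agent listed in `ids`."""
        for agent_id, o, a, r, alive in zip(ids, obs, actions, rewards, alives):
            if self.size >= self.capacity:
                return  # buffer full: drop the rest of this step
            entry = self.entries.get(agent_id)
            if entry is None:
                entry = self.entries[agent_id] = EpisodeEntry()
            entry.append(o, a, r, alive)
            self.size += 1

    def episodes(self):
        """Return the recorded per-agent episodes."""
        return list(self.entries.values())

    def reset(self):
        """Clear all episodes, e.g. at the start of a new game round."""
        self.entries = {}
        self.size = 0
```

Because recording simply stops once the cap is reached, the last episodes can be cut short. Bounding the number of tracked agents instead would keep every recorded episode whole.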

geek-ai/MAgent on GitHub: examples/train_gather.py
    # init env
    env = magent.GridWorld(load_config(size=args.map_size))
    env.set_render_dir("build/render")

    handles = env.get_handles()
    food_handle = handles[0]
    player_handles = handles[1:]

    # sample eval observation set
    eval_obs = None
    if args.eval:
        print("sample eval set...")
        env.reset()
        generate_map(env, args.map_size, food_handle, player_handles)
        eval_obs = magent.utility.sample_observation(env, player_handles, 0, 2048, 500)

    # load models
    models = [
        RLModel(env, player_handles[0], args.name,
                batch_size=512, memory_size=2 ** 19, target_update=1000,
                train_freq=4, eval_obs=eval_obs)
    ]

    # load saved model
    save_dir = "save_model"
    if args.load_from is not None:
        start_from = args.load_from
        print("load models...")
        for model in models:
            model.load(save_dir, start_from)
    else:
        ...  # excerpt truncated here
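In the excerpt above, magent.utility.sample_observation collects a set of observations before training, and that set is passed to the model as eval_obs. Below is a minimal sketch of the idea. It assumes observations are gathered by rolling out uniformly random actions and are then subsampled to a fixed size. ToyEnv and sample_eval_set are hypothetical names. The sketch does not attempt to reproduce the meaning of the real function's positional arguments (0, 2048, 500).

```python
import random


class ToyEnv:
    """Hypothetical toy environment (not magent.GridWorld).

    Every step, each agent observes a random feature vector; actions are ignored.
    """

    def __init__(self, n_agents, obs_dim, n_actions, seed=0):
        self.n_agents = n_agents
        self.obs_dim = obs_dim
        self.n_actions = n_actions
        self.rng = random.Random(seed)

    def get_observation(self):
        return [[self.rng.random() for _ in range(self.obs_dim)]
                for _ in range(self.n_agents)]

    def step(self, actions):
        return False  # the toy episode never terminates on its own


def sample_eval_set(env, n_obs, max_steps, seed=0):
    """Collect observations under random actions, then keep a fixed-size subsample."""
    rng = random.Random(seed)
    pool = []
    for _ in range(max_steps):
        obs = env.get_observation()
        pool.extend(obs)
        # Random actions stand in for an untrained policy.
        actions = [rng.randrange(env.n_actions) for _ in obs]
        if env.step(actions):  # stop early if the episode ends
            break
    # Sample without replacement; never ask for more than was collected.
    return rng.sample(pool, min(n_obs, len(pool)))
```

Seeding both the environment and the sampler makes the evaluation set reproducible. That way, any metric computed on it can be compared across training runs.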