How to use the torchx.nn module in torchx

To help you get started, we’ve selected a few torchx examples, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github SurrealAI / surreal / surreal / utils / pytorch / mlp.py View on Github external
import torchx.nn as nnx
from surreal.utils.common import iter_last


def fc_layers(input_size, output_size, hiddens, initializer='xavier'):
    """Build the stack of linear layers for a fully connected network.

    Args:
        input_size: number of input features of the first layer.
        output_size: number of output features of the last layer.
        hiddens: list or tuple of hidden layer widths; may be empty, in
            which case a single ``input_size -> output_size`` layer is built.
        initializer: when equal to ``'xavier'``, ``conv_fc_init`` is applied
            to the layers; any other value leaves them exactly as
            ``nn.Linear`` constructed them.

    Returns:
        An ``nn.ModuleList`` holding one ``nn.Linear`` per consecutive pair
        of sizes in ``[input_size, *hiddens, output_size]``.
    """
    assert isinstance(hiddens, (list, tuple))
    # Unpack instead of `list + hiddens`: the assert above admits tuples,
    # and concatenating a list with a tuple raises TypeError.
    sizes = [input_size, *hiddens, output_size]
    fcs = nn.ModuleList() # IMPORTANT for .cuda() to work!!
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        fcs.append(nn.Linear(n_in, n_out))
    if initializer == 'xavier':
        conv_fc_init(fcs)
    return fcs


class MLP(nnx.Module):
    def __init__(self, input_size, output_size, hiddens, activation=None):
        """Set up the activation and the linear layers of the MLP.

        Args:
            input_size: forwarded to ``fc_layers`` as ``input_size``.
            output_size: forwarded to ``fc_layers`` as ``output_size``.
            hiddens: forwarded to ``fc_layers`` as ``hiddens``.
            activation: must be ``None`` (``F.relu`` is then used); any
                other value raises ``NotImplementedError``.
        """
        super().__init__()
        # Guard first: only the default activation is supported so far.
        if activation is not None:
            raise NotImplementedError # TODO: other activators
        self.activation = F.relu
        self.layers = fc_layers(
            input_size=input_size,
            output_size=output_size,
            hiddens=hiddens,
        )

    def reinitialize(self):
        """Run ``conv_fc_init`` again over this module's ``layers``."""
        conv_fc_init(self.layers)

    def forward(self, x):
        for is_last, fc in iter_last(self.layers):
            x = fc(x)
github SurrealAI / surreal / surreal / model / ddpg_net.py View on Github external
from queue import Queue
import torch.nn as nn
from torch.nn.init import xavier_uniform
import surreal.utils as U
import torch.nn.functional as F
import numpy as np
import itertools
import torchx.nn as nnx

from .model_builders import *
from .z_filter import ZFilter

class DDPGModel(nnx.Module):

    def __init__(self,
                 obs_spec,
                 action_dim,
                 use_layernorm,
                 actor_fc_hidden_sizes,
                 critic_fc_hidden_sizes,
                 conv_out_channels,
                 conv_kernel_sizes,
                 conv_strides,
                 conv_hidden_dim,
                 critic_only=False,
                 ):
        super(DDPGModel, self).__init__()

        # hyperparameters
github SurrealAI / surreal / surreal / model / gail_net.py View on Github external
import torch.nn as nn
from torch.nn.init import xavier_uniform
import surreal.utils as U
import torch.nn.functional as F
import numpy as np
import torchx
import torchx.nn as nnx

from .model_builders import *
from .z_filter import ZFilter
from .ppo_net import DiagGauss, PPOModel

import itertools


class GAILModel(nnx.Module):
    """
        The GAIL model just defines a discriminator.
    """
    def __init__(self,
                 obs_spec,
                 action_dim,
                 model_config,
                 use_cuda,
                 use_z_filter=False):

        super(GAILModel, self).__init__()
        self.obs_spec = obs_spec
        self.action_dim = action_dim
        self.model_config = model_config
        self.use_z_filter = use_z_filter
github SurrealAI / surreal / surreal / model / z_filter.py View on Github external
import torch
import torch.nn as nn
import surreal.utils as U
import numpy as np
import torchx.nn as nnx

class ZFilter(nnx.Module):
    """
        Keeps historical average and std of inputs
        Whitens data and clamps to +/- 5 std
        Attributes:
            in_size: state dimension
                required from input
            eps: tolerance value for computing Z-filter (whitening)
                default to 10
            running_sum: running sum of all previous states
                (Note, type is torch.cuda.FloatTensor or torch.FloatTensor)
            running_sumsq: sum of square of all previous states
                (Note, type is torch.cuda.FloatTensor or torch.FloatTensor)
            count: number of experiences accumulated
                (Note, type is torch.cuda.FloatTensor or torch.FloatTensor)
    """
    def __init__(self, obs_spec, eps=1e-5):
github SurrealAI / surreal / surreal / model / reward_filter.py View on Github external
import torch
import numpy as np
import torchx.nn as nnx

class RewardFilter(nnx.Module):
    """
        Keeps historical average of rewards
        Attributes:
            eps: tolerance value for computing reward filter (whitening)
                default to 10
            running_sum: running sum of all previous rewards
                (Note, type is torch.cuda.FloatTensor or torch.FloatTensor)
            running_sumsq: sum of square of all previous states
                (Note, type is torch.cuda.FloatTensor or torch.FloatTensor)
            count: number of experiences accumulated
                (Note, type is torch.cuda.FloatTensor or torch.FloatTensor)
    """
    def __init__(self, eps=1e-5):
        """
            Constructor for RewardFilter class
            Args:
github SurrealAI / surreal / surreal / model / q_net.py View on Github external
import torch.nn as nn
from torch.nn.init import xavier_uniform
import surreal.utils as U
import torch.nn.functional as F
import numpy as np
import torchx.nn as nnx


class DuelingQbase(nnx.Module):
    def init_dueling(self, *,
                     action_dim,
                     prelinear_size,
                     fc_hidden_sizes,
                     dueling):
        """
        Args:
            - prelinear_size: size of feature vector before the linear layers,
                like flattened conv or LSTM features
            - fc_hidden_sizes: list of fully connected layer sizes before `action_dim` softmax
        """
        self.dueling = dueling
        self.prelinear_size = prelinear_size
        U.assert_type(fc_hidden_sizes, list)
        hiddens = [prelinear_size] + fc_hidden_sizes
        self.fc_action_layers = nn.ModuleList()
github SurrealAI / surreal / surreal / distributed / module_dict.py View on Github external
def __init__(self, module_dict):
        """Validate and store a mapping of names to modules.

        ``module_dict`` must be a ``dict`` whose keys are strings and whose
        values are ``nnx.Module`` instances; each check is delegated to
        ``U.assert_type``. The dict itself is kept, not copied.
        """
        U.assert_type(module_dict, dict)
        for name, module in module_dict.items():
            key_error = 'Key "{}" must be string.'.format(name)
            U.assert_type(name, str, key_error)
            module_error = '"{}" must be torchx.nn.Module.'.format(module)
            U.assert_type(module, nnx.Module, module_error)
        self._module_dict = module_dict
github SurrealAI / surreal / surreal / model / layer_norm.py View on Github external
#import torch
#import torch.nn as nn
import torchx.nn as nnx

# Inspired by https://github.com/pytorch/pytorch/issues/1959
class LayerNorm(nnx.Module):
    """Normalize a tensor along dimension 1 by its mean and std.

    The module defines no parameters; ``eps`` is a fixed constant added to
    the standard deviation before dividing.
    """

    def __init__(self):
        super().__init__()
        self.eps = 1e-6

    def forward(self, x):
        """Return ``(x - mean) / (std + eps)`` computed over dimension 1.

        Only rank-2 (N, C) and rank-4 (N, C, H, W) inputs are accepted;
        any other rank fails the assertion.
        """
        assert len(x.shape) in (2, 4)
        channel_axis = 1
        centered = x - x.mean(channel_axis, keepdim=True)
        # eps keeps the denominator away from zero for constant channels.
        spread = x.std(channel_axis, keepdim=True) + self.eps
        return centered / spread
github SurrealAI / surreal / surreal / model / ppo_net.py View on Github external
prob_shape = prob.shape
            prob = prob.reshape(-1, self.d * 2)
        mean_nd = prob[:, :self.d]
        std_nd = prob[:, self.d:]
        return np.random.randn(prob.shape[0], self.d) * std_nd + mean_nd

    def maxprob(self, prob):
        '''
            Method deterministically sample actions of maximum likelihood

            Returns the first ``self.d`` entries along the last axis of
            ``prob``, for both 2-D and 3-D inputs.
            Args:
                prob: 2-D or 3-D array/tensor; the last axis is sliced
            Returns:
                ``prob`` with its last axis truncated to ``self.d`` entries
        '''
        if len(prob.shape) == 3:
            # Slice (not index) the last axis: a bare `self.d` index would
            # return a single column instead of the first `self.d` columns,
            # disagreeing with the 2-D branch below.
            return prob[:, :, :self.d]
        return prob[:, :self.d]


class PPOModel(nnx.Module):
    '''
        PPO Model class that wraps aroud the actor and critic networks
        Attributes:
            actor: Actor network, see surreal.model.model_builders.builders
            critic: Critic network, see surreal.model.model_builders.builders
            z_filter: observation z_filter. see surreal.model.z_filter
        Member functions:
            update_target_param: updates kept parameters to that of another model
            update_target_param: updates kept z_filter to that of another model
            forward_actor: forward pass actor to generate policy with option
                to use z-filter
            forward_actor: forward pass critic to generate policy with option
                to use z-filter
            z_update: updates Z_filter running obs mean and variance
    '''
    def __init__(self,
github SurrealAI / surreal / surreal / model / model_builders / builders.py View on Github external
xp = L.Linear(hidden_sizes[1])(xp_input_concat)
        xp = L.ReLU()(xp)
        if use_layernorm:
            xp = L.LayerNorm(1)(xp)
        xp = L.Linear(1)(xp)

        self.model_concat = L.Functional(inputs=xp_input_concat, outputs=xp)
        self.model_concat.build((None, D_act + hidden_sizes[0]))

    def forward(self, obs, act):
        '''
            Run ``obs`` through ``model_obs``, concatenate the result with
            ``act`` along dimension 1, and return the output of
            ``model_concat`` on that joined tensor.
        '''
        obs_features = self.model_obs(obs)
        joined = torch.cat((obs_features, act), 1)
        return self.model_concat(joined)

class PPO_ActorNetwork(nnx.Module):
    '''
        PPO custom actor network structure
    '''
    def __init__(self, D_obs, D_act, hidden_sizes=[64, 64], init_log_sig=0):
        '''
            Constructor for PPO actor network
            Args: 
                D_obs: observation space dimension, scalar
                D_act: action space dimension, scalar
                hidden_sizes: list of fully connected dimension
                init_log_sig: initial value for log standard deviation parameter
        '''
        super(PPO_ActorNetwork, self).__init__()
        # assumes D_obs here is the correct RNN hidden dim
        xp_input = L.Placeholder((None, D_obs))
        xp = L.Linear(hidden_sizes[0])(xp_input)