How to use the ray.rllib.utils.annotations.override function in ray

To help you get started, we’ve selected a few examples from the ray project, based on popular ways the override decorator is used in public code.
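Before the excerpts, here is a minimal, self-contained sketch of the decorator itself (the class and method names are hypothetical, chosen only for illustration): @override(ParentClass) marks a method as overriding one defined on ParentClass, and raises an error when the class is defined if the parent has no method of that name.

from ray.rllib.utils.annotations import override


class ParentPolicy:
    """Hypothetical base class used only for this illustration."""

    def compute_actions(self, obs_batch):
        raise NotImplementedError


class ChildPolicy(ParentPolicy):
    # @override(ParentPolicy) documents that this method replaces
    # ParentPolicy.compute_actions and fails loudly at class-definition
    # time if the parent has no method of that name (e.g. after a typo
    # or an upstream rename).
    @override(ParentPolicy)
    def compute_actions(self, obs_batch):
        return [0 for _ in obs_batch]

The excerpts below show the same pattern applied throughout RLlib: each overridden method is tagged with the parent class it comes from, which serves both as a consistency check and as in-code documentation.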


From ray-project/ray, rllib/evaluation/rollout_worker.py:
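    # Copy each policy's weights from the given dict into this worker's policy map.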
    @override(EvaluatorInterface)
    def set_weights(self, weights):
        for pid, w in weights.items():
            self.policy_map[pid].set_weights(w)
From ray-project/ray, rllib/models/tf/modelv1_compat.py:
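        # Report the wrapped ModelV1 instance's output tensor as this ModelV2's last output.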
        @override(ModelV2)
        def last_output(self):
            return self.cur_instance.outputs
From ray-project/ray, rllib/contrib/maddpg/maddpg_policy.py:
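    # MADDPG adds no extra feed-dict entries when computing actions.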
    @override(TFPolicy)
    def extra_compute_action_feed_dict(self):
        return {}
From ray-project/ray, rllib/policy/torch_policy_template.py:
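        # `postprocess_fn` is captured from the enclosing policy-template builder;
        # if the user supplied none, the sample batch is returned unchanged.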
        @override(Policy)
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch
            return postprocess_fn(self, sample_batch, other_agent_batches,
                                  episode)
From ray-project/ray, python/ray/rllib/evaluation/tf_policy_template.py:
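        # Use the user-supplied `gradients_fn` from the template builder if present,
        # otherwise fall back to TFPolicy's default gradient computation.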
        @override(TFPolicy)
        def gradients(self, optimizer, loss):
            if gradients_fn:
                return gradients_fn(self, optimizer, loss)
            else:
                return TFPolicy.gradients(self, optimizer, loss)
From ray-project/ray, rllib/optimizers/aso_aggregator.py:
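    # Broadcast updated weights once enough samples have been sent since the last broadcast.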
    @override(Aggregator)
    def should_broadcast(self):
        return self.num_sent_since_broadcast >= self.broadcast_interval
From ray-project/ray, rllib/agents/trainer.py:
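    # Restore trainer state from a pickled checkpoint file.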
    @override(Trainable)
    def _restore(self, checkpoint_path):
        extra_data = pickle.load(open(checkpoint_path, "rb"))
        self.__setstate__(extra_data)
From ray-project/ray, rllib/models/torch/fcnet.py:
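    # Flatten the observation, run the shared hidden layers, cache the
    # value-branch output for later use, and return the action logits.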
    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs_flat"]
        features = self._hidden_layers(obs.reshape(obs.shape[0], -1))
        logits = self._logits(features)
        self._cur_value = self._value_branch(features).squeeze(1)
        return logits, state
From ray-project/ray, rllib/agents/qmix/model.py:
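    # One recurrent step: encode the flat observation, advance the hidden state,
    # and emit per-action Q-values.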
    @override(TorchModelV2)
    def forward(self, input_dict, hidden_state, seq_lens):
        x = F.relu(self.fc1(input_dict["obs_flat"].float()))
        h_in = hidden_state[0].reshape(-1, self.rnn_hidden_dim)
        h = self.rnn(x, h_in)
        q = self.fc2(h)
        return q, [h]
From ray-project/ray, rllib/models/torch/torch_action_dist.py:
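    # Delegate sampling to the wrapped torch distribution.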
    @override(ActionDistribution)
    def sample(self):
        return self.dist.sample()