# Copyright (c) Facebook, Inc. and its affiliates.
from __future__ import annotations
import enum
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
if TYPE_CHECKING:
from minihack import MiniHack
from nle.nethack import Command, CompassDirection
# Action that ``LocActionEvent.check`` treats as the confirmation step: the
# event fires when this action follows the target action taken on the location.
# NOTE(review): presumably NW shares NetHack's "y" key (a "yes" answer) - confirm.
Y_cmd = CompassDirection.NW
class EventType(enum.IntEnum):
    """Integer identifiers for the categories of event.

    The member names mirror the ``Event`` subclasses of this module
    (NOTE(review): the enum is not read anywhere in the visible code -
    confirm its external users before changing any value).
    """

    MESSAGE = 0  # cf. MessageEvent
    LOC_ACTION = 1  # cf. LocActionEvent
    COORD = 2  # cf. CoordEvent
    LOC = 3  # cf. LocEvent
# Names of comestible (edible) items.
# NOTE(review): not referenced by the visible code in this module; presumably
# consumed by callers - the order is kept exactly as it was.
COMESTIBLES = [
    "orange", "meatball", "meat ring", "meat stick",
    "kelp frond", "eucalyptus leaf", "clove of garlic", "sprig of wolfsbane",
    "carrot", "egg", "banana", "melon",
    "candy bar", "lump of royal jelly",
]
class Event(ABC):
    """Base class for everything that can occur in a MiniHack episode.

    Concrete events derive from this class and implement :meth:`check`.
    """

    def __init__(
        self,
        reward: float,
        repeatable: bool,
        terminal_required: bool,
        terminal_sufficient: bool,
    ):
        """Store the reward and the termination settings of the event.

        Args:
            reward (float):
                The reward handed out when the event occurs.
            repeatable (bool):
                Whether the event can occur repeatedly (i.e. whether the
                reward can be collected more than once).
            terminal_required (bool):
                Whether this event is required for the episode to terminate.
            terminal_sufficient (bool):
                Whether this event causes the episode to terminate on its own.
        """
        # A freshly built event has not fired yet.
        self.achieved = False
        self.reward = reward
        self.repeatable = repeatable
        self.terminal_required = terminal_required
        self.terminal_sufficient = terminal_sufficient

    @abstractmethod
    def check(self, env, previous_observation, action, observation) -> float:
        """Decide whether the event has occurred on this transition.

        Args:
            env (MiniHack):
                The MiniHack environment in question.
            previous_observation (tuple):
                The previous state observation.
            action (int):
                The action taken.
            observation (tuple):
                The current observation.

        Returns:
            float: The reward.
        """

    def reset(self):
        """Forget that the event was achieved (start of a new episode)."""
        self.achieved = False

    def _set_achieved(self) -> float:
        """Record an occurrence of the event and return its reward.

        Repeatable events are never flagged as achieved, so they remain
        eligible to fire again.
        """
        if self.repeatable:
            return self.reward
        self.achieved = True
        return self.reward
def _standing_on_top(env, location):
return not env.screen_contains(location)
class LocActionEvent(Event):
    """An event which checks whether an action is performed at a specified
    location.

    The event fires in two steps: first ``action`` has to be taken while
    ``_standing_on_top(env, loc)`` holds, then the next checked action has to
    be ``Y_cmd``. Any other action in between discards the pending state.
    """

    def __init__(
        self,
        *args,
        loc: str,
        action: Command,
    ):
        """Initialise the Event.

        Args:
            *args:
                Forwarded unchanged to ``Event.__init__`` (reward, repeatable,
                terminal_required, terminal_sufficient).
            loc (str):
                The name of the location to reach.
            action (Command):
                The action to perform at the location.
        """
        super().__init__(*args)
        self.loc = loc
        self.action = action
        # True while the action has been performed at the location and only
        # the confirming ``Y_cmd`` is still missing.
        self.status = False

    def check(self, env, previous_observation, action, observation) -> float:
        """Advance the two-step detection and return the reward once it fires.

        Args:
            env (MiniHack):
                The environment; ``env._actions[action]`` is the command
                that was taken.
            previous_observation (tuple):
                Unused.
            action (int):
                Index of the action taken.
            observation (tuple):
                Unused.

        Returns:
            float: The event reward when the confirmation arrives, else 0.0.
        """
        del previous_observation, observation
        taken = env._actions[action]
        if taken == self.action and _standing_on_top(env, self.loc):
            self.status = True
        elif taken == Y_cmd and self.status:
            # Consume the pending state so a repeatable event needs the
            # action to be performed again before it can pay out once more.
            self.status = False
            return self._set_achieved()
        else:
            self.status = False
        # Float, as announced by the return annotation and done by siblings.
        return 0.0

    def reset(self):
        """Reset the achieved flag and discard any pending confirmation."""
        super().reset()
        self.status = False
class LocEvent(Event):
    """An event which checks whether a specified location is reached."""

    def __init__(self, *args, loc: str):
        """Initialise the Event.

        Args:
            *args:
                Forwarded unchanged to ``Event.__init__`` (reward, repeatable,
                terminal_required, terminal_sufficient).
            loc (str):
                The name of the location to reach.
        """
        # The docstring has to be the first statement of the function,
        # otherwise it is a dead expression and ``__doc__`` stays None.
        super().__init__(*args)
        self.loc = loc

    def check(self, env, previous_observation, action, observation) -> float:
        """Return the reward when the agent stands on the location.

        Only ``env`` is consulted, through ``_standing_on_top``; the other
        arguments are ignored.

        Returns:
            float: The event reward if the location is reached, else 0.0.
        """
        del previous_observation, action, observation
        if _standing_on_top(env, self.loc):
            return self._set_achieved()
        return 0.0
class CoordEvent(Event):
    """An event which occurs when reaching certain coordinates."""

    def __init__(self, *args, coordinates: Tuple[int, int]):
        """Initialise the Event.

        Args:
            *args:
                Forwarded unchanged to ``Event.__init__`` (reward, repeatable,
                terminal_required, terminal_sufficient).
            coordinates (tuple):
                The coordinates to reach for the event. Any two-element
                sequence is accepted and stored as a tuple.
        """
        super().__init__(*args)
        # ``check`` compares against a tuple: a list would never compare
        # equal and an array would give an ambiguous truth value, so the
        # target is normalised once here.
        self.coordinates = tuple(coordinates)

    def check(self, env, previous_observation, action, observation) -> float:
        """Return the reward when the agent is at the target coordinates.

        The current position is read from the first two entries of
        ``observation[env._blstats_index]`` (presumably x and y - confirm
        against the blstats layout).

        Returns:
            float: The event reward if the coordinates match, else 0.0.
        """
        del previous_observation, action
        current = tuple(observation[env._blstats_index][:2])
        if self.coordinates == current:
            return self._set_achieved()
        return 0.0
class MessageEvent(Event):
    """An event which occurs when any of the `messages` appear."""

    def __init__(self, *args, messages: List[str]):
        """Initialise the Event.

        Args:
            *args:
                Forwarded unchanged to ``Event.__init__`` (reward, repeatable,
                terminal_required, terminal_sufficient).
            messages (list):
                The messages to be seen to trigger the event.
        """
        super().__init__(*args)
        self.messages = messages

    def check(self, env, previous_observation, action, observation) -> float:
        """Return the reward when a watched message is in the current message.

        The message entry of ``observation`` is located through
        ``env._original_observation_keys``, converted to bytes and decoded as
        UTF-8; each watched message is then searched for as a substring.

        Returns:
            float: The event reward if any message matches, else 0.0.
        """
        del previous_observation, action
        message_slot = env._original_observation_keys.index("message")
        current_text = observation[message_slot].tobytes().decode("utf-8")
        if any(wanted in current_text for wanted in self.messages):
            return self._set_achieved()
        return 0.0
class AbstractRewardManager(ABC):
    """Abstract base class of the ``RewardManager``, used for defining custom
    reward functions.
    """

    def __init__(self):
        # Both flags start unset; ``GroupedRewardManager.add_reward_manager``
        # assigns them when the manager is added to a group.
        self.terminal_sufficient = self.terminal_required = None

    @abstractmethod
    def collect_reward(self) -> float:
        """Return the reward accumulated in ``check_episode_end_call`` and
        reset it.

        Returns:
            float: The reward.
        """
        raise NotImplementedError

    @abstractmethod
    def check_episode_end_call(
        self, env, previous_observation, action, observation
    ) -> bool:
        """Check if the task has ended, and accumulate any reward from the
        transition in ``self._reward``.

        Args:
            env (MiniHack):
                The MiniHack environment in question.
            previous_observation (tuple):
                The previous state observation.
            action (int):
                The action taken.
            observation (tuple):
                The current observation.

        Returns:
            bool: Whether the episode has ended.
        """
        raise NotImplementedError

    @abstractmethod
    def reset(self) -> None:
        """Reset all events; to be called when a new episode starts."""
        raise NotImplementedError
class RewardManager(AbstractRewardManager):
    """This class is used for managing rewards, events and termination for
    MiniHack tasks.

    Some notes on the ordering of calls in the MiniHack/NetHack base class:

    - ``step(action)`` is called on the environment
    - Within ``step``, first a copy of the last observation is made, and then
      the underlying NetHack game is stepped
    - Then ``_is_episode_end(observation)`` is called to check whether the
      episode has ended (and this is overridden if we've gone over our
      max_steps, or the underlying NetHack game says we're done, i.e. we
      died)
    - Then ``_reward_fn(last_observation, observation)`` is called to
      calculate the reward at this time-step
    - if ``end_status`` tells us the game is done, we quit the game
    - then ``step`` returns the observation, calculated reward, done, and
      some statistics.

    All this means that we need to check whether an observation is terminal
    in ``_is_episode_end`` before we're calculating the reward function.

    The call of ``_is_episode_end`` in ``MiniHack`` will call
    ``check_episode_end_call`` in this class, which checks for termination and
    accumulates any reward, which is returned and zeroed in
    ``collect_reward``.
    """

    def __init__(self):
        # Sets ``terminal_sufficient`` / ``terminal_required`` to None; they
        # are only used when this manager is put in a GroupedRewardManager.
        super().__init__()
        self.events: List[Event] = []
        self.custom_reward_functions: List[
            Callable[[MiniHack, Any, int, Any], float]
        ] = []
        # Reward accumulated since the last ``collect_reward`` call.
        self._reward = 0.0

    def add_custom_reward_fn(
        self, reward_fn: Callable[[MiniHack, Any, int, Any], float]
    ) -> None:
        """Add a custom reward function which is called after every step to
        calculate reward.

        The function should be a callable which takes the environment,
        previous observation, action and current observation and returns a
        float reward.

        Args:
            reward_fn (Callable[[MiniHack, Any, int, Any], float]):
                A reward function which takes an environment, previous
                observation, action, next observation and returns a reward.
        """
        self.custom_reward_functions.append(reward_fn)

    def add_event(self, event: Event):
        """Add an event to be managed by the reward manager.

        Args:
            event (Event):
                The event to be added.
        """
        self.events.append(event)

    def _add_message_event(
        self, msgs, reward, repeatable, terminal_required, terminal_sufficient
    ):
        """Register a ``MessageEvent`` watching for any of ``msgs``."""
        self.add_event(
            MessageEvent(
                reward,
                repeatable,
                terminal_required,
                terminal_sufficient,
                messages=msgs,
            )
        )

    def _add_loc_action_event(
        self,
        loc,
        action,
        reward,
        repeatable,
        terminal_required,
        terminal_sufficient,
    ):
        """Register a ``LocActionEvent`` for ``action`` taken at ``loc``.

        ``action`` is upper-cased and looked up by name in ``Command``;
        ``loc`` is lower-cased before being stored.

        Raises:
            KeyError: If ``action`` does not name a member of ``Command``.
        """
        try:
            command = Command[action.upper()]
        except KeyError:
            # The enum's own KeyError only carries the bare name, so it is
            # replaced (not chained) by a self-explanatory one.
            raise KeyError(
                "Action {} is not in the action space.".format(action.upper())
            ) from None
        self.add_event(
            LocActionEvent(
                reward,
                repeatable,
                terminal_required,
                terminal_sufficient,
                loc=loc.lower(),
                action=command,
            )
        )

    def add_eat_event(
        self,
        name: str,
        reward=1,
        repeatable=False,
        terminal_required=True,
        terminal_sufficient=False,
    ):
        """Add an event which is triggered when `name` is eaten.

        Detection is message based. Besides the message naming the object,
        two generic eating messages are watched that do not mention `name`.
        Apples and pears get extra messages.

        Args:
            name (str):
                The name of the object being eaten.
            reward (float):
                The reward for this event. Defaults to 1.
            repeatable (bool):
                Whether this event can be triggered multiple times. Defaults
                to False.
            terminal_required (bool):
                Whether this event is required for termination. Defaults to
                True.
            terminal_sufficient (bool):
                Whether this event is sufficient for termination. Defaults to
                False.
        """
        msgs = [
            f"This {name} is delicious",
            "Blecch! Rotten food!",
            "last bite of your meal",
        ]
        if name == "apple":
            msgs.append("Delicious! Must be a Macintosh!")
            msgs.append("Core dumped.")
        if name == "pear":
            msgs.append("Core dumped.")
        self._add_message_event(
            msgs, reward, repeatable, terminal_required, terminal_sufficient
        )

    def add_wield_event(
        self,
        name: str,
        reward=1,
        repeatable=False,
        terminal_required=True,
        terminal_sufficient=False,
    ):
        """Add event which is triggered when a specific weapon is wielded.

        Args:
            name (str):
                The name of the weapon to be wielded.
            reward (float):
                The reward for this event. Defaults to 1.
            repeatable (bool):
                Whether this event can be triggered multiple times. Defaults
                to False.
            terminal_required (bool):
                Whether this event is required for termination. Defaults to
                True.
            terminal_sufficient (bool):
                Whether this event is sufficient for termination. Defaults to
                False.
        """
        msgs = [
            f"{name} wields itself to your hand!",
            f"{name} (weapon in hand)",
        ]
        self._add_message_event(
            msgs, reward, repeatable, terminal_required, terminal_sufficient
        )

    def add_wear_event(
        self,
        name: str,
        reward=1,
        repeatable=False,
        terminal_required=True,
        terminal_sufficient=False,
    ):
        """Add event which is triggered when a specific armor is worn.

        Args:
            name (str):
                The name of the armor to be worn.
            reward (float):
                The reward for this event. Defaults to 1.
            repeatable (bool):
                Whether this event can be triggered multiple times. Defaults
                to False.
            terminal_required (bool):
                Whether this event is required for termination. Defaults to
                True.
            terminal_sufficient (bool):
                Whether this event is sufficient for termination. Defaults to
                False.
        """
        msgs = [f"You are now wearing a {name}"]
        self._add_message_event(
            msgs, reward, repeatable, terminal_required, terminal_sufficient
        )

    def add_amulet_event(
        self,
        reward=1,
        repeatable=False,
        terminal_required=True,
        terminal_sufficient=False,
    ):
        """Add event which is triggered when an amulet is worn.

        Args:
            reward (float):
                The reward for this event. Defaults to 1.
            repeatable (bool):
                Whether this event can be triggered multiple times. Defaults
                to False.
            terminal_required (bool):
                Whether this event is required for termination. Defaults to
                True.
            terminal_sufficient (bool):
                Whether this event is sufficient for termination. Defaults to
                False.
        """
        self._add_message_event(
            ["amulet (being worn)."],
            reward,
            repeatable,
            terminal_required,
            terminal_sufficient,
        )

    def add_kill_event(
        self,
        name: str,
        reward=1,
        repeatable=False,
        terminal_required=True,
        terminal_sufficient=False,
    ):
        """Add event which is triggered when a specified monster is killed.

        Args:
            name (str):
                The name of the monster to be killed.
            reward (float):
                The reward for this event. Defaults to 1.
            repeatable (bool):
                Whether this event can be triggered multiple times. Defaults
                to False.
            terminal_required (bool):
                Whether this event is required for termination. Defaults to
                True.
            terminal_sufficient (bool):
                Whether this event is sufficient for termination. Defaults to
                False.
        """
        self._add_message_event(
            [f"You kill the {name}"],
            reward,
            repeatable,
            terminal_required,
            terminal_sufficient,
        )

    def add_message_event(
        self,
        msgs: List[str],
        reward=1,
        repeatable=False,
        terminal_required=True,
        terminal_sufficient=False,
    ):
        """Add event which is triggered when any of the given messages are
        seen.

        Args:
            msgs (List[str]):
                The messages to watch for; seeing any one of them triggers
                the event.
            reward (float):
                The reward for this event. Defaults to 1.
            repeatable (bool):
                Whether this event can be triggered multiple times. Defaults
                to False.
            terminal_required (bool):
                Whether this event is required for termination. Defaults to
                True.
            terminal_sufficient (bool):
                Whether this event is sufficient for termination. Defaults to
                False.
        """
        self._add_message_event(
            msgs, reward, repeatable, terminal_required, terminal_sufficient
        )

    def add_positional_event(
        self,
        place_name: str,
        action_name: str,
        reward=1,
        repeatable=False,
        terminal_required=True,
        terminal_sufficient=False,
    ):
        """Add event which is triggered on taking a given action at a given
        place.

        Args:
            place_name (str):
                The name of the place to trigger the event (lower-cased
                before use).
            action_name (str):
                The name of the action to trigger the event; it is
                upper-cased and must name a member of ``Command``.
            reward (float):
                The reward for this event. Defaults to 1.
            repeatable (bool):
                Whether this event can be triggered multiple times. Defaults
                to False.
            terminal_required (bool):
                Whether this event is required for termination. Defaults to
                True.
            terminal_sufficient (bool):
                Whether this event is sufficient for termination. Defaults to
                False.

        Raises:
            KeyError: If ``action_name`` is not a known ``Command``.
        """
        self._add_loc_action_event(
            place_name,
            action_name,
            reward,
            repeatable,
            terminal_required,
            terminal_sufficient,
        )

    def add_coordinate_event(
        self,
        coordinates: Tuple[int, int],
        reward=1,
        repeatable=False,
        terminal_required=True,
        terminal_sufficient=False,
    ):
        """Add event which is triggered when reaching the specified
        coordinates.

        Args:
            coordinates (Tuple[int, int]):
                The coordinates to be reached (tuple of ints).
            reward (float):
                The reward for this event. Defaults to 1.
            repeatable (bool):
                Whether this event can be triggered multiple times. Defaults
                to False.
            terminal_required (bool):
                Whether this event is required for termination. Defaults to
                True.
            terminal_sufficient (bool):
                Whether this event is sufficient for termination. Defaults to
                False.
        """
        self.add_event(
            CoordEvent(
                reward,
                repeatable,
                terminal_required,
                terminal_sufficient,
                coordinates=coordinates,
            )
        )

    def add_location_event(
        self,
        location: str,
        reward=1,
        repeatable=False,
        terminal_required=True,
        terminal_sufficient=False,
    ):
        """Add event which is triggered on reaching a specified location.

        Args:
            location (str):
                The name of the location to be reached.
            reward (float):
                The reward for this event. Defaults to 1.
            repeatable (bool):
                Whether this event can be triggered multiple times. Defaults
                to False.
            terminal_required (bool):
                Whether this event is required for termination. Defaults to
                True.
            terminal_sufficient (bool):
                Whether this event is sufficient for termination. Defaults to
                False.
        """
        self.add_event(
            LocEvent(
                reward,
                repeatable,
                terminal_required,
                terminal_sufficient,
                loc=location,
            )
        )

    def _set_achieved(self, event: Event) -> float:
        """Flag a non-repeatable ``event`` as achieved and return its reward.

        Same logic as ``Event._set_achieved``, applied from outside the
        event.
        """
        if not event.repeatable:
            event.achieved = True
        return event.reward

    def _standing_on_top(self, env, name):
        """Returns whether the agent is standing on top of the given object.

        The object name (e.g. altar, sink, fountain) must exist on the map.

        Args:
            env (MiniHack):
                The environment object.
            name (str):
                The name of the object.

        Returns:
            bool: True if the object name is not in the screen descriptions
            (with agent info taking the space of the corresponding tile
            rather than the object).
        """
        return not env.screen_contains(name)

    def check_episode_end_call(
        self, env, previous_observation, action, observation
    ) -> bool:
        """Check every pending event and custom reward function, accumulate
        the resulting reward and report whether the episode is complete.

        Events already flagged as achieved are skipped.

        Args:
            env (MiniHack):
                The MiniHack environment in question.
            previous_observation (tuple):
                The previous state observation.
            action (int):
                The action taken.
            observation (tuple):
                The current observation.

        Returns:
            bool: Whether the episode has ended.
        """
        reward = 0.0
        for event in self.events:
            if event.achieved:
                continue
            reward += event.check(
                env, previous_observation, action, observation
            )
        for custom_reward_function in self.custom_reward_functions:
            reward += custom_reward_function(
                env, previous_observation, action, observation
            )
        self._reward += reward
        return self._check_complete()

    def _check_complete(self) -> bool:
        """Checks whether the episode is complete.

        Requires any event which is sufficient to be achieved, OR all
        required events to be achieved.
        """
        all_required_done = True
        for event in self.events:
            # This event is enough, we're done
            if event.achieved and event.terminal_sufficient:
                return True
            # We need this event and we haven't done it, we're not done
            # (keep looping: a later sufficient event may still end it)
            if not event.achieved and event.terminal_required:
                all_required_done = False
        # True only if every terminal_required event has been achieved
        return all_required_done

    def collect_reward(self) -> float:
        """Return the accumulated reward and zero the accumulator."""
        result = self._reward
        self._reward = 0.0
        return result

    def reset(self):
        """Zero the accumulated reward and reset every managed event."""
        self._reward = 0.0
        for event in self.events:
            event.reset()
class SequentialRewardManager(RewardManager):
    """A reward manager that ignores ``terminal_required`` and
    ``terminal_sufficient``, and just requires every event to be completed in
    the order it is added to the reward manager.

    Only the event at ``current_event_idx`` is checked on each step. The
    index advances when that event's ``achieved`` flag is set.

    NOTE(review): events created with ``repeatable=True`` never set
    ``achieved`` (see ``Event._set_achieved``), so the sequence would not
    advance past one - confirm such events are not meant to be used here.
    """

    def __init__(self):
        # Index into ``self.events`` of the event currently waited for.
        self.current_event_idx = 0
        super().__init__()

    def check_episode_end_call(
        self, env, previous_observation, action, observation
    ) -> bool:
        """Check the current event of the sequence and accumulate its reward.

        Args:
            env (MiniHack):
                The MiniHack environment in question.
            previous_observation (tuple):
                The previous state observation.
            action (int):
                The action taken.
            observation (tuple):
                The current observation.

        Returns:
            bool: True once every event of the sequence has been achieved.
        """
        # Nothing left to wait for (sequence finished, or no events at all):
        # indexing ``self.events`` here would raise IndexError.
        if self._check_complete():
            return True
        event = self.events[self.current_event_idx]
        reward = event.check(env, previous_observation, action, observation)
        if event.achieved:
            self.current_event_idx += 1
        self._reward += reward
        return self._check_complete()

    def _check_complete(self) -> bool:
        """Return whether the index has moved past the last event."""
        return self.current_event_idx >= len(self.events)

    def reset(self):
        """Reset the events and restart the sequence from its first event."""
        super().reset()
        # Without this the index of the previous episode would be kept and
        # the new episode would start in the middle (or at the end) of the
        # sequence.
        self.current_event_idx = 0
class GroupedRewardManager(AbstractRewardManager):
    """Operates as a collection of reward managers.

    The rewards from each reward manager are summed, and termination can be
    specified by ``terminal_sufficient`` and ``terminal_required`` on each
    reward manager.

    Given this can be nested arbitrarily deeply (as each reward manager could
    itself be a GroupedRewardManager), this enables complex specification of
    groups of rewards.
    """

    def __init__(self):
        # Defines ``terminal_sufficient`` / ``terminal_required`` (None) so
        # the attributes exist when this group is nested in another one.
        super().__init__()
        self.reward_managers: List[AbstractRewardManager] = []

    def check_episode_end_call(
        self, env, previous_observation, action, observation
    ) -> bool:
        """Step every reward manager and report whether the group is done.

        Every manager is called on every step, so each can accumulate its
        reward for the transition. The termination decision is made only
        afterwards: the group is done if any ``terminal_sufficient`` manager
        is done, or if no ``terminal_required`` manager is still pending.

        Args:
            env (MiniHack):
                The MiniHack environment in question.
            previous_observation (tuple):
                The previous state observation.
            action (int):
                The action taken.
            observation (tuple):
                The current observation.

        Returns:
            bool: Whether the episode has ended.
        """
        any_sufficient_done = False
        all_required_done = True
        for reward_manager in self.reward_managers:
            done = reward_manager.check_episode_end_call(
                env, previous_observation, action, observation
            )
            # This reward manager has completed and it's sufficient
            if done and reward_manager.terminal_sufficient:
                any_sufficient_done = True
            # This reward manager is required and hasn't completed
            if not done and reward_manager.terminal_required:
                all_required_done = False
        # No early return above: returning inside the loop would skip the
        # remaining managers, losing their reward for this step.
        return any_sufficient_done or all_required_done

    def add_reward_manager(
        self,
        reward_manager: AbstractRewardManager,
        terminal_required: bool,
        terminal_sufficient: bool,
    ) -> None:
        """Add a new reward manager, with ``terminal_sufficient`` and
        ``terminal_required`` acting as for individual events.

        The two flags are stored on ``reward_manager`` itself.

        Args:
            reward_manager (RewardManager):
                The reward manager to be added.
            terminal_required (bool):
                Whether this reward manager terminating is required for the
                episode to terminate.
            terminal_sufficient (bool):
                Whether this reward manager terminating is sufficient for the
                episode to terminate.
        """
        reward_manager.terminal_required = terminal_required
        reward_manager.terminal_sufficient = terminal_sufficient
        self.reward_managers.append(reward_manager)

    def collect_reward(self) -> float:
        """Collect (and thereby zero) the reward of every manager.

        Returns:
            float: The sum of the collected rewards.
        """
        # Explicit float start value: 0.0 is returned for an empty group.
        return sum(
            (
                reward_manager.collect_reward()
                for reward_manager in self.reward_managers
            ),
            0.0,
        )

    def reset(self):
        """Reset every reward manager of the group."""
        # Not read by this class; kept because earlier versions defined it.
        self._reward = 0.0
        for reward_manager in self.reward_managers:
            reward_manager.reset()