# Copyright (c) Facebook, Inc. and its affiliates.
import json
import os
import re
from collections import defaultdict
from functools import lru_cache
from typing import List
from urllib.parse import unquote
import pkg_resources

try:
    import inflect
    import stanza

    PREPROCESSING_ALLOWED = True
    import_error = None
except ImportError as error:  # noqa
    PREPROCESSING_ALLOWED = False
    import_error = error
DATA_DIR_PATH = pkg_resources.resource_filename("nle", "minihack/dat")

EXCEPTIONS = (
    "floor of a room",
    "agent",
    "staircase up",
)


class TextProcessor:
    """Base class for modeling relations between an object and subject."""

    def __init__(self):
        # Will only download the model the first time
        stanza.download("en")
        self.nlp = stanza.Pipeline(lang="en", processors="tokenize,mwt,pos")
        self.inflect = inflect.engine()

    @lru_cache(maxsize=None)
    def preprocess(self, input_str: str) -> str:
        # Remove bracketed expressions and non-letter characters
        text = re.sub(r"\([^)]*\)", "", input_str)
        pattern = re.compile(r"[^a-zA-Z]+")
        text = pattern.sub(" ", text)
        # Collapse runs of multiple spaces into a single space
        text = re.sub(r"[ ]{2,}", " ", text)
        return text.strip()

    @lru_cache(maxsize=None)
    def process(self, input_str: str) -> str:
        input_str = self.preprocess(input_str)
        # First, find the nouns in the phrase
        result = self.nlp(input_str)
        nouns = [
            word.text
            for sent in result.sentences
            for word in sent.words
            if word.upos in {"NOUN", "PROPN"}
        ]
        if not nouns:
            return input_str
        # Pick the last noun in the input
        noun = nouns[-1]
        # Singularise the noun - returns False if the word is already singular
        singular = self.inflect.singular_noun(noun.lower())
        if not singular:
            return noun.lower()
        else:
            return singular
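
# A rough usage sketch, not part of the original module. `process` reduces a
# screen description to a single, singular noun; the example inputs and the
# outputs shown are illustrative assumptions, since the exact result depends
# on the stanza POS model downloaded at construction time.
#
#     processor = TextProcessor()
#     processor.process("two gold pieces")   # -> "piece"
#     processor.process("jackal corpse")     # -> "corpse"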


class NetHackWiki:
    """A class representing NetHack Wiki data - pages and the links between them.

    Args:
        raw_wiki_file_name (str):
            The path to the raw file of the NetHack wiki. The raw file can be
            downloaded using the `get_nhwiki_data.sh` script located in
            `minihack/scripts`.
        processed_wiki_file_name (str):
            The path to the processed file of the NetHack wiki. The processing
            is performed in the `__init__` function of this class.
        save_processed_json (bool):
            Whether to save the processed json file of the wiki. Only
            considered when a raw wiki file is passed. Defaults to True.
        ignore_inpage_anchors (bool):
            Whether to ignore in-page anchors. Defaults to True.
        preprocess_input (bool):
            Whether to perform preprocessing on the wiki data. Defaults to True.
        exceptions (Tuple[str] or None):
            Names of entities in screen descriptions that are ignored. If None,
            there are no exceptions. Defaults to None.
    """

    def __init__(
        self,
        raw_wiki_file_name: str,
        processed_wiki_file_name: str,
        save_processed_json: bool = True,
        ignore_inpage_anchors: bool = True,
        preprocess_input: bool = True,
        exceptions: tuple = None,
    ) -> None:
        if os.path.isfile(processed_wiki_file_name):
            with open(processed_wiki_file_name, "r") as json_file:
                self.wiki = json.load(json_file)
        elif os.path.isfile(raw_wiki_file_name):
            raw_json = load_json(raw_wiki_file_name)
            self.wiki = process_json(
                raw_json, ignore_inpage_anchors=ignore_inpage_anchors
            )
            if save_processed_json:
                with open(processed_wiki_file_name, "w+") as json_file:
                    json.dump(self.wiki, json_file)
        else:
            raise ValueError(
                """One of `raw_wiki_file_name` or `processed_wiki_file_name`
                must be supplied as an argument and be a file. Try using
                `nle/minihack/scripts/get_nhwiki_data.sh` to download the
                data."""
            )
        self.exceptions = exceptions if exceptions is not None else EXCEPTIONS
        self.preprocess_input = preprocess_input
        if preprocess_input:
            if PREPROCESSING_ALLOWED:
                self.text_processor = TextProcessor()
            else:
                print(
                    "To perform text preprocessing, `inflect` and `stanza` "
                    f"must be installed. See {import_error} for more information."
                )
                self.preprocess_input = False

    def get_page_text(self, page: str) -> str:
        """Get the text of a page.

        Args:
            page (str): The page name.

        Returns:
            str: The text of the page.
        """
        if page in self.exceptions:
            return ""
        if self.preprocess_input:
            page = self.text_processor.process(page)
        return self.wiki.get(page, {}).get("text", "")

    def get_page_data(self, page: str) -> dict:
        """Get the data of a page.

        Args:
            page (str): The page name.

        Returns:
            dict: The page data as a dict.
        """
        if page in self.exceptions:
            return {}
        if self.preprocess_input:
            page = self.text_processor.process(page)
        return self.wiki.get(page, {})
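
# A minimal usage sketch, not part of the original module. The file names are
# assumptions; `minihack/scripts/get_nhwiki_data.sh` produces the raw dump,
# and the processed json is written on first use when `save_processed_json`
# is True:
#
#     wiki = NetHackWiki(
#         raw_wiki_file_name="nhwiki.json",
#         processed_wiki_file_name="nhwiki_processed.json",
#     )
#     text = wiki.get_page_text("kitten")   # page text, or "" if unknown
#     data = wiki.get_page_data("kitten")   # dict with text, anchors, etc.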


def load_json(file_name: str) -> list:
    """Load a file containing a json object per line into a list of dicts."""
    with open(file_name, "r") as json_file:
        input_json = []
        for line in json_file:
            input_json.append(json.loads(line))
    return input_json
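
# Note (reconstructed from the surrounding code; the field names are those
# accessed in `process_json` below): the raw wiki dump is expected in
# JSON-lines form, one object per line rather than a single top-level array:
#
#     {"wikipedia_title": "Kitten", "text": [...], "anchors": [...], ...}
#     {"wikipedia_title": "Wand of wishing", "text": [...], ...}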


def process_json(wiki_json: List[dict], ignore_inpage_anchors) -> dict:
    """Process a list of json pages of the wiki into one dict of all pages."""
    result: dict = {}
    redirects = {}
    result["_global_counts"] = defaultdict(int)

    def href_normalise(x: str):
        result = unquote(x.lower())
        if ignore_inpage_anchors:
            result = result.split("#")[0]
        return result.replace("_", " ")

    for page in wiki_json:
        relevant_page_info = dict(
            title=page["wikipedia_title"].lower(),
            length=len("".join(page["text"])),
            categories=page["categories"].split(","),
            raw_text="".join(page["text"]),
            text=clean_page_text(page["page_data"]),
        )
        relevant_page_info["anchors"] = [
            dict(
                text=anchor["text"].lower(),
                page=href_normalise(anchor.get("title", anchor.get("href"))),
                start=anchor["start"],
            )
            for anchor in page["anchors"]
        ]
        redirect_anchors = [
            anchor
            for anchor in page["anchors"]
            if anchor.get("title")
            and href_normalise(anchor["href"])
            != href_normalise(anchor["title"])
        ]
        redirects.update(
            {
                href_normalise(anchor["href"]): href_normalise(anchor["title"])
                for anchor in redirect_anchors
            }
        )
        unique_anchors: dict = defaultdict(int)
        for anchor in relevant_page_info["anchors"]:
            unique_anchors[anchor["page"]] += 1
            result["_global_counts"][anchor["page"]] += 1
        relevant_page_info["unique_anchors"] = dict(unique_anchors)
        result[relevant_page_info["title"]] = relevant_page_info
    for alias, page in redirects.items():
        result[alias] = result[page]
    return result
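
# Sketch of the dict returned by `process_json`, reconstructed from the code
# above; all values shown are illustrative:
#
#     {
#         "_global_counts": {"kitten": 12, ...},  # anchor counts across pages
#         "kitten": {
#             "title": "kitten",
#             "length": 1234,
#             "categories": [...],
#             "raw_text": "...",
#             "text": "...",
#             "anchors": [{"text": ..., "page": ..., "start": ...}, ...],
#             "unique_anchors": {"cat": 3, ...},
#         },
#         # plus redirect aliases pointing at the same page dicts
#     }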


def clean_page_text(text: List[str]) -> str:
    """Clean Markdown text to make it more suitable for passing to an NLP model.

    This is currently very basic; more advanced parsing could be employed
    if necessary."""
    return re.sub(r"[^a-zA-Z0-9_\s\.]", "", ",".join(text))
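
# Illustrative example of `clean_page_text`: the lines are joined with commas,
# then every character outside [a-zA-Z0-9_\s.] is stripped - including the
# joining commas themselves:
#
#     clean_page_text(["A wand of wishing!", "It grants (up to) three wishes."])
#     # -> "A wand of wishingIt grants up to three wishes."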