Skip to content
This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

Commit

Permalink
tutorial for new DataSource (#420)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #420

This is a tutorial to create a DataSource (using the new data handler)

Differential Revision: D14554266

fbshipit-source-id: c35f7e0d88edbe7e291544d63530d9ccf7400b9a
  • Loading branch information
Titousensei authored and facebook-github-bot committed Mar 27, 2019
1 parent 6ec230d commit 6d4abf4
Show file tree
Hide file tree
Showing 3 changed files with 487 additions and 0 deletions.
169 changes: 169 additions & 0 deletions demo/datasource/source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import os
from random import Random
from typing import Dict, List, Optional, Type

from pytext.data.sources.data_source import RootDataSource


def load_vocab(file_path):
    """
    Build a vocab dictionary from a file containing one entry per line.

    The key is the zero-based line index (i.e. line_no - 1) converted to a
    string, and the value is the line's content with surrounding whitespace
    removed. Keys are strings so they can be looked up directly with the
    tokens read from the ATIS content files.

    Args:
        file_path: path of the dict file to read.

    Returns:
        dict mapping str(index) -> stripped line content.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(file_path, "r", encoding="utf-8") as file_contents:
        return {str(idx): word.strip() for idx, word in enumerate(file_contents)}


class LookupReader:
    """
    Re-iterable reader which translates a file of vocab indices into text.

    Each iteration re-opens `file_path` and yields one string per line: the
    whitespace-separated tokens of the line, minus the first and the last one,
    are looked up in `vocab` and joined with single spaces. Tokens missing
    from `vocab` are rendered as "__UNK__".

    NOTE(review): the first and last token of every line are dropped
    unconditionally. This matches query rows wrapped in BOS/EOS; confirm it is
    also what is wanted for the intent files read through this class.
    """

    def __init__(self, file_path, vocab):
        self.file_path = file_path
        self.vocab = vocab

    def __iter__(self):
        # Decode explicitly as UTF-8 so reading does not depend on the
        # platform's default locale encoding.
        with open(self.file_path, "r", encoding="utf-8") as reader:
            for line in reader:
                # ATIS every row starts/ends with BOS/EOS: remove them
                tokens = line.split()[1:-1]
                # split() already removes surrounding whitespace from tokens.
                yield " ".join(self.vocab.get(tok, "__UNK__") for tok in tokens)


class AtisIntentDataSource(RootDataSource):
    """
    DataSource which loads queries and intents from the ATIS dataset.

    The simple usage is to provide the path to the unzipped atis directory, and
    it will use the default filenames and parameters.

    ATIS dataset has the following characteristics:
    - words are stored in a dict file.
    - content files contain only indices of words.
    - there's no eval set: we'll take random rows from the train set.
    - all queries start with BOS (Beginning Of Sentence) and end with EOS
      (End Of Sentence), which we'll remove.
    """

    class Config(RootDataSource.Config):
        path: str = "."
        field_names: Optional[List[str]] = None
        validation_split: Optional[float] = 0.25
        random_seed: Optional[int] = None
        # Filenames can be overridden if necessary
        intent_filename: Optional[str] = "atis.dict.intent.csv"
        vocab_filename: Optional[str] = "atis.dict.vocab.csv"
        test_queries_filename: Optional[str] = "atis.test.query.csv"
        test_intent_filename: Optional[str] = "atis.test.intent.csv"
        train_queries_filename: Optional[str] = "atis.train.query.csv"
        train_intent_filename: Optional[str] = "atis.train.intent.csv"

    # Config mimics the constructor
    # This will be the default in future pytext.
    @classmethod
    def from_config(cls, config: Config, schema: Dict[str, Type]):
        """Build the data source by passing every Config field to __init__."""
        return cls(schema=schema, **config._asdict())

    def __init__(
        self,
        path="my_directory",
        field_names=None,
        validation_split=0.25,
        random_seed=None,
        intent_filename="atis.dict.intent.csv",
        vocab_filename="atis.dict.vocab.csv",
        test_queries_filename="atis.test.query.csv",
        test_intent_filename="atis.test.intent.csv",
        train_queries_filename="atis.train.query.csv",
        train_intent_filename="atis.train.intent.csv",
        **kwargs,
    ):
        """
        Args:
            path: directory containing the ATIS files.
            field_names: exactly two names: the query field, then the intent
                field, used as keys of the rows produced.
            validation_split: fraction of the train rows (approximate, rows
                are drawn at random) which goes to the eval set.
            random_seed: seed of the train/eval split. When None, a seed is
                drawn once for this instance.
            *_filename: names of the ATIS files inside `path`.
            **kwargs: forwarded to RootDataSource (e.g. `schema`).
        """
        super().__init__(**kwargs)

        assert (
            len(field_names or []) == 2
        ), "AtisIntentDataSource only handles 2 field_names: {}".format(field_names)

        # Random(None) would seed every selector differently, so the train and
        # eval selections would no longer be complements of each other.
        # Resolve an unspecified seed once, here, and reuse it for every split.
        if random_seed is None:
            random_seed = Random().getrandbits(64)
        self.random_seed = random_seed
        self.validation_split = validation_split

        # Load the vocab dicts in memory. They are kept as attributes so that
        # other code can look up words and intents by index as well.
        self.words = load_vocab(os.path.join(path, vocab_filename))
        self.intents = load_vocab(os.path.join(path, intent_filename))

        self.query_field = field_names[0]
        self.intent_field = field_names[1]

        self.test_queries_filepath = os.path.join(path, test_queries_filename)
        self.test_intent_filepath = os.path.join(path, test_intent_filename)
        self.train_queries_filepath = os.path.join(path, train_queries_filename)
        self.train_intent_filepath = os.path.join(path, train_intent_filename)

    def _selector(self, select_eval):
        """
        Return a function deciding, row after row, whether to keep the row.

        Every selector restarts the same pseudo-random sequence from the
        beginning (same seed), and draws one number per call. With
        `select_eval=False` a row is kept when the draw is at or above
        `validation_split`; with `select_eval=True` the decision is inverted,
        which guarantees that the training set and eval set are exact
        complements.
        """
        rng = Random(self.random_seed)

        def fn():
            return select_eval ^ (rng.random() >= self.validation_split)

        return fn

    def _iter_rows(self, query_reader, intent_reader, select_fn=lambda: True):
        """
        Yield one dict per (query, intent) pair accepted by `select_fn`.

        `select_fn` is called exactly once per pair, in file order, so two
        selectors sharing a seed see the rows in lockstep.
        """
        for query_str, intent_str in zip(query_reader, intent_reader):
            if select_fn():
                yield {self.query_field: query_str, self.intent_field: intent_str}

    def _iter_raw_train(self):
        """Iterate over the rows of the train files not selected for eval."""
        return iter(
            self._iter_rows(
                query_reader=LookupReader(self.train_queries_filepath, self.words),
                intent_reader=LookupReader(self.train_intent_filepath, self.intents),
                select_fn=self._selector(select_eval=False),
            )
        )

    def _iter_raw_eval(self):
        """Iterate over the rows of the train files selected for eval."""
        return iter(
            self._iter_rows(
                query_reader=LookupReader(self.train_queries_filepath, self.words),
                intent_reader=LookupReader(self.train_intent_filepath, self.intents),
                select_fn=self._selector(select_eval=True),
            )
        )

    def _iter_raw_test(self):
        """Iterate over all the rows of the test files."""
        return iter(
            self._iter_rows(
                query_reader=LookupReader(self.test_queries_filepath, self.words),
                intent_reader=LookupReader(self.test_intent_filepath, self.intents),
            )
        )


# Need to declare str type for this source.
# This will be included by default in future pytext.
@AtisIntentDataSource.register_type(str)
def load_text(s):
    """Return `s` converted with str(); registered above for the `str` type."""
    text = str(s)
    return text


if __name__ == "__main__":
    # Manual check: print every row of each split of the ATIS directory given
    # as first command line argument.
    import sys

    src = AtisIntentDataSource(sys.argv[1], field_names=["query", "intent"], schema={})
    splits = (
        ("TRAIN", src._iter_raw_train),
        ("EVAL", src._iter_raw_eval),
        ("TEST", src._iter_raw_test),
    )
    # Each iterator is created only when its turn comes, in the fixed order
    # train, eval, test.
    for label, make_rows in splits:
        for row in make_rows():
            print(label, row)
Loading

0 comments on commit 6d4abf4

Please sign in to comment.