forked from facebookarchive/bAbI-tasks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
babi-tasks
executable file
·76 lines (66 loc) · 1.97 KB
/
babi-tasks
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env th
-- Copyright (c) 2015-present, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
local tablex = require 'pl.tablex'
local utilities = require 'babi.utilities'
-- Location of the bAbI installation. The BABI_HOME environment variable
-- takes precedence; otherwise utilities.babi_home() supplies it.
-- Deliberately assigned as a global (no `local`): nothing in this script
-- reads it, so it is presumably consumed by the babi.* modules -- confirm
-- before making it local.
do
  local from_env = os.getenv('BABI_HOME')
  if from_env then
    BABI_HOME = from_env
  else
    BABI_HOME = utilities.babi_home()
  end
end
--- Generate stories for one bAbI task and write them to `output`.
-- Loads the module `babi.tasks.<task_name>`, overlays `user_config` on the
-- task's DEFAULT_CONFIG (if it has one; user values win), seeds the RNG and
-- writes `number` stories, each followed by a newline.
-- @tparam string task_name module name under babi.tasks (e.g. 'WhereIsActor')
-- @tparam[opt=1] number number how many stories to write
-- @tparam[opt=io.stdout] file output handle (anything with a `write` method)
-- @tparam[opt={}] table user_config option overrides for the task config
-- @tparam[opt=os.time()] number seed value passed to math.randomseed; give a
--   fixed value for reproducible output (os.time() only changes once per
--   second, so runs started within the same second share a seed)
local function generate(task_name, number, output, user_config, seed)
  local task = require('babi.tasks.' .. task_name)
  number = number or 1
  output = output or io.stdout
  -- Union merge (dup = true): defaults first, then user_config overrides.
  -- `or {}` guards against a nil user_config, which merge cannot iterate.
  local config = tablex.merge(task.DEFAULT_CONFIG or {}, user_config or {}, true)
  math.randomseed(seed or os.time())
  for _ = 1, number do
    local story
    -- task:generate may return nil/false (presumably when a random draw does
    -- not yield a valid story); keep sampling until it produces one.
    repeat
      story = task:generate(config)
    until story
    output:write(story .. '\n')
  end
  -- Push buffered output out now. Closing is left to the caller because
  -- `output` may be io.stdout.
  if output.flush then
    output:flush()
  end
end
-- Command-line entry point:
--
--   babi-tasks task [number] [output_file] [--option value ...]
--
-- `task` is a task module name or a task number (1-20, see task_names).
-- `number` (default 1) is how many stories to generate.
-- `output_file`, if given, is opened in append mode; otherwise stories go to
-- stdout. Each `--option value` pair becomes `config.option = value` for the
-- task: dashes in the option name turn into underscores, and numeric values
-- are converted with tonumber.
local USAGE = 'Usage: ' .. (arg and arg[0] or 'babi-tasks')
  .. ' task [number] [output_file] [--option value ...]'
assert(arg and #arg > 0, USAGE)

-- The 20 bAbI tasks; the index of each name is its task number.
local task_names = {
  'WhereIsActor',         --  1
  'WhereIsObject',        --  2
  'WhereWasObject',       --  3
  'IsDir',                --  4
  'WhoWhatGave',          --  5
  'IsActorThere',         --  6
  'Counting',             --  7
  'Listing',              --  8
  'Negation',             --  9
  'Indefinite',           -- 10
  'BasicCoreference',     -- 11
  'Conjunction',          -- 12
  'CompoundCoreference',  -- 13
  'Time',                 -- 14
  'Deduction',            -- 15
  'Induction',            -- 16
  'PositionalReasoning',  -- 17
  'Size',                 -- 18
  'PathFinding',          -- 19
  'Motivations',          -- 20
}

-- Abort with `message` plus the usage line (level 0: no file:line prefix).
local function usage_error(message)
  error(message .. '\n' .. USAGE, 0)
end

-- Turn the task argument into a module name: a number selects an entry of
-- task_names, anything else is used verbatim as the module name.
local function resolve_task(task_arg)
  local task_number = tonumber(task_arg)
  if task_number == nil then
    return task_arg
  end
  local name = task_names[task_number]
  if name == nil then
    usage_error(('unknown task number %s (expected 1-%d)')
      :format(task_arg, #task_names))
  end
  return name
end

-- Split argv into task name, story count, output path (nil means stdout) and
-- the table of --option overrides. Positional arguments come first. From the
-- first --option on, arguments must be --option value pairs.
local function parse_args(argv)
  local task_name = resolve_task(argv[1])
  local number, output_path, options = 1, nil, {}
  local i = 2
  while i <= #argv and argv[i]:sub(1, 2) ~= '--' do
    if i == 2 then
      number = tonumber(argv[i])
      if number == nil or number < 0 or number % 1 ~= 0 then
        usage_error('number must be a non-negative integer, got ' .. argv[i])
      end
    elseif i == 3 then
      output_path = argv[i]
    else
      usage_error('unexpected argument ' .. argv[i])
    end
    i = i + 1
  end
  while i <= #argv do
    local flag, value = argv[i], argv[i + 1]
    if flag:sub(1, 2) ~= '--' or #flag == 2 then
      usage_error('expected --option, got ' .. flag)
    elseif value == nil then
      usage_error('missing value for ' .. flag)
    end
    -- Parenthesized so only gsub's first result (the string) is the key.
    options[(flag:sub(3):gsub('-', '_'))] = tonumber(value) or value
    i = i + 2
  end
  return task_name, number, output_path, options
end

local task_name, number, output_path, options = parse_args(arg)
-- Open the output only after the whole command line has validated, so a bad
-- invocation never creates or touches the file.
local output = io.stdout
if output_path ~= nil then
  output = assert(io.open(output_path, 'a'))
end
-- Direct call instead of `unpack(...)`: `unpack` is a global only in
-- Lua 5.1/LuaJIT (it moved to table.unpack in 5.2).
generate(task_name, number, output, options)
if output ~= io.stdout then
  output:close()
end