-
Notifications
You must be signed in to change notification settings - Fork 36
/
utils.py
135 lines (115 loc) · 3.64 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, Normalize
from random import choice, randint, sample, shuffle, uniform
from dsl import *
# module-level log of the difficulty values drawn by unifint, in call order
rng = []
def unifint(
    diff_lb: float,
    diff_ub: float,
    bounds: Tuple[int, int]
) -> int:
    """
    samples an integer from bounds, steered by a random difficulty

    diff_lb: lower bound for difficulty, must be in range [0, diff_ub]
    diff_ub: upper bound for difficulty, must be in range [diff_lb, 1]
    bounds: interval [a, b] determining the integer values that can be sampled

    the drawn difficulty is appended to the module-level rng list
    """
    global rng
    lowest, highest = bounds
    difficulty = uniform(diff_lb, diff_ub)
    rng.append(difficulty)
    # map the difficulty linearly onto the interval, then clamp into it
    candidate = round(lowest + (highest - lowest) * difficulty)
    at_least_lowest = max(lowest, candidate)
    return min(at_least_lowest, highest)
def is_grid(
    grid: Any
) -> bool:
    """
    returns True if and only if argument is a valid grid

    a valid grid is a tuple of 1 to 30 rows, each row a tuple of 1 to 30
    cells, all rows of equal length, every cell an int in range [0, 9]
    """
    # the checks short-circuit in order, so later ones can rely on earlier ones
    return (
        isinstance(grid, tuple)
        and 0 < len(grid) <= 30
        and all(isinstance(row, tuple) for row in grid)
        and all(0 < len(row) <= 30 for row in grid)
        and len({len(row) for row in grid}) == 1
        and all(isinstance(cell, int) for row in grid for cell in row)
        and all(0 <= cell <= 9 for row in grid for cell in row)
    )
def strip_prefix(
    string: str,
    prefix: str
) -> str:
    """
    removes prefix

    string: text to strip
    prefix: leading text to remove
    returns string without its leading prefix; if string does not start
    with prefix it is returned unchanged
    """
    # only slice when the prefix is really there; slicing unconditionally
    # would chop len(prefix) unrelated characters off the front
    if string.startswith(prefix):
        return string[len(prefix):]
    return string
def format_grid(
    grid: List[List[int]]
) -> Grid:
    """
    grid type casting

    converts a list of rows into a tuple of row tuples
    """
    return tuple(map(tuple, grid))
def format_example(
    example: dict
) -> dict:
    """
    example data type

    returns a new dict whose 'input' and 'output' grids have been passed
    through format_grid
    """
    return {
        part: format_grid(example[part])
        for part in ('input', 'output')
    }
def format_task(
    task: dict
) -> dict:
    """
    task data type

    returns a new dict whose 'train' and 'test' lists hold the examples
    of the task after passing each through format_example
    """
    return {
        split: [format_example(item) for item in task[split]]
        for split in ('train', 'test')
    }
def plot_task(
    task: List[dict],
    title: str = None
) -> None:
    """
    displays a task

    draws one column per example, with the example's 'input' grid in the
    top row and its 'output' grid in the bottom row, then shows the figure

    task: examples to draw, each providing 'input' and 'output' grids
    title: figure title, omitted when None
    """
    cmap = ListedColormap([
        '#000', '#0074D9', '#FF4136', '#2ECC40', '#FFDC00',
        '#AAAAAA', '#F012BE', '#FF851B', '#7FDBFF', '#870C25'
    ])
    # pin the color scale to 0..9 regardless of which values a grid contains
    norm = Normalize(vmin=0, vmax=9)
    args = {'cmap': cmap, 'norm': norm}
    height = 2
    width = len(task)
    figure_size = (width * 3, height * 3)
    # squeeze=False keeps axes two-dimensional even for a single example;
    # by default a (2, 1) layout is squeezed to 1-D and axes[0, column] fails
    figure, axes = plt.subplots(
        height, width, figsize=figure_size, squeeze=False
    )
    for column, example in enumerate(task):
        axes[0, column].imshow(example['input'], **args)
        axes[1, column].imshow(example['output'], **args)
        axes[0, column].axis('off')
        axes[1, column].axis('off')
    if title is not None:
        figure.suptitle(title, fontsize=20)
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
def fix_bugs(
    dataset: dict
) -> None:
    """
    fixes bugs in the original ARC training dataset

    each patch replaces one training example's 'output' grid, in place,
    with the result of calling fill on it with the listed value and cells
    """
    # (task id, training example index, fill value, cells to fill)
    patches = (
        ('a8d7556c', 2, 2, {(8, 12), (9, 12)}),
        ('6cf79266', 2, 1, {(6, 17), (7, 17), (8, 15), (8, 16), (8, 17)}),
        ('469497ad', 1, 7, {(5, 12), (5, 13), (5, 14)}),
        ('9edfc990', 1, 1, {(6, 13)}),
        ('e5062a87', 1, 2, {(1, 3), (1, 4), (1, 5), (1, 6)}),
        ('e5062a87', 0, 2, {(5, 2), (6, 3), (3, 6), (4, 7)}),
    )
    for task_id, index, value, cells in patches:
        example = dataset[task_id]['train'][index]
        example['output'] = fill(example['output'], value, cells)