Skip to content

Commit

Permalink
finalize logic game training pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
frankaging committed Feb 14, 2023
1 parent 7d7e0e0 commit b956bc6
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 63 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
wandb/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
Binary file modified logic_data/cfg_all.pkl
Binary file not shown.
Binary file modified logic_data/cfg_test.pkl
Binary file not shown.
Binary file modified logic_data/cfg_train.pkl
Binary file not shown.
4 changes: 2 additions & 2 deletions logic_data/decoder_config.json
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
{
"activation_function": "gelu_new",
"attn_pdrop": 0.1,
"bos_token_id": 50256,
"bos_token_id": 6,
"embd_pdrop": 0.1,
"eos_token_id": 50256,
"eos_token_id": 7,
"initializer_range": 0.02,
"layer_norm_epsilon": 1e-05,
"model_type": "gpt2",
Expand Down
Binary file removed logic_data/dev_data.pkl
Binary file not shown.
79 changes: 35 additions & 44 deletions logic_data/in_context_learning_benchmark.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": 270,
"execution_count": null,
"id": "a6915748",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -138,9 +138,9 @@
"assert len(primitive_clauses) == 192\n",
"training_clauses = primitive_clauses[:180]\n",
"eval_clauses = primitive_clauses[180:]\n",
"pickle.dump(primitive_clauses, open(\"./logic_data/cfg_all.pkl\", 'wb'))\n",
"pickle.dump(training_clauses, open(\"./logic_data/cfg_train.pkl\", 'wb'))\n",
"pickle.dump(eval_clauses, open(\"./logic_data/cfg_test.pkl\", 'wb'))\n",
"pickle.dump(primitive_clauses, open(\"./cfg_all.pkl\", 'wb'))\n",
"pickle.dump(training_clauses, open(\"./cfg_train.pkl\", 'wb'))\n",
"pickle.dump(eval_clauses, open(\"./cfg_test.pkl\", 'wb'))\n",
"\n",
"def parse(clauses):\n",
" conjs = re.split(r\"\\s*(?:and|or)\\s*\", clauses)\n",
Expand Down Expand Up @@ -220,7 +220,12 @@
" elif (d[\"EQ\"] == \"==\" and d[\"VAL\"] == False) or \\\n",
" (d[\"EQ\"] == \"!=\" and d[\"VAL\"] == True):\n",
" assert value_assignment[d['L']] != value_assignment[d['R']]\n",
" \n",
" \n",
" if \"and\" in clauses:\n",
" assert final_value == (data[0][\"VAL\"] and data[1][\"VAL\"])\n",
" elif \"or\" in clauses:\n",
" assert final_value == (data[0][\"VAL\"] or data[1][\"VAL\"])\n",
" \n",
" return value_assignment\n",
" # we need to assert check\n",
"\n",
Expand All @@ -242,24 +247,16 @@
},
{
"cell_type": "code",
"execution_count": 271,
"execution_count": null,
"id": "7a865e28",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20000/20000 [10:18<00:00, 32.33it/s]\n"
]
}
],
"outputs": [],
"source": [
"seed = 42\n",
"n_training_examples = 20000\n",
"n_eval_examples = 1000\n",
"n_test_examples = 1000\n",
"\n",
"n_training_program = 5\n",
"\n",
"FALSE_TOKEN_ID = 0\n",
"TRUE_TOKEN_ID = 1\n",
Expand All @@ -273,8 +270,10 @@
"vocab = set([i for i in range(10, 50257)]) # reserve the first 10 for special tokens.\n",
"n_fewshot = 6\n",
"n_examples = n_fewshot + 1\n",
"training_clauses = pickle.load(open(\"./logic_data/cfg_train.pkl\", 'rb'))\n",
"eval_clauses = pickle.load(open(\"./logic_data/cfg_test.pkl\", 'rb'))\n",
"training_clauses = pickle.load(open(\"./cfg_train.pkl\", 'rb'))\n",
"eval_clauses = pickle.load(open(\"./cfg_test.pkl\", 'rb'))\n",
"if n_training_program is not None:\n",
" training_clauses = random.sample(training_clauses, k=n_training_program)\n",
"\n",
"all_train_input_ids = []\n",
"all_train_output_ids = []\n",
Expand Down Expand Up @@ -304,18 +303,10 @@
},
{
"cell_type": "code",
"execution_count": 272,
"execution_count": null,
"id": "8dd755b4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:30<00:00, 32.59it/s]\n"
]
}
],
"outputs": [],
"source": [
"all_eval_input_ids = []\n",
"all_eval_output_ids = []\n",
Expand Down Expand Up @@ -345,18 +336,10 @@
},
{
"cell_type": "code",
"execution_count": 273,
"execution_count": null,
"id": "c7c2ebcf",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:31<00:00, 31.95it/s]\n"
]
}
],
"outputs": [],
"source": [
"all_test_input_ids = []\n",
"all_test_output_ids = []\n",
Expand Down Expand Up @@ -386,7 +369,7 @@
},
{
"cell_type": "code",
"execution_count": 274,
"execution_count": null,
"id": "697df13d",
"metadata": {},
"outputs": [],
Expand All @@ -410,15 +393,23 @@
},
{
"cell_type": "code",
"execution_count": 275,
"execution_count": null,
"id": "7de80f65",
"metadata": {},
"outputs": [],
"source": [
"pickle.dump(train_data, open(\"./logic_data/train_data.pkl\", 'wb'))\n",
"pickle.dump(dev_data, open(\"./logic_data/dev_data.pkl\", 'wb'))\n",
"pickle.dump(test_data, open(\"./logic_data/test_data.pkl\", 'wb'))"
"pickle.dump(train_data, open(f\"./train_data.n_rule.{n_training_program}.n_shot.{n_fewshot}.pkl\", 'wb'))\n",
"pickle.dump(dev_data, open(f\"./dev_data.n_rule.{n_training_program}.n_shot.{n_fewshot}.pkl\", 'wb'))\n",
"pickle.dump(test_data, open(f\"./test_data.pkl\", 'wb'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "556950d8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -437,7 +428,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.9.13"
}
},
"nbformat": 4,
Expand Down
Binary file removed logic_data/test_data.pkl
Binary file not shown.
Binary file removed logic_data/train_data.pkl
Binary file not shown.
Loading

0 comments on commit b956bc6

Please sign in to comment.