|
9 | 9 | "%load_ext autoreload\n", |
10 | 10 | "%autoreload 2\n", |
11 | 11 | "from autora.doc.runtime.predict_hf import Predictor\n", |
12 | | - "from autora.doc.runtime.prompts import INSTR, SYS, InstructionPrompts, SystemPrompts" |
| 12 | + "from autora.doc.runtime.prompts import PROMPTS, PromptIds" |
13 | 13 | ] |
14 | 14 | }, |
15 | 15 | { |
|
29 | 29 | "metadata": {}, |
30 | 30 | "outputs": [], |
31 | 31 | "source": [ |
32 | | - "TEST_CODE = \"\"\"\n", |
33 | | - "from sweetpea import *\n", |
34 | | - "from sweetpea.primitives import *\n", |
35 | | - "\n", |
36 | | - "number_list = [125, 132, 139, 146, 160, 167, 174, 181]\n", |
37 | | - "letter_list = ['b', 'd', 'f', 'h', 's', 'u', 'w', 'y']\n", |
38 | | - "\n", |
39 | | - "number = Factor(\"number\", number_list)\n", |
40 | | - "letter = Factor(\"letter\", letter_list)\n", |
41 | | - "task = Factor(\"task\", [\"number task\", \"letter task\", \"free choice task\"])\n", |
42 | | - "\n", |
43 | | - "\n", |
44 | | - "def is_forced_trial_switch(task):\n", |
45 | | - " return (task[-1] == \"number task\" and task[0] == \"letter task\") or \\\n", |
46 | | - " (task[-1] == \"letter task\" and task[0] == \"number task\")\n", |
47 | | - "\n", |
48 | | - "\n", |
49 | | - "def is_forced_trial_repeat(task):\n", |
50 | | - " return (task[-1] == \"number task\" and task[0] == \"number task\") or \\\n", |
51 | | - " (task[-1] == \"letter task\" and task[0] == \"letter task\")\n", |
52 | | - "\n", |
53 | | - "\n", |
54 | | - "def is_free_trial_transition(task):\n", |
55 | | - " return task[-1] != \"free choice task\" and task[0] == \"free choice task\"\n", |
56 | | - "\n", |
57 | | - "\n", |
58 | | - "def is_free_trial_repeat(task):\n", |
59 | | - " return task[-1] == \"free choice task\" and task[0] == \"free choice task\"\n", |
60 | | - "\n", |
61 | | - "\n", |
62 | | - "def is_not_relevant_transition(task):\n", |
63 | | - " return not (is_forced_trial_repeat(task) or is_forced_trial_switch(task) or is_free_trial_repeat(\n", |
64 | | - " task) or is_free_trial_transition(task))\n", |
65 | | - "\n", |
66 | | - "\n", |
67 | | - "transit = Factor(\"task transition\", [\n", |
68 | | - " DerivedLevel(\"forced switch\", transition(is_forced_trial_switch, [task]), 3),\n", |
69 | | - " DerivedLevel(\"forced repeat\", transition(is_forced_trial_repeat, [task])),\n", |
70 | | - " DerivedLevel(\"free transition\", transition(is_free_trial_transition, [task]), 4),\n", |
71 | | - " DerivedLevel(\"free repeat\", transition(is_free_trial_repeat, [task]), 4),\n", |
72 | | - " DerivedLevel(\"forced first\", transition(is_not_relevant_transition, [task]), 4)\n", |
73 | | - "])\n", |
74 | | - "design = [letter, number, task, transit]\n", |
75 | | - "crossing = [[letter], [number], [transit]]\n", |
76 | | - "constraints = [MinimumTrials(256)]\n", |
77 | | - "\n", |
78 | | - "block = MultiCrossBlock(design, crossing, constraints)\n", |
79 | | - "\n", |
80 | | - "experiment = synthesize_trials(block, 1)\n", |
81 | | - "\n", |
82 | | - "save_experiments_csv(block, experiment, 'code_1_sequences/seq')\n", |
| 32 | + "TEST_VAR_CODE = \"\"\"\n", |
| 33 | + "iv = Variable(name=\"x\", value_range=(0, 2 * np.pi), allowed_values=np.linspace(0, 2 * np.pi, 30))\n", |
| 34 | + "dv = Variable(name=\"y\", type=ValueType.REAL)\n", |
| 35 | + "variables = VariableCollection(independent_variables=[iv], dependent_variables=[dv])\n", |
83 | 36 | "\"\"\"" |
84 | 37 | ] |
85 | 38 | }, |
|
89 | 42 | "metadata": {}, |
90 | 43 | "outputs": [], |
91 | 44 | "source": [ |
92 | | - "output = pred.predict(\n", |
93 | | - " SYS[SystemPrompts.SYS_1],\n", |
94 | | - " INSTR[InstructionPrompts.INSTR_SWEETP_EXAMPLE],\n", |
95 | | - " [TEST_CODE],\n", |
96 | | - " temperature=0.05,\n", |
97 | | - " top_k=10,\n", |
98 | | - " num_ret_seq=3,\n", |
99 | | - ")[0]\n", |
100 | | - "for i, o in enumerate(output):\n", |
101 | | - " print(f\"******** Output {i} ********\\n{o}*************\\n\")" |
| 45 | + "def test(promptid, code):\n", |
| 46 | + " output = pred.predict(\n", |
| 47 | + " PROMPTS[promptid],\n", |
| 48 | + " [code],\n", |
| 49 | + " do_sample=0,\n", |
| 50 | + " max_length=800,\n", |
| 51 | + " temperature=0.05,\n", |
| 52 | + " top_k=10,\n", |
| 53 | + " num_ret_seq=1,\n", |
| 54 | + " )[0]\n", |
| 55 | + " for i, o in enumerate(output):\n", |
| 56 | + " print(f\"{promptid}\\n******* Output {i} ********\\n{o}\\n*************\\n\")" |
| 57 | + ] |
| 58 | + }, |
| 59 | + { |
| 60 | + "cell_type": "code", |
| 61 | + "execution_count": null, |
| 62 | + "metadata": {}, |
| 63 | + "outputs": [], |
| 64 | + "source": [ |
| 65 | + "test(PromptIds.AUTORA_VARS_ZEROSHOT, TEST_VAR_CODE)" |
| 66 | + ] |
| 67 | + }, |
| 68 | + { |
| 69 | + "cell_type": "code", |
| 70 | + "execution_count": null, |
| 71 | + "metadata": {}, |
| 72 | + "outputs": [], |
| 73 | + "source": [ |
| 74 | + "test(PromptIds.AUTORA_VARS_ONESHOT, TEST_VAR_CODE)" |
102 | 75 | ] |
103 | 76 | } |
104 | 77 | ], |
|
118 | 91 | "name": "python", |
119 | 92 | "nbconvert_exporter": "python", |
120 | 93 | "pygments_lexer": "ipython3", |
121 | | - "version": "3.8.18" |
| 94 | + "version": "3.11.5" |
122 | 95 | } |
123 | 96 | }, |
124 | 97 | "nbformat": 4, |
|
0 commit comments