Skip to content

Commit c3045cf

Browse files
committed
Initial Arduino test
1 parent 9570cb3 commit c3045cf

File tree

2 files changed

+238
-0
lines changed

2 files changed

+238
-0
lines changed

kit_test/arduino_test.py

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
"""
2+
Arduino test.
3+
4+
To run this with an arduino, first connect the Arduino test shield to the Arduino board.
5+
The test will:
6+
- Detect an Arduino board by its USB VID and PID.
7+
- Record the board's serial number.
8+
- Optional: Record the board's asset tag.
9+
- Flash the Arduino board with a test sketch.
10+
- Record the test sketch's outputs.
11+
- Flash the Arduino board with the stock firmware.
12+
"""
13+
import argparse
14+
import csv
15+
import logging
16+
import os
17+
import subprocess
18+
import sys
19+
import textwrap
20+
from pathlib import Path
21+
from shutil import which
22+
from tempfile import NamedTemporaryFile
23+
from typing import Any, Dict, Optional
24+
25+
import serial
26+
27+
from .arduino_binaries import STOCK_FW, TEST_FW
28+
from .hal import VidPid, discover_boards
29+
30+
logger = logging.getLogger("arduino_test")
31+
32+
BAUDRATE = 19200 # NOTE: This needs to match the baudrate in the test sketch
33+
SUPPORTED_VID_PIDS = [
34+
VidPid(0x2341, 0x0043), # Arduino Uno rev 3
35+
VidPid(0x2A03, 0x0043), # Arduino Uno rev 3
36+
VidPid(0x1A86, 0x7523), # Uno
37+
VidPid(0x10C4, 0xEA60), # Ruggeduino
38+
VidPid(0x16D0, 0x0613), # Ruggeduino
39+
]
40+
41+
42+
def get_avrdude_path() -> Path:
43+
"""Get the path to avrdude."""
44+
if sys.platform.startswith('win'):
45+
from avrdude_windows import get_avrdude_path
46+
47+
return Path(get_avrdude_path())
48+
else:
49+
avrdude_path = which('avrdude')
50+
if avrdude_path is None:
51+
raise FileNotFoundError("avrdude not found in PATH")
52+
return Path(avrdude_path)
53+
54+
55+
def flash_arduino(avrdude: Path, serial_port: str, sketch_path: Path) -> None:
56+
"""Flash the Arduino board with a sketch binary."""
57+
try:
58+
subprocess.check_call([
59+
str(avrdude), "-p", "atmega328p", "-c", "arduino",
60+
"-P", serial_port, "-D", "-U",
61+
f"flash:w:{sketch_path!s}:i"
62+
])
63+
except subprocess.CalledProcessError as e:
64+
logger.error(f"Failed to flash Arduino: {e}")
65+
raise AssertionError("Failed to flash Arduino") from e
66+
67+
68+
def parse_test_output(test_output: str, results: Dict[str, Any]) -> None:
69+
"""Parse the test output from the Arduino."""
70+
current_test = 1
71+
72+
lines = test_output.splitlines()
73+
lines_iter = iter(lines)
74+
for line in lines_iter:
75+
line = line.strip()
76+
if not line:
77+
continue
78+
79+
if line.startswith("TEST"):
80+
test_name = line.split()[1]
81+
try:
82+
test_num = int(test_name)
83+
except ValueError:
84+
continue
85+
86+
assert test_num == current_test, f"Missing test {current_test}"
87+
current_test += 1
88+
89+
# TODO
90+
if test_num == 1:
91+
pass
92+
elif test_num > 1 and test_num < 6:
93+
94+
pass
95+
elif test_num >= 6:
96+
pass
97+
98+
99+
def test_arduino(
100+
output_writer: csv.DictWriter,
101+
collect_asset: bool,
102+
avrdude: Path,
103+
test_sketch_hex: Path,
104+
stock_fw_hex: Path,
105+
) -> None:
106+
"""Test an arduino."""
107+
results: Dict[str, Any] = {}
108+
serial_port: Optional[serial.Serial] = None
109+
110+
# Find arduino port
111+
ports = discover_boards(SUPPORTED_VID_PIDS)
112+
if len(ports) == 0:
113+
logger.error("No arduinos found.")
114+
return
115+
116+
arduino = ports[0]
117+
serial_num = arduino.identity.asset_tag
118+
119+
try:
120+
results['serial'] = serial_num
121+
results['passed'] = False # default to failure
122+
if collect_asset:
123+
asset_tag = input("Enter the asset tag: ")
124+
results['asset'] = asset_tag
125+
126+
# Flash arduino with test sketch
127+
flash_arduino(avrdude, arduino.port, test_sketch_hex)
128+
129+
logger.info(f"Opening serial port {arduino.port}")
130+
serial_port = serial.Serial(
131+
port=arduino.port,
132+
baudrate=BAUDRATE,
133+
timeout=30,
134+
)
135+
logger.info(f"Flashed {test_sketch_hex} to {arduino.port}")
136+
137+
try:
138+
test_output = serial_port.read_until(b'TEST COMPLETE\n').decode('utf-8')
139+
test_summary = serial_port.readline().decode('utf-8').strip()
140+
except serial.SerialTimeoutException:
141+
logger.error("Timed out waiting for test output")
142+
raise AssertionError("Timed out waiting for test output")
143+
finally:
144+
serial_port.close()
145+
146+
parse_test_output(test_output, results)
147+
# Test summary only contains content when there are failures
148+
assert test_summary == "", f"Test failed: {test_summary}"
149+
150+
# Flash arduino with stock firmware
151+
flash_arduino(avrdude, arduino.port, stock_fw_hex)
152+
153+
logger.info("Board passed")
154+
results['passed'] = True
155+
finally:
156+
output_writer.writerow(results)
157+
if serial_port is not None:
158+
serial_port.close()
159+
160+
161+
def main(args: argparse.Namespace) -> None:
162+
"""Main function for the arduino test."""
163+
new_log = True
164+
fieldnames = ['asset', 'serial', 'passed']
165+
166+
try:
167+
avrdude = get_avrdude_path()
168+
except FileNotFoundError:
169+
logger.error(
170+
"avrdude not found in PATH, "
171+
"please install the avrdude package from your package manager.")
172+
sys.exit(1)
173+
174+
if not args.test_hex.is_file():
175+
logger.error(f"Test firmware not found: {args.test_hex}")
176+
sys.exit(1)
177+
if not args.stock_fw_hex.is_file():
178+
logger.error(f"Stock firmware not found: {args.stock_fw_hex}")
179+
sys.exit(1)
180+
181+
if args.log:
182+
logfile = args.log
183+
if os.path.exists(logfile):
184+
new_log = False
185+
else:
186+
logfile = NamedTemporaryFile(delete=False).name
187+
188+
with open(logfile, 'a', newline='') as csvfile:
189+
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
190+
if new_log:
191+
writer.writeheader()
192+
193+
while True:
194+
try:
195+
test_arduino(
196+
writer,
197+
args.collect_asset,
198+
avrdude,
199+
args.test_hex,
200+
args.stock_fw_hex,
201+
)
202+
except AssertionError as e:
203+
logger.error(f"Test failed: {e}")
204+
205+
result = input("Test another arduino? [Y/n]") or 'y'
206+
if result.lower() != 'y':
207+
break
208+
209+
logger.info(f"Test results saved to {logfile}")
210+
211+
212+
def create_subparser(subparsers: argparse._SubParsersAction) -> None:
213+
"""Arduino test command parser."""
214+
parser = subparsers.add_parser(
215+
"arduino",
216+
formatter_class=argparse.RawDescriptionHelpFormatter,
217+
description=textwrap.dedent(__doc__),
218+
help="Test an arduino. Requires the Arduino test shield.",
219+
)
220+
221+
parser.add_argument('--log', default=None, help='A CSV file to save test results to.')
222+
parser.add_argument('--collect-asset', action='store_true',
223+
help='Collect the asset tag from the Arduino board.')
224+
parser.add_argument(
225+
'--test-hex', type=Path, default=TEST_FW,
226+
help=(
227+
'The compiled hex file of the Arduino test sketch. '
228+
'Defaults to the packaged test firmware.'
229+
))
230+
parser.add_argument(
231+
'--stock-fw-hex', type=Path, default=STOCK_FW,
232+
help=(
233+
'The compiled hex file of the Arduino firmware to leave the arduino with. '
234+
'Defaults to the packaged stock firmware.'
235+
))
236+
237+
parser.set_defaults(func=main)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dynamic = ["version"]
2020
requires-python = ">=3.8"
2121
dependencies = [
2222
"pyserial >=3,<4",
23+
"avrdude-windows ==7.1.0; sys_platform == 'win32'",
2324
]
2425

2526
[project.optional-dependencies]

0 commit comments

Comments
 (0)