Skip to content
167 changes: 110 additions & 57 deletions socs/agents/hwp_pid/drivers/pid_controller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
# For a deeper understanding of the pid command message syntax refer to: https://assets.omega.com/manuals/M3397.pdf
import socket
import time
from dataclasses import dataclass, field
from typing import Optional, Union


@dataclass
class DecodedResponse:
msg_type: str
msg: str
measure: Optional[Union[int, float]] = field(default=None)


class PID:
Expand Down Expand Up @@ -175,10 +184,9 @@ def tune_stop(self):
responses.append(self.send_message("*W01400000"))
responses.append(self.send_message("*R01"))
responses.append(self.send_message("*Z02"))
messages = self.return_messages(responses)
if self.verb:
print(responses)
print(messages)
print(self.return_messages(responses))

stop_params = [0.2, 0, 0]
self.set_pid(stop_params)
Expand Down Expand Up @@ -214,10 +222,9 @@ def tune_freq(self):
responses.append(self.send_message(f"*W014{self.hex_freq}"))
responses.append(self.send_message("*R01"))
responses.append(self.send_message("*Z02"))
messages = self.return_messages(responses)
if self.verb:
print(responses)
print(messages)
print(self.return_messages(responses))

tune_params = [0.2, 63, 0]
self.set_pid(tune_params)
Expand All @@ -236,11 +243,17 @@ def get_freq(self):

responses = []
responses.append(self.send_message("*X01"))
if self.verb:
print(responses)

freq = self.return_messages(responses)[0]
return freq
decoded_resp = self.return_messages(responses)[0]
attempts = 3
for attempt in range(attempts):
if self.verb:
print(responses)
print(decoded_resp)
if decoded_resp.msg_type == 'measure':
return decoded_resp.measure
elif decoded_resp.msg_type == 'error':
print(f"Error reading freq: {decoded_resp.msg}")
raise ValueError('Could not get current frequency')

def get_target(self):
"""Returns the target frequency of the CHWP.
Expand All @@ -256,12 +269,17 @@ def get_target(self):

responses = []
responses.append(self.send_message("*R01"))
target = self.return_messages(responses)[0]
if self.verb:
print(responses)
print('Setpoint = ' + str(target))

return target
decoded_resp = self.return_messages(responses)[0]
attempts = 3
for attempt in range(attempts):
if self.verb:
print(responses)
print(decoded_resp)
if decoded_resp.msg_type == 'read':
return decoded_resp.measure
elif decoded_resp.msg_type == 'error':
print(f"Error reading target: {decoded_resp.msg}")
raise ValueError('Could not get target frequency')

def get_direction(self):
"""Get the current rotation direction.
Expand All @@ -280,14 +298,17 @@ def get_direction(self):

responses = []
responses.append(self.send_message("*R02"))
direction = self.return_messages(responses)[0]
if self.verb:
if direction == 1:
print('Direction = Reverse')
elif direction == 0:
print('Direction = Forward')

return direction
decoded_resp = self.return_messages(responses)[0]
attempts = 3
for attempt in range(attempts):
if self.verb:
print(responses)
print(decoded_resp)
if decoded_resp.msg_type == 'read':
return decoded_resp.measure
elif decoded_resp.msg_type == 'error':
print(f"Error reading direction: {decoded_resp.msg}")
raise ValueError('Could not get direction')

def set_pid(self, params):
"""Sets the PID parameters of the controller.
Expand Down Expand Up @@ -366,7 +387,7 @@ def send_message(self, msg):
msg (str): Command to send to the controller.

Returns:
str: Respnose from the controller.
str: Response from the controller.

"""
for attempt in range(2):
Expand All @@ -393,7 +414,7 @@ def return_messages(self, msg):
msg (list): List of messages to decode.

Returns:
list: Decoded responses.
list: DecodedResponse

"""
return self._decode_array(msg)
Expand Down Expand Up @@ -421,40 +442,71 @@ def _decode_array(input_array):
- W02: write setpoint for pid 2 (rotation direction setpoint)
- W0C: write action type for pid 1 (how to interpret sign of (setpoint-value))
- X01: read value for pid 1 (current rotation frequency)
"?" character indicates the error messages.
The helper function goes through the raw response strings and replaces them
with their decoded values.

Args:
input_array (list): List of str messages to decode

Returns:
list: Decoded responses
list: DecodedResponse

"""
output_array = list(input_array)

for index, string in enumerate(list(input_array)):
if not isinstance(string, str):
output_array[index] = DecodedResponse(msg_type='error', msg='Unrecognized response')
continue
header = string[0]

if header == 'R':
if '?' in string:
output_array[index] = PID._decode_error(string)
elif header == 'R':
output_array[index] = PID._decode_read(string)
elif header == 'W':
output_array[index] = PID._decode_write(string)
elif header == 'E':
output_array[index] = 'PID Enabled'
output_array[index] = DecodedResponse(msg_type='enable', msg='PID Enabled')
elif header == 'D':
output_array[index] = 'PID Disabled'
output_array[index] = DecodedResponse(msg_type='disable', msg='PID Disabled')
elif header == 'P':
pass
elif header == 'G':
pass
elif header == 'X':
output_array[index] = PID._decode_measure(string)
elif header == 'Z':
output_array[index] = DecodedResponse(msg_type='reset', msg='PID Reset')
else:
pass
output_array[index] = DecodedResponse(msg_type='error', msg='Unrecognized response')

return output_array

@staticmethod
def _decode_error(string):
"""Helper function to decode error messages

Args:
string (str): Error message type string to decode

Returns:
DecodedResponse

"""
if '?+9999.' in string:
return DecodedResponse(msg_type='error', msg='Exceed Maximum Error')
elif '?43' in string:
return DecodedResponse(msg_type='error', msg='Command Error')
elif '?46' in string:
return DecodedResponse(msg_type='error', msg='Format Error')
elif '?50' in string:
return DecodedResponse(msg_type='error', msg='Parity Error')
elif '?56' in string:
return DecodedResponse(msg_type='error', msg='Serial Device Address Error')
else:
return DecodedResponse(msg_type='error', msg='Unrecognized Error')

@staticmethod
def _decode_read(string):
"""Helper function to decode "read (hex)" type response strings
Expand All @@ -476,26 +528,23 @@ def _decode_read(string):
string (str): Read (hex) type string to decode

Returns:
Decoded value
DecodedResponse

"""
if isinstance(string, str):
end_string = string.split('\r')[-1]
read_type = end_string[1:3]
else:
read_type = '00'
end_string = string.split('\r')[-1]
read_type = end_string[1:3]
# Decode target
if read_type == '01':
target = float(int(end_string[4:], 16) / 1000.)
return target
return DecodedResponse(msg_type='read', msg='Setpoint = ' + str(target), measure=target)
# Decode direction
if read_type == '02':
elif read_type == '02':
if int(end_string[4:], 16) / 1000. > 2.5:
return 1
return DecodedResponse(msg_type='read', msg='Direction = Reverse', measure=1)
else:
return 0
return DecodedResponse(msg_type='read', msg='Direction = Forward', measure=0)
else:
return 'Unrecognized Read'
return DecodedResponse(msg_type='error', msg='Unrecognized Read')

@staticmethod
def _decode_write(string):
Expand All @@ -505,18 +554,24 @@ def _decode_write(string):
string (str): Write (hex) type string to decode

Returns:
str: Decoded string
DecodedResponse

"""
write_type = string[1:]
if write_type == '01':
return 'Changed Setpoint'
if write_type == '02':
return 'Changed Direction'
if write_type == '0C':
return 'Changed Action Type'
return DecodedResponse(msg_type='write', msg='Changed Setpoint')
elif write_type == '02':
return DecodedResponse(msg_type='write', msg='Changed Direction')
elif write_type == '0C':
return DecodedResponse(msg_type='write', msg='Changed Action Type')
elif write_type == '17':
return DecodedResponse(msg_type='write', msg='Changed PID 1 P Param')
elif write_type == '18':
return DecodedResponse(msg_type='write', msg='Changed PID 1 I Param')
elif write_type == '19':
return DecodedResponse(msg_type='write', msg='Changed PID 1 D Param')
else:
return 'Unrecognized Write'
return DecodedResponse(msg_type='error', msg='Unrecognized Write')

@staticmethod
def _decode_measure(string):
Expand All @@ -529,14 +584,12 @@ def _decode_measure(string):
string (str): Read (decimal) type string to decode

Return:
float: Decoded value
DecodedReponse
"""
if isinstance(string, str):
end_string = string.split('\r')[-1]
measure_type = end_string[1:3]
else:
measure_type = '00'
end_string = string.split('\r')[-1]
measure_type = end_string[1:3]
if measure_type == '01':
return float(end_string[3:])
freq = float(end_string[3:])
return DecodedResponse(msg_type='measure', msg='Current frequency = ' + str(freq), measure=freq)
else:
return 9.999
return DecodedResponse(msg_type='error', msg='Unrecognized Measure')
12 changes: 6 additions & 6 deletions socs/testing/hwp_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ def process_pid_msg(self, data):
with self.state.lock:
# self.logger.debug(cmd)
if cmd == "*W02400000":
self.state.pid.direction = "reverse"
logger.info("Setting direction: reverse")
return "asdfl"
elif cmd == "*W02401388":
self.state.pid.direction = "forward"
logger.info("Setting direction: forward")
return "asdfl"
elif cmd == "*W02401388":
self.state.pid.direction = "reverse"
logger.info("Setting direction: reverse")
return "asdfl"
elif cmd.startswith("*W014"):
setpt = hex_str_to_dec(cmd[5:], 3)
logger.info("SETPOINT %s Hz", setpt)
Expand All @@ -192,9 +192,9 @@ def process_pid_msg(self, data):
return f"R01{PID._convert_to_hex(self.state.pid.freq_setpoint, 3)}"
elif cmd == "*R02": # Get Direction
if self.state.pid.direction == "forward":
return "1"
return "R02400000"
else:
return "0"
return "R02401388"
else:
self.logger.info("Unknown cmd: %s", cmd)
return "unknown"
Expand Down
11 changes: 8 additions & 3 deletions tests/agents/hwp_pid/test_pid_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,14 @@ def test_decode_write():
print(PID._decode_write('W02'))


def test_decode_array():
print(PID._decode_array(['R02400000']))
def test_decode_measure():
print(PID._decode_measure('X012.000'))


def test_decode_measure_unknown():
assert PID._decode_measure(['R02400000']) == 9.999
decoded_resp = PID._decode_measure('X022.000')
assert decoded_resp.msg_type == 'error'


def test_decode_array():
print(PID._decode_array(['R02400000']))
4 changes: 2 additions & 2 deletions tests/integration/test_hwp_pid_agent_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ def test_hwp_rotation_set_direction(wait_for_crossbar, hwp_emu, run_agent, clien
assert resp.status == ocs.OK
assert resp.session['op_code'] == OpCode.SUCCEEDED.value
data = client.get_state().session['data']
assert data['direction'] == '0'
assert data['direction'] == 0

resp = client.set_direction(direction='1')
assert resp.status == ocs.OK
assert resp.session['op_code'] == OpCode.SUCCEEDED.value
data = client.get_state().session['data']
assert data['direction'] == '1'
assert data['direction'] == 1


@pytest.mark.integtest
Expand Down