Skip to content

Commit cb8985e

Browse files
authored
fix: notification issues (#309)
* fix: notification issues * refactor: add gatt connection api to simplify gatt processing Closes #308
1 parent 06dc8f1 commit cb8985e

File tree

10 files changed

+521
-212
lines changed

10 files changed

+521
-212
lines changed

examples/apps/src/ble_bas_peripheral.rs

Lines changed: 30 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ where
5252

5353
let _ = join(ble_task(runner), async {
5454
loop {
55-
match advertise("Trouble Example", &mut peripheral).await {
55+
match advertise("Trouble Example", &mut peripheral, &server).await {
5656
Ok(conn) => {
5757
// set up tasks when the connection is established to a central, so they don't run when no one is connected.
5858
let a = gatt_events_task(&server, &conn);
@@ -101,67 +101,53 @@ async fn ble_task<C: Controller>(mut runner: Runner<'_, C>) {
101101
///
102102
/// This function will handle the GATT events and process them.
103103
/// This is how we interact with read and write requests.
104-
async fn gatt_events_task(server: &Server<'_>, conn: &Connection<'_>) -> Result<(), Error> {
104+
async fn gatt_events_task(server: &Server<'_>, conn: &GattConnection<'_, '_>) -> Result<(), Error> {
105105
let level = server.battery_service.level;
106106
loop {
107107
match conn.next().await {
108-
ConnectionEvent::Disconnected { reason } => {
108+
GattConnectionEvent::Disconnected { reason } => {
109109
info!("[gatt] disconnected: {:?}", reason);
110110
break;
111111
}
112-
ConnectionEvent::Gatt { data } => {
113-
// We can choose to handle event directly without an attribute table
114-
// let req = data.request();
115-
// ..
116-
// data.reply(conn, Ok(AttRsp::Error { .. }))
117-
118-
// But to simplify things, process it in the GATT server that handles
119-
// the protocol details
120-
match data.process(server).await {
121-
// Server processing emits
122-
Ok(Some(event)) => {
123-
match &event {
124-
GattEvent::Read(event) => {
125-
if event.handle() == level.handle {
126-
let value = server.get(&level);
127-
info!("[gatt] Read Event to Level Characteristic: {:?}", value);
128-
}
129-
}
130-
GattEvent::Write(event) => {
131-
if event.handle() == level.handle {
132-
info!("[gatt] Write Event to Level Characteristic: {:?}", event.data());
133-
}
112+
GattConnectionEvent::Gatt { event } => match event {
113+
Ok(event) => {
114+
match &event {
115+
GattEvent::Read(event) => {
116+
if event.handle() == level.handle {
117+
let value = server.get(&level);
118+
info!("[gatt] Read Event to Level Characteristic: {:?}", value);
134119
}
135120
}
136-
137-
// This step is also performed at drop(), but writing it explicitly is necessary
138-
// in order to ensure reply is sent.
139-
match event.accept() {
140-
Ok(reply) => {
141-
reply.send().await;
142-
}
143-
Err(e) => {
144-
warn!("[gatt] error sending response: {:?}", e);
121+
GattEvent::Write(event) => {
122+
if event.handle() == level.handle {
123+
info!("[gatt] Write Event to Level Characteristic: {:?}", event.data());
145124
}
146125
}
147126
}
148-
Ok(_) => {}
149-
Err(e) => {
150-
warn!("[gatt] error processing event: {:?}", e);
127+
128+
// This step is also performed at drop(), but writing it explicitly is necessary
129+
// in order to ensure reply is sent.
130+
match event.accept() {
131+
Ok(reply) => {
132+
reply.send().await;
133+
}
134+
Err(e) => warn!("[gatt] error sending response: {:?}", e),
151135
}
152136
}
153-
}
137+
Err(e) => warn!("[gatt] error processing event: {:?}", e),
138+
},
154139
}
155140
}
156141
info!("[gatt] task finished");
157142
Ok(())
158143
}
159144

160145
/// Create an advertiser to use to connect to a BLE Central, and wait for it to connect.
161-
async fn advertise<'a, C: Controller>(
146+
async fn advertise<'a, 'b, C: Controller>(
162147
name: &'a str,
163148
peripheral: &mut Peripheral<'a, C>,
164-
) -> Result<Connection<'a>, BleHostError<C::Error>> {
149+
server: &'b Server<'_>,
150+
) -> Result<GattConnection<'a, 'b>, BleHostError<C::Error>> {
165151
let mut advertiser_data = [0; 31];
166152
AdStructure::encode_slice(
167153
&[
@@ -181,7 +167,7 @@ async fn advertise<'a, C: Controller>(
181167
)
182168
.await?;
183169
info!("[adv] advertising");
184-
let conn = advertiser.accept().await?;
170+
let conn = advertiser.accept().await?.with_attribute_server(server)?;
185171
info!("[adv] connection established");
186172
Ok(conn)
187173
}
@@ -190,18 +176,18 @@ async fn advertise<'a, C: Controller>(
190176
/// This task will notify the connected central of a counter value every 2 seconds.
191177
/// It will also read the RSSI value every 2 seconds.
192178
/// and will stop when the connection is closed by the central or an error occurs.
193-
async fn custom_task<C: Controller>(server: &Server<'_>, conn: &Connection<'_>, stack: &Stack<'_, C>) {
179+
async fn custom_task<C: Controller>(server: &Server<'_>, conn: &GattConnection<'_, '_>, stack: &Stack<'_, C>) {
194180
let mut tick: u8 = 0;
195181
let level = server.battery_service.level;
196182
loop {
197183
tick = tick.wrapping_add(1);
198184
info!("[custom_task] notifying connection of tick {}", tick);
199-
if level.notify(server, conn, &tick).await.is_err() {
185+
if level.notify(conn, &tick).await.is_err() {
200186
info!("[custom_task] error notifying connection");
201187
break;
202188
};
203189
// read RSSI (Received Signal Strength Indicator) of the connection.
204-
if let Ok(rssi) = conn.rssi(stack).await {
190+
if let Ok(rssi) = conn.raw().rssi(stack).await {
205191
info!("[custom_task] RSSI: {:?}", rssi);
206192
} else {
207193
info!("[custom_task] error getting RSSI");

host-macros/src/server.rs

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ use proc_macro2::TokenStream as TokenStream2;
88
use quote::{quote, quote_spanned};
99
use syn::meta::ParseNestedMeta;
1010
use syn::spanned::Spanned;
11-
use syn::{parse_quote, Expr, Result};
11+
use syn::{Expr, Result, parse_quote};
1212

1313
#[derive(Default)]
1414
pub(crate) struct ServerArgs {
1515
mutex_type: Option<syn::Type>,
1616
attribute_table_size: Option<Expr>,
17+
cccd_table_size: Option<Expr>,
18+
connections_max: Option<Expr>,
1719
}
1820

1921
impl ServerArgs {
@@ -42,11 +44,21 @@ impl ServerArgs {
4244
})?;
4345
self.attribute_table_size = Some(buffer.parse()?);
4446
}
45-
other => {
46-
return Err(meta.error(format!(
47-
"Unsupported server property: '{other}'.\nSupported properties are: mutex_type, attribute_table_size"
48-
)))
47+
"cccd_table_size" => {
48+
let buffer = meta.value().map_err(|_| {
49+
Error::custom("cccd_table_size must be followed by `= [size]`. e.g. cccd_table_size = 4".to_string())
50+
})?;
51+
self.cccd_table_size = Some(buffer.parse()?);
52+
}
53+
"connections_max" => {
54+
let buffer = meta.value().map_err(|_| {
55+
Error::custom("connections_max must be followed by `= [size]`. e.g. connections_max = 1".to_string())
56+
})?;
57+
self.connections_max = Some(buffer.parse()?);
4958
}
59+
other => return Err(meta.error(format!(
60+
"Unsupported server property: '{other}'.\nSupported properties are: mutex_type, attribute_table_size, cccd_table_size, connections_max"
61+
))),
5062
}
5163
Ok(())
5264
}
@@ -75,6 +87,7 @@ impl ServerBuilder {
7587
let mut code_service_init = TokenStream2::new();
7688
let mut code_server_populate = TokenStream2::new();
7789
let mut code_attribute_summation = TokenStream2::new();
90+
let mut code_cccd_summation = TokenStream2::new();
7891
for service in &self.properties.fields {
7992
let vis = &service.vis;
8093
let service_span = service.span();
@@ -95,6 +108,10 @@ impl ServerBuilder {
95108

96109
code_attribute_summation.extend(quote_spanned! {service_span=>
97110
+ #service_type::ATTRIBUTE_COUNT
111+
});
112+
113+
code_cccd_summation.extend(quote_spanned! {service_span=>
114+
+ #service_type::CCCD_COUNT
98115
})
99116
}
100117

@@ -104,16 +121,30 @@ impl ServerBuilder {
104121
parse_quote!(trouble_host::gap::GAP_SERVICE_ATTRIBUTE_COUNT #code_attribute_summation)
105122
};
106123

124+
let cccd_table_size = if let Some(value) = self.arguments.cccd_table_size {
125+
value
126+
} else {
127+
parse_quote!(0 #code_cccd_summation)
128+
};
129+
130+
let connections_max = if let Some(value) = self.arguments.connections_max {
131+
value
132+
} else {
133+
parse_quote!(1)
134+
};
135+
107136
quote! {
108137
const _ATTRIBUTE_TABLE_SIZE: usize = #attribute_table_size;
109138
// This pattern causes the assertion to happen at compile time
110139
const _: () = {
111140
core::assert!(_ATTRIBUTE_TABLE_SIZE >= trouble_host::gap::GAP_SERVICE_ATTRIBUTE_COUNT #code_attribute_summation, "Specified attribute table size is insufficient. Please increase attribute_table_size or remove the argument entirely to allow automatic sizing of the attribute table.");
112141
};
142+
const _CCCD_TABLE_SIZE: usize = #cccd_table_size;
143+
const _CONNECTIONS_MAX: usize = #connections_max;
113144

114145
#visibility struct #name<'values>
115146
{
116-
server: trouble_host::prelude::AttributeServer<'values, #mutex_type, _ATTRIBUTE_TABLE_SIZE>,
147+
server: trouble_host::prelude::AttributeServer<'values, #mutex_type, _ATTRIBUTE_TABLE_SIZE, _CCCD_TABLE_SIZE, _CONNECTIONS_MAX>,
117148
#code_service_definition
118149
}
119150

@@ -179,7 +210,7 @@ impl ServerBuilder {
179210

180211
impl<'values> core::ops::Deref for #name<'values>
181212
{
182-
type Target = trouble_host::prelude::AttributeServer<'values, #mutex_type, _ATTRIBUTE_TABLE_SIZE>;
213+
type Target = trouble_host::prelude::AttributeServer<'values, #mutex_type, _ATTRIBUTE_TABLE_SIZE, _CCCD_TABLE_SIZE, _CONNECTIONS_MAX>;
183214

184215
fn deref(&self) -> &Self::Target {
185216
&self.server

host-macros/src/service.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ pub(crate) struct ServiceBuilder {
109109
properties: syn::ItemStruct,
110110
args: ServiceArgs,
111111
attribute_count: usize,
112+
cccd_count: usize,
112113
code_impl: TokenStream2,
113114
code_build_chars: TokenStream2,
114115
code_struct_init: TokenStream2,
@@ -121,6 +122,7 @@ impl ServiceBuilder {
121122
properties,
122123
args,
123124
attribute_count: 1, // Service counts as an attribute
125+
cccd_count: 0,
124126
code_struct_init: TokenStream2::new(),
125127
code_impl: TokenStream2::new(),
126128
code_fields: TokenStream2::new(),
@@ -136,7 +138,12 @@ impl ServiceBuilder {
136138
/// If the characteristic has either the notify or indicate property,
137139
/// a Client Characteristic Configuration Descriptor (CCCD) declaration will also be added.
138140
fn increment_attributes(&mut self, access: &AccessArgs) -> usize {
139-
self.attribute_count += if access.notify || access.indicate { 3 } else { 2 };
141+
if access.notify || access.indicate {
142+
self.cccd_count += 1;
143+
self.attribute_count += 3;
144+
} else {
145+
self.attribute_count += 2;
146+
}
140147
self.attribute_count
141148
}
142149
/// Construct the macro blueprint for the service struct.
@@ -150,6 +157,7 @@ impl ServiceBuilder {
150157
let code_build_chars = self.code_build_chars;
151158
let uuid = self.args.uuid;
152159
let attribute_count = self.attribute_count;
160+
let cccd_count = self.cccd_count;
153161
quote! {
154162
#visibility struct #struct_name {
155163
#fields
@@ -159,6 +167,7 @@ impl ServiceBuilder {
159167
#[allow(unused)]
160168
impl #struct_name {
161169
#visibility const ATTRIBUTE_COUNT: usize = #attribute_count;
170+
#visibility const CCCD_COUNT: usize = #cccd_count;
162171

163172
#visibility fn new<M, const MAX_ATTRIBUTES: usize>(table: &mut trouble_host::attribute::AttributeTable<'_, M, MAX_ATTRIBUTES>) -> Self
164173
where

host/src/attribute.rs

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::Error;
1313
use crate::att::AttErrorCode;
1414
use crate::attribute_server::AttributeServer;
1515
use crate::cursor::{ReadCursor, WriteCursor};
16-
use crate::prelude::{AsGatt, Connection, FixedGattValue, FromGatt};
16+
use crate::prelude::{AsGatt, Connection, FixedGattValue, FromGatt, GattConnection};
1717
use crate::types::gatt_traits::FromGattError;
1818
pub use crate::types::uuid::Uuid;
1919

@@ -629,17 +629,13 @@ impl<T: FromGatt> Characteristic<T> {
629629
/// If the provided connection has not subscribed for this characteristic, it will not be notified.
630630
///
631631
/// If the characteristic does not support notifications, an error is returned.
632-
pub async fn notify<M: RawMutex, const MAX: usize>(
633-
&self,
634-
server: &AttributeServer<'_, M, MAX>,
635-
connection: &Connection<'_>,
636-
value: &T,
637-
) -> Result<(), Error> {
632+
pub async fn notify(&self, connection: &GattConnection<'_, '_>, value: &T) -> Result<(), Error> {
638633
let value = value.as_gatt();
639-
server.table().set_raw(self.handle, value)?;
634+
let server = connection.server;
635+
server.set(self.handle, value)?;
640636

641637
let cccd_handle = self.cccd_handle.ok_or(Error::NotFound)?;
642-
638+
let connection = connection.raw();
643639
if !server.should_notify(connection, cccd_handle) {
644640
// No reason to fail?
645641
return Ok(());
@@ -662,9 +658,9 @@ impl<T: FromGatt> Characteristic<T> {
662658
}
663659

664660
/// Set the value of the characteristic in the provided attribute server.
665-
pub fn set<M: RawMutex, const MAX: usize>(
661+
pub fn set<M: RawMutex, const AT: usize, const CT: usize, const CN: usize>(
666662
&self,
667-
server: &AttributeServer<'_, M, MAX>,
663+
server: &AttributeServer<'_, M, AT, CT, CN>,
668664
value: &T,
669665
) -> Result<(), Error> {
670666
let value = value.as_gatt();
@@ -676,7 +672,10 @@ impl<T: FromGatt> Characteristic<T> {
676672
///
677673
/// If the characteristic for the handle cannot be found, an error is returned.
678674
///
679-
pub fn get<M: RawMutex, const MAX: usize>(&self, server: &AttributeServer<'_, M, MAX>) -> Result<T, Error> {
675+
pub fn get<M: RawMutex, const AT: usize, const CT: usize, const CN: usize>(
676+
&self,
677+
server: &AttributeServer<'_, M, AT, CT, CN>,
678+
) -> Result<T, Error> {
680679
server.table().get(self)
681680
}
682681

@@ -874,6 +873,8 @@ pub enum CCCDFlag {
874873
}
875874

876875
/// CCCD flag.
876+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
877+
#[derive(Clone, Copy, Default)]
877878
pub struct CCCD(pub(crate) u16);
878879

879880
impl<const T: usize> From<[CCCDFlag; T]> for CCCD {
@@ -887,6 +888,16 @@ impl<const T: usize> From<[CCCDFlag; T]> for CCCD {
887888
}
888889

889890
impl CCCD {
891+
/// Get raw value
892+
pub fn raw(&self) -> u16 {
893+
self.0
894+
}
895+
896+
/// Clear all properties
897+
pub fn disable(&mut self) {
898+
self.0 = 0;
899+
}
900+
890901
/// Check if any of the properties are set.
891902
pub fn any(&self, props: &[CCCDFlag]) -> bool {
892903
for p in props {
@@ -896,4 +907,15 @@ impl CCCD {
896907
}
897908
false
898909
}
910+
911+
/// Enable or disable notifications
912+
pub fn set_notify(&mut self, is_enabled: bool) {
913+
let mask: u16 = CCCDFlag::Notify as u16;
914+
self.0 = if is_enabled { self.0 | mask } else { self.0 & !mask };
915+
}
916+
917+
/// Check if notifications are enabled
918+
pub fn should_notify(&self) -> bool {
919+
(self.0 & (CCCDFlag::Notify as u16)) != 0
920+
}
899921
}

0 commit comments

Comments
 (0)